├── training
│   ├── train_vqgan.py
│   ├── __init__.py
│   └── optimizer.py
├── evaluations
│   ├── requirements.txt
│   ├── sample_sg_512.py
│   ├── sample_sg_256.py
│   ├── maskgit_toast.ipynb
│   └── README.md
├── muse
│   ├── data.py
│   ├── __init__.py
│   ├── sampling.py
│   ├── modeling_paella_vq.py
│   ├── logging.py
│   ├── modeling_ema.py
│   ├── lr_schedulers.py
│   ├── pipeline_muse_toast.py
│   ├── modeling_maskgit_vqgan.py
│   └── training_utils.py
├── .gitignore
├── setup.cfg
├── acc_config.yaml
├── acc_config_mult.yaml
├── scripts
│   ├── convert_coco_to_wds.py
│   ├── compute_offline_ema.py
│   ├── benchmark_models.py
│   ├── convert_imagenet_to_wds.py
│   ├── convert_imagenet_local.py
│   ├── log_generations_wandb.py
│   ├── makeshards.py
│   ├── log_inpainting_images.py
│   ├── convert_maskgit_transformer.py
│   ├── convert_maskgit_vqgan.py
│   └── pre_encode.py
├── setup.py
├── docker
│   └── Dockerfile
├── README.md
└── configs
    ├── ft_256_toast_cls_b256_corr.yaml
    └── ft_512_toast_cls_b256_corr.yaml
/training/train_vqgan.py: -------------------------------------------------------------------------------- 1 | """Training script for VQGAN.""" 2 | -------------------------------------------------------------------------------- /evaluations/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu>=2.0 2 | scipy 3 | requests 4 | tqdm -------------------------------------------------------------------------------- /muse/data.py: -------------------------------------------------------------------------------- 1 | """All data related utilities and loaders are defined here.""" 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | results_corr/ 3 | muse.egg-info/ 4 | build/ 5 | README.md 6 | muse/__pycache__/ 7 | training/__pycache__/ 8 | scripts/tokenizer_imagenet* 9 | scripts/maskgit_imagenet* -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = accelerate 7 | known_third_party = 8 | numpy 9 | torch 10 | torch_xla 11 | 12 | line_length = 119 13 | lines_after_imports = 2 14 | multi_line_output = 3 15 | use_parentheses = True 16 | 17 | [flake8] 18 | ignore = E203, E722, E501, E741, W503, W605 19 | max-line-length = 119 -------------------------------------------------------------------------------- /acc_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | # debug: false 3 | distributed_type: 'NO' # 'NO' | 'MULTI_GPU' (set num_processes > 1) | 'TPU' (set tpu_use_cluster = True) 4 | downcast_bf16: 'NO' 5 | gpu_ids: '0 ' 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 1 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /acc_config_mult.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | # debug: false 3 | distributed_type: 'MULTI_GPU' # 'NO' | 'MULTI_GPU' (set num_processes > 1) | 'TPU' (set tpu_use_cluster = True) 4 
| main_process_port: 29501 5 | downcast_bf16: 'no' 6 | gpu_ids: '0,1,2,3' 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: bf16 10 | num_machines: 1 11 | num_processes: 4 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /muse/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | __version__ = "0.0.1" 17 | 18 | from .modeling_ema import EMAModel 19 | from .modeling_maskgit_vqgan import MaskGitVQGAN 20 | from .modeling_movq import MOVQ 21 | from .modeling_paella_vq import PaellaVQModel 22 | from .modeling_taming_vqgan import VQGANModel 23 | from .modeling_transformer import MaskGitTransformer 24 | from .pipeline_muse import PipelineMuse, PipelineMuseInpainting 25 | # from .pipeline_muse_org import PipelineMuseOrg 26 | from .sampling import get_mask_chedule 27 | from .modeling_transformer_toast import MaskGitTransformerTOAST 28 | -------------------------------------------------------------------------------- /scripts/convert_coco_to_wds.py: -------------------------------------------------------------------------------- 1 | # To download the 2017 train split of coco 2 | # $ wget http://images.cocodataset.org/zips/train2017.zip 3 | # $ wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 4 | # 5 | # script assumes they're downloaded and unzipped into ../data 6 | # 7 | # we write into ../data/coco-2017-train/ because shard writer doesn't support piped 8 | # uploads 9 | # 10 | # after writing to disk, run 11 | # $ aws s3 cp ../data/coco-2017-train/ s3://muse-datasets/coco/2017/train/ --recursive 12 | 13 | # Needed for `PIL.Image` to work in wds :/ 14 | 15 | import json 16 | 17 | import webdataset as wds 18 | from cv2 import COLOR_BGR2RGB, cvtColor, imread 19 | 20 | 21 | def main(): 22 | with open("../data/annotations/captions_train2017.json") as f: 23 | annotations = json.load(f)["annotations"] 24 | 25 | annotations_by_image_id = {} 26 | 27 | for annotation in annotations: 28 | image_id = annotation["image_id"] 29 | 30 | if image_id not in annotations_by_image_id: 31 | annotations_by_image_id[image_id] = [] 32 | 33 | annotations_by_image_id[image_id].append(annotation) 34 | 35 | shard_writer = wds.ShardWriter("../data/coco-2017-train/%05d.tar", maxsize=5e8) 36 | 37 | for image_id, annotations in annotations_by_image_id.items(): 38 | print(f"writing {image_id}") 39 | image = imread("../data/train2017/%012d.jpg" % image_id) 40 | image = cvtColor(image, COLOR_BGR2RGB) 41 | 42 | annotations_metadata = [] 43 | 44 | for annotation in annotations: 45 | annotations_metadata.append({"id": annotation["id"], "caption": annotation["caption"]}) 46 | 47 | metadata = {"annotations": json.dumps(annotations_metadata)} 48 | 49 | shard_writer.write({"__key__": str(image_id), "json": metadata, "jpg": image}) 50 | 51 | shard_writer.close() 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import os 17 | 18 | # To use a consistent encoding 19 | from codecs import open 20 | 21 | import setuptools 22 | 23 | _deps = [ 24 | "transformers==4.26.1", 25 | "accelerate==0.17.1", 26 | "einops==0.6.0", 27 | "omegaconf==2.3.0", 28 | "webdataset>=0.2.39", 29 | "datasets", 30 | "wandb", 31 | "sentencepiece", # for T5 tokenizer 32 | "plotly", 33 | "pandas", 34 | ] 35 | 36 | _extras_dev_deps = [ 37 | "black[jupyter]~=23.1", 38 | "isort>=5.5.4", 39 | "flake8>=3.8.3", 40 | ] 41 | 42 | 43 | here = os.path.abspath(os.path.dirname(__file__)) 44 | 45 | with open(os.path.join(here, "README.md"), encoding="utf-8") as f: 46 | long_description = f.read() 47 | 48 | # read version 49 | with open(os.path.join(here, "muse", "__init__.py"), encoding="utf-8") as f: 50 | for line in f: 51 | if line.startswith("__version__"): 52 | version = line.split("=")[1].strip().strip('"') 53 | break 54 | else: 55 | raise RuntimeError("Unable to find version string.") 56 | 57 | setuptools.setup( 58 | name="muse", 59 | version=version, 60 | description="The best generative model in PyTorch", 61 | long_description=long_description, 62 | long_description_content_type="text/markdown", 63 | packages=setuptools.find_packages(), 64 | install_requires=_deps, 65 | extras_require={ 66 | "dev": [_extras_dev_deps], 67 | }, 68 | ) 69 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Builds GPU docker image of PyTorch 2 | # Uses multi-staged approach to reduce size 3 | # Stage 1 4 | # Use base conda image to reduce time 5 | FROM continuumio/miniconda3:latest AS compile-image 6 | # Specify py version 7 | ENV PYTHON_VERSION=3.10 8 | # Install apt libs 9 | RUN apt-get update && \ 10 | apt-get install -y curl git wget nano && \ 11 | apt-get clean && \ 12 | rm -rf /var/lib/apt/lists* 13 | 14 | # Create our conda env 15 | RUN conda create --name accelerate python=${PYTHON_VERSION} ipython jupyter pip 16 | # We don't install pytorch here yet since CUDA isn't available 17 | # instead we use the direct torch wheel 18 | ENV PATH /opt/conda/envs/accelerate/bin:$PATH 19 | # Activate our bash shell 20 | RUN chsh -s /bin/bash 21 | SHELL ["/bin/bash", "-c"] 22 | # Activate the conda env and install torch + accelerate 23 | RUN source activate accelerate && \ 24 | python3 -m pip install --no-cache-dir \ 25 | torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 \ 26 | --index-url https://download.pytorch.org/whl/cu116 27 | 28 | RUN python3 -m pip install --no-cache-dir bitsandbytes 29 | 30 | # Stage 2 31 | FROM nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04 AS build-image 32 | COPY --from=compile-image /opt/conda /opt/conda 33 | ENV PATH /opt/conda/bin:$PATH 34 | 35 | # Install apt libs 36 | RUN apt-get update && \ 37 | apt-get install -y curl git wget tmux htop && \ 38 | apt-get clean && \ 39 | rm -rf /var/lib/apt/lists* 40 | 41 | RUN echo "source activate accelerate" >> ~/.profile 42 | 43 | #RUN /bin/bash -c "source activate accelerate && pip install xformers" 44 | #RUN echo "Installing Apex..." 
45 | #WORKDIR /tmp/unique_for_apex 46 | #RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git 47 | #WORKDIR /tmp/unique_for_apex/apex 48 | 49 | #RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 50 | #WORKDIR /muse # move to git page 51 | #RUN pip install -e ".[extra]" 52 | #WORKDIR / 53 | # Activate the virtualenv 54 | CMD ["/bin/bash"] 55 | -------------------------------------------------------------------------------- /scripts/compute_offline_ema.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from muse import EMAModel 8 | from muse import MaskGitCorrTransformerOrg 9 | 10 | 11 | def offline_ema(args): 12 | checkpoint_dir_path = args.checkpoint_dir_path 13 | ema_save_path = args.ema_save_path 14 | ema_decay = args.ema_decay 15 | checkpoint_interval = args.checkpoint_interval 16 | 17 | dirs = os.listdir(checkpoint_dir_path) 18 | dirs = [d for d in dirs if d.startswith("checkpoint")] 19 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 20 | dirs = [Path(checkpoint_dir_path) / dir_ for dir_ in dirs] 21 | print(Path(checkpoint_dir_path) / dirs[0] / "unwrapped_model") 22 | print(dirs[0]) 23 | transformer_config = MaskGitCorrTransformerOrg.load_config(dirs[0] / "unwrapped_model") 24 | # if transformer_config["_class_name"] == "MaskGitTransformer": 25 | # model_cls = MaskGitCorrTransformerOrg 26 | # elif transformer_config["_class_name"] == "MaskGiTUViT": 27 | # model_cls = MaskGiTUViT 28 | model_cls = MaskGitCorrTransformerOrg 29 | device = "cpu" 30 | if torch.cuda.is_available(): 31 | device = "cuda" 32 | 33 | model = model_cls.from_pretrained(dirs[0] / "unwrapped_model").to(device) 34 | ema_model = EMAModel(parameters=model.parameters(), decay=ema_decay, update_every=checkpoint_interval) 35 | ema_model.to(device) 36 | 37 | end_step = int(str(dirs[-1]).split("-")[-1]) 38 | for step in range(0, end_step): 39 | if (step + 1) % checkpoint_interval == 0: 40 | print(f"Loading checkpoint {step + 1}...") 41 | model = model_cls.from_pretrained(Path(checkpoint_dir_path) / f"checkpoint-{step + 1}" / "unwrapped_model") 42 | model.to(device) 43 | 44 | ema_model.step(model.parameters()) 45 | 46 | ema_model.copy_to(model.parameters()) 47 | model.save_pretrained(ema_save_path) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--checkpoint_dir_path", type=str, default=None, required=True) 53 | parser.add_argument("--ema_save_path", type=str, default=None, required=True) 54 | parser.add_argument("--ema_decay", type=float, default=0.9999) 55 | parser.add_argument("--checkpoint_interval", type=int, default=1000) 56 | 57 | args = parser.parse_args() 58 | offline_ema(args) 59 | -------------------------------------------------------------------------------- /scripts/benchmark_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | 4 | import torch 5 | import torch.utils.benchmark as benchmark 6 | 7 | from muse import MaskGitTransformer, MaskGiTUViT 8 | 9 | 10 | def benchmark_torch_function(f, *args, **kwargs): 11 | t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}) 12 | return round(t0.blocked_autorange(min_run_time=1).mean, 2) 13 | 14 | 15 | def 
create_model_and_benchmark(args): 16 | if args.model_type == "transformer": 17 | config = MaskGitTransformer.load_config(args.config_path) 18 | model = MaskGitTransformer.from_config(config).to(args.device) 19 | elif args.model_type == "uvit": 20 | config = MaskGiTUViT.load_config(args.config_path) 21 | model = MaskGiTUViT.from_config(config).to(args.device) 22 | 23 | model.eval() 24 | 25 | print("Running benchmark for vanilla attention in FP32 ...") 26 | encoder_hidden_states = torch.randn( 27 | args.batch_size, args.text_length, model.config.encoder_hidden_size, device=args.device, dtype=torch.float32 28 | ) 29 | f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps) 30 | time_vanilla = benchmark_torch_function(f) 31 | 32 | print("Running benchmark for vanilla attention in FP16 ...") 33 | encoder_hidden_states = encoder_hidden_states.half() 34 | model = model.half() 35 | f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps) 36 | time_vanilla_fp16 = benchmark_torch_function(f) 37 | 38 | print("Running benchmark for efficient attention in FP16 ...") 39 | model.enable_xformers_memory_efficient_attention() 40 | f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps) 41 | time_efficient_fp16 = benchmark_torch_function(f) 42 | 43 | # print results with nice formatting 44 | print(f"Vanilla attention in FP32: {time_vanilla} ms") 45 | print(f"Vanilla attention in FP16: {time_vanilla_fp16} ms") 46 | print(f"Efficient attention in FP16: {time_efficient_fp16} ms") 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("--config_path", type=str, required=True) 52 | parser.add_argument("--batch_size", type=int, default=32) 53 | parser.add_argument("--model_type", type=str, default="transformer", choices=["transformer", "uvit"]) 54 | parser.add_argument("--text_length", type=int, default=96) 55 | parser.add_argument("--time_steps", type=int, default=12) 56 | parser.add_argument("--device", type=str, default="cuda") 57 | 58 | args = parser.parse_args() 59 | create_model_and_benchmark(args) 60 | -------------------------------------------------------------------------------- /evaluations/sample_sg_512.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from tqdm import tqdm 4 | 5 | sys.path.append('.') 6 | import torch 7 | 8 | import numpy as np 9 | device="cuda:1" 10 | from muse.pipeline_muse_toast import PipelineMuse 11 | 12 | 13 | model_paths = [ 14 | "results_corr/ft_512_toast_cls_b256_corr/checkpoint-50000/ema_model", 15 | ] 16 | for model_path in model_paths: 17 | pipe = PipelineMuse.from_pretrained(transformer_path=model_path, 18 | is_class_conditioned=True, 19 | vae_path="./scripts/tokenizer_imagenet512_torch/", 20 | use_toast=True).to(device) 21 | pipe.transformer.eval() 22 | pipe.vae.eval() 23 | 24 | sample_root = f"." 
25 | 26 | temperatures = [45] 27 | sampling_types = ['self_guidance'] # maskgit, self_guidance 28 | schedule = ["cosine"] 29 | sampling_steps = [18] 30 | guidance_scale = [1] 31 | 32 | for gs in guidance_scale: 33 | for schedule_type in schedule: 34 | for sampling_type in sampling_types: 35 | for t in temperatures: 36 | for i in range(len(sampling_steps)): 37 | step = sampling_steps[i] 38 | batch_size = 32 39 | # num_images = 50000 40 | num_images = 500 41 | num_iter = num_images // batch_size + 1 42 | 43 | model_name = model_path.split('/')[1] 44 | checkpoint_name = model_path.split('/')[2] 45 | 46 | save_dir = os.path.join(sample_root, f"{model_name}_{checkpoint_name}_{num_images//1000}k_{sampling_type}_{step}_{schedule_type}_{t}_{schedule_type}_gs_{gs}.npz") 47 | print(save_dir) 48 | 49 | all_images = [] 50 | all_labels = [] 51 | for o in tqdm(range(1, num_iter+1)): 52 | class_ids = torch.randint(0, 1000, (batch_size,)) 53 | images = pipe(class_ids=class_ids, num_images_per_prompt=batch_size, 54 | return_intermediate=False, timesteps=step, temperature=t, 55 | sampling_type=sampling_type, schedule=schedule_type, guidance_scale=gs, min_c_ratio=0.5) 56 | 57 | all_images += images 58 | all_labels.append(class_ids.numpy()) 59 | 60 | arr = np.array([np.array(image) for image in all_images]) 61 | arr = arr[:num_images] 62 | label_arr = np.concatenate(all_labels) 63 | label_arr = label_arr[:num_images] 64 | 65 | np.savez(save_dir, arr, label_arr) 66 | -------------------------------------------------------------------------------- /evaluations/sample_sg_256.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from tqdm import tqdm 4 | 5 | sys.path.append('.') 6 | import torch 7 | 8 | import numpy as np 9 | device="cuda:0" 10 | from muse.pipeline_muse_toast import PipelineMuse 11 | 12 | 13 | model_paths = [ 14 | "results_corr/ft_256_toast_cls_b256_corr/checkpoint-50000/ema_model", 15 | ] 16 | for model_path in model_paths: 17 | pipe = PipelineMuse.from_pretrained(transformer_path=model_path, 18 | is_class_conditioned=True, 19 | vae_path="./scripts/tokenizer_imagenet256_torch/", 20 | use_toast=True).to(device) 21 | pipe.transformer.eval() 22 | pipe.vae.eval() 23 | 24 | sample_root = f"." 
25 | 26 | temperatures = [25] 27 | sampling_types = ['self_guidance'] # maskgit, self_guidance 28 | schedule = ["cosine"] 29 | sampling_steps = [18] 30 | guidance_scale = [1] 31 | 32 | for gs in guidance_scale: 33 | for schedule_type in schedule: 34 | for sampling_type in sampling_types: 35 | for t in temperatures: 36 | for i in range(len(sampling_steps)): 37 | step = sampling_steps[i] 38 | batch_size = 128 39 | # num_images = 50000 40 | num_images = 500 41 | num_iter = num_images // batch_size + 1 42 | 43 | model_name = model_path.split('/')[1] 44 | checkpoint_name = model_path.split('/')[2] 45 | 46 | save_dir = os.path.join(sample_root, f"{model_name}_{checkpoint_name}_{num_images//1000}k_{sampling_type}_{step}_{schedule_type}_{t}_{schedule_type}_gs_{gs}.npz") 47 | print(save_dir) 48 | 49 | all_images = [] 50 | all_labels = [] 51 | for o in tqdm(range(1, num_iter+1)): 52 | class_ids = torch.randint(0, 1000, (batch_size,)) 53 | images = pipe(class_ids=class_ids, num_images_per_prompt=batch_size, 54 | return_intermediate=False, timesteps=step, temperature=t, 55 | sampling_type=sampling_type, schedule=schedule_type, guidance_scale=gs, min_c_ratio=0.5) 56 | 57 | all_images += images 58 | all_labels.append(class_ids.numpy()) 59 | 60 | arr = np.array([np.array(image) for image in all_images]) 61 | arr = arr[:num_images] 62 | label_arr = np.concatenate(all_labels) 63 | label_arr = label_arr[:num_images] 64 | 65 | np.savez(save_dir, arr, label_arr) 66 | -------------------------------------------------------------------------------- /scripts/convert_imagenet_to_wds.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Adapted from https://github.com/webdataset/webdataset-imagenet/blob/main/convert-imagenet.py 17 | 18 | import argparse 19 | import os 20 | import sys 21 | import time 22 | 23 | import webdataset as wds 24 | from datasets import load_dataset 25 | 26 | 27 | def convert_imagenet_to_wds(output_dir, max_train_samples_per_shard, max_val_samples_per_shard): 28 | assert not os.path.exists(os.path.join(output_dir, "imagenet-train-000000.tar")) 29 | assert not os.path.exists(os.path.join(output_dir, "imagenet-val-000000.tar")) 30 | 31 | opat = os.path.join(output_dir, "imagenet-train-%06d.tar") 32 | output = wds.ShardWriter(opat, maxcount=max_train_samples_per_shard) 33 | dataset = load_dataset("imagenet-1k", streaming=True, split="train", use_auth_token=True) 34 | now = time.time() 35 | for i, example in enumerate(dataset): 36 | if i % max_train_samples_per_shard == 0: 37 | print(i, file=sys.stderr) 38 | img, label = example["image"], example["label"] 39 | output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label}) 40 | output.close() 41 | time_taken = time.time() - now 42 | print(f"Wrote {i+1} train examples in {time_taken // 3600} hours.") 43 | 44 | opat = os.path.join(output_dir, "imagenet-val-%06d.tar") 45 | output = wds.ShardWriter(opat, maxcount=max_val_samples_per_shard) 46 | dataset = load_dataset("imagenet-1k", streaming=True, split="validation", use_auth_token=True) 47 | now = time.time() 48 | for i, example in enumerate(dataset): 49 | if i % max_val_samples_per_shard == 0: 50 | print(i, file=sys.stderr) 51 | img, label = example["image"], example["label"] 52 | output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label}) 53 | output.close() 54 | time_taken = time.time() - now 55 | print(f"Wrote {i+1} val examples in {time_taken // 60} min.") 56 | 57 | 58 | if __name__ == "__main__": 59 | # create parase object 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory.") 62 | parser.add_argument("--max_train_samples_per_shard", type=int, default=4000, help="Path to the output directory.") 63 | parser.add_argument("--max_val_samples_per_shard", type=int, default=1000, help="Path to the output directory.") 64 | args = parser.parse_args() 65 | 66 | # create output directory 67 | os.makedirs(args.output_dir, exist_ok=True) 68 | convert_imagenet_to_wds(args.output_dir, args.max_train_samples_per_shard, args.max_val_samples_per_shard) 69 | -------------------------------------------------------------------------------- /scripts/convert_imagenet_local.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Adapted from https://github.com/webdataset/webdataset-imagenet/blob/main/convert-imagenet.py 17 | 18 | import argparse 19 | import os 20 | import sys 21 | import time 22 | 23 | import webdataset as wds 24 | from datasets import load_dataset 25 | 26 | 27 | def convert_imagenet_to_wds(output_dir, max_train_samples_per_shard, max_val_samples_per_shard): 28 | assert not os.path.exists(os.path.join(output_dir, "imagenet-train-000000.tar")) 29 | assert not os.path.exists(os.path.join(output_dir, "imagenet-val-000000.tar")) 30 | 31 | opat = os.path.join(output_dir, "imagenet-train-%06d.tar") 32 | output = wds.ShardWriter(opat, maxcount=max_train_samples_per_shard) 33 | # dataset = load_dataset("imagenet-1k", split="train", data_dir='/mnt/sdb/datset/ILSVRC2012') 34 | dataset = load_dataset("imagenet-1k", split="train", data_dir='/mnt/sdb/ILSVRC2012/') 35 | print('load data') 36 | now = time.time() 37 | for i, example in enumerate(dataset): 38 | if i % max_train_samples_per_shard == 0: 39 | print(i, file=sys.stderr) 40 | img, label = example["image"], example["label"] 41 | output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label}) 42 | output.close() 43 | time_taken = time.time() - now 44 | print(f"Wrote {i+1} train examples in {time_taken // 3600} hours.") 45 | 46 | opat = os.path.join(output_dir, "imagenet-val-%06d.tar") 47 | output = wds.ShardWriter(opat, maxcount=max_val_samples_per_shard) 48 | dataset = load_dataset("imagenet-1k", data_dir='/mnt/sdb/ILSVRC2012/', split="validation") 49 | now = time.time() 50 | for i, example in enumerate(dataset): 51 | if i % max_val_samples_per_shard == 0: 52 | print(i, file=sys.stderr) 53 | img, label = example["image"], example["label"] 54 | output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label}) 55 | output.close() 56 | time_taken = time.time() - now 57 | print(f"Wrote {i+1} val examples in {time_taken // 60} min.") 58 | 59 | 60 | if __name__ == "__main__": 61 | # create parase object 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory.") 64 | parser.add_argument("--max_train_samples_per_shard", type=int, default=4000, help="Path to the output directory.") 65 | parser.add_argument("--max_val_samples_per_shard", type=int, default=1000, help="Path to the output directory.") 66 | args = parser.parse_args() 67 | 68 | # create output directory 69 | os.makedirs(args.output_dir, exist_ok=True) 70 | convert_imagenet_to_wds(args.output_dir, args.max_train_samples_per_shard, args.max_val_samples_per_shard) 71 | -------------------------------------------------------------------------------- /scripts/log_generations_wandb.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from itertools import islice 4 | 5 | import torch 6 | import wandb 7 | 8 | from muse import PipelineMuse 9 | 10 | 11 | def chunk(it, size): 12 | it = iter(it) 13 | return iter(lambda: tuple(islice(it, size)), ()) 14 | 15 | 16 | def generate_and_log(args): 17 | run_name = f"{args.transformer} samples at checkpoint {args.checkpoint}" 18 | wandb.init( 19 | project=args.project, 20 | entity=args.entity, 21 | name=run_name, 22 | notes=( 23 | f"Samples from {args.run_id} at checkpoint {args.checkpoint} with timesteps={args.timesteps}," 24 | f" guidance_scale={args.guidance_scale}, temperature={args.temperature}" 25 | ), 26 | ) 27 | 28 | pipe = 
PipelineMuse.from_pretrained( 29 | text_encoder_path=args.text_encoder, 30 | vae_path=args.vae, 31 | transformer_path=args.transformer, 32 | ).to(device=args.device) 33 | pipe.transformer.enable_xformers_memory_efficient_attention() 34 | 35 | # open args.prompts_file_path and read prompts in a list 36 | with open(args.prompts_file_path, "r") as f: 37 | prompts = f.readlines() 38 | 39 | # divide the prompts into batches of size args.batch_size 40 | prompts = list(chunk(prompts, args.batch_size)) 41 | 42 | # generate images and log in wandb table 43 | table = wandb.Table(columns=["prompt"] + [f"image {i}" for i in range(args.num_generations)]) 44 | for batch in prompts: 45 | images = pipe( 46 | batch, 47 | timesteps=args.timesteps, 48 | guidance_scale=args.guidance_scale, 49 | temperature=args.temperature, 50 | num_images_per_prompt=args.num_generations, 51 | use_maskgit_generate=True, 52 | use_fp16=True, 53 | ) 54 | 55 | # create rows like this: [prompt, image 1, image 2, ...] 56 | # where each image is a wandb.Image 57 | # and log in wandb table 58 | images = list(chunk(images, args.num_generations)) 59 | for prompt, gen_images in zip(batch, images): 60 | row = [prompt] 61 | for image in gen_images: 62 | row.append(wandb.Image(image)) 63 | table.add_data(*row) 64 | 65 | wandb.log({"samples": table}) 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = ArgumentParser() 70 | parser.add_argument("--project", type=str, default="muse") 71 | parser.add_argument("--entity", type=str, default="psuraj") 72 | parser.add_argument("--run_id", type=str, required=True) 73 | parser.add_argument("--timesteps", type=int, default=12) 74 | parser.add_argument("--temperature", type=float, default=1.0) 75 | parser.add_argument("--guidance_scale", type=float, default=8) 76 | parser.add_argument("--num_generations", type=int, default=8) 77 | parser.add_argument("--checkpoint", type=str, required=True) 78 | parser.add_argument("--text_encoder", type=str, default="google/t5-v1_1-large") 79 | parser.add_argument("--vae", type=str, default="openMUSE/maskgit-vqgan-imagenet-f16-256") 80 | parser.add_argument("--transformer", type=str, required=True) 81 | parser.add_argument("--prompts_file_path", type=str, required=True) 82 | parser.add_argument("--device", type=str, default="cuda") 83 | parser.add_argument("--batch_size", type=int, default=64) 84 | 85 | args = parser.parse_args() 86 | generate_and_log(args) 87 | -------------------------------------------------------------------------------- /muse/sampling.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/lucidrains/muse-maskgit-pytorch 2 | 3 | import math 4 | from functools import partial 5 | 6 | import torch 7 | 8 | 9 | def log(t, eps=1e-20): 10 | return torch.log(t.clamp(min=eps)) 11 | 12 | 13 | def gumbel_noise(t, generator=None): 14 | noise = torch.zeros_like(t).uniform_(0, 1, generator=generator) 15 | return -log(-log(noise)) 16 | 17 | 18 | def gumbel_sample(t, temperature=1.0, dim=-1, generator=None): 19 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim) 20 | 21 | 22 | def gumbel_sample_max(t, temperature=1.0, dim=-1, generator=None): 23 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).max(dim=dim) 24 | 25 | def top_k(logits, thres=0.9): 26 | k = math.ceil((1 - thres) * logits.shape[-1]) 27 | val, ind = logits.topk(k, dim=-1) 28 | probs = torch.full_like(logits, float("-inf")) 29 | probs.scatter_(2, ind, val) 
30 | return probs 31 | 32 | def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None): 33 | confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator) 34 | sorted_confidence = torch.sort(confidence, dim=-1).values 35 | cut_off = torch.gather(sorted_confidence, 1, mask_len.long()) 36 | masking = confidence < cut_off 37 | return masking 38 | 39 | 40 | def cosine_schedule(t): 41 | return torch.cos(t * math.pi * 0.5) 42 | 43 | def linear_schedule(t): 44 | mask_ratio = 1 - t 45 | mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0) 46 | return mask_ratio 47 | 48 | def exponential_schedule(t, total_unkown=256): 49 | mask_ratio = 1 - torch.exp(-torch.log(torch.tensor(total_unkown)) * (1-t)) 50 | return mask_ratio 51 | 52 | def sqrt_schedule(t): 53 | return (1 - torch.sqrt(t)).clamp(min=1e-6, max=1.0) 54 | 55 | def pow(t, method): 56 | exponent = float(method.replace("pow", "")) 57 | mask_ratio = 1.0 - t**exponent 58 | mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0) 59 | return mask_ratio 60 | 61 | 62 | def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6): 63 | for item in [t, start, end, tau]: 64 | item = torch.tensor(item) if not torch.is_tensor(item) else item 65 | 66 | # A gamma function based on sigmoid function. 67 | v_start = torch.sigmoid(torch.tensor(start / tau)) 68 | v_end = torch.sigmoid(torch.tensor(end / tau)) 69 | output = torch.sigmoid((t * (end - start) + start) / tau) 70 | output = (v_end - output) / (v_end - v_start) 71 | return torch.clip(output, clip_min, 1.0) 72 | 73 | 74 | def get_mask_chedule(method, **schedule_kwargs): 75 | if method == "cosine": 76 | return cosine_schedule 77 | elif method == "linear": 78 | return linear_schedule 79 | elif "pow" in method: 80 | return partial(pow, method=method) 81 | elif method == "sigmoid": 82 | return partial(sigmoid_schedule, **schedule_kwargs) 83 | else: 84 | raise ValueError("Unknown schedule method: {}".format(method)) 85 | 86 | 87 | def new_arange(x, *size): 88 | """ 89 | Return a Tensor of `size` filled with a range function on the device of x. 90 | If size is empty, using the size of the variable x. 
91 | """ 92 | if len(size) == 0: 93 | size = x.size() 94 | return torch.arange(size[-1], device=x.device).expand(*size).contiguous() 95 | 96 | def uniform(shape, min=0, max=1, device=None): 97 | return torch.zeros(shape, device=device).float().uniform_(0, 1) -------------------------------------------------------------------------------- /scripts/makeshards.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path 4 | import random 5 | import argparse 6 | 7 | from torchvision import datasets 8 | from PIL import Image 9 | import webdataset as wds 10 | 11 | 12 | parser = argparse.ArgumentParser("""Generate sharded dataset from original ImageNet data.""") 13 | parser.add_argument("--splits", default="train,val", help="which splits to write") 14 | parser.add_argument( 15 | "--filekey", action="store_true", help="use file as key (default: index)" 16 | ) 17 | parser.add_argument("--maxsize", type=float, default=1e9) 18 | parser.add_argument("--maxcount", type=float, default=100000) 19 | parser.add_argument( 20 | "--shards", default="./shards", help="directory where shards are written" 21 | ) 22 | parser.add_argument( 23 | "--data", 24 | default="./data", 25 | help="directory containing ImageNet data distribution suitable for torchvision.datasets", 26 | ) 27 | args = parser.parse_args() 28 | 29 | 30 | assert args.maxsize > 10000000 31 | assert args.maxcount < 1000000 32 | 33 | 34 | if not os.path.isdir(os.path.join(args.data, "train")): 35 | print(f"{args.data}: should be directory containing ImageNet", file=sys.stderr) 36 | print(f"suitable as argument for torchvision.datasets.ImageNet(...)", file=sys.stderr) 37 | sys.exit(1) 38 | 39 | 40 | if not os.path.isdir(os.path.join(args.shards, ".")): 41 | print(f"{args.shards}: should be a writable destination directory for shards", file=sys.stderr) 42 | sys.exit(1) 43 | 44 | 45 | splits = args.splits.split(",") 46 | 47 | 48 | def readfile(fname): 49 | "Read a binary file from disk." 50 | with open(fname, "rb") as stream: 51 | return stream.read() 52 | 53 | 54 | all_keys = set() 55 | 56 | 57 | def write_dataset(imagenet, base="./shards", split="train"): 58 | 59 | # We're using the torchvision ImageNet dataset 60 | # to parse the metadata; however, we will read 61 | # the compressed images directly from disk (to 62 | # avoid having to reencode them) 63 | ds = datasets.ImageNet(imagenet, split=split) 64 | nimages = len(ds.imgs) 65 | print("# nimages", nimages) 66 | 67 | # We shuffle the indexes to make sure that we 68 | # don't get any large sequences of a single class 69 | # in the dataset. 70 | indexes = list(range(nimages)) 71 | #random.shuffle(indexes) 72 | 73 | # This is the output pattern under which we write shards. 74 | pattern = os.path.join(base, f"imagenet-{split}-%06d.tar") 75 | 76 | with wds.ShardWriter(pattern, maxsize=int(args.maxsize), maxcount=int(args.maxcount)) as sink: 77 | for i in indexes: 78 | 79 | # Internal information from the ImageNet dataset 80 | # instance: the file name and the numerical class. 81 | fname, cls = ds.imgs[i] 82 | assert cls == ds.targets[i] 83 | 84 | # Read the JPEG-compressed image file contents. 85 | #image = readfile(fname) 86 | image = Image.open(fname).convert("RGB") 87 | # Construct a uniqu keye from the filename. 88 | key = os.path.splitext(os.path.basename(fname))[0] 89 | 90 | # Useful check. 91 | assert key not in all_keys 92 | all_keys.add(key) 93 | 94 | # Construct a sample. 
95 | xkey = key if args.filekey else "%08d" % i 96 | sample = {"__key__": xkey, "jpg": image, "cls": cls} 97 | 98 | # Write the sample to the sharded tar archives. 99 | sink.write(sample) 100 | 101 | 102 | for split in splits: 103 | print("# split", split) 104 | write_dataset(args.data, base=args.shards, split=split) 105 | -------------------------------------------------------------------------------- /scripts/log_inpainting_images.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from argparse import ArgumentParser 5 | from itertools import islice 6 | 7 | import numpy as np 8 | import torch 9 | import wandb 10 | from PIL import Image 11 | 12 | from muse import PipelineMuseInpainting 13 | 14 | 15 | def chunk(it, size): 16 | it = iter(it) 17 | return iter(lambda: tuple(islice(it, size)), ()) 18 | 19 | 20 | def generate_and_log(args): 21 | os.makedirs(args.output_dir, exist_ok=True) 22 | vae_scaling_factor = args.vae_scaling_factor 23 | pipe = PipelineMuseInpainting.from_pretrained( 24 | model_name_or_path=args.model_name_or_path, 25 | is_class_conditioned=args.is_class_conditioned, 26 | ).to(device=args.device) 27 | pipe.transformer.enable_xformers_memory_efficient_attention() 28 | 29 | if args.is_class_conditioned: 30 | imagenet_class_ids = [args.imagenet_class_id] 31 | class_ids = torch.tensor(imagenet_class_ids).to(device=args.device, dtype=torch.long) 32 | inputs = {"class_ids": class_ids} 33 | else: 34 | inputs = {"text": args.text} 35 | 36 | mask = np.zeros((args.image_size // vae_scaling_factor, args.image_size // vae_scaling_factor)) 37 | mask[args.mask_start_x : args.mask_end_x, args.mask_start_y : args.mask_end_y] = 1 38 | mask = mask.reshape(-1) 39 | mask = torch.tensor(mask).to(args.device, dtype=torch.bool) 40 | 41 | image = Image.open(args.input_image).resize((args.image_size, args.image_size)) 42 | 43 | masked_image = copy.deepcopy(np.array(image)) 44 | masked_image[ 45 | args.mask_start_x * vae_scaling_factor : args.mask_end_x * vae_scaling_factor, 46 | args.mask_start_y * vae_scaling_factor : args.mask_end_y * vae_scaling_factor, 47 | ] = 0 48 | masked_image = Image.fromarray(masked_image) 49 | masked_image.save(os.path.join(args.output_dir, "segmented.jpg")) 50 | images = pipe( 51 | image=image, 52 | mask=mask, 53 | **inputs, 54 | timesteps=args.timesteps, 55 | guidance_scale=args.guidance_scale, 56 | temperature=args.temperature, 57 | use_maskgit_generate=not args.not_maskgit_generate, 58 | num_images_per_prompt=args.num_generations, 59 | image_size=args.image_size, 60 | ) 61 | 62 | if args.is_class_conditioned: 63 | images = list(chunk(images, args.num_generations)) 64 | for class_id, class_images in zip(imagenet_class_ids, images): 65 | for i, image in enumerate(class_images): 66 | image.save(os.path.join(args.output_dir, f"output_{class_id}_{i}.jpg")) 67 | else: 68 | for i, image in enumerate(images): 69 | image.save(os.path.join(args.output_dir, f"output_{i}.jpg")) 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = ArgumentParser() 74 | parser.add_argument("--is_class_conditioned", action="store_true") 75 | parser.add_argument("--timesteps", type=int, default=18) 76 | parser.add_argument("--temperature", type=float, default=1.0) 77 | parser.add_argument("--guidance_scale", type=float, default=2.0) 78 | parser.add_argument("--not_maskgit_generate", action="store_true") 79 | parser.add_argument("--num_generations", type=int, default=8) 80 | 
parser.add_argument("--model_name_or_path", type=str, default="openMUSE/muse-laiona6-uvit-clip-220k") 81 | parser.add_argument("--device", type=str, default="cuda") 82 | parser.add_argument("--imagenet_class_id", type=int, default=248) 83 | parser.add_argument("--text", type=str, default="a picture of a dog") 84 | parser.add_argument("--input_image", type=str, required=True) 85 | parser.add_argument("--image_size", type=int, default=256) 86 | parser.add_argument("--mask_start_x", type=int, default=4) 87 | parser.add_argument("--mask_start_y", type=int, default=4) 88 | parser.add_argument("--mask_end_x", type=int, default=12) 89 | parser.add_argument("--mask_end_y", type=int, default=12) 90 | parser.add_argument("--vae_scaling_factor", type=int, default=16) 91 | parser.add_argument("--output_dir", type=str, default="generated") 92 | args = parser.parse_args() 93 | generate_and_log(args) 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unlocking the Capabilities of Masked Generative Models for Image Synthesis via Self-Guidance: Official PyTorch Implementation (NeurIPS 2024) 2 | [![arXiv](https://img.shields.io/badge/arXiv%20paper-2404.02905-b31b1b.svg)](https://arxiv.org/abs/2410.13136)  3 | 4 | [NeurIPS 2024] Official implementation of the paper "Unlocking the Capabilities of Masked Generative Models for Image Synthesis via Self-Guidance". 5 | 6 | ![git](https://github.com/user-attachments/assets/dac17103-f85c-43eb-abd9-0eff738e58ec) 7 | 8 | ## Get Started 9 | Our code is based on the [open-muse](https://github.com/huggingface/open-muse), an open pytorch reproduction of masked generative models such as MaskGIT and MUSE. Please refer the codebase for the more information and source code. 10 | 11 | ### Installation 12 | We provide a `docker/Dockerfile` to simplify the setup of our repository. Once the Docker container is running, follow the scripts below to get started 13 | 14 | ``` 15 | source activate accelerate && pip install xformers 16 | 17 | ### only for fine-tuning ### 18 | mkdir /tmp/unique_for_apex 19 | cd /tmp/unique_for_apex 20 | SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git 21 | cd /tmp/unique_for_apex/apex 22 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 23 | 24 | ### install project dependencies ### 25 | cd /UnlockMGM 26 | pip install -e ".[extra]" 27 | ``` 28 | 29 | Depending on your settings, [apex](https://github.com/NVIDIA/apex) could not be installed successfully. Please carefully check the CUDA version for the Dockerfile or you can fine-tune without fused_adam optimizer. You can skip installing Apex for the sampling only. 30 | 31 | ### Prepare Dataset 32 | Please refer to [open-muse](https://github.com/huggingface/open-muse) to prepare `webdataset` for ImageNet. 33 | 34 | ## Converting VQGAN and MaskGIT Weights 35 | 36 | 1. **Download JAX Checkpoints** 37 | Download the JAX checkpoints for the VQGAN tokenizer and MaskGIT from the [MaskGIT repository](https://github.com/google-research/maskgit). Place the downloaded files in the `scripts/` directory. 38 | 39 | 2. 
**Convert JAX Checkpoints to PyTorch** 40 | Run the following scripts to convert the JAX checkpoints to PyTorch format: 41 | ```bash 42 | python scripts/convert_maskgit_vqgan.py 43 | python scripts/convert_maskgit_transformer.py 44 | ``` 45 | 46 | 3. **Adjust Resolution** 47 | Update the `resolution` parameter in the conversion scripts to convert the 256x256 and 512x512 checkpoints, respectively. 48 | 49 | ## Fine-tuning TOAST 50 | For ImageNet 256: 51 | ``` 52 | accelerate launch --config_file acc_config_mult.yaml training/ft_maskgit_org_toast.py config=configs/ft_256_toast_cls_b256_corr.yaml 53 | ``` 54 | For ImageNet 512: 55 | ``` 56 | accelerate launch --config_file acc_config_mult.yaml training/ft_maskgit_org_toast.py config=configs/ft_512_toast_cls_b256_corr.yaml 57 | ``` 58 | Please adjust the `batch_size` in the config files and the `num_processes` in the `acc_config_mult.yaml` file so that the total batch size is 256 (with the provided settings this is 4 GPUs x 64 per-GPU batch size x 1 gradient-accumulation step for 256x256, and 4 x 32 x 2 for 512x512). 59 | 60 | 61 | ## Sample Images with Self-Guidance 62 | 63 | We provide a [Jupyter notebook](https://github.com/JiwanHur/UnlockMGM/blob/main/evaluations/maskgit_toast.ipynb) to sample and visualize images using self-guidance. 64 | 65 | To sample images for the evaluations, use `./evaluations/sample_sg_256.py` and `./evaluations/sample_sg_512.py`. 66 | 67 | ## Fine-tuning checkpoints 68 | | model | weights | 69 | |:----------:|:---------:| 70 | |ImageNet-256|[checkpoint](https://huggingface.co/HURJIWAN/UnlockMGM/resolve/main/UnlockMGM_imagenet_256.zip)| 71 | |ImageNet-512|[checkpoint](https://huggingface.co/HURJIWAN/UnlockMGM/resolve/main/UnlockMGM_imagenet_512.zip)| 72 | 73 | 74 | ## Acknowledgements 75 | This code is heavily based on the following repositories. Thanks to all the authors for their amazing work! 76 | - [open-muse](https://github.com/huggingface/open-muse) 77 | - [maskgit](https://github.com/google-research/maskgit) 78 | - [TOAST](https://github.com/bfshi/TOAST) 79 | - [DiffFit](https://github.com/mkshing/DiffFit-pytorch) 80 | - [guided-diffusion](https://github.com/openai/guided-diffusion) 81 | - [CMLMC](https://github.com/layer6ai-labs/CMLMC) 82 | - [webdataset](https://github.com/webdataset/webdataset) 83 | - [apex](https://github.com/NVIDIA/apex) 84 | 85 | ## Citation 86 | ``` 87 | @Article{hur2024unlocking, 88 | title={Unlocking the Capabilities of Masked Generative Models for Image Synthesis via Self-Guidance}, 89 | author={Hur, Jiwan and Lee, Dong-Jae and Han, Gyojin and Choi, Jaehyun and Jeon, Yunho and Kim, Junmo}, 90 | journal={arXiv preprint arXiv:2410.13136}, 91 | year={2024} 92 | } 93 | ``` 94 | 95 | -------------------------------------------------------------------------------- /configs/ft_256_toast_cls_b256_corr.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: "siit-maskgm" 3 | 4 | experiment: 5 | project: "git_muse" 6 | name: "ft_256_toast_cls_b256_corr" 7 | output_dir: "results_corr/ft_256_toast_cls_b256_corr" 8 | max_train_examples: 1281167 # total number of imagenet examples 9 | max_eval_examples: 12800 10 | save_every: 10000 # 5000 11 | eval_every: 2000 # 1000 12 | generate_every: 2000 # 1000 13 | log_every: 50 14 | log_grad_norm_every: 100 15 | resume_from_checkpoint: False 16 | resume_lr_scheduler: True 17 | fine_tune_path: './scripts/maskgit_imagenet256_torch' 18 | 19 | model: 20 | vq_model: 21 | type: "maskgit_vqgan" 22 | pretrained: "./scripts/tokenizer_imagenet256_torch" 23 | 24 | transformer: 25 | vocab_size: 2025 # (1024 + 1000 + 1 = 2025 -> VQ codes + ImageNet classes + mask token) 
26 | max_position_embeddings: 257 # (256 + 1 for class id) 27 | hidden_size: 768 28 | num_hidden_layers: 24 29 | num_attention_heads: 16 30 | intermediate_size: 3072 31 | codebook_size: 1024 32 | num_vq_tokens: 256 33 | num_classes: 1000 34 | initializer_range: 0.02 35 | norm_type: "layernorm" 36 | layer_norm_eps: 1e-12 37 | layer_norm_embeddings: True 38 | use_normformer: False 39 | use_encoder_layernorm: False 40 | use_mlm_layer: True 41 | use_mlm_layernorm: True 42 | use_maskgit_mlp: True 43 | use_bias: True 44 | hidden_dropout: 0.1 45 | attention_dropout: 0.1 46 | use_embed_fusion: False 47 | 48 | fine_tune: 49 | # use_prompt_token: False # | False | simple | token_generator | late_prompt | 50 | # num_prompt_token: 1 51 | tg_factor: 16 # token generator factor 52 | # prompt_init: 'image_token' # | trunc_norm | mask_token | image_token | 53 | # use_shallow_cpe: False # | False | dense | sparse | 54 | # use_deep_cpe: False # | False | dense | sparse | 55 | # cpe_init: 'zero' # | zero | trunc_norm | 56 | # shallow_cpe_len: 6 57 | # deep_cpe_len: 5 58 | # deep_layer_index: 18 # 13 + 5 (=deep_cpe_len) 59 | # use_bitfit: False 60 | # use_ln: False 61 | # use_aux_loss: False 62 | # use_token_generator: False 63 | # use_cpe: False 64 | # aux_loss_weight: 0.00 65 | # use_ttur: False 66 | # use_pe: False 67 | # ttur_lr_scalefactor: 10 68 | use_toast: True 69 | # use_cls_hidden: False 70 | train_head: False 71 | use_blank_second: True 72 | predict_all: False 73 | lambda_var: 0.1 74 | 75 | gradient_checkpointing: True 76 | enable_xformers_memory_efficient_attention: True 77 | 78 | 79 | dataset: 80 | params: 81 | train_shards_path_or_url: "/mnt/sdb/imagenet_shards/shards/imagenet-train-{000000..000320}.tar" 82 | eval_shards_path_or_url: "/mnt/sdb/imagenet_shards/shards/imagenet-val-{000000..000012}.tar" 83 | batch_size: ${training.batch_size} 84 | shuffle_buffer_size: 1000 85 | num_workers: 32 86 | resolution: 256 87 | pin_memory: True 88 | persistent_workers: True 89 | preprocessing: 90 | resolution: 256 91 | center_crop: False 92 | random_flip: False 93 | random_resize_and_crop: False 94 | 95 | 96 | optimizer: 97 | name: fused_adam 98 | params: # default adamw params 99 | learning_rate: 1e-4 100 | scale_lr: False # scale learning rate by total batch size 101 | beta1: 0.9 102 | beta2: 0.96 103 | weight_decay: 0 104 | epsilon: 1e-8 105 | 106 | 107 | lr_scheduler: 108 | scheduler: "cosine_with_restarts" 109 | params: 110 | learning_rate: ${optimizer.params.learning_rate} 111 | warmup_steps: 5000 112 | num_cycles: 6 113 | 114 | 115 | training: 116 | # freeze_cpe: False 117 | gradient_accumulation_steps: 1 118 | batch_size: 64 # per GPU 119 | mixed_precision: "bf16" 120 | enable_tf32: True 121 | use_ema: True 122 | ema_rate: 0.9999 123 | ema_update_every: 1 124 | seed: 208 125 | max_train_steps: 50000 # 5000 approximates 1 epochs on 256 batch size 126 | overfit_one_batch: False 127 | min_masking_rate: 0.0 128 | label_smoothing: 0.1 129 | max_grad_norm: null 130 | correction_loss_weight: 1.0 131 | masking_type: "no_masking" 132 | # related to vae code sampling 133 | use_soft_code_target: False 134 | use_stochastic_code: False 135 | soft_code_temp: 1.0 136 | use_generation_loss: False 137 | use_correction_loss: True 138 | use_dynamic_substitution: False 139 | substitution_rate: 0.3 140 | 141 | -------------------------------------------------------------------------------- /configs/ft_512_toast_cls_b256_corr.yaml: -------------------------------------------------------------------------------- 1 | 
wandb: 2 | entity: "siit-maskgm" 3 | 4 | experiment: 5 | project: "git_muse" 6 | name: "ft_512_toast_cls_b256_corr_blank_corr_ids_min0_varloss" 7 | output_dir: "results_corr/ft_512_toast_cls_b256_corr_blank_corr_ids_min0_varloss" 8 | max_train_examples: 1281167 # total number of imagenet examples 9 | max_eval_examples: 12800 10 | save_every: 10000 # 5000 11 | eval_every: 2000 # 1000 12 | generate_every: 2000 # 1000 13 | log_every: 50 14 | log_grad_norm_every: 100 15 | resume_from_checkpoint: False 16 | resume_lr_scheduler: True 17 | fine_tune_path: './scripts/maskgit_imagenet512_torch' 18 | 19 | model: 20 | vq_model: 21 | type: "maskgit_vqgan" 22 | pretrained: "./scripts/tokenizer_imagenet512_torch/" 23 | 24 | transformer: 25 | vocab_size: 2025 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + ) 26 | max_position_embeddings: 1025 # (1024 + 1 + 1 for class id and prompt token) 27 | hidden_size: 768 28 | num_hidden_layers: 24 29 | num_attention_heads: 16 30 | intermediate_size: 3072 31 | codebook_size: 1024 32 | num_vq_tokens: 1024 33 | num_classes: 1000 34 | initializer_range: 0.02 35 | norm_type: "layernorm" 36 | layer_norm_eps: 1e-12 37 | layer_norm_embeddings: True 38 | use_normformer: False 39 | use_encoder_layernorm: False 40 | use_mlm_layer: True 41 | use_mlm_layernorm: True 42 | use_maskgit_mlp: True 43 | use_bias: True 44 | hidden_dropout: 0.1 45 | attention_dropout: 0.1 46 | use_embed_fusion: False 47 | 48 | fine_tune: 49 | use_prompt_token: False # | False | simple | token_generator | late_prompt | 50 | num_prompt_token: 1 51 | tg_factor: 16 # token generator factor 52 | prompt_init: 'image_token' # | trunc_norm | mask_token | image_token | 53 | use_shallow_cpe: False # | False | dense | sparse | 54 | use_deep_cpe: False # | False | dense | sparse | 55 | cpe_init: 'zero' # | zero | trunc_norm | 56 | shallow_cpe_len: 6 57 | deep_cpe_len: 5 58 | deep_layer_index: 18 # 13 + 5 (=deep_cpe_len) 59 | use_bitfit: False 60 | use_ln: False 61 | use_aux_loss: False 62 | use_token_generator: False 63 | use_cpe: False 64 | aux_loss_weight: 0.00 65 | use_ttur: False 66 | use_pe: False 67 | ttur_lr_scalefactor: 10 68 | use_toast: True 69 | use_cls_hidden: False 70 | train_head: False 71 | use_blank_second: True 72 | predict_all: False 73 | lambda_var: 0.1 74 | 75 | gradient_checkpointing: True 76 | enable_xformers_memory_efficient_attention: True 77 | 78 | 79 | dataset: 80 | params: 81 | train_shards_path_or_url: "/mnt/sdb/imagenet_shards/shards/imagenet-train-{000000..000320}.tar" 82 | eval_shards_path_or_url: "/mnt/sdb/imagenet_shards/shards/imagenet-val-{000000..000012}.tar" 83 | batch_size: ${training.batch_size} 84 | shuffle_buffer_size: 1000 85 | num_workers: 32 86 | resolution: 512 87 | pin_memory: True 88 | persistent_workers: True 89 | preprocessing: 90 | resolution: 512 91 | center_crop: False 92 | random_flip: False 93 | random_resize_and_crop: False 94 | 95 | 96 | optimizer: 97 | name: fused_adam 98 | params: # default adamw params 99 | learning_rate: 1e-4 100 | scale_lr: False # scale learning rate by total batch size 101 | beta1: 0.9 102 | beta2: 0.96 103 | weight_decay: 0 104 | epsilon: 1e-8 105 | 106 | 107 | lr_scheduler: 108 | scheduler: "cosine_with_restarts" 109 | params: 110 | learning_rate: ${optimizer.params.learning_rate} 111 | warmup_steps: 5000 112 | num_cycles: 6 113 | 114 | 115 | training: 116 | freeze_cpe: False 117 | gradient_accumulation_steps: 2 118 | batch_size: 32 # per GPU 119 | mixed_precision: "bf16" 120 | enable_tf32: True 121 | use_ema: True 122 | ema_rate: 
0.9999 123 | ema_update_every: 1 124 | seed: 208 125 | max_train_steps: 50000 # 5000 approximates 1 epochs on 256 batch size 126 | overfit_one_batch: False 127 | min_masking_rate: 0.0 128 | label_smoothing: 0.1 129 | max_grad_norm: null 130 | correction_loss_weight: 1.0 131 | masking_type: "no_masking" 132 | # related to vae code sampling 133 | use_soft_code_target: False 134 | use_stochastic_code: False 135 | soft_code_temp: 1.0 136 | use_generation_loss: False 137 | use_correction_loss: True 138 | use_dynamic_substitution: False 139 | substitution_rate: 0.3 140 | 141 | -------------------------------------------------------------------------------- /evaluations/maskgit_toast.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "import sys\n", 12 | "sys.path.append('../')\n", 13 | "import numpy as np\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "device=\"cuda:0\"\n", 16 | "from imagenet_dict import imagenet_dict\n", 17 | "from muse.pipeline_muse_toast import PipelineMuse" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "def display_intermediate_img_and_mask(i_batch, timestep, images, intermediate_images, mask_index, res=256):\n", 27 | " seq_len = 16 if res == 256 else 32\n", 28 | " vq_len = 256 if res == 256 else 1024\n", 29 | " inter_arr = [np.array(intermediate_images[i][i_batch]) for i in range(timestep)]\n", 30 | " mask_arr = [np.array(mask_index[i][i_batch].reshape(-1, seq_len, seq_len).squeeze().detach().cpu()) for i in range(timestep)]\n", 31 | "\n", 32 | " fig, axes = plt.subplots(timestep+1, 3, figsize=(10, (timestep+2)*2))\n", 33 | "\n", 34 | " mask_arr_np = np.concatenate([np.zeros((1,seq_len,seq_len)), 1-np.array(mask_arr)])\n", 35 | " for i in range(timestep):\n", 36 | " previous = mask_arr_np[i]\n", 37 | " now = mask_arr_np[i+1]\n", 38 | " approved = previous * now\n", 39 | " created = now - approved\n", 40 | " denied = previous - approved\n", 41 | " image_rgb = np.zeros((seq_len, seq_len, 3))\n", 42 | " image_rgb[created == 1] = [0, 1, 1] # Sky Blue for created\n", 43 | " image_rgb[approved == 1] = [0, 1, 0] # Green for approved\n", 44 | " image_rgb[denied == 1] = [1, 0, 0] # Red for denied\n", 45 | " \n", 46 | " axes[i][0].imshow(inter_arr[i], interpolation='nearest', aspect='equal')\n", 47 | " axes[i][1].matshow(1-mask_arr[i], cmap='gray')\n", 48 | " axes[i][2].imshow(image_rgb, interpolation='nearest', aspect='equal')\n", 49 | " \n", 50 | " axes[i][0].axis(\"off\")\n", 51 | " axes[i][1].axis(\"off\")\n", 52 | " axes[i][2].axis(\"off\")\n", 53 | " \n", 54 | " # Insert the count of approved pixels between the subplots\n", 55 | " unmasked = now.sum() / vq_len\n", 56 | " prev_unmasked = previous.sum() / vq_len\n", 57 | " approved = approved.sum() / vq_len\n", 58 | " denied = denied.sum() / vq_len\n", 59 | " created = created.sum() / vq_len\n", 60 | " txt = f'prev_unmasked: {prev_unmasked:.2f} || unmasked: {unmasked:.2f} || approved: {approved:.2f} || denied: {denied:.2f} || created: {created:.2f}'\n", 61 | " \n", 62 | " axes[i][1].text(0.5, 1.05, txt, \n", 63 | " horizontalalignment='center', \n", 64 | " verticalalignment='center', \n", 65 | " transform=axes[i][1].transAxes)\n", 66 | "\n", 67 | " axes[timestep][0].imshow(images[i_batch], 
interpolation='nearest', aspect='equal')\n", 68 | " axes[timestep][0].axis(\"off\")\n", 69 | " axes[timestep][1].axis(\"off\")\n", 70 | " axes[timestep][2].axis(\"off\")\n", 71 | " \n", 72 | " fig.tight_layout() \n", 73 | " fig.show()" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "pipe = PipelineMuse.from_pretrained(transformer_path=\"../results_corr/ft_256_toast_cls_b256_corr/checkpoint-50000/ema_model\", \n", 83 | " is_class_conditioned=True, use_toast=True, vae_path=\"../scripts/tokenizer_imagenet256_torch/\").to(device)\n", 84 | "# pipe = PipelineMuse.from_pretrained(transformer_path=\"../results_corr/ft_512_toast_cls_b256_corr/checkpoint-50000/ema_model\", \n", 85 | "# is_class_conditioned=True, use_toast=True, vae_path=\"../scripts/tokenizer_imagenet512_torch/\").to(device)\n", 86 | "pipe.transformer.eval()\n", 87 | "pipe.vae.eval()\n", 88 | "print(\"Loaded model\")" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# Scene class id of Imagenet\n", 98 | "\n", 99 | "class_ids = 105 # 105: koala 454: bookstore\n", 100 | "\n", 101 | "timesteps = 18\n", 102 | "images, intermediate_images, intermediate, mask_index = pipe(class_ids=class_ids, num_images_per_prompt=4, \n", 103 | " timesteps=timesteps, temperature=10, sampling_type='self_guidance', #maskgit or self_guidance\n", 104 | " return_intermediate=True, guidance_scale=1.0)\n", 105 | "print(imagenet_dict[class_ids])\n", 106 | "for i in range(4):\n", 107 | " display(images[i])" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "i_batch = 0\n", 117 | "display_intermediate_img_and_mask(i_batch, timesteps, images, intermediate_images, mask_index, res=256)" 118 | ] 119 | } 120 | ], 121 | "metadata": { 122 | "kernelspec": { 123 | "display_name": "Python 3 (ipykernel)", 124 | "language": "python", 125 | "name": "python3" 126 | }, 127 | "language_info": { 128 | "codemirror_mode": { 129 | "name": "ipython", 130 | "version": 3 131 | }, 132 | "file_extension": ".py", 133 | "mimetype": "text/x-python", 134 | "name": "python", 135 | "nbconvert_exporter": "python", 136 | "pygments_lexer": "ipython3", 137 | "version": "3.10.13" 138 | }, 139 | "orig_nbformat": 4, 140 | "vscode": { 141 | "interpreter": { 142 | "hash": "50dd96137dc1217dbb5a4b77b01cf18314368902f96ea7b6b16b7f34afe8268a" 143 | } 144 | } 145 | }, 146 | "nbformat": 4, 147 | "nbformat_minor": 2 148 | } 149 | -------------------------------------------------------------------------------- /evaluations/README.md: -------------------------------------------------------------------------------- 1 | # Evaluations 2 | 3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files. 4 | 5 | # Download batches 6 | 7 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format. 8 | 9 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall. 
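For a quick sanity check of a downloaded batch (such as one of the sample batches linked below), you can open it with NumPy and inspect the stored arrays. This is a minimal sketch, not part of the original instructions; the exact key names (e.g. `arr_0` for images) depend on how the batch was written, so check `batch.files` first.

```python
import numpy as np

# Open a sample batch named in this README; key names may differ between reference and sample batches.
batch = np.load("admnet_guided_upsampled_imagenet256.npz")
print(batch.files)                 # names of the arrays stored in the .npz file
images = batch[batch.files[0]]
print(images.shape, images.dtype)  # a 50k-sample batch is typically (50000, H, W, 3) uint8
```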
10 | 11 | Here are links to download all of the sample and reference batches: 12 | 13 | * LSUN 14 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz) 15 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz) 16 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz) 17 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz) 18 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz) 19 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz) 20 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz) 21 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz) 22 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz) 23 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz) 24 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz) 25 | 26 | * ImageNet 27 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz) 28 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz) 29 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz) 30 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz) 31 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz) 32 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz) 33 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz) 34 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz) 35 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz) 36 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) 37 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz) 38 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz) 39 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz) 40 | * [ADM-G + 
ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz) 41 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz) 42 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz) 43 | * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) 44 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz) 45 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz) 46 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz) 47 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz) 48 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz) 49 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz) 50 | 51 | # Run evaluations 52 | 53 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`. 54 | 55 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB. 56 | 57 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging: 58 | 59 | ``` 60 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz 61 | ... 62 | computing reference batch activations... 63 | computing/reading reference batch statistics... 64 | computing sample batch activations... 65 | computing/reading sample batch statistics... 66 | Computing evaluations... 67 | Inception Score: 215.8370361328125 68 | FID: 3.9425574129223264 69 | sFID: 6.140433703346162 70 | Precision: 0.8265 71 | Recall: 0.5309 72 | ``` 73 | -------------------------------------------------------------------------------- /scripts/convert_maskgit_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # !pip install ml_collections 18 | # !wget https://storage.googleapis.com/maskgit-public/checkpoints/tokenizer_imagenet256_checkpoint 19 | import logging 20 | 21 | import jax.numpy as jnp 22 | import numpy as np 23 | from flax.traverse_util import flatten_dict 24 | import argparse 25 | 26 | import sys 27 | sys.path.append('.') 28 | from muse.modeling_transformer import MaskGitTransformer 29 | # breakpoint() 30 | import tensorflow as tf 31 | import flax 32 | 33 | def restore_from_path(path): 34 | with tf.io.gfile.GFile(path, "rb") as f: 35 | state = flax.serialization.from_bytes(None, f.read()) 36 | return state 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | def rename_flax_dict(params): 42 | keys = list(params.keys()) 43 | 44 | for key in keys: 45 | new_key = ".".join(key) 46 | params[new_key] = params.pop(key) 47 | keys = list(params.keys()) 48 | 49 | for key in keys: 50 | new_key = key.replace("Embed_0", "embed") 51 | for i in range(24): 52 | new_key = new_key.replace(f"TransformerLayer_{i}.", f"transformer_layers.{i}.") 53 | new_key = new_key.replace("Attention_0.attention_output_ln", "attn_layer_norm") 54 | new_key = new_key.replace("Attention_0.self_attention", "attention") 55 | new_key = new_key.replace("MlmLayer_0", "mlm_layer") 56 | new_key = new_key.replace("mlm_bias", "to_logits") 57 | new_key = new_key.replace("Mlp_0", "ffn") 58 | new_key = new_key.replace("layer_output_ln", "layer_norm") 59 | new_key = new_key.replace("scale", "weight") 60 | new_key = new_key.replace("embeddings.embedding", "embeddings.weight") 61 | new_key = new_key.replace("kernel", "weight") 62 | params[new_key] = params.pop(key) 63 | params['mlm_layer.to_logits.weight'] = params['embed.word_embeddings.weight'] 64 | return params 65 | 66 | 67 | def load_flax_weights_in_pytorch_model(pt_model, flax_state): 68 | """Load flax checkpoints in a PyTorch model""" 69 | 70 | try: 71 | import torch # noqa: F401 72 | except ImportError: 73 | logger.error( 74 | "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see" 75 | " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" 76 | " instructions." 
77 | ) 78 | raise 79 | 80 | pt_model_dict = pt_model.state_dict() 81 | 82 | # keep track of unexpected & missing keys 83 | unexpected_keys = [] 84 | missing_keys = set(pt_model_dict.keys()) 85 | 86 | for flax_key, flax_tensor in flax_state.items(): 87 | flax_key_tuple = tuple(flax_key.split(".")) 88 | flax_tensor = torch.from_numpy(np.array(flax_tensor)) 89 | if flax_key_tuple[0] == "transformer_layers" and len(flax_tensor.shape) == 3: 90 | if flax_tensor.shape[0] == 768: 91 | flax_tensor = flax_tensor.permute(1, 2, 0) 92 | elif flax_tensor.shape[-1] == 768: 93 | flax_tensor = flax_tensor.permute(2, 0, 1) 94 | if flax_key_tuple[0] == "transformer_layers" and len(flax_tensor.shape) == 2: 95 | if not 'bias' == flax_key_tuple[-1]: 96 | flax_tensor = flax_tensor.permute(1, 0) 97 | if flax_key_tuple[0] == "mlm_layer" and len(flax_tensor.shape) == 2: 98 | if not flax_key_tuple[1] == "to_logits": 99 | flax_tensor = flax_tensor.permute(1, 0) 100 | 101 | if flax_key_tuple[0] == "transformer_layers" and 'bias' == flax_key_tuple[-1] and len(flax_tensor.shape) == 2: 102 | flax_tensor = flax_tensor.reshape(-1) 103 | if flax_key_tuple[0] == "transformer_layers" and 'weight' == flax_key_tuple[-1] and len(flax_tensor.shape) == 3: 104 | if flax_tensor.shape[0] == 768: 105 | flax_tensor = flax_tensor.reshape(flax_tensor.shape[0], -1) 106 | elif flax_tensor.shape[-1] == 768: 107 | flax_tensor = flax_tensor.reshape(-1, flax_tensor.shape[-1]) 108 | 109 | flax_key = ".".join(flax_key_tuple) 110 | if flax_key in pt_model_dict: 111 | if flax_tensor.shape != pt_model_dict[flax_key].shape: 112 | raise ValueError( 113 | f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " 114 | f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." 115 | ) 116 | else: 117 | # add weight to pytorch dict 118 | flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor 119 | pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) 120 | # remove from missing keys 121 | missing_keys.remove(flax_key) 122 | else: 123 | # weight is not expected by PyTorch model 124 | unexpected_keys.append(flax_key) 125 | 126 | pt_model.load_state_dict(pt_model_dict) 127 | 128 | # re-transform missing_keys to list 129 | missing_keys = list(missing_keys) 130 | 131 | if len(unexpected_keys) > 0: 132 | logger.warning( 133 | "Some weights of the Flax model were not used when initializing the PyTorch model" 134 | f" {pt_model.__class__.__name__}: {unexpected_keys}." 135 | ) 136 | else: 137 | logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n") 138 | if len(missing_keys) > 0: 139 | logger.warning( 140 | f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" 141 | f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" 142 | " use it for predictions and inference." 
143 | ) 144 | else: 145 | logger.warning(f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.") 146 | 147 | return pt_model 148 | 149 | 150 | def convert(flax_model_path, pytorch_dump_folder_path, num_vq_tokens): 151 | params = restore_from_path(flax_model_path)["params"] 152 | params = flatten_dict(params) 153 | params = rename_flax_dict(params) 154 | 155 | pt_model = MaskGitTransformer( 156 | vocab_size=2025, #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + ) 157 | max_position_embeddings=num_vq_tokens+1, # num_vq_tokens + 1 for class token 158 | hidden_size=768, 159 | num_hidden_layers=24, 160 | num_attention_heads=16, 161 | intermediate_size=3072, 162 | codebook_size=1024, 163 | num_vq_tokens=num_vq_tokens, 164 | num_classes=1000, 165 | layer_norm_eps=1e-12, 166 | layer_norm_embeddings=True, 167 | use_bias=True, 168 | use_encoder_layernorm=False, 169 | use_maskgit_mlp=True, 170 | use_normformer=False 171 | ) 172 | 173 | pt_model = load_flax_weights_in_pytorch_model(pt_model, params) 174 | pt_model.save_pretrained(pytorch_dump_folder_path) 175 | return pt_model 176 | 177 | def parse_args(): 178 | parser = argparse.ArgumentParser(description="Simple loading script for maskgit.") 179 | parser.add_argument( 180 | "--flax_model_path", 181 | type=str, 182 | default=None, 183 | required=True, 184 | help="Path to flax maskgit transformer.", 185 | ) 186 | parser.add_argument( 187 | "--pytorch_dump_folder_path", 188 | type=str, 189 | default=None, 190 | required=True, 191 | help="Path to dump pytorch model.", 192 | ) 193 | return parser.parse_args() 194 | 195 | if __name__ == "__main__": 196 | # args = parse_args() 197 | resolution = 256 # 256 or 512 198 | flax_model_path = f'./scripts/maskgit_imagenet{resolution}_checkpoint' 199 | pytorch_dump_folder_path = f'./scripts/maskgit_imagenet{resolution}_torch' 200 | num_vq_tokens = 256 if resolution == 256 else 1024 201 | convert(flax_model_path, pytorch_dump_folder_path, num_vq_tokens) -------------------------------------------------------------------------------- /muse/modeling_paella_vq.py: -------------------------------------------------------------------------------- 1 | # VQGAN taken from https://github.com/dome272/Paella/ 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from .modeling_utils import ConfigMixin, ModelMixin, register_to_config 10 | 11 | # TODO: This model only supports inference, not training. Make it trainable. 12 | 13 | 14 | class VectorQuantizer(nn.Module): 15 | """ 16 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 17 | Discretization bottleneck part of the VQ-VAE. 18 | """ 19 | 20 | def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25): 21 | r""" 22 | Args: 23 | num_embeddings: number of vectors in the quantized space. 24 | embedding_dim: dimensionality of the tensors in the quantized space. 25 | Inputs to the modules must be in this format as well. 26 | commitment_cost: scalar which controls the weighting of the loss terms 27 | (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta). 
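(Illustrative note, not part of the original docstring.) With the default commitment_cost = 0.25, the loss returned by forward(..., return_loss=True) below is mean((sg(z_q) - z)**2) + 0.25 * mean((z_q - sg(z))**2), where sg() denotes stop-gradient (implemented via tensor.detach()).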
28 | """ 29 | super().__init__() 30 | 31 | self.num_embeddings = num_embeddings 32 | self.codebook_dim = embedding_dim 33 | self.commitment_cost = commitment_cost 34 | 35 | self.codebook = nn.Embedding(num_embeddings, embedding_dim) 36 | self.codebook.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings) 37 | 38 | def forward(self, hidden_states, return_loss=False): 39 | """ 40 | Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the 41 | closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width) 42 | quantization pipeline: 43 | 1. get encoder input (B,C,H,W) 44 | 2. flatten input to (B*H*W,C) 45 | """ 46 | # reshape z -> (batch, height, width, channel) and flatten 47 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() 48 | 49 | distances = self.compute_distances(hidden_states) 50 | min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1) 51 | min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states) 52 | min_encodings.scatter_(1, min_encoding_indices, 1) 53 | 54 | # get quantized latent vectors 55 | z_q = torch.matmul(min_encodings, self.codebook.weight).view(hidden_states.shape) 56 | 57 | # reshape to (batch, num_tokens) 58 | min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1) 59 | 60 | # compute loss for embedding 61 | loss = None 62 | if return_loss: 63 | loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean( 64 | (z_q - hidden_states.detach()) ** 2 65 | ) 66 | # preserve gradients 67 | z_q = hidden_states + (z_q - hidden_states).detach() 68 | 69 | # reshape back to match original input shape 70 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 71 | 72 | return z_q, min_encoding_indices, loss 73 | 74 | def compute_distances(self, hidden_states): 75 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 76 | hidden_states_flattended = hidden_states.reshape((-1, self.codebook_dim)) 77 | return torch.cdist(hidden_states_flattended, self.codebook.weight) 78 | 79 | def get_codebook_entry(self, indices): 80 | # indices are expected to be of shape (batch, num_tokens) 81 | # get quantized latent vectors 82 | batch, num_tokens = indices.shape 83 | z_q = self.codebook(indices) 84 | z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2) 85 | return z_q 86 | 87 | # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372 88 | def get_soft_code(self, hidden_states, temp=1.0, stochastic=False): 89 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel) 90 | distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings) 91 | 92 | soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings) 93 | if stochastic: 94 | code = torch.multinomial(soft_code, 1) # (batch * height * width, 1) 95 | else: 96 | code = distances.argmin(dim=-1) # (batch * height * width) 97 | 98 | code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width) 99 | batch, num_tokens = code.shape 100 | soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings) 101 | return soft_code, code 102 | 103 | def get_code(self, hidden_states): 104 | # reshape z -> (batch, height, width, channel) 105 | hidden_states = hidden_states.permute(0, 2, 
3, 1).contiguous() 106 | distances = self.compute_distances(hidden_states) 107 | indices = torch.argmin(distances, axis=1).unsqueeze(1) 108 | indices = indices.reshape(hidden_states.shape[0], -1) 109 | return indices 110 | 111 | 112 | class ResBlock(nn.Module): 113 | def __init__(self, c, c_hidden): 114 | super().__init__() 115 | # depthwise/attention 116 | self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) 117 | self.depthwise = nn.Sequential(nn.ReplicationPad2d(1), nn.Conv2d(c, c, kernel_size=3, groups=c)) 118 | 119 | self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) 120 | self.channelwise = nn.Sequential( 121 | nn.Linear(c, c_hidden), 122 | nn.GELU(), 123 | nn.Linear(c_hidden, c), 124 | ) 125 | 126 | self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) 127 | 128 | def _basic_init(module): 129 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 130 | torch.nn.init.xavier_uniform_(module.weight) 131 | if module.bias is not None: 132 | nn.init.constant_(module.bias, 0) 133 | 134 | self.apply(_basic_init) 135 | 136 | def _norm(self, x, norm): 137 | return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 138 | 139 | def forward(self, x): 140 | mods = self.gammas 141 | x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] 142 | x = x + self.depthwise(x_temp) * mods[2] 143 | x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] 144 | x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] 145 | return x 146 | 147 | 148 | class PaellaVQModel(ModelMixin, ConfigMixin): 149 | @register_to_config 150 | def __init__( 151 | self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.3764 152 | ): # 1.0 153 | super().__init__() 154 | self.c_latent = c_latent 155 | self.scale_factor = scale_factor 156 | c_levels = [c_hidden // (2**i) for i in reversed(range(levels))] 157 | 158 | # Encoder blocks 159 | self.in_block = nn.Sequential(nn.PixelUnshuffle(2), nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)) 160 | down_blocks = [] 161 | for i in range(levels): 162 | if i > 0: 163 | down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) 164 | block = ResBlock(c_levels[i], c_levels[i] * 4) 165 | down_blocks.append(block) 166 | down_blocks.append( 167 | nn.Sequential( 168 | nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), 169 | nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 170 | ) 171 | ) 172 | self.down_blocks = nn.Sequential(*down_blocks) 173 | 174 | self.codebook_size = codebook_size 175 | self.vquantizer = VectorQuantizer(codebook_size, c_latent) 176 | 177 | # Decoder blocks 178 | up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))] 179 | for i in range(levels): 180 | for j in range(bottleneck_blocks if i == 0 else 1): 181 | block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) 182 | up_blocks.append(block) 183 | if i < levels - 1: 184 | up_blocks.append( 185 | nn.ConvTranspose2d( 186 | c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1 187 | ) 188 | ) 189 | self.up_blocks = nn.Sequential(*up_blocks) 190 | self.out_block = nn.Sequential( 191 | nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), 192 | nn.PixelShuffle(2), 193 | ) 194 | 195 | def encode(self, x): 196 | x = self.in_block(x) 197 | x = self.down_blocks(x) 198 | # qe, (vq_loss, commit_loss), indices = self.vquantizer(x, dim=1) 199 | # return qe / self.scale_factor, x 
/ self.scale_factor, indices, vq_loss + commit_loss * 0.25 200 | quantized_states, codebook_indices, codebook_loss = self.vquantizer(x) 201 | quantized_states = quantized_states / self.scale_factor 202 | output = (quantized_states, codebook_indices, codebook_loss) 203 | return output 204 | 205 | def decode(self, x): 206 | x = x * self.scale_factor 207 | x = self.up_blocks(x) 208 | x = self.out_block(x) 209 | return x 210 | 211 | def decode_code(self, codebook_indices): 212 | x = self.vquantizer.get_codebook_entry(codebook_indices) 213 | x = self.up_blocks(x) 214 | x = self.out_block(x) 215 | return x 216 | 217 | def get_code(self, pixel_values): 218 | x = self.in_block(pixel_values) 219 | x = self.down_blocks(x) 220 | return self.vquantizer.get_code(x) 221 | 222 | def forward(self, x, quantize=False): 223 | qe = self.encode(x)[0] 224 | x = self.decode(qe) 225 | return x 226 | -------------------------------------------------------------------------------- /scripts/convert_maskgit_vqgan.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # !pip install ml_collections 18 | # !wget https://storage.googleapis.com/maskgit-public/checkpoints/tokenizer_imagenet256_checkpoint 19 | import logging 20 | import sys 21 | sys.path.append('.') 22 | import jax.numpy as jnp 23 | import numpy as np 24 | from flax.traverse_util import flatten_dict 25 | from muse import MaskGitVQGAN 26 | import tensorflow as tf 27 | import flax 28 | 29 | def restore_from_path(path): 30 | with tf.io.gfile.GFile(path, "rb") as f: 31 | state = flax.serialization.from_bytes(None, f.read()) 32 | return state 33 | 34 | 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | def rename_flax_dict(params): 40 | keys = list(params.keys()) 41 | 42 | for key in keys: 43 | new_key = ".".join(key) 44 | params[new_key] = params.pop(key) 45 | keys = list(params.keys()) 46 | 47 | block_map = { 48 | 0: (0, 0), 49 | 1: (0, 1), 50 | 2: (1, 0), 51 | 3: (1, 1), 52 | 4: (2, 0), 53 | 5: (2, 1), 54 | 6: (3, 0), 55 | 7: (3, 1), 56 | 8: (4, 0), 57 | 9: (4, 1), 58 | } 59 | 60 | encoder_keys = [key for key in keys if "encoder.ResBlock" in key] 61 | for key in encoder_keys: 62 | if "ResBlock_10" in key: 63 | new_key = key.replace("ResBlock_10", "mid.0") 64 | new_key = new_key.replace("Conv_0", "conv1") 65 | new_key = new_key.replace("Conv_1", "conv2") 66 | new_key = new_key.replace("GroupNorm_0", "norm1") 67 | new_key = new_key.replace("GroupNorm_1", "norm2") 68 | params[new_key] = params.pop(key) 69 | elif "ResBlock_11" in key: 70 | new_key = key.replace("ResBlock_11", "mid.1") 71 | new_key = new_key.replace("Conv_0", "conv1") 72 | new_key = new_key.replace("Conv_1", "conv2") 73 | new_key = new_key.replace("GroupNorm_0", "norm1") 74 | new_key = new_key.replace("GroupNorm_1", "norm2") 75 | params[new_key] = params.pop(key) 76 | keys = 
list(params.keys()) 77 | 78 | encoder_keys = [key for key in keys if "encoder.ResBlock" in key] 79 | for key in encoder_keys: 80 | name = key.split(".")[1] 81 | res_name, idx = name.split("_") 82 | idx1, idx2 = block_map[int(idx)] 83 | new_key = key.replace(name, f"down.{idx1}.block.{idx2}") 84 | new_key = new_key.replace("Conv_0", "conv1") 85 | new_key = new_key.replace("Conv_1", "conv2") 86 | new_key = new_key.replace("Conv_2", "nin_shortcut") 87 | new_key = new_key.replace("GroupNorm_0", "norm1") 88 | new_key = new_key.replace("GroupNorm_1", "norm2") 89 | params[new_key] = params.pop(key) 90 | keys = list(params.keys()) 91 | 92 | decoder_keys = [key for key in keys if "decoder.ResBlock" in key] 93 | for key in decoder_keys: 94 | if "ResBlock_0" in key: 95 | new_key = key.replace("ResBlock_0", "mid.0") 96 | new_key = new_key.replace("Conv_0", "conv1") 97 | new_key = new_key.replace("Conv_1", "conv2") 98 | new_key = new_key.replace("GroupNorm_0", "norm1") 99 | new_key = new_key.replace("GroupNorm_1", "norm2") 100 | params[new_key] = params.pop(key) 101 | elif "ResBlock_1." in key: 102 | new_key = key.replace("ResBlock_1", "mid.1") 103 | new_key = new_key.replace("Conv_0", "conv1") 104 | new_key = new_key.replace("Conv_1", "conv2") 105 | new_key = new_key.replace("GroupNorm_0", "norm1") 106 | new_key = new_key.replace("GroupNorm_1", "norm2") 107 | params[new_key] = params.pop(key) 108 | keys = list(params.keys()) 109 | 110 | decoder_keys = [key for key in keys if "decoder.ResBlock" in key] 111 | for key in decoder_keys: 112 | name = key.split(".")[1] 113 | res_name, idx = name.split("_") 114 | idx = int(idx) - 2 115 | idx1, idx2 = block_map[int(idx)] 116 | idx1 = 4 - idx1 117 | new_key = key.replace(name, f"up.{idx1}.block.{idx2}") 118 | new_key = new_key.replace("Conv_0", "conv1") 119 | new_key = new_key.replace("Conv_1", "conv2") 120 | new_key = new_key.replace("Conv_2", "nin_shortcut") 121 | new_key = new_key.replace("GroupNorm_0", "norm1") 122 | new_key = new_key.replace("GroupNorm_1", "norm2") 123 | params[new_key] = params.pop(key) 124 | keys = list(params.keys()) 125 | 126 | for i in range(1, 5): 127 | w = f"decoder.Conv_{i}.kernel" 128 | b = f"decoder.Conv_{i}.bias" 129 | new_w = f"decoder.up.{5 - i}.upsample_conv.kernel" 130 | new_b = f"decoder.up.{5 - i}.upsample_conv.bias" 131 | params[new_w] = params.pop(w) 132 | params[new_b] = params.pop(b) 133 | keys = list(params.keys()) 134 | 135 | for key in keys: 136 | if "Conv_" in key: 137 | new_key = key.replace("Conv_0", "conv_in") 138 | new_key = new_key.replace("Conv_5", "conv_out") 139 | new_key = new_key.replace("Conv_1", "conv_out") 140 | params[new_key] = params.pop(key) 141 | elif "GroupNorm" in key: 142 | new_key = key.replace("GroupNorm_0", "norm_out") 143 | params[new_key] = params.pop(key) 144 | params["quantize.embedding.embedding"] = params.pop("quantizer.codebook") 145 | 146 | return params 147 | 148 | 149 | def load_flax_weights_in_pytorch_model(pt_model, flax_state): 150 | """Load flax checkpoints in a PyTorch model""" 151 | 152 | try: 153 | import torch # noqa: F401 154 | except ImportError: 155 | logger.error( 156 | "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see" 157 | " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" 158 | " instructions." 
159 | ) 160 | raise 161 | 162 | pt_model_dict = pt_model.state_dict() 163 | 164 | # keep track of unexpected & missing keys 165 | unexpected_keys = [] 166 | missing_keys = set(pt_model_dict.keys()) 167 | 168 | for flax_key, flax_tensor in flax_state.items(): 169 | flax_key_tuple = tuple(flax_key.split(".")) 170 | 171 | # rename flax weights to PyTorch format 172 | if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict: 173 | # conv layer 174 | flax_key_tuple = flax_key_tuple[:-1] + ("weight",) 175 | flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) 176 | elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict: 177 | # linear layer 178 | flax_key_tuple = flax_key_tuple[:-1] + ("weight",) 179 | flax_tensor = flax_tensor.T 180 | elif flax_key_tuple[-1] in ["scale", "embedding"]: 181 | flax_key_tuple = flax_key_tuple[:-1] + ("weight",) 182 | 183 | flax_key = ".".join(flax_key_tuple) 184 | 185 | if "in_proj.weight" in flax_key: 186 | flax_key = flax_key.replace("in_proj.weight", "in_proj_weight") 187 | 188 | if flax_key in pt_model_dict: 189 | if flax_tensor.shape != pt_model_dict[flax_key].shape: 190 | raise ValueError( 191 | f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " 192 | f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." 193 | ) 194 | else: 195 | # add weight to pytorch dict 196 | flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor 197 | pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) 198 | # remove from missing keys 199 | missing_keys.remove(flax_key) 200 | else: 201 | # weight is not expected by PyTorch model 202 | unexpected_keys.append(flax_key) 203 | 204 | pt_model.load_state_dict(pt_model_dict) 205 | 206 | # re-transform missing_keys to list 207 | missing_keys = list(missing_keys) 208 | 209 | if len(unexpected_keys) > 0: 210 | logger.warning( 211 | "Some weights of the Flax model were not used when initializing the PyTorch model" 212 | f" {pt_model.__class__.__name__}: {unexpected_keys}." 213 | ) 214 | else: 215 | logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n") 216 | if len(missing_keys) > 0: 217 | logger.warning( 218 | f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" 219 | f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" 220 | " use it for predictions and inference." 
221 | ) 222 | else: 223 | logger.warning(f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.") 224 | 225 | return pt_model 226 | 227 | 228 | def convert(flax_model_path, pytorch_dump_folder_path): 229 | params = restore_from_path(flax_model_path)["params"] 230 | params = flatten_dict(params) 231 | params = rename_flax_dict(params) 232 | 233 | pt_model = MaskGitVQGAN() 234 | 235 | pt_model = load_flax_weights_in_pytorch_model(pt_model, params) 236 | pt_model.save_pretrained(pytorch_dump_folder_path) 237 | return pt_model 238 | 239 | 240 | if __name__ == "__main__": 241 | resolution = 256 # 256 or 512 242 | flax_model_path = f'./scripts/tokenizer_imagenet{resolution}_checkpoint' 243 | pytorch_dump_folder_path = f'./scripts/tokenizer_imagenet{resolution}_torch' 244 | convert(flax_model_path, pytorch_dump_folder_path) -------------------------------------------------------------------------------- /muse/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Optuna, Hugging Face 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Logging utilities.""" 16 | 17 | import logging 18 | import os 19 | import sys 20 | import threading 21 | from logging import CRITICAL # NOQA 22 | from logging import DEBUG # NOQA 23 | from logging import ERROR # NOQA 24 | from logging import FATAL # NOQA 25 | from logging import INFO # NOQA 26 | from logging import NOTSET # NOQA 27 | from logging import WARN # NOQA 28 | from logging import WARNING # NOQA 29 | from typing import Optional 30 | 31 | from tqdm import auto as tqdm_lib 32 | 33 | _lock = threading.Lock() 34 | _default_handler: Optional[logging.Handler] = None 35 | 36 | log_levels = { 37 | "debug": logging.DEBUG, 38 | "info": logging.INFO, 39 | "warning": logging.WARNING, 40 | "error": logging.ERROR, 41 | "critical": logging.CRITICAL, 42 | } 43 | 44 | _default_log_level = logging.WARNING 45 | 46 | _tqdm_active = True 47 | 48 | 49 | def _get_default_logging_level(): 50 | """ 51 | If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is 52 | not - fall back to `_default_log_level` 53 | """ 54 | env_level_str = os.getenv("muse_VERBOSITY", None) 55 | if env_level_str: 56 | if env_level_str in log_levels: 57 | return log_levels[env_level_str] 58 | else: 59 | logging.getLogger().warning( 60 | f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }" 61 | ) 62 | return _default_log_level 63 | 64 | 65 | def _get_library_name() -> str: 66 | return __name__.split(".")[0] 67 | 68 | 69 | def _get_library_root_logger() -> logging.Logger: 70 | return logging.getLogger(_get_library_name()) 71 | 72 | 73 | def _configure_library_root_logger() -> None: 74 | global _default_handler 75 | 76 | with _lock: 77 | if _default_handler: 78 | # This library has already configured the library root logger. 
79 | return 80 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 81 | _default_handler.flush = sys.stderr.flush 82 | 83 | # Apply our default configuration to the library root logger. 84 | library_root_logger = _get_library_root_logger() 85 | library_root_logger.addHandler(_default_handler) 86 | library_root_logger.setLevel(_get_default_logging_level()) 87 | library_root_logger.propagate = False 88 | 89 | 90 | def _reset_library_root_logger() -> None: 91 | global _default_handler 92 | 93 | with _lock: 94 | if not _default_handler: 95 | return 96 | 97 | library_root_logger = _get_library_root_logger() 98 | library_root_logger.removeHandler(_default_handler) 99 | library_root_logger.setLevel(logging.NOTSET) 100 | _default_handler = None 101 | 102 | 103 | def get_log_levels_dict(): 104 | return log_levels 105 | 106 | 107 | def get_logger(name: Optional[str] = None) -> logging.Logger: 108 | """ 109 | Return a logger with the specified name. 110 | 111 | This function is not supposed to be directly accessed unless you are writing a custom muse module. 112 | """ 113 | 114 | if name is None: 115 | name = _get_library_name() 116 | 117 | _configure_library_root_logger() 118 | return logging.getLogger(name) 119 | 120 | 121 | def get_verbosity() -> int: 122 | """ 123 | Return the current level for the 🤗 muse' root logger as an int. 124 | 125 | Returns: 126 | `int`: The logging level. 127 | 128 | 129 | 130 | 🤗 muse has following logging levels: 131 | 132 | - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL` 133 | - 40: `muse.logging.ERROR` 134 | - 30: `muse.logging.WARNING` or `muse.logging.WARN` 135 | - 20: `muse.logging.INFO` 136 | - 10: `muse.logging.DEBUG` 137 | 138 | """ 139 | 140 | _configure_library_root_logger() 141 | return _get_library_root_logger().getEffectiveLevel() 142 | 143 | 144 | def set_verbosity(verbosity: int) -> None: 145 | """ 146 | Set the verbosity level for the 🤗 muse' root logger. 
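(Illustrative usage, not part of the original docstring.) A minimal example, assuming the module is imported directly from the package:

        import muse.logging as logging
        logging.set_verbosity(logging.INFO)  # show INFO and above from the library's loggers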
147 | 148 | Args: 149 | verbosity (`int`): 150 | Logging level, e.g., one of: 151 | 152 | - `muse.logging.CRITICAL` or `muse.logging.FATAL` 153 | - `muse.logging.ERROR` 154 | - `muse.logging.WARNING` or `muse.logging.WARN` 155 | - `muse.logging.INFO` 156 | - `muse.logging.DEBUG` 157 | """ 158 | 159 | _configure_library_root_logger() 160 | _get_library_root_logger().setLevel(verbosity) 161 | 162 | 163 | def set_verbosity_info(): 164 | """Set the verbosity to the `INFO` level.""" 165 | return set_verbosity(INFO) 166 | 167 | 168 | def set_verbosity_warning(): 169 | """Set the verbosity to the `WARNING` level.""" 170 | return set_verbosity(WARNING) 171 | 172 | 173 | def set_verbosity_debug(): 174 | """Set the verbosity to the `DEBUG` level.""" 175 | return set_verbosity(DEBUG) 176 | 177 | 178 | def set_verbosity_error(): 179 | """Set the verbosity to the `ERROR` level.""" 180 | return set_verbosity(ERROR) 181 | 182 | 183 | def disable_default_handler() -> None: 184 | """Disable the default handler of the HuggingFace muse' root logger.""" 185 | 186 | _configure_library_root_logger() 187 | 188 | assert _default_handler is not None 189 | _get_library_root_logger().removeHandler(_default_handler) 190 | 191 | 192 | def enable_default_handler() -> None: 193 | """Enable the default handler of the HuggingFace muse' root logger.""" 194 | 195 | _configure_library_root_logger() 196 | 197 | assert _default_handler is not None 198 | _get_library_root_logger().addHandler(_default_handler) 199 | 200 | 201 | def add_handler(handler: logging.Handler) -> None: 202 | """adds a handler to the HuggingFace muse' root logger.""" 203 | 204 | _configure_library_root_logger() 205 | 206 | assert handler is not None 207 | _get_library_root_logger().addHandler(handler) 208 | 209 | 210 | def remove_handler(handler: logging.Handler) -> None: 211 | """removes given handler from the HuggingFace muse' root logger.""" 212 | 213 | _configure_library_root_logger() 214 | 215 | assert handler is not None and handler not in _get_library_root_logger().handlers 216 | _get_library_root_logger().removeHandler(handler) 217 | 218 | 219 | def disable_propagation() -> None: 220 | """ 221 | Disable propagation of the library log outputs. Note that log propagation is disabled by default. 222 | """ 223 | 224 | _configure_library_root_logger() 225 | _get_library_root_logger().propagate = False 226 | 227 | 228 | def enable_propagation() -> None: 229 | """ 230 | Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent 231 | double logging if the root logger has been configured. 232 | """ 233 | 234 | _configure_library_root_logger() 235 | _get_library_root_logger().propagate = True 236 | 237 | 238 | def enable_explicit_format() -> None: 239 | """ 240 | Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows: 241 | ``` 242 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE 243 | ``` 244 | All handlers currently bound to the root logger are affected by this method. 245 | """ 246 | handlers = _get_library_root_logger().handlers 247 | 248 | for handler in handlers: 249 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") 250 | handler.setFormatter(formatter) 251 | 252 | 253 | def reset_format() -> None: 254 | """ 255 | Resets the formatting for HuggingFace muse' loggers. 256 | 257 | All handlers currently bound to the root logger are affected by this method. 
258 | """ 259 | handlers = _get_library_root_logger().handlers 260 | 261 | for handler in handlers: 262 | handler.setFormatter(None) 263 | 264 | 265 | def warning_advice(self, *args, **kwargs): 266 | """ 267 | This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this 268 | warning will not be printed 269 | """ 270 | no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False) 271 | if no_advisory_warnings: 272 | return 273 | self.warning(*args, **kwargs) 274 | 275 | 276 | logging.Logger.warning_advice = warning_advice 277 | 278 | 279 | class EmptyTqdm: 280 | """Dummy tqdm which doesn't do anything.""" 281 | 282 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument 283 | self._iterator = args[0] if args else None 284 | 285 | def __iter__(self): 286 | return iter(self._iterator) 287 | 288 | def __getattr__(self, _): 289 | """Return empty function.""" 290 | 291 | def empty_fn(*args, **kwargs): # pylint: disable=unused-argument 292 | return 293 | 294 | return empty_fn 295 | 296 | def __enter__(self): 297 | return self 298 | 299 | def __exit__(self, type_, value, traceback): 300 | return 301 | 302 | 303 | class _tqdm_cls: 304 | def __call__(self, *args, **kwargs): 305 | if _tqdm_active: 306 | return tqdm_lib.tqdm(*args, **kwargs) 307 | else: 308 | return EmptyTqdm(*args, **kwargs) 309 | 310 | def set_lock(self, *args, **kwargs): 311 | self._lock = None 312 | if _tqdm_active: 313 | return tqdm_lib.tqdm.set_lock(*args, **kwargs) 314 | 315 | def get_lock(self): 316 | if _tqdm_active: 317 | return tqdm_lib.tqdm.get_lock() 318 | 319 | 320 | tqdm = _tqdm_cls() 321 | 322 | 323 | def is_progress_bar_enabled() -> bool: 324 | """Return a boolean indicating whether tqdm progress bars are enabled.""" 325 | global _tqdm_active 326 | return bool(_tqdm_active) 327 | 328 | 329 | def enable_progress_bar(): 330 | """Enable tqdm progress bar.""" 331 | global _tqdm_active 332 | _tqdm_active = True 333 | 334 | 335 | def disable_progress_bar(): 336 | """Disable tqdm progress bar.""" 337 | global _tqdm_active 338 | _tqdm_active = False 339 | -------------------------------------------------------------------------------- /muse/modeling_ema.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Dict, Iterable, Optional, Union 3 | 4 | import torch 5 | 6 | 7 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 8 | class EMAModel: 9 | """ 10 | Exponential Moving Average of models weights 11 | """ 12 | 13 | def __init__( 14 | self, 15 | parameters: Iterable[torch.nn.Parameter], 16 | decay: float = 0.9999, 17 | min_decay: float = 0.0, 18 | update_after_step: int = 0, 19 | update_every: int = 1, 20 | use_ema_warmup: bool = False, 21 | inv_gamma: Union[float, int] = 1.0, 22 | power: Union[float, int] = 2 / 3, 23 | model_cls: Optional[Any] = None, 24 | model_config: Dict[str, Any] = None, 25 | ): 26 | """ 27 | Args: 28 | parameters (Iterable[torch.nn.Parameter]): The parameters to track. 29 | decay (float): The decay factor for the exponential moving average. 30 | min_decay (float): The minimum decay factor for the exponential moving average. 31 | update_after_step (int): The number of steps to wait before starting to update the EMA weights. 32 | update_every (int): The number of steps between each EMA update. 33 | use_ema_warmup (bool): Whether to use EMA warmup. 34 | inv_gamma (float): 35 | Inverse multiplicative factor of EMA warmup. 
Default: 1. Only used if `use_ema_warmup` is True. 36 | power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. 37 | device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA 38 | weights will be stored on CPU. 39 | 40 | @crowsonkb's notes on EMA Warmup: 41 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 42 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 43 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 44 | at 215.4k steps). 45 | """ 46 | 47 | parameters = list(parameters) 48 | self.shadow_params = [p.clone().detach() for p in parameters] 49 | self.temp_stored_params = None 50 | 51 | self.decay = decay 52 | self.min_decay = min_decay 53 | self.update_after_step = update_after_step 54 | self.update_every = update_every 55 | self.use_ema_warmup = use_ema_warmup 56 | self.inv_gamma = inv_gamma 57 | self.power = power 58 | self.optimization_step = 0 59 | self.cur_decay_value = None # set in `step()` 60 | 61 | self.model_cls = model_cls 62 | self.model_config = model_config 63 | 64 | @classmethod 65 | def from_pretrained(cls, path, model_cls) -> "EMAModel": 66 | _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) 67 | model = model_cls.from_pretrained(path) 68 | 69 | ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) 70 | 71 | ema_model.load_state_dict(ema_kwargs) 72 | return ema_model 73 | 74 | def save_pretrained(self, path): 75 | if self.model_cls is None: 76 | raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") 77 | 78 | if self.model_config is None: 79 | raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") 80 | 81 | model = self.model_cls.from_config(self.model_config) 82 | state_dict = self.state_dict() 83 | state_dict.pop("shadow_params", None) 84 | 85 | model.register_to_config(**state_dict) 86 | self.copy_to(model.parameters()) 87 | model.save_pretrained(path) 88 | 89 | def get_decay(self, optimization_step: int) -> float: 90 | """ 91 | Compute the decay factor for the exponential moving average. 92 | """ 93 | step = max(0, optimization_step - self.update_after_step - 1) 94 | 95 | if step <= 0: 96 | return 0.0 97 | 98 | if self.use_ema_warmup: 99 | cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power 100 | else: 101 | cur_decay_value = (1 + step) / (10 + step) 102 | 103 | cur_decay_value = min(cur_decay_value, self.decay) 104 | # make sure decay is not smaller than min_decay 105 | cur_decay_value = max(cur_decay_value, self.min_decay) 106 | return cur_decay_value 107 | 108 | @torch.no_grad() 109 | def step(self, parameters: Iterable[torch.nn.Parameter]): 110 | parameters = list(parameters) 111 | 112 | self.optimization_step += 1 113 | 114 | if (self.optimization_step - 1) % self.update_every != 0: 115 | return 116 | 117 | # Compute the decay factor for the exponential moving average. 
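# (Illustrative note, not part of the original file.) Without EMA warmup, get_decay() below
# returns (1 + step) / (10 + step) clamped to [min_decay, decay], so the effective decay ramps
# up quickly: step=10 -> 0.55, step=100 -> ~0.92, step=1000 -> ~0.99, and then saturates at the
# configured decay (0.9999 by default).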
118 | decay = self.get_decay(self.optimization_step) 119 | self.cur_decay_value = decay 120 | one_minus_decay = 1 - decay 121 | 122 | for s_param, param in zip(self.shadow_params, parameters): 123 | if param.requires_grad: 124 | s_param.sub_(one_minus_decay * (s_param - param)) 125 | else: 126 | s_param.copy_(param) 127 | 128 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 129 | """ 130 | Copy current averaged parameters into given collection of parameters. 131 | 132 | Args: 133 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 134 | updated with the stored moving averages. If `None`, the parameters with which this 135 | `ExponentialMovingAverage` was initialized will be used. 136 | """ 137 | parameters = list(parameters) 138 | for s_param, param in zip(self.shadow_params, parameters): 139 | param.data.copy_(s_param.to(param.device).data) 140 | 141 | def to(self, device=None, dtype=None) -> None: 142 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. 143 | 144 | Args: 145 | device: like `device` argument to `torch.Tensor.to` 146 | """ 147 | # .to() on the tensors handles None correctly 148 | self.shadow_params = [ 149 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) 150 | for p in self.shadow_params 151 | ] 152 | 153 | def state_dict(self) -> dict: 154 | r""" 155 | Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during 156 | checkpointing to save the ema state dict. 157 | """ 158 | # Following PyTorch conventions, references to tensors are returned: 159 | # "returns a reference to the state and not its copy!" - 160 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict 161 | return { 162 | "decay": self.decay, 163 | "min_decay": self.min_decay, 164 | "optimization_step": self.optimization_step, 165 | "update_after_step": self.update_after_step, 166 | "use_ema_warmup": self.use_ema_warmup, 167 | "inv_gamma": self.inv_gamma, 168 | "power": self.power, 169 | "shadow_params": self.shadow_params, 170 | } 171 | 172 | def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: 173 | r""" 174 | Args: 175 | Save the current parameters for restoring later. 176 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 177 | temporarily stored. 178 | """ 179 | self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] 180 | 181 | def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: 182 | r""" 183 | Args: 184 | Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: 185 | affecting the original optimization process. Store the parameters before the `copy_to()` method. After 186 | validation (or model saving), use this to restore the former parameters. 187 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 188 | updated with the stored parameters. If `None`, the parameters with which this 189 | `ExponentialMovingAverage` was initialized will be used. 190 | """ 191 | if self.temp_stored_params is None: 192 | raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`") 193 | for c_param, param in zip(self.temp_stored_params, parameters): 194 | param.data.copy_(c_param.data) 195 | 196 | # Better memory-wise. 
197 | self.temp_stored_params = None 198 | 199 | def load_state_dict(self, state_dict: dict) -> None: 200 | r""" 201 | Args: 202 | Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the 203 | ema state dict. 204 | state_dict (dict): EMA state. Should be an object returned 205 | from a call to :meth:`state_dict`. 206 | """ 207 | # deepcopy, to be consistent with module API 208 | state_dict = copy.deepcopy(state_dict) 209 | 210 | self.decay = state_dict.get("decay", self.decay) 211 | if self.decay < 0.0 or self.decay > 1.0: 212 | raise ValueError("Decay must be between 0 and 1") 213 | 214 | self.min_decay = state_dict.get("min_decay", self.min_decay) 215 | if not isinstance(self.min_decay, float): 216 | raise ValueError("Invalid min_decay") 217 | 218 | self.optimization_step = state_dict.get("optimization_step", self.optimization_step) 219 | if not isinstance(self.optimization_step, int): 220 | raise ValueError("Invalid optimization_step") 221 | 222 | self.update_after_step = state_dict.get("update_after_step", self.update_after_step) 223 | if not isinstance(self.update_after_step, int): 224 | raise ValueError("Invalid update_after_step") 225 | 226 | self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) 227 | if not isinstance(self.use_ema_warmup, bool): 228 | raise ValueError("Invalid use_ema_warmup") 229 | 230 | self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) 231 | if not isinstance(self.inv_gamma, (float, int)): 232 | raise ValueError("Invalid inv_gamma") 233 | 234 | self.power = state_dict.get("power", self.power) 235 | if not isinstance(self.power, (float, int)): 236 | raise ValueError("Invalid power") 237 | 238 | shadow_params = state_dict.get("shadow_params", None) 239 | if shadow_params is not None: 240 | self.shadow_params = shadow_params 241 | if not isinstance(self.shadow_params, list): 242 | raise ValueError("shadow_params must be a list") 243 | if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): 244 | raise ValueError("shadow_params must all be Tensors") 245 | -------------------------------------------------------------------------------- /training/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """PyTorch implementation of the Lion optimizer.""" 16 | import torch 17 | from torch.optim.optimizer import Optimizer 18 | import math 19 | 20 | class Lion(Optimizer): 21 | r"""Implements Lion algorithm.""" 22 | 23 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0, **kwargs): 24 | """Initialize the hyperparameters. 
25 | Args: 26 | params (iterable): iterable of parameters to optimize or dicts defining 27 | parameter groups 28 | lr (float, optional): learning rate (default: 1e-4) 29 | betas (Tuple[float, float], optional): coefficients used for computing 30 | running averages of gradient and its square (default: (0.9, 0.99)) 31 | weight_decay (float, optional): weight decay coefficient (default: 0) 32 | """ 33 | 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= betas[0] < 1.0: 37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 38 | if not 0.0 <= betas[1] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 40 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 41 | super().__init__(params, defaults) 42 | 43 | @torch.no_grad() 44 | def step(self, closure=None): 45 | """Performs a single optimization step. 46 | Args: 47 | closure (callable, optional): A closure that reevaluates the model 48 | and returns the loss. 49 | Returns: 50 | the loss. 51 | """ 52 | loss = None 53 | if closure is not None: 54 | with torch.enable_grad(): 55 | loss = closure() 56 | 57 | for group in self.param_groups: 58 | for p in group["params"]: 59 | if p.grad is None: 60 | continue 61 | 62 | # Perform stepweight decay 63 | p.data.mul_(1 - group["lr"] * group["weight_decay"]) 64 | 65 | grad = p.grad 66 | state = self.state[p] 67 | # State initialization 68 | if len(state) == 0: 69 | # Exponential moving average of gradient values 70 | state["exp_avg"] = torch.zeros_like(p) 71 | 72 | exp_avg = state["exp_avg"] 73 | beta1, beta2 = group["betas"] 74 | 75 | # Weight update 76 | update = exp_avg * beta1 + grad * (1 - beta1) 77 | p.add_(torch.sign(update), alpha=-group["lr"]) 78 | # Decay the momentum running average coefficient 79 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) 80 | 81 | return loss 82 | 83 | class Adafactor(Optimizer): 84 | """Implements Adafactor algorithm. 85 | 86 | This implementation is based on: 87 | `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` 88 | (see https://arxiv.org/abs/1804.04235) 89 | 90 | Note that this optimizer internally adjusts the learning rate 91 | depending on the *scale_parameter*, *relative_step* and 92 | *warmup_init* options. To use a manual (external) learning rate 93 | schedule you should set `scale_parameter=False` and 94 | `relative_step=False`. 
95 | 96 | Arguments: 97 | params (iterable): iterable of parameters to optimize or dicts defining 98 | parameter groups 99 | lr (float, optional): external learning rate (default: None) 100 | eps (tuple[float, float]): regularization constans for square gradient 101 | and parameter scale respectively (default: (1e-30, 1e-3)) 102 | clip_threshold (float): threshold of root mean square of 103 | final gradient update (default: 1.0) 104 | decay_rate (float): coefficient used to compute running averages of square 105 | gradient (default: -0.8) 106 | beta1 (float): coefficient used for computing running averages of gradient 107 | (default: None) 108 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 109 | scale_parameter (bool): if True, learning rate is scaled by root mean square of 110 | parameter (default: True) 111 | relative_step (bool): if True, time-dependent learning rate is computed 112 | instead of external learning rate (default: True) 113 | warmup_init (bool): time-dependent learning rate computation depends on 114 | whether warm-up initialization is being used (default: False) 115 | """ 116 | 117 | def __init__(self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0, 118 | decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True, 119 | relative_step=True, warmup_init=False): 120 | if lr is not None and relative_step: 121 | raise ValueError('Cannot combine manual lr and relative_step options') 122 | if warmup_init and not relative_step: 123 | raise ValueError('warmup_init requires relative_step=True') 124 | 125 | defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold, decay_rate=decay_rate, 126 | beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, 127 | relative_step=relative_step, warmup_init=warmup_init) 128 | super(Adafactor, self).__init__(params, defaults) 129 | 130 | @property 131 | def supports_memory_efficient_fp16(self): 132 | return True 133 | 134 | @property 135 | def supports_flat_params(self): 136 | return False 137 | 138 | def _get_lr(self, param_group, param_state): 139 | rel_step_sz = param_group['lr'] 140 | if param_group['relative_step']: 141 | min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 142 | rel_step_sz = min(min_step, 1.0/math.sqrt(param_state['step'])) 143 | param_scale = 1.0 144 | if param_group['scale_parameter']: 145 | param_scale = max(param_group['eps'][1], param_state['RMS']) 146 | return param_scale * rel_step_sz 147 | 148 | def _get_options(self, param_group, param_shape): 149 | factored = len(param_shape) >= 2 150 | use_first_moment = param_group['beta1'] is not None 151 | return factored, use_first_moment 152 | 153 | def _rms(self, tensor): 154 | return tensor.norm(2) / (tensor.numel() ** 0.5) 155 | 156 | def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): 157 | r_factor = ( 158 | exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True) 159 | ).rsqrt_() 160 | c_factor = exp_avg_sq_col.rsqrt() 161 | return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0)) 162 | 163 | def step(self, closure=None): 164 | """Performs a single optimization step. 165 | 166 | Arguments: 167 | closure (callable, optional): A closure that reevaluates the model 168 | and returns the loss. 
169 | """ 170 | loss = None 171 | if closure is not None: 172 | loss = closure() 173 | 174 | for group in self.param_groups: 175 | for p in group['params']: 176 | if p.grad is None: 177 | continue 178 | grad = p.grad.data 179 | if grad.dtype in {torch.float16, torch.bfloat16}: 180 | grad = grad.float() 181 | if grad.is_sparse: 182 | raise RuntimeError('Adafactor does not support sparse gradients.') 183 | 184 | state = self.state[p] 185 | grad_shape = grad.shape 186 | 187 | factored, use_first_moment = self._get_options(group, grad_shape) 188 | # State Initialization 189 | if len(state) == 0: 190 | state['step'] = 0 191 | 192 | if use_first_moment: 193 | # Exponential moving average of gradient values 194 | state['exp_avg'] = torch.zeros_like(grad) 195 | if factored: 196 | state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) 197 | state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) 198 | else: 199 | state['exp_avg_sq'] = torch.zeros_like(grad) 200 | 201 | state['RMS'] = 0 202 | else: 203 | if use_first_moment: 204 | state['exp_avg'] = state['exp_avg'].to(grad) 205 | if factored: 206 | state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) 207 | state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) 208 | else: 209 | state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) 210 | 211 | p_data_fp32 = p.data 212 | if p.data.dtype in {torch.float16, torch.bfloat16}: 213 | p_data_fp32 = p_data_fp32.float() 214 | 215 | state['step'] += 1 216 | state['RMS'] = self._rms(p_data_fp32) 217 | group['lr'] = self._get_lr(group, state) 218 | 219 | beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) 220 | update = (grad**2) + group['eps'][0] 221 | if factored: 222 | exp_avg_sq_row = state['exp_avg_sq_row'] 223 | exp_avg_sq_col = state['exp_avg_sq_col'] 224 | 225 | exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1)) 226 | exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2)) 227 | 228 | # Approximation of exponential moving average of square of gradient 229 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) 230 | update.mul_(grad) 231 | else: 232 | exp_avg_sq = state['exp_avg_sq'] 233 | 234 | exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update) 235 | update = exp_avg_sq.rsqrt().mul_(grad) 236 | 237 | update.div_( 238 | (self._rms(update) / group['clip_threshold']).clamp_(min=1.0) 239 | ) 240 | update.mul_(group['lr']) 241 | 242 | if use_first_moment: 243 | exp_avg = state['exp_avg'] 244 | exp_avg.mul_(group['beta1']).add_(1 - group['beta1'], update) 245 | update = exp_avg 246 | 247 | if group['weight_decay'] != 0: 248 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 249 | 250 | p_data_fp32.add_(-update) 251 | 252 | if p.data.dtype in {torch.float16, torch.bfloat16}: 253 | p.data.copy_(p_data_fp32) 254 | 255 | return loss -------------------------------------------------------------------------------- /muse/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
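For reference, a hedged sketch of how the two optimizers defined in `training/optimizer.py` above might be constructed. The model and hyperparameter values are illustrative placeholders, not values taken from this repo's configs:

```python
# Sketch: constructing Lion and Adafactor from training/optimizer.py.
# `model`, learning rates, and betas are placeholders.
import torch.nn as nn

from training.optimizer import Adafactor, Lion

model = nn.Linear(16, 16)

# Lion keeps a single momentum buffer per parameter and steps with the sign of the update.
lion = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.01)

# Adafactor with an external learning-rate schedule: per its docstring, disable the
# internal schedule with scale_parameter=False and relative_step=False.
adafactor_external_lr = Adafactor(
    model.parameters(), lr=1e-4, scale_parameter=False, relative_step=False
)

# Adafactor with the internal, time-dependent learning rate (lr must stay None here,
# otherwise __init__ raises "Cannot combine manual lr and relative_step options").
adafactor_relative = Adafactor(model.parameters(), lr=None, relative_step=True)
```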
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for diffusion models.""" 16 | 17 | import math 18 | from enum import Enum 19 | from typing import Optional, Union 20 | 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | from .logging import get_logger 25 | 26 | logger = get_logger(__name__) 27 | 28 | 29 | class SchedulerType(Enum): 30 | LINEAR = "linear" 31 | COSINE = "cosine" 32 | COSINE_WITH_RESTARTS = "cosine_with_restarts" 33 | POLYNOMIAL = "polynomial" 34 | CONSTANT = "constant" 35 | CONSTANT_WITH_WARMUP = "constant_with_warmup" 36 | 37 | 38 | def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): 39 | """ 40 | Create a schedule with a constant learning rate, using the learning rate set in optimizer. 41 | 42 | Args: 43 | optimizer ([`~torch.optim.Optimizer`]): 44 | The optimizer for which to schedule the learning rate. 45 | last_epoch (`int`, *optional*, defaults to -1): 46 | The index of the last epoch when resuming training. 47 | 48 | Return: 49 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 50 | """ 51 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 52 | 53 | 54 | def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): 55 | """ 56 | Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate 57 | increases linearly between 0 and the initial lr set in the optimizer. 58 | 59 | Args: 60 | optimizer ([`~torch.optim.Optimizer`]): 61 | The optimizer for which to schedule the learning rate. 62 | num_warmup_steps (`int`): 63 | The number of steps for the warmup phase. 64 | last_epoch (`int`, *optional*, defaults to -1): 65 | The index of the last epoch when resuming training. 66 | 67 | Return: 68 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 69 | """ 70 | 71 | def lr_lambda(current_step: int): 72 | if current_step < num_warmup_steps: 73 | return float(current_step) / float(max(1.0, num_warmup_steps)) 74 | return 1.0 75 | 76 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 77 | 78 | 79 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 80 | """ 81 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 82 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 83 | 84 | Args: 85 | optimizer ([`~torch.optim.Optimizer`]): 86 | The optimizer for which to schedule the learning rate. 87 | num_warmup_steps (`int`): 88 | The number of steps for the warmup phase. 89 | num_training_steps (`int`): 90 | The total number of training steps. 91 | last_epoch (`int`, *optional*, defaults to -1): 92 | The index of the last epoch when resuming training. 93 | 94 | Return: 95 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
96 | """ 97 | 98 | def lr_lambda(current_step: int): 99 | if current_step < num_warmup_steps: 100 | return float(current_step) / float(max(1, num_warmup_steps)) 101 | return max( 102 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 103 | ) 104 | 105 | return LambdaLR(optimizer, lr_lambda, last_epoch) 106 | 107 | 108 | def get_cosine_schedule_with_warmup( 109 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 110 | ): 111 | """ 112 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 113 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 114 | initial lr set in the optimizer. 115 | 116 | Args: 117 | optimizer ([`~torch.optim.Optimizer`]): 118 | The optimizer for which to schedule the learning rate. 119 | num_warmup_steps (`int`): 120 | The number of steps for the warmup phase. 121 | num_training_steps (`int`): 122 | The total number of training steps. 123 | num_periods (`float`, *optional*, defaults to 0.5): 124 | The number of periods of the cosine function in a schedule (the default is to just decrease from the max 125 | value to 0 following a half-cosine). 126 | last_epoch (`int`, *optional*, defaults to -1): 127 | The index of the last epoch when resuming training. 128 | 129 | Return: 130 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 131 | """ 132 | 133 | def lr_lambda(current_step): 134 | if current_step < num_warmup_steps: 135 | return float(current_step) / float(max(1, num_warmup_steps)) 136 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 137 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 138 | 139 | return LambdaLR(optimizer, lr_lambda, last_epoch) 140 | 141 | 142 | def get_cosine_with_hard_restarts_schedule_with_warmup( 143 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 144 | ): 145 | """ 146 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 147 | initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases 148 | linearly between 0 and the initial lr set in the optimizer. 149 | 150 | Args: 151 | optimizer ([`~torch.optim.Optimizer`]): 152 | The optimizer for which to schedule the learning rate. 153 | num_warmup_steps (`int`): 154 | The number of steps for the warmup phase. 155 | num_training_steps (`int`): 156 | The total number of training steps. 157 | num_cycles (`int`, *optional*, defaults to 1): 158 | The number of hard restarts to use. 159 | last_epoch (`int`, *optional*, defaults to -1): 160 | The index of the last epoch when resuming training. 161 | 162 | Return: 163 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
164 | """ 165 | 166 | def lr_lambda(current_step): 167 | if current_step < num_warmup_steps: 168 | return float(current_step) / float(max(1, num_warmup_steps)) 169 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 170 | if progress >= 1.0: 171 | return 0.0 172 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) 173 | 174 | return LambdaLR(optimizer, lr_lambda, last_epoch) 175 | 176 | 177 | def get_polynomial_decay_schedule_with_warmup( 178 | optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 179 | ): 180 | """ 181 | Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the 182 | optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the 183 | initial lr set in the optimizer. 184 | 185 | Args: 186 | optimizer ([`~torch.optim.Optimizer`]): 187 | The optimizer for which to schedule the learning rate. 188 | num_warmup_steps (`int`): 189 | The number of steps for the warmup phase. 190 | num_training_steps (`int`): 191 | The total number of training steps. 192 | lr_end (`float`, *optional*, defaults to 1e-7): 193 | The end LR. 194 | power (`float`, *optional*, defaults to 1.0): 195 | Power factor. 196 | last_epoch (`int`, *optional*, defaults to -1): 197 | The index of the last epoch when resuming training. 198 | 199 | Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT 200 | implementation at 201 | https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 202 | 203 | Return: 204 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 205 | 206 | """ 207 | 208 | lr_init = optimizer.defaults["lr"] 209 | if not (lr_init > lr_end): 210 | raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") 211 | 212 | def lr_lambda(current_step: int): 213 | if current_step < num_warmup_steps: 214 | return float(current_step) / float(max(1, num_warmup_steps)) 215 | elif current_step > num_training_steps: 216 | return lr_end / lr_init # as LambdaLR multiplies by lr_init 217 | else: 218 | lr_range = lr_init - lr_end 219 | decay_steps = num_training_steps - num_warmup_steps 220 | pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps 221 | decay = lr_range * pct_remaining**power + lr_end 222 | return decay / lr_init # as LambdaLR multiplies by lr_init 223 | 224 | return LambdaLR(optimizer, lr_lambda, last_epoch) 225 | 226 | 227 | TYPE_TO_SCHEDULER_FUNCTION = { 228 | SchedulerType.LINEAR: get_linear_schedule_with_warmup, 229 | SchedulerType.COSINE: get_cosine_schedule_with_warmup, 230 | SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, 231 | SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, 232 | SchedulerType.CONSTANT: get_constant_schedule, 233 | SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, 234 | } 235 | 236 | 237 | def get_scheduler( 238 | name: Union[str, SchedulerType], 239 | optimizer: Optimizer, 240 | num_warmup_steps: Optional[int] = None, 241 | num_training_steps: Optional[int] = None, 242 | num_cycles: int = 1, 243 | power: float = 1.0, 244 | ): 245 | """ 246 | Unified API to get any scheduler from its name. 247 | 248 | Args: 249 | name (`str` or `SchedulerType`): 250 | The name of the scheduler to use. 
251 | optimizer (`torch.optim.Optimizer`): 252 | The optimizer that will be used during training. 253 | num_warmup_steps (`int`, *optional*): 254 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being 255 | optional), the function will raise an error if it's unset and the scheduler type requires it. 256 | num_training_steps (`int``, *optional*): 257 | The number of training steps to do. This is not required by all schedulers (hence the argument being 258 | optional), the function will raise an error if it's unset and the scheduler type requires it. 259 | num_cycles (`int`, *optional*): 260 | The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. 261 | power (`float`, *optional*, defaults to 1.0): 262 | Power factor. See `POLYNOMIAL` scheduler 263 | last_epoch (`int`, *optional*, defaults to -1): 264 | The index of the last epoch when resuming training. 265 | """ 266 | name = SchedulerType(name) 267 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 268 | if name == SchedulerType.CONSTANT: 269 | return schedule_func(optimizer) 270 | 271 | # All other schedulers require `num_warmup_steps` 272 | if num_warmup_steps is None: 273 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") 274 | 275 | if name == SchedulerType.CONSTANT_WITH_WARMUP: 276 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) 277 | 278 | # All other schedulers require `num_training_steps` 279 | if num_training_steps is None: 280 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") 281 | 282 | if name == SchedulerType.COSINE_WITH_RESTARTS: 283 | return schedule_func( 284 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles 285 | ) 286 | 287 | if name == SchedulerType.POLYNOMIAL: 288 | return schedule_func( 289 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power 290 | ) 291 | 292 | 293 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) 294 | -------------------------------------------------------------------------------- /muse/pipeline_muse_toast.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
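A short usage sketch for `get_scheduler` above; the optimizer, step counts, and scheduler choice are illustrative rather than taken from the repo's configs:

```python
# Sketch: pairing an optimizer with one of the schedulers from muse/lr_schedulers.py.
# The warmup and training step counts below are placeholders.
import torch

from muse.lr_schedulers import get_scheduler

params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

lr_scheduler = get_scheduler(
    "cosine",                     # any SchedulerType value: "linear", "cosine", "constant", ...
    optimizer,
    num_warmup_steps=2_000,       # required by every schedule except "constant"
    num_training_steps=200_000,   # required by the decaying schedules
)

for _ in range(10):
    optimizer.step()
    lr_scheduler.step()           # LambdaLR rescales the base lr by the schedule value
```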
14 | 15 | from typing import List, Optional, Union 16 | 17 | import numpy as np 18 | import torch 19 | from PIL import Image 20 | from transformers import ( 21 | AutoTokenizer, 22 | CLIPTextModel, 23 | PreTrainedTokenizer, 24 | T5EncoderModel, 25 | ) 26 | import pdb 27 | 28 | from .modeling_maskgit_vqgan import MaskGitVQGAN 29 | from .modeling_taming_vqgan import VQGANModel 30 | from .modeling_transformer import MaskGitTransformer 31 | from .modeling_transformer_toast import MaskGitTransformerTOAST 32 | from .sampling import cosine_schedule, linear_schedule, sqrt_schedule 33 | 34 | def schedule_func(t, mu=0.8, v=2): 35 | eps = 1e-4 36 | t = t+eps 37 | some = (t*(1-mu)/(mu*(1-t)))**(-v) 38 | return 1-1/(1+some) + eps 39 | 40 | class PipelineMuse: 41 | def __init__( 42 | self, 43 | vae: Union[MaskGitVQGAN, VQGANModel], 44 | transformer: Union[MaskGitTransformer, MaskGitTransformerTOAST], 45 | is_class_conditioned: bool = False, 46 | text_encoder: Optional[Union[T5EncoderModel, CLIPTextModel]] = None, 47 | tokenizer: Optional[PreTrainedTokenizer] = None, 48 | ) -> None: 49 | self.text_encoder = text_encoder 50 | self.tokenizer = tokenizer 51 | self.vae = vae 52 | self.transformer = transformer 53 | self.is_class_conditioned = is_class_conditioned 54 | self.device = "cpu" 55 | 56 | def to(self, device="cpu", dtype=torch.float32) -> None: 57 | if not self.is_class_conditioned: 58 | self.text_encoder.to(device, dtype=dtype) 59 | self.vae.to(device, dtype=dtype) 60 | self.transformer.to(device, dtype=dtype) 61 | self.device = device 62 | self.dtype = dtype 63 | return self 64 | 65 | @torch.no_grad() 66 | def __call__( 67 | self, 68 | text: Optional[Union[str, List[str]]] = None, 69 | class_ids: torch.LongTensor = None, 70 | timesteps: int = 8, 71 | guidance_scale: float = 8.0, 72 | temperature: float = 1.0, 73 | temperature2: float = 0.0, 74 | topk_filter_thres: float = 0.9, 75 | num_images_per_prompt: int = 1, 76 | sampling_type: str = "self_guidance", 77 | return_intermediate: bool = False, 78 | min_c_step:int = 0, 79 | max_c_step:int = 999, 80 | momentum: float = 0.0, 81 | correct_every: int = -1, 82 | threshold: float = 0.0, 83 | use_toast_correct_only: bool = False, 84 | no_temp_for_correct: bool = False, 85 | substitution_rate: float = 0.2, 86 | schedule: str = "cosine", 87 | stop_tdib_step: int = 100, 88 | min_c_ratio=0.0, 89 | guidance_anneal=False, 90 | degub: bool = False, 91 | ): 92 | if text is None and class_ids is None: 93 | raise ValueError("Either text or class_ids must be provided.") 94 | 95 | if text is not None and class_ids is not None: 96 | raise ValueError("Only one of text or class_ids may be provided.") 97 | 98 | if class_ids is not None: 99 | if isinstance(class_ids, int): 100 | class_ids = [class_ids] 101 | 102 | class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long) 103 | # duplicate class ids for each generation per prompt 104 | class_ids = class_ids.repeat_interleave(num_images_per_prompt, dim=0) 105 | elif isinstance(class_ids, torch.Tensor): 106 | class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long) 107 | 108 | model_inputs = {"class_ids": class_ids} 109 | else: 110 | if isinstance(text, str): 111 | text = [text] 112 | 113 | input_ids = self.tokenizer( 114 | text, return_tensors="pt", padding="max_length", truncation=True, max_length=16 115 | ).input_ids # TODO: remove hardcode 116 | input_ids = input_ids.to(self.device) 117 | encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state 118 | 119 | # duplicate text 
embeddings for each generation per prompt, using mps friendly method 120 | bs_embed, seq_len, _ = encoder_hidden_states.shape 121 | encoder_hidden_states = encoder_hidden_states.repeat(1, num_images_per_prompt, 1) 122 | encoder_hidden_states = encoder_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1) 123 | model_inputs = {"encoder_hidden_states": encoder_hidden_states} 124 | 125 | if sampling_type == 'maskgit': 126 | generate = self.transformer.generate2 127 | elif sampling_type == 'self_guidance': 128 | generate = self.transformer.generate_sg 129 | else: 130 | raise NotImplementedError 131 | 132 | if schedule == "cosine": 133 | noise_schedule = cosine_schedule 134 | elif schedule == "linear": 135 | noise_schedule = linear_schedule 136 | elif schedule == "sqrt": 137 | noise_schedule = sqrt_schedule 138 | elif schedule == "custom": 139 | noise_schedule = schedule_func 140 | else: 141 | raise NotImplementedError 142 | 143 | outputs = generate( 144 | **model_inputs, 145 | timesteps=timesteps, 146 | guidance_scale=guidance_scale, 147 | temperature=temperature, 148 | return_intermediate=return_intermediate, 149 | noise_schedule=noise_schedule, 150 | ) 151 | 152 | if return_intermediate: 153 | generated_tokens, intermediate, mask_index = outputs 154 | else: 155 | generated_tokens = outputs 156 | 157 | images = self.vae.decode_code(generated_tokens) 158 | if return_intermediate: 159 | intermediate_images = [self.vae.decode_code(tokens) for tokens in intermediate] 160 | 161 | # Convert to PIL images 162 | images = [self.to_pil_image(image) for image in images] 163 | if return_intermediate: 164 | intermediate_images = [[self.to_pil_image(image) for image in images] for images in intermediate_images] 165 | # return images, intermediate_images, mask_index 166 | return images, intermediate_images, intermediate, mask_index 167 | return images 168 | 169 | def to_pil_image(self, image: torch.Tensor): 170 | image = image.permute(1, 2, 0).cpu().numpy() 171 | image = 2.0 * image - 1.0 172 | image = np.clip(image, -1.0, 1.0) 173 | image = (image + 1.0) / 2.0 174 | image = (255 * image).astype(np.uint8) 175 | image = Image.fromarray(image).convert("RGB") 176 | return image 177 | 178 | @classmethod 179 | def from_pretrained( 180 | cls, 181 | model_name_or_path: str = None, 182 | text_encoder_path: Optional[str] = None, 183 | vae_path: Optional[str] = None, 184 | transformer_path: Optional[str] = None, 185 | is_class_conditioned: bool = False, 186 | use_toast: bool = False, 187 | use_taming: bool = False, 188 | debug: bool = False, 189 | ) -> None: 190 | """ 191 | Instantiate a PipelineMuse from a pretrained model. Either model_name_or_path or all of text_encoder_path, vae_path, and 192 | transformer_path must be provided. 193 | """ 194 | MaskGitTransformer = MaskGitTransformerTOAST if use_toast else MaskGitTransformer 195 | # if debug: 196 | # from .modeling_transformer_org_correct_debugging import MaskGitCorrTransformerOrg 197 | # MaskGitTransformer = MaskGitCorrTransformerOrg 198 | # print("set debug mode!") 199 | 200 | if model_name_or_path is None: 201 | if vae_path is None or transformer_path is None: 202 | raise ValueError( 203 | "If model_name_or_path is None, then text_encoder_path, vae_path, and transformer_path must be" 204 | " provided." 
205 | ) 206 | 207 | text_encoder = None 208 | tokenizer = None 209 | 210 | if not is_class_conditioned: 211 | text_encoder = T5EncoderModel.from_pretrained(text_encoder_path) 212 | tokenizer = AutoTokenizer.from_pretrained(text_encoder_path) 213 | 214 | if use_taming: 215 | vae = VQGANModel.from_pretrained(vae_path) 216 | else: 217 | vae = MaskGitVQGAN.from_pretrained(vae_path) 218 | transformer = MaskGitTransformer.from_pretrained(transformer_path) 219 | else: 220 | text_encoder = None 221 | tokenizer = None 222 | 223 | if not is_class_conditioned: 224 | text_encoder = T5EncoderModel.from_pretrained(model_name_or_path, subfolder="text_encoder") 225 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, subfolder="text_encoder") 226 | if use_taming: 227 | vae = VQGANModel.from_pretrained(model_name_or_path, subfolder="vae") 228 | else: 229 | vae = MaskGitVQGAN.from_pretrained(model_name_or_path, subfolder="vae") 230 | transformer = MaskGitTransformer.from_pretrained(model_name_or_path, subfolder="transformer") 231 | 232 | return cls( 233 | vae=vae, 234 | transformer=transformer, 235 | text_encoder=text_encoder, 236 | tokenizer=tokenizer, 237 | is_class_conditioned=is_class_conditioned, 238 | ) 239 | 240 | @torch.no_grad() 241 | def correct_vq_code( 242 | self, 243 | vq_tokens: torch.Tensor = None, 244 | class_ids: torch.LongTensor = None, 245 | correct_step: int = 1, 246 | ): 247 | """ 248 | # refine the given vq_code with MaskGIT Corrector 249 | # vq_codes: [batch_size, seq_len] 250 | """ 251 | 252 | batch_size, seq_len = vq_tokens.shape 253 | if class_ids is not None: 254 | if isinstance(class_ids, int): 255 | class_ids = [class_ids] 256 | 257 | class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long) 258 | # duplicate class ids for each generation per prompt 259 | class_ids = class_ids.repeat_interleave(batch_size, dim=0) 260 | elif isinstance(class_ids, torch.Tensor): 261 | class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long) 262 | 263 | input_ids = torch.cat([class_ids.unsqueeze(1), vq_tokens], dim=1) 264 | 265 | for i in correct_step: 266 | # classifier free guidance 267 | logits = self(vq_tokens, encoder_hidden_states=encoder_hidden_states) 268 | logits = logits[..., : self.config.codebook_size] 269 | 270 | # remove class token 271 | if class_ids is not None: 272 | input_ids = input_ids[:, 1:] 273 | logits = logits[:, 1:] 274 | 275 | # add gumbel noise for correction for more diversity 276 | # logits = logits + temperature * gumbel_noise(logits) 277 | # logits = logits + 1 * gumbel_noise(logits) 278 | unknown_map = input_ids == mask_token_id 279 | 280 | # Samples the ids using categorical sampling: [batch_size, seq_length]. 
281 | # sampled_ids = torch.stack([torch.multinomial(l.softmax(dim=-1), 1).squeeze(1) for l in logits]) 282 | sampled_ids = gumbel_sample(logits, temperature=1.0 * (1.0 - ratio)) 283 | 284 | # Replace the input_ids with highest probability tokens 285 | selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) 286 | selected_probs = selected_probs.squeeze(-1) 287 | # Mark the masked region so that it is not replaced 288 | selected_probs = torch.where(unknown_map, torch.finfo(selected_probs.dtype).min, selected_probs) 289 | 290 | num_unmasked = (~unknown_map).sum(dim=1) 291 | # num_substitute = (num_unmasked * substitution_rate * (1.0 - ratio)).round().clamp(min=1)[..., None] 292 | num_substitute = (num_unmasked * substitution_rate).round().clamp(min=1)[..., None] 293 | 294 | # Negation of num_substitute means select only num_substitute tokens using mask_by_random_topk 295 | substitute_mask = mask_by_random_topk(seq_len-num_substitute, selected_probs, temperature=0.0) 296 | 297 | input_ids = torch.where(substitute_mask, input_ids, sampled_ids) 298 | 299 | if return_intermediate: 300 | return sampled_ids, intermediate, mask_index 301 | return sampled_ids 302 | 303 | @torch.no_grad() 304 | def transfer( 305 | self, 306 | text: Optional[Union[str, List[str]]] = None, 307 | class_ids: torch.LongTensor = None, 308 | timesteps: int = 8, 309 | guidance_scale: float = 8.0, 310 | temperature: float = 1.0, 311 | num_images_per_prompt: int = 1, 312 | sampling_type: str = "maskgit", 313 | return_intermediate: bool = False, 314 | schedule: str = "cosine", 315 | ): 316 | if text is None and class_ids is None: 317 | raise ValueError("Either text or class_ids must be provided.") 318 | 319 | if text is not None and class_ids is not None: 320 | raise ValueError("Only one of text or class_ids may be provided.") 321 | 322 | if class_ids is not None: 323 | if isinstance(class_ids, int): 324 | class_ids = [class_ids] 325 | 326 | class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long) 327 | # duplicate class ids for each generation per prompt 328 | class_ids = class_ids.repeat_interleave(num_images_per_prompt, dim=0) 329 | elif isinstance(class_ids, torch.Tensor): 330 | class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long) 331 | 332 | model_inputs = {"class_ids": class_ids} 333 | else: 334 | if isinstance(text, str): 335 | text = [text] 336 | 337 | input_ids = self.tokenizer( 338 | text, return_tensors="pt", padding="max_length", truncation=True, max_length=16 339 | ).input_ids # TODO: remove hardcode 340 | input_ids = input_ids.to(self.device) 341 | encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state 342 | 343 | # duplicate text embeddings for each generation per prompt, using mps friendly method 344 | bs_embed, seq_len, _ = encoder_hidden_states.shape 345 | encoder_hidden_states = encoder_hidden_states.repeat(1, num_images_per_prompt, 1) 346 | encoder_hidden_states = encoder_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1) 347 | model_inputs = {"encoder_hidden_states": encoder_hidden_states} 348 | 349 | if sampling_type == 'maskgit': 350 | generate = self.transformer.generate2 351 | elif sampling_type == 'sg': 352 | generate = self.transformer.generate_sg 353 | else: 354 | raise NotImplementedError 355 | 356 | if schedule == "cosine": 357 | noise_schedule = cosine_schedule 358 | elif schedule == "linear": 359 | noise_schedule = linear_schedule 360 | elif schedule == "sqrt": 361 | noise_schedule = sqrt_schedule 362 | elif 
schedule == "custom": 363 | noise_schedule = schedule_func 364 | else: 365 | raise NotImplementedError 366 | 367 | outputs = generate( 368 | **model_inputs, 369 | timesteps=timesteps, 370 | guidance_scale=guidance_scale, 371 | temperature=temperature, 372 | return_intermediate=return_intermediate, 373 | noise_schedule=noise_schedule, 374 | ) 375 | # import pdb 376 | # pdb.set_trace() 377 | 378 | if return_intermediate: 379 | generated_tokens, intermediate, mask_index = outputs 380 | else: 381 | generated_tokens = outputs 382 | 383 | images = self.vae.decode_code(generated_tokens) 384 | if return_intermediate: 385 | intermediate_images = [self.vae.decode_code(tokens) for tokens in intermediate] 386 | 387 | # Convert to PIL images 388 | images = [self.to_pil_image(image) for image in images] 389 | if return_intermediate: 390 | intermediate_images = [[self.to_pil_image(image) for image in images] for images in intermediate_images] 391 | # return images, intermediate_images, mask_index 392 | return images, intermediate_images, intermediate, mask_index 393 | return images -------------------------------------------------------------------------------- /muse/modeling_maskgit_vqgan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC and The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""MaskGIT Tokenizer based on VQGAN. 16 | 17 | This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841] 18 | with several modifications. The non-local layers are removed from VQGAN for 19 | faster speed. 
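For context, a hedged sketch of how `PipelineMuse` from `muse/pipeline_muse_toast.py` above is typically driven for class-conditional sampling. The checkpoint path, class id, and sampling options are placeholders; it assumes a checkpoint laid out with the `vae/` and `transformer/` subfolders that `from_pretrained` expects:

```python
# Sketch: class-conditional sampling with the PipelineMuse defined above.
# The checkpoint path and class id are illustrative placeholders.
import torch

from muse.pipeline_muse_toast import PipelineMuse

pipeline = PipelineMuse.from_pretrained(
    "path/to/maskgit-checkpoint",   # hypothetical local path or hub id
    is_class_conditioned=True,
    use_toast=True,
).to("cuda", dtype=torch.float32)

images = pipeline(
    class_ids=207,                  # a single ImageNet class id, repeated per image
    timesteps=8,
    guidance_scale=2.0,
    num_images_per_prompt=4,
    sampling_type="self_guidance",  # or "maskgit"
    schedule="cosine",
)
images[0].save("sample.png")        # __call__ returns a list of PIL images
```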
20 | """ 21 | 22 | import math 23 | from typing import Tuple 24 | 25 | import torch 26 | import torch.nn.functional as F 27 | from torch import nn 28 | 29 | from .modeling_utils import ConfigMixin, ModelMixin, register_to_config 30 | 31 | 32 | # Conv2D with same padding 33 | class Conv2dSame(nn.Conv2d): 34 | def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: 35 | return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | ih, iw = x.size()[-2:] 39 | 40 | pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) 41 | pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) 42 | 43 | if pad_h > 0 or pad_w > 0: 44 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 45 | return super().forward(x) 46 | 47 | 48 | class ResnetBlock(nn.Module): 49 | def __init__( 50 | self, 51 | in_channels: int, 52 | out_channels: int = None, 53 | dropout_prob: float = 0.0, 54 | ): 55 | super().__init__() 56 | 57 | self.in_channels = in_channels 58 | self.out_channels = out_channels 59 | self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels 60 | 61 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 62 | self.conv1 = Conv2dSame(self.in_channels, self.out_channels_, kernel_size=3, bias=False) 63 | 64 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True) 65 | self.dropout = nn.Dropout(dropout_prob) 66 | self.conv2 = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=3, bias=False) 67 | 68 | if self.in_channels != self.out_channels_: 69 | self.nin_shortcut = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=1, bias=False) 70 | 71 | def forward(self, hidden_states): 72 | residual = hidden_states 73 | hidden_states = self.norm1(hidden_states) 74 | hidden_states = F.silu(hidden_states) 75 | hidden_states = self.conv1(hidden_states) 76 | 77 | hidden_states = self.norm2(hidden_states) 78 | hidden_states = F.silu(hidden_states) 79 | hidden_states = self.dropout(hidden_states) 80 | hidden_states = self.conv2(hidden_states) 81 | 82 | if self.in_channels != self.out_channels_: 83 | residual = self.nin_shortcut(hidden_states) 84 | 85 | return hidden_states + residual 86 | 87 | 88 | class DownsamplingBlock(nn.Module): 89 | def __init__(self, config, block_idx: int): 90 | super().__init__() 91 | 92 | self.config = config 93 | self.block_idx = block_idx 94 | 95 | in_channel_mult = (1,) + tuple(self.config.channel_mult) 96 | block_in = self.config.hidden_channels * in_channel_mult[self.block_idx] 97 | block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] 98 | 99 | res_blocks = nn.ModuleList() 100 | for _ in range(self.config.num_res_blocks): 101 | res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)) 102 | block_in = block_out 103 | self.block = res_blocks 104 | 105 | self.downsample = self.block_idx != self.config.num_resolutions - 1 106 | 107 | def forward(self, hidden_states): 108 | for res_block in self.block: 109 | hidden_states = res_block(hidden_states) 110 | 111 | if self.downsample: 112 | hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2) 113 | 114 | return hidden_states 115 | 116 | 117 | class UpsamplingBlock(nn.Module): 118 | def __init__(self, config, block_idx: int): 119 | super().__init__() 120 | 121 | 
self.config = config 122 | self.block_idx = block_idx 123 | 124 | if self.block_idx == self.config.num_resolutions - 1: 125 | block_in = self.config.hidden_channels * self.config.channel_mult[-1] 126 | else: 127 | block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1] 128 | 129 | block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] 130 | 131 | res_blocks = [] 132 | for _ in range(self.config.num_res_blocks): 133 | res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)) 134 | block_in = block_out 135 | self.block = nn.ModuleList(res_blocks) 136 | 137 | self.add_upsample = self.block_idx != 0 138 | if self.add_upsample: 139 | self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3) 140 | 141 | def forward(self, hidden_states): 142 | for res_block in self.block: 143 | hidden_states = res_block(hidden_states) 144 | 145 | if self.add_upsample: 146 | hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") 147 | hidden_states = self.upsample_conv(hidden_states) 148 | 149 | return hidden_states 150 | 151 | 152 | class Encoder(nn.Module): 153 | def __init__(self, config): 154 | super().__init__() 155 | self.config = config 156 | # downsampling 157 | self.conv_in = Conv2dSame(self.config.num_channels, self.config.hidden_channels, kernel_size=3, bias=False) 158 | 159 | downsample_blocks = [] 160 | for i_level in range(self.config.num_resolutions): 161 | downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level)) 162 | self.down = nn.ModuleList(downsample_blocks) 163 | 164 | # middle 165 | mid_channels = self.config.hidden_channels * self.config.channel_mult[-1] 166 | res_blocks = nn.ModuleList() 167 | for _ in range(self.config.num_res_blocks): 168 | res_blocks.append(ResnetBlock(mid_channels, mid_channels, dropout_prob=self.config.dropout)) 169 | self.mid = res_blocks 170 | 171 | # end 172 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True) 173 | self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1) 174 | 175 | def forward(self, pixel_values): 176 | # downsampling 177 | hidden_states = self.conv_in(pixel_values) 178 | for block in self.down: 179 | hidden_states = block(hidden_states) 180 | 181 | # middle 182 | for block in self.mid: 183 | hidden_states = block(hidden_states) 184 | 185 | # end 186 | hidden_states = self.norm_out(hidden_states) 187 | hidden_states = F.silu(hidden_states) 188 | hidden_states = self.conv_out(hidden_states) 189 | return hidden_states 190 | 191 | 192 | class Decoder(nn.Module): 193 | def __init__(self, config): 194 | super().__init__() 195 | 196 | self.config = config 197 | 198 | # compute in_channel_mult, block_in and curr_res at lowest res 199 | block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1] 200 | curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1) 201 | self.z_shape = (1, self.config.z_channels, curr_res, curr_res) 202 | 203 | # z to block_in 204 | self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3) 205 | 206 | # middle 207 | res_blocks = nn.ModuleList() 208 | for _ in range(self.config.num_res_blocks): 209 | res_blocks.append(ResnetBlock(block_in, block_in, dropout_prob=self.config.dropout)) 210 | self.mid = res_blocks 211 | 212 | # upsampling 213 | upsample_blocks = [] 214 | for i_level in reversed(range(self.config.num_resolutions)): 215 | 
upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level)) 216 | self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order 217 | 218 | # end 219 | block_out = self.config.hidden_channels * self.config.channel_mult[0] 220 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True) 221 | self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3) 222 | 223 | def forward(self, hidden_states): 224 | # z to block_in 225 | hidden_states = self.conv_in(hidden_states) 226 | 227 | # middle 228 | for block in self.mid: 229 | hidden_states = block(hidden_states) 230 | 231 | # upsampling 232 | for block in reversed(self.up): 233 | hidden_states = block(hidden_states) 234 | 235 | # end 236 | hidden_states = self.norm_out(hidden_states) 237 | hidden_states = F.silu(hidden_states) 238 | hidden_states = self.conv_out(hidden_states) 239 | 240 | return hidden_states 241 | 242 | 243 | class VectorQuantizer(nn.Module): 244 | """ 245 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 246 | Discretization bottleneck part of the VQ-VAE. 247 | """ 248 | 249 | def __init__(self, num_embeddings, embedding_dim, commitment_cost): 250 | r""" 251 | Args: 252 | num_embeddings: number of vectors in the quantized space. 253 | embedding_dim: dimensionality of the tensors in the quantized space. 254 | Inputs to the modules must be in this format as well. 255 | commitment_cost: scalar which controls the weighting of the loss terms 256 | (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta). 257 | """ 258 | super().__init__() 259 | 260 | self.num_embeddings = num_embeddings 261 | self.embedding_dim = embedding_dim 262 | self.commitment_cost = commitment_cost 263 | 264 | self.embedding = nn.Embedding(num_embeddings, embedding_dim) 265 | self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings) 266 | 267 | def forward(self, hidden_states, return_loss=False): 268 | """ 269 | Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the 270 | closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width) 271 | quantization pipeline: 272 | 1. get encoder input (B,C,H,W) 273 | 2. 
flatten input to (B*H*W,C) 274 | """ 275 | # reshape z -> (batch, height, width, channel) and flatten 276 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() 277 | 278 | distances = self.compute_distances(hidden_states) 279 | min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1) 280 | min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states) 281 | min_encodings.scatter_(1, min_encoding_indices, 1) 282 | 283 | # get quantized latent vectors 284 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape) 285 | 286 | # reshape to (batch, num_tokens) 287 | min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1) 288 | 289 | # compute loss for embedding 290 | loss = None 291 | if return_loss: 292 | loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean( 293 | (z_q - hidden_states.detach()) ** 2 294 | ) 295 | # preserve gradients 296 | z_q = hidden_states + (z_q - hidden_states).detach() 297 | 298 | # reshape back to match original input shape 299 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 300 | 301 | return z_q, min_encoding_indices, loss 302 | 303 | def compute_distances(self, hidden_states): 304 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 305 | hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim)) 306 | emb_weights = self.embedding.weight.t() 307 | 308 | inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True) 309 | codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True) 310 | distances = torch.addmm( 311 | inputs_norm_sq + codebook_t_norm_sq, 312 | hidden_states_flattended, 313 | emb_weights, 314 | alpha=-2.0, 315 | ) 316 | return distances 317 | 318 | def get_codebook_entry(self, indices): 319 | # indices are expected to be of shape (batch, num_tokens) 320 | # get quantized latent vectors 321 | batch, num_tokens = indices.shape 322 | z_q = self.embedding(indices) 323 | z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2) 324 | return z_q 325 | 326 | # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372 327 | def get_soft_code(self, hidden_states, temp=1.0, stochastic=False): 328 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel) 329 | distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings) 330 | 331 | soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings) 332 | if stochastic: 333 | code = torch.multinomial(soft_code, 1) # (batch * height * width, 1) 334 | else: 335 | code = distances.argmin(dim=-1) # (batch * height * width) 336 | 337 | code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width) 338 | batch, num_tokens = code.shape 339 | soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings) 340 | return soft_code, code 341 | 342 | def get_code(self, hidden_states): 343 | # reshape z -> (batch, height, width, channel) 344 | hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() 345 | distances = self.compute_distances(hidden_states) 346 | indices = torch.argmin(distances, axis=1).unsqueeze(1) 347 | indices = indices.reshape(hidden_states.shape[0], -1) 348 | return indices 349 | 350 | 351 | class MaskGitVQGAN(ModelMixin, ConfigMixin): 352 | @register_to_config 353 | def 
__init__( 354 | self, 355 | resolution: int = 256, 356 | num_channels: int = 3, 357 | hidden_channels: int = 128, 358 | channel_mult: Tuple = (1, 1, 2, 2, 4), 359 | num_res_blocks: int = 2, 360 | attn_resolutions: int = (16,), 361 | z_channels: int = 256, 362 | num_embeddings: int = 1024, 363 | quantized_embed_dim: int = 256, 364 | dropout: float = 0.0, 365 | resample_with_conv: bool = True, 366 | commitment_cost: float = 0.25, 367 | ): 368 | super().__init__() 369 | 370 | self.config.num_resolutions = len(channel_mult) 371 | self.config.reduction_factor = 2 ** (self.config.num_resolutions - 1) 372 | self.config.latent_size = resolution // self.config.reduction_factor 373 | 374 | self.encoder = Encoder(self.config) 375 | self.decoder = Decoder(self.config) 376 | self.quantize = VectorQuantizer( 377 | self.config.num_embeddings, self.config.quantized_embed_dim, self.config.commitment_cost 378 | ) 379 | 380 | def encode(self, pixel_values, return_loss=False): 381 | hidden_states = self.encoder(pixel_values) 382 | quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss) 383 | output = (quantized_states, codebook_indices) 384 | if return_loss: 385 | output = output + (codebook_loss,) 386 | return output 387 | 388 | def decode(self, quantized_states): 389 | reconstructed_pixel_values = self.decoder(quantized_states) 390 | return reconstructed_pixel_values 391 | 392 | def decode_code(self, codebook_indices): 393 | quantized_states = self.quantize.get_codebook_entry(codebook_indices) 394 | reconstructed_pixel_values = self.decode(quantized_states) 395 | return reconstructed_pixel_values 396 | 397 | def get_soft_code(self, pixel_values, temp=1.0, stochastic=False): 398 | hidden_states = self.encoder(pixel_values) 399 | soft_code, codebook_indices = self.quantize.get_soft_code(hidden_states, temp=temp, stochastic=stochastic) 400 | return soft_code, codebook_indices 401 | 402 | def get_code(self, pixel_values): 403 | hidden_states = self.encoder(pixel_values) 404 | codebook_indices = self.quantize.get_code(hidden_states) 405 | return codebook_indices 406 | 407 | def forward(self, pixel_values, return_loss=False): 408 | hidden_states = self.encoder(pixel_values) 409 | quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss) 410 | reconstructed_pixel_values = self.decode(quantized_states) 411 | outputs = (reconstructed_pixel_values, quantized_states, codebook_indices) 412 | if return_loss: 413 | outputs = outputs + (codebook_loss,) 414 | return outputs 415 | -------------------------------------------------------------------------------- /muse/training_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
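A small sketch of the tokenizer round trip implemented by `MaskGitVQGAN` above: encode pixels to discrete codebook indices, then decode the indices back to pixels. The batch is random data and the config values are the class defaults:

```python
# Sketch: encode/decode round trip with the MaskGitVQGAN defined above.
# Default config: 256x256 input, reduction factor 16, so a 16x16 = 256-token grid.
import torch

from muse.modeling_maskgit_vqgan import MaskGitVQGAN

vq = MaskGitVQGAN()                          # defaults: 1024 codebook entries, dim 256
pixel_values = torch.rand(2, 3, 256, 256)    # dummy batch in [0, 1]

quantized, codes = vq.encode(pixel_values)   # codes: (2, 256) long indices
reconstruction = vq.decode_code(codes)       # back to (2, 3, 256, 256)

print(codes.shape, reconstruction.shape)
```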
15 | 16 | import copy 17 | import os 18 | import random 19 | from typing import Any, Dict, Iterable, Optional, Union 20 | 21 | import numpy as np 22 | import pandas as pd 23 | import torch 24 | import torch.nn.functional as F 25 | 26 | 27 | def enable_full_determinism(seed: int): 28 | """ 29 | Helper function for reproducible behavior during distributed training. See 30 | https://pytorch.org/docs/stable/notes/randomness.html for PyTorch 31 | """ 32 | # set seed first 33 | set_seed(seed) 34 | 35 | # Enable PyTorch deterministic mode. This potentially requires either the environment 36 | # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, 37 | # depending on the CUDA version, so we set them both here 38 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 39 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" 40 | torch.use_deterministic_algorithms(True) 41 | 42 | # Enable CUDNN deterministic mode 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | 46 | 47 | def set_seed(seed: int): 48 | """ 49 | Helper function for reproducible behavior to set the seed in `random`, `numpy` and `torch`. 50 | Args: 51 | seed (`int`): The seed to set. 52 | """ 53 | random.seed(seed) 54 | np.random.seed(seed) 55 | torch.manual_seed(seed) 56 | torch.cuda.manual_seed_all(seed) 57 | # ^^ safe to call this function even if cuda is not available 58 | 59 | 60 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 61 | class EMA: 62 | """ 63 | Exponential Moving Average of model weights 64 | """ 65 | 66 | def __init__( 67 | self, 68 | parameters: Iterable[torch.nn.Parameter], 69 | decay: float = 0.9999, 70 | min_decay: float = 0.0, 71 | update_after_step: int = 0, 72 | use_ema_warmup: bool = False, 73 | inv_gamma: Union[float, int] = 1.0, 74 | power: Union[float, int] = 2 / 3, 75 | model_cls: Optional[Any] = None, 76 | model_config: Optional[Dict[str, Any]] = None, 77 | **kwargs, 78 | ): 79 | """ 80 | Args: 81 | parameters (Iterable[torch.nn.Parameter]): The parameters to track. 82 | decay (float): The decay factor for the exponential moving average. 83 | min_decay (float): The minimum decay factor for the exponential moving average. 84 | update_after_step (int): The number of steps to wait before starting to update the EMA weights. 85 | use_ema_warmup (bool): Whether to use EMA warmup. 86 | inv_gamma (float): 87 | Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. 88 | power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. 89 | device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA 90 | weights will be stored on CPU. 91 | 92 | @crowsonkb's notes on EMA Warmup: 93 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 94 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 95 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 96 | at 215.4k steps).
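As a rough worked check of those numbers against `get_decay` below (which computes `1 - (1 + step / inv_gamma) ** -power` during warmup): with inv_gamma=1 and power=2/3, step 31,622 gives 1 - 31,623 ** (-2/3) ≈ 1 - 1/1000 = 0.999 and step 1,000,000 gives ≈ 1 - 1/10,000 = 0.9999; with power=3/4, steps 10,000 and 215,443 give 0.999 and 0.9999 respectively.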
97 | """ 98 | 99 | parameters = list(parameters) 100 | self.shadow_params = [p.clone().detach() for p in parameters] 101 | 102 | self.temp_stored_params = None 103 | 104 | self.decay = decay 105 | self.min_decay = min_decay 106 | self.update_after_step = update_after_step 107 | self.use_ema_warmup = use_ema_warmup 108 | self.inv_gamma = inv_gamma 109 | self.power = power 110 | self.optimization_step = 0 111 | self.cur_decay_value = None # set in `step()` 112 | 113 | self.model_cls = model_cls 114 | self.model_config = model_config 115 | 116 | @classmethod 117 | def from_pretrained(cls, path, model_cls) -> "EMA": 118 | _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) 119 | model = model_cls.from_pretrained(path) 120 | 121 | ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) 122 | 123 | ema_model.load_state_dict(ema_kwargs) 124 | return ema_model 125 | 126 | def save_pretrained(self, path): 127 | if self.model_cls is None: 128 | raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") 129 | 130 | if self.model_config is None: 131 | raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") 132 | 133 | model = self.model_cls.from_config(self.model_config) 134 | state_dict = self.state_dict() 135 | state_dict.pop("shadow_params", None) 136 | 137 | model.register_to_config(**state_dict) 138 | self.copy_to(model.parameters()) 139 | model.save_pretrained(path) 140 | 141 | def get_decay(self, optimization_step: int) -> float: 142 | """ 143 | Compute the decay factor for the exponential moving average. 144 | """ 145 | step = max(0, optimization_step - self.update_after_step - 1) 146 | 147 | if step <= 0: 148 | return 0.0 149 | 150 | if self.use_ema_warmup: 151 | cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power 152 | else: 153 | cur_decay_value = (1 + step) / (10 + step) 154 | 155 | cur_decay_value = min(cur_decay_value, self.decay) 156 | # make sure decay is not smaller than min_decay 157 | cur_decay_value = max(cur_decay_value, self.min_decay) 158 | return cur_decay_value 159 | 160 | @torch.no_grad() 161 | def step(self, parameters: Iterable[torch.nn.Parameter]): 162 | parameters = list(parameters) 163 | 164 | self.optimization_step += 1 165 | 166 | # Compute the decay factor for the exponential moving average. 167 | decay = self.get_decay(self.optimization_step) 168 | self.cur_decay_value = decay 169 | one_minus_decay = 1 - decay 170 | 171 | for s_param, param in zip(self.shadow_params, parameters): 172 | if param.requires_grad: 173 | s_param.sub_(one_minus_decay * (s_param - param)) 174 | else: 175 | s_param.copy_(param) 176 | 177 | torch.cuda.empty_cache() 178 | 179 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 180 | """ 181 | Copy current averaged parameters into given collection of parameters. 182 | 183 | Args: 184 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 185 | updated with the stored moving averages. If `None`, the parameters with which this 186 | `ExponentialMovingAverage` was initialized will be used. 187 | """ 188 | parameters = list(parameters) 189 | for s_param, param in zip(self.shadow_params, parameters): 190 | param.data.copy_(s_param.to(param.device).data) 191 | 192 | def to(self, device=None, dtype=None) -> None: 193 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
194 | 195 | Args: 196 | device: like `device` argument to `torch.Tensor.to` 197 | """ 198 | # .to() on the tensors handles None correctly 199 | self.shadow_params = [ 200 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) 201 | for p in self.shadow_params 202 | ] 203 | 204 | def state_dict(self) -> dict: 205 | r""" 206 | Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during 207 | checkpointing to save the ema state dict. 208 | """ 209 | # Following PyTorch conventions, references to tensors are returned: 210 | # "returns a reference to the state and not its copy!" - 211 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict 212 | return { 213 | "decay": self.decay, 214 | "min_decay": self.min_decay, 215 | "optimization_step": self.optimization_step, 216 | "update_after_step": self.update_after_step, 217 | "use_ema_warmup": self.use_ema_warmup, 218 | "inv_gamma": self.inv_gamma, 219 | "power": self.power, 220 | "shadow_params": self.shadow_params, 221 | } 222 | 223 | def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: 224 | r""" 225 | Save the current parameters for restoring later. 226 | 227 | Args: 228 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. 229 | """ 230 | self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] 231 | 232 | def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: 233 | r""" 234 | Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters 235 | without affecting the original optimization process. Store the parameters before the `copy_to()` method. 236 | After validation (or model saving), use this to restore the former parameters. 237 | 238 | Args: 239 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. 240 | If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. 241 | """ 242 | if self.temp_stored_params is None: 243 | raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`") 244 | for c_param, param in zip(self.temp_stored_params, parameters): 245 | param.data.copy_(c_param.data) 246 | 247 | # Better memory-wise. 248 | self.temp_stored_params = None 249 | 250 | def load_state_dict(self, state_dict: dict) -> None: 251 | r""" 252 | Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load 253 | the ema state dict. 254 | 255 | Args: 256 | state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.
257 | """ 258 | # deepcopy, to be consistent with module API 259 | state_dict = copy.deepcopy(state_dict) 260 | 261 | self.decay = state_dict.get("decay", self.decay) 262 | if self.decay < 0.0 or self.decay > 1.0: 263 | raise ValueError("Decay must be between 0 and 1") 264 | 265 | self.min_decay = state_dict.get("min_decay", self.min_decay) 266 | if not isinstance(self.min_decay, float): 267 | raise ValueError("Invalid min_decay") 268 | 269 | self.optimization_step = state_dict.get("optimization_step", self.optimization_step) 270 | if not isinstance(self.optimization_step, int): 271 | raise ValueError("Invalid optimization_step") 272 | 273 | self.update_after_step = state_dict.get("update_after_step", self.update_after_step) 274 | if not isinstance(self.update_after_step, int): 275 | raise ValueError("Invalid update_after_step") 276 | 277 | self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) 278 | if not isinstance(self.use_ema_warmup, bool): 279 | raise ValueError("Invalid use_ema_warmup") 280 | 281 | self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) 282 | if not isinstance(self.inv_gamma, (float, int)): 283 | raise ValueError("Invalid inv_gamma") 284 | 285 | self.power = state_dict.get("power", self.power) 286 | if not isinstance(self.power, (float, int)): 287 | raise ValueError("Invalid power") 288 | 289 | shadow_params = state_dict.get("shadow_params", None) 290 | if shadow_params is not None: 291 | self.shadow_params = shadow_params 292 | if not isinstance(self.shadow_params, list): 293 | raise ValueError("shadow_params must be a list") 294 | if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): 295 | raise ValueError("shadow_params must all be Tensors") 296 | 297 | 298 | # calculates entropy over each pixel distribution 299 | def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id): 300 | # only calculated entropy over image tokens that were masked in the original image 301 | masked_tokens = input_ids == mask_id 302 | num_masked_pixels = masked_tokens.sum(-1) 303 | 304 | probs = F.softmax(logits, dim=-1) 305 | log_probs = F.log_softmax(logits, dim=-1) 306 | 307 | entropy_per_pixel = -((probs * log_probs).sum(-1)) 308 | 309 | # the predictions for non-masked aren't used, so set their entropies to zero 310 | entropy_per_pixel[~masked_tokens] = 0 311 | 312 | entropy_per_image_numerator = entropy_per_pixel.sum(-1) 313 | entropy_per_image = entropy_per_image_numerator / num_masked_pixels 314 | 315 | total_buckets = 10 316 | masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets) 317 | 318 | entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets) 319 | 320 | return entropy_by_masked_bucket 321 | 322 | 323 | # calculates entropy over the averaged distribution of pixels for the whole image 324 | def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id): 325 | # only calculated entropy over image tokens that were masked in the original image 326 | masked_tokens = input_ids == mask_id 327 | num_masked_pixels = masked_tokens.sum(-1, keepdim=True) 328 | 329 | pixel_probs = F.softmax(logits, dim=-1) 330 | pixel_probs[~masked_tokens] = 0 331 | image_probs_numerator = pixel_probs.sum(-2) 332 | image_probs = image_probs_numerator / num_masked_pixels 333 | 334 | image_log_probs = image_probs.log() 335 | 336 | entropy_per_image = -((image_probs * image_log_probs).sum(-1)) 337 | 338 | total_buckets = 10 339 | masked_buckets = 
input_ids_to_masked_buckets(input_ids, mask_id, total_buckets) 340 | 341 | entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets) 342 | 343 | return entropy_by_masked_bucket 344 | 345 | 346 | def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing): 347 | cross_entropy_per_image = F.cross_entropy( 348 | logits.view(-1, output_size), 349 | labels.view(-1), 350 | ignore_index=-100, 351 | label_smoothing=label_smoothing, 352 | reduction="none", 353 | ) 354 | 355 | total_buckets = 10 356 | masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets) 357 | 358 | cross_entropy_by_percent_masked_bucket = average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets) 359 | 360 | return cross_entropy_by_percent_masked_bucket 361 | 362 | 363 | def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id): 364 | probs = F.softmax(logits, dim=-1) 365 | 366 | total_buckets = 10 367 | masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets) 368 | 369 | data = [] 370 | 371 | for bucket_idx in range(total_buckets): 372 | indices_for_bucket = (masked_buckets == bucket_idx).nonzero(as_tuple=True)[0] 373 | 374 | # It's ok if none were noised in the range of this bucket. This 375 | # function will be called for a later training step where it's likely 376 | # there will be an element noised in the range. 377 | if indices_for_bucket.shape[0] == 0: 378 | continue 379 | 380 | index_for_bucket = indices_for_bucket[0] 381 | 382 | image_probs = probs[index_for_bucket] 383 | 384 | # find the index of a masked pixel for the image 385 | input_ids_for_image = input_ids[index_for_bucket] 386 | masked_pixels_probs = image_probs[input_ids_for_image == mask_id] 387 | 388 | masked_pixel_probs = masked_pixels_probs[0] 389 | 390 | masked_pixel_probs = masked_pixel_probs.cpu().numpy() 391 | 392 | for masked_pixel_prob in masked_pixel_probs: 393 | data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob}) 394 | 395 | df = pd.DataFrame(data) 396 | 397 | return df 398 | 399 | 400 | def average_by_buckets(values, masked_buckets, total_buckets): 401 | unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True) 402 | 403 | numerator = torch.zeros(total_buckets, device=values.device) 404 | 405 | numerator.scatter_add_(0, masked_buckets, values) 406 | 407 | # default value is one because the buckets for which there aren't 408 | # any values will have a numerator of zero. So we just need to not divide 409 | # by zero. 410 | denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long) 411 | denominator[unique_buckets] = bucket_counts 412 | 413 | averaged_by_buckets = numerator / denominator 414 | 415 | return averaged_by_buckets 416 | 417 | 418 | def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10): 419 | assert total_buckets == 10 420 | 421 | masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1] 422 | 423 | # we do not formally use timesteps to noise images. Instead, we mask a percent 424 | # of the pixels. We don't want to log entropy for every mask percent between 0 and 1, 425 | # and we also want to track how the entropy evolves over time w/in a range of mask 426 | # percents that should have similar entropy.
So we bucket the masked percents into a 427 | # fixed number of buckets 428 | 429 | # we could generalize this later if needed but for now, let's just assume a fixed 430 | # number of 10 buckets. 431 | 432 | # How this maps to a bucket index: 433 | # (mask) * bucket_index + 434 | # (mask_1) * bucket_index_1 435 | # 436 | # -> Where the mask is true will be set to the expected bucket index, 437 | # where the mask is false will be set to 0. 438 | # 439 | # Given the probabilities are between 0 and 1, each masked_percent will get mapped 440 | # to a timestep by one and only one of the masks. 441 | 442 | masked_buckets = ( 443 | ((0 < masked_percent) & (masked_percent <= 0.1)) * 0 444 | + ((0.1 < masked_percent) & (masked_percent <= 0.2)) * 1 445 | + ((0.2 < masked_percent) & (masked_percent <= 0.3)) * 2 446 | + ((0.3 < masked_percent) & (masked_percent <= 0.4)) * 3 447 | + ((0.4 < masked_percent) & (masked_percent <= 0.5)) * 4 448 | + ((0.5 < masked_percent) & (masked_percent <= 0.6)) * 5 449 | + ((0.6 < masked_percent) & (masked_percent <= 0.7)) * 6 450 | + ((0.7 < masked_percent) & (masked_percent <= 0.8)) * 7 451 | + ((0.8 < masked_percent) & (masked_percent <= 0.9)) * 8 452 | + ((0.9 < masked_percent) & (masked_percent <= 1.0)) * 9 453 | ) 454 | 455 | return masked_buckets 456 | -------------------------------------------------------------------------------- /scripts/pre_encode.py: -------------------------------------------------------------------------------- 1 | # This script is used to pre encode coyo, laion 6a, and laion 5a. 2 | # 3 | # It can be run as both a standalone job or via slurm. When run via slurm, be 4 | # sure to pass `--slurm` so the script can split shards amongst workers based on 5 | # the env vars `$SLURM_NTASKS` and `$SLURM_PROCID`. It is intended that one copy 6 | # of the script is launched per gpu, and cpu access is controlled implicitly 7 | # through slurm setting `$CUDA_VISIBLE_DEVICES`. See 8 | # ../slurm_scrips/{pre_encoded_laion_6, pre_encode_laion_5, 9 | # pre_encode_coyo}.slurm for example sbatch scripts. 
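As a small illustration of the `input_ids_to_masked_buckets` helper defined in training_utils.py above (the token values and mask id here are made up purely for the example):

    import torch

    mask_id = 99  # arbitrary stand-in for the real mask token id
    input_ids = torch.tensor([
        [99,  5,  7, 99,  2,  1,  0,  3],   # 2/8 = 25% masked  -> bucket 2
        [99, 99, 99, 99, 99, 99,  4,  4],   # 6/8 = 75% masked  -> bucket 7
        [99, 99, 99, 99, 99, 99, 99, 99],   # 8/8 = 100% masked -> bucket 9
    ])
    masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1]
    print(masked_percent)  # tensor([0.2500, 0.7500, 1.0000])
    # input_ids_to_masked_buckets(input_ids, mask_id) -> tensor([2, 7, 9])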
10 | # 11 | # Benchmarks: 12 | # COYO) 64.1 GPU * sec / shard 13 | # laion) 75 GPU * sec / shard 14 | # 15 | # To convert a time per shard into a time to convert the 16 | # whole dataset, use 17 | # X (GPU * sec / shard) * Y shards * 1/8 (nodes/GPU) * 1/Z nodes = seconds to encode Y shards 18 | # 19 | # Shard counts: 20 | # COYO) 74,752 shards (0-74,751) 21 | # laion 6a) 1,211 shards (0 - 1,210) 22 | # laion 5a) 60,581 shards (0 - 60,580) 23 | # 24 | # Encoding times using 8 nodes: 25 | # COYO) 20h48m 26 | # laion 6a) 23.4 minutes 27 | # laion 5a) 19h43m 28 | 29 | import argparse 30 | import concurrent.futures 31 | import logging 32 | import os 33 | import re 34 | from collections import OrderedDict 35 | from threading import Lock 36 | 37 | import numpy as np 38 | import torch 39 | import torchvision.transforms.functional as TF 40 | import webdataset as wds 41 | from torch.utils.data import DataLoader 42 | from torchvision.transforms import InterpolationMode 43 | from transformers import CLIPTextModel, CLIPTokenizerFast 44 | 45 | from muse import PaellaVQModel, VQGANModel 46 | 47 | torch.set_float32_matmul_precision("high") 48 | torch.set_grad_enabled(False) 49 | 50 | PAELLA_F8_VQVAE = "openMUSE/paellavq-f8-8192-laion" 51 | VQGAN_F16_VQVAE = "openMUSE/vqgan-f16-8192-laion" 52 | CLIP = "openMUSE/CLIP-ViT-L-14-DataComp.XL-s13B-b90K-penultimate" 53 | 54 | PAELLA_F8_VQVAE_EXT = f"{'.'.join(PAELLA_F8_VQVAE.split('/'))}.pth" 55 | VQGAN_F16_VQVAE_EXT = f"{'.'.join(VQGAN_F16_VQVAE.split('/'))}.pth" 56 | CLIP_EXT = f"{'.'.join(CLIP.split('/'))}.pth" 57 | 58 | LAION_AESTHETICS_V2_5_PLUS = "s3://hf-datasets-laion-5b-us-west-2/glacier/laion-data/laion-aesthetics-v2-5-plus-data" 59 | LAION_AESTHETICS_V2_6_PLUS = "s3://muse-datasets/laion-aesthetic6plus-data" 60 | COYO = "s3://hf-datasets-coyo-700m-us-west-2/data" 61 | 62 | LAION_AESTHETICS_V2_5_PLUS_PRE_ENCODED = "s3://muse-datasets/hf-datasets-laion-aesthetics-v2-5-plus-data-pre-encoded" 63 | LAION_AESTHETICS_V2_6_PLUS_PRE_ENCODED = "s3://muse-datasets/hf-datasets-laion-aesthetic6plus-data-pre-encoded" 64 | COYO_PRE_ENCODED = "s3://muse-datasets/hf-datasets-coyo-700m-pre-encoded" 65 | 66 | logger = logging.getLogger(__name__) 67 | 68 | tar_regex = r"\/([^\/]+\.tar)" 69 | 70 | 71 | def get_tar_file_name(url): 72 | match = re.search(tar_regex, url) 73 | assert match is not None, url 74 | tar_file_name = match.group(1) 75 | return tar_file_name 76 | 77 | 78 | def format_shard_number(shard_n: int): 79 | return "{:0>{}}".format(shard_n, 5) 80 | 81 | 82 | class Uploads: 83 | """ 84 | Uploads manages the post encoding steps, both CUDA -> cpu and the s3 upload. 85 | 86 | In order to avoid an expensive cuda sync event of the encode for every batch, 87 | instead "submit" the entirety of the post processing to a thread pool. Once the 88 | thread pool is full, we hand over the entirety of the thread pool to the python 89 | interpreter. This effectively allows multiple encoding batches to execute at once. 90 | At a 160 batch size, this uses <40 GB VRAM. 91 | 92 | TODO - probably would be better to wait until the thread pool is full and then 93 | execute just the least recent post processing? This could even be done without a 94 | thread pool or with a single thread, since it's executing one job at a time. Hmmm. 95 | 96 | The class must manage 97 | 1) the thread pool 98 | 2) the list of pending futures that have been submitted 99 | 3) a list of tar writers to upload results 100 | 101 | For the list of tar writers, we keep at most 5 open at a time. 
When we need to 102 | open an additional writer, we close the earliest opened one assuming that we have 103 | finished writing to it as the archives are read sequentially. This is an assumption 104 | but 5 is a safe buffer as we realistically will never be writing to more than 2 at a time 105 | for a reasonably sized thread pool. 106 | 107 | The list of tar writers is managed with a global lock because it opens a sub process and 108 | iirc Popen is not thread safe. Additionally each tar writer is managed with its own lock 109 | because writes are not thread safe and can corrupt the archive. 110 | """ 111 | 112 | def __init__(self, skip_upload, upload_to, num_writing_threads): 113 | self.open_lock = Lock() 114 | self.uploads = OrderedDict() 115 | self.skip_upload = skip_upload 116 | self.upload_to = upload_to 117 | self.futures = [] 118 | self.num_writing_threads = num_writing_threads 119 | self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.num_writing_threads) 120 | 121 | def __enter__(self): 122 | return self 123 | 124 | def __exit__(self, exc_type, exc_val, exc_tb): 125 | # Finish all pending encodings 126 | [x.result() for x in concurrent.futures.as_completed(self.futures)] 127 | 128 | self.executor.shutdown(wait=True) 129 | 130 | # Close all unclosed file writes 131 | for tar_file_name, tar_writer in self.uploads.items(): 132 | tar_writer["writer"].close() 133 | 134 | return False 135 | 136 | def submit( 137 | self, 138 | __key__, 139 | __url__, 140 | encoder_hidden_states, 141 | attention_mask_lengths, 142 | encoded_image_f8, 143 | encoded_image_f16, 144 | metadata, 145 | ): 146 | future = self.executor.submit( 147 | self._upload_thread_entrypoint, 148 | __key__, 149 | __url__, 150 | encoder_hidden_states, 151 | attention_mask_lengths, 152 | encoded_image_f8, 153 | encoded_image_f16, 154 | metadata, 155 | ) 156 | 157 | self.futures.append(future) 158 | 159 | # Give cuda some time to complete the encodings before moving to cpu and uploading 160 | if len(self.futures) == self.num_writing_threads: 161 | [x.result() for x in concurrent.futures.as_completed(self.futures)] 162 | self.futures = [] 163 | 164 | def _upload_thread_entrypoint( 165 | self, 166 | __key__, 167 | __url__, 168 | encoder_hidden_states, 169 | attention_mask_lengths, 170 | encoded_image_f8, 171 | encoded_image_f16, 172 | metadata, 173 | ): 174 | encoder_hidden_states = torch.unbind(encoder_hidden_states) 175 | encoded_image_f8 = torch.unbind(encoded_image_f8) 176 | encoded_image_f16 = torch.unbind(encoded_image_f16) 177 | 178 | for ( 179 | __key__, 180 | __url__, 181 | encoded_image_f8, 182 | encoded_image_f16, 183 | encoder_hidden_states, 184 | attention_mask_length, 185 | metadata, 186 | ) in zip( 187 | __key__, 188 | __url__, 189 | encoded_image_f8, 190 | encoded_image_f16, 191 | encoder_hidden_states, 192 | attention_mask_lengths, 193 | metadata, 194 | ): 195 | encoded_image_f8 = encoded_image_f8.clone().to("cpu") 196 | encoded_image_f16 = encoded_image_f16.clone().to("cpu") 197 | encoder_hidden_states = encoder_hidden_states.clone().to("cpu") 198 | 199 | if self.skip_upload: 200 | continue 201 | 202 | tar_file_name = get_tar_file_name(__url__) 203 | 204 | # It is not strictly clear to me if it is necessary to lock this whole block or 205 | # just part(s) of the kickout/create new writer. Just lock the whole function to be 206 | # safe. 
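            # Note: self.uploads is an OrderedDict used as a small LRU of open tar writers;
            # at most five stay open, and the oldest is closed and evicted before a new
            # writer is created for a shard we haven't seen yet.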
207 | self.open_lock.acquire() 208 | 209 | if tar_file_name not in self.uploads: 210 | if len(self.uploads) == 5: 211 | # kick out the earliest one 212 | key = next(iter(self.uploads.keys())) 213 | self.uploads[key]["writer"].close() 214 | del self.uploads[key] 215 | 216 | upload_command = f"pipe:aws s3 cp - {self.upload_to}/{tar_file_name}" 217 | logger.warning(f"opening new writer for {upload_command}") 218 | 219 | self.uploads[tar_file_name] = { 220 | "writer": wds.TarWriter(upload_command), 221 | "lock": Lock(), 222 | } 223 | 224 | upload = self.uploads[tar_file_name] 225 | 226 | self.open_lock.release() 227 | 228 | metadata = dict(metadata) 229 | metadata["attention_mask_length"] = attention_mask_length 230 | 231 | sample = { 232 | "__key__": __key__, 233 | PAELLA_F8_VQVAE_EXT: encoded_image_f8, 234 | VQGAN_F16_VQVAE_EXT: encoded_image_f16, 235 | CLIP_EXT: encoder_hidden_states, 236 | "json": metadata, 237 | } 238 | 239 | # Not locking around the write will corrupt the tar file 240 | upload["lock"].acquire() 241 | upload["writer"].write(sample) 242 | upload["lock"].release() 243 | 244 | 245 | def distribute_shards(start_shard_all, end_shard_all, slurm_ntasks): 246 | total_shards = end_shard_all - start_shard_all + 1 247 | shards_per_task = total_shards // slurm_ntasks 248 | shards_per_task = [shards_per_task] * slurm_ntasks 249 | 250 | # to distribute the remainder of tasks for non-evenly divisible number of shards 251 | left_over_shards = total_shards % slurm_ntasks 252 | 253 | for slurm_procid in range(left_over_shards): 254 | shards_per_task[slurm_procid] += 1 255 | 256 | assert sum(shards_per_task) == total_shards 257 | 258 | distributed_shards = [] 259 | 260 | for slurm_procid in range(len(shards_per_task)): 261 | if slurm_procid == 0: 262 | start_shard = start_shard_all 263 | else: 264 | start_shard = distributed_shards[slurm_procid - 1][1] + 1 265 | 266 | end_shard = start_shard + shards_per_task[slurm_procid] - 1 267 | distributed_shards.append((start_shard, end_shard)) 268 | 269 | assert sum([end_shard - start_shard + 1 for start_shard, end_shard in distributed_shards]) == total_shards 270 | 271 | return distributed_shards 272 | 273 | 274 | def main(): 275 | parser = argparse.ArgumentParser() 276 | parser.add_argument( 277 | "--dataset", 278 | type=str, 279 | help="The dataset to pre-encode", 280 | choices=["laion_5", "laion_6", "coyo"], 281 | required=True, 282 | ) 283 | parser.add_argument( 284 | "--start_shard", 285 | type=int, 286 | help="The starting shard to pre-encode.", 287 | required=True, 288 | ) 289 | parser.add_argument( 290 | "--end_shard", 291 | type=int, 292 | help="The ending shard to pre-encode, inclusive. If not given, defaults to `--start_shard`.", 293 | required=False, 294 | ) 295 | parser.add_argument( 296 | "--slurm", 297 | action="store_true", 298 | help=( 299 | "If set, this process is running under a batch of slurm tasks." 300 | "`--start_shard` and `--end_shard` must be set for the entirety of shards over all slurm tasks." 301 | " The shards that will be encoded in each instance of the task will be determined via" 302 | " the env vars `$SLURM_NTASKS` and `$SLURM_PROCID`." 
303 | ), 304 | ) 305 | parser.add_argument( 306 | "--batch_size", type=int, help="The batch size to encode at a time", required=False, default=160 307 | ) 308 | parser.add_argument( 309 | "--resolution", type=int, help="The resolution to convert the image to.", required=False, default=256 310 | ) 311 | parser.add_argument( 312 | "--skip_upload", 313 | action="store_true", 314 | help="Set to not actually upload results, helpful for only testing encoding.", 315 | ) 316 | parser.add_argument( 317 | "--num_writing_threads", 318 | type=int, 319 | required=False, 320 | default=40, 321 | ) 322 | 323 | args = parser.parse_args() 324 | 325 | if args.slurm and args.end_shard is None: 326 | raise ValueError("`--end_shard` must be set when `--slurm` is set") 327 | 328 | if args.end_shard is None: 329 | args.end_shard = args.start_shard 330 | 331 | if args.end_shard < args.start_shard: 332 | raise ValueError("`--end_shard` must be >= `--start_shard`") 333 | 334 | if args.batch_size < 1: 335 | raise ValueError("`--batch_size` must be >= 1") 336 | 337 | if args.resolution < 1: 338 | raise ValueError("`--resolution` must be >= 1") 339 | 340 | if args.dataset == "laion_5": 341 | args.dataset = LAION_AESTHETICS_V2_5_PLUS 342 | elif args.dataset == "laion_6": 343 | args.dataset = LAION_AESTHETICS_V2_6_PLUS 344 | elif args.dataset == "coyo": 345 | args.dataset = COYO 346 | else: 347 | assert False 348 | 349 | if args.dataset == LAION_AESTHETICS_V2_5_PLUS: 350 | upload_to = LAION_AESTHETICS_V2_5_PLUS_PRE_ENCODED 351 | elif args.dataset == LAION_AESTHETICS_V2_6_PLUS: 352 | upload_to = LAION_AESTHETICS_V2_6_PLUS_PRE_ENCODED 353 | elif args.dataset == COYO: 354 | upload_to = COYO_PRE_ENCODED 355 | else: 356 | assert False 357 | 358 | logger.warning("********************") 359 | logger.warning("Pre-encoding dataset") 360 | logger.warning(f"dataset: {args.dataset}") 361 | logger.warning(f"start_shard: {args.start_shard}") 362 | logger.warning(f"end_shard: {args.end_shard}") 363 | logger.warning(f"upload_to: {upload_to}") 364 | logger.warning(f"batch_size: {args.batch_size}") 365 | logger.warning("********************") 366 | 367 | if args.slurm: 368 | slurm_procid = int(os.environ["SLURM_PROCID"]) 369 | slurm_ntasks = int(os.environ["SLURM_NTASKS"]) 370 | 371 | distributed_shards = distribute_shards(args.start_shard, args.end_shard, slurm_ntasks) 372 | 373 | start_shard_task, end_shard_task = distributed_shards[slurm_procid] 374 | 375 | args.start_shard = start_shard_task 376 | args.end_shard = end_shard_task 377 | 378 | logger.warning("************") 379 | logger.warning("Running as slurm task") 380 | logger.warning(f"SLURM_NTASKS: {slurm_ntasks}") 381 | logger.warning(f"SLURM_PROCID: {slurm_procid}") 382 | logger.warning(f"start_shard: {start_shard_task}, end_shard: {end_shard_task}") 383 | logger.warning("************") 384 | logger.warning(f"all slurm processes") 385 | for slurm_proc_id_, (start_shard, end_shard) in enumerate(distributed_shards): 386 | logger.warning(f"slurm process: {slurm_proc_id_}, start_shard: {start_shard}, end_shard: {end_shard}") 387 | logger.warning("************") 388 | 389 | vae_f8 = PaellaVQModel.from_pretrained(PAELLA_F8_VQVAE) 390 | vae_f8.to("cuda") 391 | vae_f8.requires_grad_(False) 392 | 393 | vae_f16 = VQGANModel.from_pretrained(VQGAN_F16_VQVAE) 394 | vae_f16.to("cuda") 395 | vae_f16.requires_grad_(False) 396 | 397 | tokenizer = CLIPTokenizerFast.from_pretrained(CLIP) 398 | text_encoder = CLIPTextModel.from_pretrained(CLIP) 399 | text_encoder.to_bettertransformer() 400 | 
text_encoder.to("cuda") 401 | 402 | shard_range = "{" + format_shard_number(args.start_shard) + ".." + format_shard_number(args.end_shard) + "}" 403 | download_shards = f"pipe:aws s3 cp {args.dataset}/{shard_range}.tar -" 404 | 405 | logger.warning(f"downloading shards {download_shards}") 406 | 407 | src = ( 408 | wds.WebDataset( 409 | download_shards, 410 | ) 411 | .decode("pil", handler=wds.warn_and_continue) 412 | .rename(image="jpg;png;jpeg;webp", prompt="text;txt;caption", metadata="json") 413 | .map( 414 | lambda dict: { 415 | "__key__": dict["__key__"], 416 | "__url__": dict["__url__"], 417 | "image": dict["image"], 418 | "prompt": dict["prompt"], 419 | "metadata": dict["metadata"], 420 | } 421 | ) 422 | .to_tuple("__key__", "__url__", "image", "prompt", "metadata") 423 | .batched(args.batch_size) 424 | ) 425 | src = DataLoader( 426 | src, 427 | batch_size=None, 428 | shuffle=False, 429 | num_workers=0, 430 | ) 431 | 432 | with Uploads(args.skip_upload, upload_to, args.num_writing_threads) as uploads: 433 | for __key__, __url__, image, prompt, metadata in src: 434 | logger.warning(f"Encoding {len(__key__)} examples: {__key__[0]} to {__key__[-1]}.") 435 | 436 | encoded_prompts = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt") 437 | 438 | attention_masks = encoded_prompts.attention_mask 439 | # attention masks are [1, 1, 1, 1, 0, ....., 0] so summing gives us the 440 | # index of last non-zero element. 441 | attention_mask_lengths = attention_masks.sum(-1) 442 | # Will be stored as a part of json metadata 443 | attention_mask_lengths = attention_mask_lengths.tolist() 444 | 445 | input_ids = encoded_prompts.input_ids.to("cuda") 446 | 447 | all_images = [] 448 | 449 | for image_ in image: 450 | # The following is minorly more efficient than the default 451 | # torchvision to_tensor and lets use move to cuda earlier :P 452 | mode = image_.mode 453 | 454 | height = image_.height 455 | width = image_.width 456 | 457 | if hasattr(image_, "getbands"): 458 | channels = len(image_.getbands()) 459 | else: 460 | channels = image_.channels 461 | 462 | if mode == "I": 463 | nptype = np.int32 464 | elif mode == "I;16": 465 | nptype = np.int16 466 | elif mode == "F": 467 | nptype = np.float32 468 | else: 469 | nptype = np.uint8 470 | 471 | image_ = np.array(image_, nptype) 472 | image_ = torch.from_numpy(image_) 473 | image_: torch.Tensor = image_.to("cuda") 474 | 475 | image_ = image_.view(height, width, channels) 476 | image_ = image_.permute((2, 0, 1)).contiguous() 477 | 478 | if mode != "1" and image_.dtype == torch.uint8: 479 | image_ = image_.to(dtype=torch.float32).div(255) 480 | 481 | image_ = TF.resize( 482 | image_, size=args.resolution, interpolation=InterpolationMode.BILINEAR, antialias=True 483 | ) 484 | 485 | image_ = TF.center_crop(image_, args.resolution) 486 | 487 | all_images.append(image_) 488 | 489 | image = torch.stack(all_images) 490 | 491 | encoder_hidden_states = text_encoder(input_ids)[0] 492 | 493 | with torch.cuda.amp.autocast(): 494 | encoded_image_f8 = vae_f8.get_code(image) 495 | 496 | with torch.cuda.amp.autocast(): 497 | encoded_image_f16 = vae_f16.get_code(image) 498 | 499 | uploads.submit( 500 | __key__, 501 | __url__, 502 | encoder_hidden_states, 503 | attention_mask_lengths, 504 | encoded_image_f8, 505 | encoded_image_f16, 506 | metadata, 507 | ) 508 | 509 | 510 | if __name__ == "__main__": 511 | main() 512 | --------------------------------------------------------------------------------