├── .gitignore ├── DockerFile ├── LICENSE ├── README.md ├── assets ├── 2b_20b │ ├── .DS_Store │ ├── 1l.jpg │ ├── 1r.jpg │ ├── 2l.jpg │ ├── 2r.jpg │ ├── 3l.jpg │ ├── 3r.jpg │ ├── 4l.jpg │ └── 4r.jpg ├── 2b_8b │ ├── 1l.webp │ ├── 1r.webp │ ├── 2l.webp │ ├── 2r.webp │ ├── 3l.webp │ ├── 3r.webp │ └── 4r.webp ├── framework_row.png ├── scaling_models.png ├── scaling_vocabulary.png └── show_images.jpg ├── cog.yaml ├── conf.py ├── data ├── infinity_toy_data │ ├── .DS_Store │ ├── images │ │ ├── 1220076234599834949.jpg │ │ ├── 1713835988126009050.jpg │ │ ├── 2509925642183738470.jpg │ │ ├── 2732145247443895234.jpg │ │ ├── 3861311014320456446.jpg │ │ ├── 4265467520443567280.jpg │ │ ├── 5134521536907147208.jpg │ │ ├── 5179780969343495162.jpg │ │ ├── 5846166776365405949.jpg │ │ └── 6128985124434332020.jpg │ └── splits │ │ ├── 1.000_000002500.jsonl │ │ └── 1.500_000002500.jsonl └── labels │ └── imagenet │ └── val.txt ├── evaluation ├── README.md ├── gen_eval │ ├── _base_ │ │ ├── datasets │ │ │ └── coco_panoptic.py │ │ └── default_runtime.py │ ├── evaluate_images.py │ ├── infer4eval.py │ ├── mask2former │ │ ├── mask2former_r50_lsj_8x2_50e_coco-panoptic.py │ │ ├── mask2former_r50_lsj_8x2_50e_coco.py │ │ ├── mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py │ │ └── mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py │ ├── prompts │ │ └── create_prompts.py │ ├── rename.py │ └── summary_scores.py ├── hpsv2 │ └── eval_hpsv2.py ├── image_reward │ ├── cal_imagereward.py │ └── infer4eval.py └── validation_loss │ └── validation_loss.py ├── infinity ├── dataset │ ├── build.py │ └── dataset_t2i_iterable.py ├── models │ ├── __init__.py │ ├── basic.py │ ├── bitwise_self_correction.py │ ├── bsq_vae │ │ ├── conv.py │ │ ├── dynamic_resolution.py │ │ ├── flux_vqgan.py │ │ ├── multiscale_bsq.py │ │ └── vae.py │ ├── ema.py │ ├── flex_attn.py │ ├── fused_op.py │ ├── infinity.py │ ├── init_param.py │ └── t5.py └── utils │ ├── amp_opt.py │ ├── arg_util.py │ ├── csv_util.py │ ├── dist.py │ ├── dynamic_resolution.py │ ├── large_file_util.py │ ├── load.py │ ├── lr_control.py │ ├── misc.py │ ├── save_and_load.py │ └── wandb_utils.py ├── predict.py ├── requirements.txt ├── scripts ├── eval.sh ├── infer.sh └── train.sh ├── tools ├── comprehensive_infer.py ├── fid_score.py ├── inception.py ├── interactive_infer.ipynb ├── interactive_infer_8b.ipynb ├── prompt_rewriter.py ├── reproduce.py ├── run_infinity.py └── run_tokenizer.py ├── train.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | **/__pycache__/** 3 | **/.ipynb_checkpoints/** 4 | .idea/* 5 | llava/ 6 | _vis_cached/ 7 | _vqgan/ 8 | _vae/ 9 | _vae*/ 10 | ckpt/ 11 | log/ 12 | tb*/ 13 | img*/ 14 | local_output* 15 | _auto_* 16 | sd-vae-ft-mse/ 17 | stable-diffusion-v1-4/ 18 | *.pth 19 | *.pth.tar 20 | *.ckpt 21 | *.log 22 | *.txt 23 | *.ipynb 24 | toscli 25 | *.hydra 26 | wandb 27 | *.jsonl 28 | *.jpg 29 | *.png 30 | *.json 31 | *.csv 32 | *.tar.gz 33 | *.bin 34 | data/ 35 | tmp 36 | output 37 | *.tsv 38 | *.mp4 39 | output/* 40 | results/ 41 | *.JPEG 42 | debug/ 43 | weights 44 | checkpoints 45 | ref.py 46 | wandb 47 | .DS_Store -------------------------------------------------------------------------------- /DockerFile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.5.1-cuda11.8-cudnn9-devel 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive \ 4 | PYTHONUNBUFFERED=1 \ 5 | CUDA_HOME=/usr/local/cuda \ 6 | PATH="$CUDA_HOME/bin:$PATH" 7 | 8 | RUN apt-get update 
&& apt-get install -y --no-install-recommends \ 9 | git \ 10 | curl \ 11 | ffmpeg \ 12 | libsm6 \ 13 | libxext6 \ 14 | && apt-get clean && rm -rf /var/lib/apt/lists/* 15 | 16 | WORKDIR /workspace/ 17 | 18 | COPY requirements.txt /workspace/requirements.txt 19 | 20 | RUN pip install --upgrade pip \ 21 | && pip install ninja \ 22 | && MAX_JOBS=1 pip install flash-attn --no-build-isolation \ 23 | && pip install -r requirements.txt \ 24 | && pip install opencv-fixer==0.2.5 \ 25 | && python -c "from opencv_fixer import AutoFix; AutoFix()" 26 | 27 | CMD ["/bin/bash"] 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 FoundationVision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /assets/2b_20b/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/.DS_Store -------------------------------------------------------------------------------- /assets/2b_20b/1l.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/1l.jpg -------------------------------------------------------------------------------- /assets/2b_20b/1r.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/1r.jpg -------------------------------------------------------------------------------- /assets/2b_20b/2l.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/2l.jpg -------------------------------------------------------------------------------- /assets/2b_20b/2r.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/2r.jpg -------------------------------------------------------------------------------- /assets/2b_20b/3l.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/3l.jpg -------------------------------------------------------------------------------- /assets/2b_20b/3r.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/3r.jpg -------------------------------------------------------------------------------- /assets/2b_20b/4l.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/4l.jpg -------------------------------------------------------------------------------- /assets/2b_20b/4r.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_20b/4r.jpg -------------------------------------------------------------------------------- /assets/2b_8b/1l.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_8b/1l.webp -------------------------------------------------------------------------------- /assets/2b_8b/1r.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_8b/1r.webp -------------------------------------------------------------------------------- /assets/2b_8b/2l.webp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_8b/2l.webp -------------------------------------------------------------------------------- /assets/2b_8b/2r.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_8b/2r.webp -------------------------------------------------------------------------------- /assets/2b_8b/3l.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_8b/3l.webp -------------------------------------------------------------------------------- /assets/2b_8b/3r.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_8b/3r.webp -------------------------------------------------------------------------------- /assets/2b_8b/4r.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/2b_8b/4r.webp -------------------------------------------------------------------------------- /assets/framework_row.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/framework_row.png -------------------------------------------------------------------------------- /assets/scaling_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/scaling_models.png -------------------------------------------------------------------------------- /assets/scaling_vocabulary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/scaling_vocabulary.png -------------------------------------------------------------------------------- /assets/show_images.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/assets/show_images.jpg -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | 8 | # a list of ubuntu apt packages to install 9 | system_packages: 10 | - "libgl1-mesa-glx" 11 | - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format == 17 | python_packages: 18 | - torch 19 | - transformers 20 | - easydict 21 | - typed-argument-parser 22 | - seaborn 23 | - kornia 24 | - gputil 25 | - colorama 26 | - omegaconf 27 | - pandas 28 | - timm==0.9.6 29 | - decord 30 | - pytz 31 | - pandas 32 | - wandb 33 | - colorama 34 | - 
imageio 35 | - einops 36 | - openai 37 | - httpx==0.20.0 38 | - opencv-python 39 | - ipython 40 | 41 | # commands run after the environment is setup 42 | run: 43 | - pip install "pydantic<2.0" 44 | - pip install -U flash-attn --no-build-isolation 45 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 46 | predict: "predict.py:Predictor" 47 | -------------------------------------------------------------------------------- /conf.py: -------------------------------------------------------------------------------- 1 | HF_TOKEN = '[YOUR HF_TOKEN]' 2 | HF_HOME = '[YOUR HF_HOME]' 3 | 4 | GPT_AK = '[YOUR GPT_AK]' 5 | -------------------------------------------------------------------------------- /data/infinity_toy_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/.DS_Store -------------------------------------------------------------------------------- /data/infinity_toy_data/images/1220076234599834949.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/1220076234599834949.jpg -------------------------------------------------------------------------------- /data/infinity_toy_data/images/1713835988126009050.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/1713835988126009050.jpg -------------------------------------------------------------------------------- /data/infinity_toy_data/images/2509925642183738470.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/2509925642183738470.jpg -------------------------------------------------------------------------------- /data/infinity_toy_data/images/2732145247443895234.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/2732145247443895234.jpg -------------------------------------------------------------------------------- /data/infinity_toy_data/images/3861311014320456446.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/3861311014320456446.jpg -------------------------------------------------------------------------------- /data/infinity_toy_data/images/4265467520443567280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/4265467520443567280.jpg -------------------------------------------------------------------------------- /data/infinity_toy_data/images/5134521536907147208.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/5134521536907147208.jpg
--------------------------------------------------------------------------------
/data/infinity_toy_data/images/5179780969343495162.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/5179780969343495162.jpg
--------------------------------------------------------------------------------
/data/infinity_toy_data/images/5846166776365405949.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/5846166776365405949.jpg
--------------------------------------------------------------------------------
/data/infinity_toy_data/images/6128985124434332020.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/Infinity/a857767719feaec9fa27e62f7fa38d5b42462733/data/infinity_toy_data/images/6128985124434332020.jpg
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | We provide [eval.sh](scripts/eval.sh) for evaluation on various benchmarks with a single command. In particular, [eval.sh](scripts/eval.sh) supports evaluation on commonly used metrics such as [GenEval](https://github.com/djghosh13/geneval), [ImageReward](https://github.com/THUDM/ImageReward), [HPSv2.1](https://github.com/tgxs002/HPSv2), FID and Validation Loss.
3 | 
4 | # Usage
5 | 
6 | 
7 | ## Basic Configuration
8 | 
9 | ```shell
10 | # set arguments
11 | pn=1M
12 | model_type=infinity_2b
13 | infinity_model_path=[infinity_model_path]
14 | out_dir_root=[out_dir_root]
15 | vae_type=32
16 | vae_path=[vae_path]
17 | cfg=4
18 | tau=1
19 | text_encoder_ckpt=[text_encoder_ckpt]
20 | text_channels=2048
21 | sub_fix=cfg${cfg}_tau${tau}
22 | ```
23 | 
24 | 
25 | ## ImageReward
26 | [ImageReward](https://github.com/THUDM/ImageReward) is a metric for evaluating the human preference score of generated images. It learns human preference by fine-tuning a CLIP model on 137K human-ranked image pairs.
27 | ```shell
28 | out_dir=${out_dir_root}/image_reward_${sub_fix}
29 | infer_eval_image_reward
30 | ```
31 | 
32 | ## HPS v2.1
33 | [HPSv2.1](https://github.com/tgxs002/HPSv2) is a metric for evaluating the human preference score of generated images. It learns human preference by fine-tuning a CLIP model on 798K image pairs ranked by human experts.
34 | ```shell
35 | out_dir=${out_dir_root}/hpsv21_${sub_fix}
36 | infer_eval_hpsv21
37 | ```
38 | 
39 | ## GenEval
40 | [GenEval](https://github.com/djghosh13/geneval) is an object-focused framework for evaluating Text-to-Image alignment.
41 | ```shell
42 | rewrite_prompt=0
43 | out_dir=${out_dir_root}/gen_eval_${sub_fix}
44 | test_gen_eval
45 | ```
46 | 
47 | ## FID
48 | For testing FID, you need to provide a jsonl file containing text prompts and ground-truth images. We highly recommend including more than 20,000 examples in the jsonl file, since a reliable FID estimate requires a large number of samples.
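The exact jsonl schema is whatever `test_fid` in [eval.sh](scripts/eval.sh) consumes; as a rough, hypothetical sketch (the field names `image_path` and `text` below are illustrative assumptions, not taken from this repository), each line pairs one ground-truth image with its caption:
```python
# Hypothetical sketch: the field names ("image_path", "text") are assumptions --
# check scripts/eval.sh (test_fid) for the schema it actually expects.
import json

examples = [
    {"image_path": "data/infinity_toy_data/images/1220076234599834949.jpg",
     "text": "a long caption describing this image"},
    # ... one entry per ground-truth image, ideally more than 20,000 in total
]

with open("val_long_caption_fid.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```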
49 | ```shell
50 | long_caption_fid=1
51 | jsonl_filepath=[jsonl_filepath]
52 | out_dir=${out_dir_root}/val_long_caption_fid_${sub_fix}
53 | rm -rf ${out_dir}
54 | test_fid
55 | ```
56 | 
57 | ## Validation Loss
58 | For testing Validation Loss, you need to provide a jsonl folder structured like the training jsonl folder. In addition, you should set the noise-applying strength for Bitwise Self-Correction to the same strength used in the training phase.
59 | ```shell
60 | out_dir=${out_dir_root}/val_loss_${sub_fix}
61 | reweight_loss_by_scale=0
62 | jsonl_folder=[jsonl_folder]
63 | noise_apply_strength=0.2
64 | test_val_loss
65 | ```
--------------------------------------------------------------------------------
/evaluation/gen_eval/_base_/datasets/coco_panoptic.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'CocoPanopticDataset'
3 | data_root = 'data/coco/'
4 | img_norm_cfg = dict(
5 |     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6 | train_pipeline = [
7 |     dict(type='LoadImageFromFile'),
8 |     dict(
9 |         type='LoadPanopticAnnotations',
10 |         with_bbox=True,
11 |         with_mask=True,
12 |         with_seg=True),
13 |     dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
14 |     dict(type='RandomFlip', flip_ratio=0.5),
15 |     dict(type='Normalize', **img_norm_cfg),
16 |     dict(type='Pad', size_divisor=32),
17 |     dict(type='SegRescale', scale_factor=1 / 4),
18 |     dict(type='DefaultFormatBundle'),
19 |     dict(
20 |         type='Collect',
21 |         keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
22 | ]
23 | test_pipeline = [
24 |     dict(type='LoadImageFromFile'),
25 |     dict(
26 |         type='MultiScaleFlipAug',
27 |         img_scale=(1333, 800),
28 |         flip=False,
29 |         transforms=[
30 |             dict(type='Resize', keep_ratio=True),
31 |             dict(type='RandomFlip'),
32 |             dict(type='Normalize', **img_norm_cfg),
33 |             dict(type='Pad', size_divisor=32),
34 |             dict(type='ImageToTensor', keys=['img']),
35 |             dict(type='Collect', keys=['img']),
36 |         ])
37 | ]
38 | data = dict(
39 |     samples_per_gpu=2,
40 |     workers_per_gpu=2,
41 |     train=dict(
42 |         type=dataset_type,
43 |         ann_file=data_root + 'annotations/panoptic_train2017.json',
44 |         img_prefix=data_root + 'train2017/',
45 |         seg_prefix=data_root + 'annotations/panoptic_train2017/',
46 |         pipeline=train_pipeline),
47 |     val=dict(
48 |         type=dataset_type,
49 |         ann_file=data_root + 'annotations/panoptic_val2017.json',
50 |         img_prefix=data_root + 'val2017/',
51 |         seg_prefix=data_root + 'annotations/panoptic_val2017/',
52 |         pipeline=test_pipeline),
53 |     test=dict(
54 |         type=dataset_type,
55 |         ann_file=data_root + 'annotations/panoptic_val2017.json',
56 |         img_prefix=data_root + 'val2017/',
57 |         seg_prefix=data_root + 'annotations/panoptic_val2017/',
58 |         pipeline=test_pipeline))
59 | evaluation = dict(interval=1, metric=['PQ'])
60 | 
--------------------------------------------------------------------------------
/evaluation/gen_eval/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | checkpoint_config = dict(interval=1)
2 | # yapf:disable
3 | log_config = dict(
4 |     interval=50,
5 |     hooks=[
6 |         dict(type='TextLoggerHook'),
7 |         # dict(type='TensorboardLoggerHook')
8 |     ])
9 | # yapf:enable
10 | custom_hooks = [dict(type='NumClassCheckHook')]
11 | 
12 | dist_params = dict(backend='nccl')
13 | log_level = 'INFO'
14 | load_from = None
15 | resume_from = None
16 | workflow = [('train', 1)]
17 | 
18 | # disable opencv multithreading to avoid system being
overloaded 19 | opencv_num_threads = 0 20 | # set multi-process start method as `fork` to speed up the training 21 | mp_start_method = 'fork' 22 | 23 | # Default setting for scaling LR automatically 24 | # - `enable` means enable scaling LR automatically 25 | # or not by default. 26 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 27 | auto_scale_lr = dict(enable=False, base_batch_size=16) 28 | -------------------------------------------------------------------------------- /evaluation/gen_eval/infer4eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import hashlib 4 | import time 5 | import argparse 6 | import json 7 | import shutil 8 | import glob 9 | import re 10 | import sys 11 | 12 | import cv2 13 | import tqdm 14 | import torch 15 | import numpy as np 16 | from pytorch_lightning import seed_everything 17 | 18 | from infinity.utils.csv_util import load_csv_as_dicts, write_dicts2csv_file 19 | from tools.run_infinity import * 20 | from conf import HF_TOKEN, HF_HOME 21 | 22 | # set environment variables 23 | os.environ['HF_TOKEN'] = HF_TOKEN 24 | os.environ['HF_HOME'] = HF_HOME 25 | os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1' 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | add_common_arguments(parser) 31 | parser.add_argument('--outdir', type=str, default='') 32 | parser.add_argument('--n_samples', type=int, default=4) 33 | parser.add_argument('--metadata_file', type=str, default='evaluation/gen_eval/prompts/evaluation_metadata.jsonl') 34 | parser.add_argument('--rewrite_prompt', type=int, default=0, choices=[0,1]) 35 | parser.add_argument('--load_rewrite_prompt_cache', type=int, default=1, choices=[0,1]) 36 | args = parser.parse_args() 37 | 38 | # parse cfg 39 | args.cfg = list(map(float, args.cfg.split(','))) 40 | if len(args.cfg) == 1: 41 | args.cfg = args.cfg[0] 42 | 43 | with open(args.metadata_file) as fp: 44 | metadatas = [json.loads(line) for line in fp] 45 | 46 | prompt_rewrite_cache_file = osp.join('evaluation/gen_eval', 'prompt_rewrite_cache.json') 47 | if osp.exists(prompt_rewrite_cache_file): 48 | with open(prompt_rewrite_cache_file, 'r') as f: 49 | prompt_rewrite_cache = json.load(f) 50 | else: 51 | prompt_rewrite_cache = {} 52 | 53 | if args.model_type == 'flux_1_dev': 54 | from diffusers import FluxPipeline 55 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda") 56 | elif args.model_type == 'flux_1_dev_schnell': 57 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda") 58 | elif 'infinity' in args.model_type: 59 | # load text encoder 60 | text_tokenizer, text_encoder = load_tokenizer(t5_path =args.text_encoder_ckpt) 61 | # load vae 62 | vae = load_visual_tokenizer(args) 63 | # load infinity 64 | infinity = load_transformer(vae, args) 65 | 66 | if args.rewrite_prompt: 67 | from tools.prompt_rewriter import PromptRewriter 68 | prompt_rewriter = PromptRewriter(system='', few_shot_history=[]) 69 | 70 | for index, metadata in enumerate(metadatas): 71 | seed_everything(args.seed) 72 | outpath = os.path.join(args.outdir, f"{index:0>5}") 73 | os.makedirs(outpath, exist_ok=True) 74 | prompt = metadata['prompt'] 75 | print(f"Prompt ({index: >3}/{len(metadatas)}): '{prompt}'") 76 | 77 | sample_path = os.path.join(outpath, "samples") 78 | os.makedirs(sample_path, exist_ok=True) 79 | with open(os.path.join(outpath, "metadata.jsonl"), "w") 
as fp: 80 | json.dump(metadata, fp) 81 | 82 | tau = args.tau 83 | cfg = args.cfg 84 | if args.rewrite_prompt: 85 | old_prompt = prompt 86 | if args.load_rewrite_prompt_cache and prompt in prompt_rewrite_cache: 87 | prompt = prompt_rewrite_cache[prompt] 88 | else: 89 | refined_prompt = prompt_rewriter.rewrite(prompt) 90 | input_key_val = extract_key_val(refined_prompt) 91 | prompt = input_key_val['prompt'] 92 | prompt_rewrite_cache[prompt] = prompt 93 | print(f'old_prompt: {old_prompt}, refined_prompt: {prompt}') 94 | 95 | images = [] 96 | for sample_j in range(args.n_samples): 97 | print(f"Generating {sample_j+1} of {args.n_samples}, prompt={prompt}") 98 | t1 = time.time() 99 | if args.model_type == 'flux_1_dev': 100 | image = pipe( 101 | prompt, 102 | height=1024, 103 | width=1024, 104 | guidance_scale=3.5, 105 | num_inference_steps=50, 106 | max_sequence_length=512, 107 | num_images_per_prompt=1, 108 | ).images[0] 109 | elif args.model_type == 'flux_1_dev_schnell': 110 | image = pipe( 111 | prompt, 112 | height=1024, 113 | width=1024, 114 | guidance_scale=0.0, 115 | num_inference_steps=4, 116 | max_sequence_length=256, 117 | generator=torch.Generator("cpu").manual_seed(0) 118 | ).images[0] 119 | elif args.model_type == 'pixart_sigma': 120 | image = pipe(prompt).images[0] 121 | elif 'infinity' in args.model_type: 122 | h_div_w_template = 1.000 123 | scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]['scales'] 124 | scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule] 125 | tgt_h, tgt_w = dynamic_resolution_h_w[h_div_w_template][args.pn]['pixel'] 126 | image = gen_one_img(infinity, vae, text_tokenizer, text_encoder, prompt, tau_list=tau, cfg_sc=3, cfg_list=cfg, scale_schedule=scale_schedule, cfg_insertion_layer=[args.cfg_insertion_layer], vae_type=args.vae_type) 127 | else: 128 | raise ValueError 129 | t2 = time.time() 130 | print(f'{args.model_type} infer one image takes {t2-t1:.2f}s') 131 | images.append(image) 132 | for i, image in enumerate(images): 133 | save_file = os.path.join(sample_path, f"{i:05}.jpg") 134 | if 'infinity' in args.model_type: 135 | cv2.imwrite(save_file, image.cpu().numpy()) 136 | else: 137 | image.save(save_file) 138 | 139 | with open(prompt_rewrite_cache_file, 'w') as f: 140 | json.dump(prompt_rewrite_cache, f, indent=2) 141 | -------------------------------------------------------------------------------- /evaluation/gen_eval/mask2former/mask2former_r50_lsj_8x2_50e_coco-panoptic.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py' 3 | ] 4 | num_things_classes = 80 5 | num_stuff_classes = 53 6 | num_classes = num_things_classes + num_stuff_classes 7 | model = dict( 8 | type='Mask2Former', 9 | backbone=dict( 10 | type='ResNet', 11 | depth=50, 12 | num_stages=4, 13 | out_indices=(0, 1, 2, 3), 14 | frozen_stages=-1, 15 | norm_cfg=dict(type='BN', requires_grad=False), 16 | norm_eval=True, 17 | style='pytorch', 18 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), 19 | panoptic_head=dict( 20 | type='Mask2FormerHead', 21 | in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside 22 | strides=[4, 8, 16, 32], 23 | feat_channels=256, 24 | out_channels=256, 25 | num_things_classes=num_things_classes, 26 | num_stuff_classes=num_stuff_classes, 27 | num_queries=100, 28 | num_transformer_feat_level=3, 29 | pixel_decoder=dict( 30 | type='MSDeformAttnPixelDecoder', 31 | num_outs=3, 32 | 
norm_cfg=dict(type='GN', num_groups=32), 33 | act_cfg=dict(type='ReLU'), 34 | encoder=dict( 35 | type='DetrTransformerEncoder', 36 | num_layers=6, 37 | transformerlayers=dict( 38 | type='BaseTransformerLayer', 39 | attn_cfgs=dict( 40 | type='MultiScaleDeformableAttention', 41 | embed_dims=256, 42 | num_heads=8, 43 | num_levels=3, 44 | num_points=4, 45 | im2col_step=64, 46 | dropout=0.0, 47 | batch_first=False, 48 | norm_cfg=None, 49 | init_cfg=None), 50 | ffn_cfgs=dict( 51 | type='FFN', 52 | embed_dims=256, 53 | feedforward_channels=1024, 54 | num_fcs=2, 55 | ffn_drop=0.0, 56 | act_cfg=dict(type='ReLU', inplace=True)), 57 | operation_order=('self_attn', 'norm', 'ffn', 'norm')), 58 | init_cfg=None), 59 | positional_encoding=dict( 60 | type='SinePositionalEncoding', num_feats=128, normalize=True), 61 | init_cfg=None), 62 | enforce_decoder_input_project=False, 63 | positional_encoding=dict( 64 | type='SinePositionalEncoding', num_feats=128, normalize=True), 65 | transformer_decoder=dict( 66 | type='DetrTransformerDecoder', 67 | return_intermediate=True, 68 | num_layers=9, 69 | transformerlayers=dict( 70 | type='DetrTransformerDecoderLayer', 71 | attn_cfgs=dict( 72 | type='MultiheadAttention', 73 | embed_dims=256, 74 | num_heads=8, 75 | attn_drop=0.0, 76 | proj_drop=0.0, 77 | dropout_layer=None, 78 | batch_first=False), 79 | ffn_cfgs=dict( 80 | embed_dims=256, 81 | feedforward_channels=2048, 82 | num_fcs=2, 83 | act_cfg=dict(type='ReLU', inplace=True), 84 | ffn_drop=0.0, 85 | dropout_layer=None, 86 | add_identity=True), 87 | feedforward_channels=2048, 88 | operation_order=('cross_attn', 'norm', 'self_attn', 'norm', 89 | 'ffn', 'norm')), 90 | init_cfg=None), 91 | loss_cls=dict( 92 | type='CrossEntropyLoss', 93 | use_sigmoid=False, 94 | loss_weight=2.0, 95 | reduction='mean', 96 | class_weight=[1.0] * num_classes + [0.1]), 97 | loss_mask=dict( 98 | type='CrossEntropyLoss', 99 | use_sigmoid=True, 100 | reduction='mean', 101 | loss_weight=5.0), 102 | loss_dice=dict( 103 | type='DiceLoss', 104 | use_sigmoid=True, 105 | activate=True, 106 | reduction='mean', 107 | naive_dice=True, 108 | eps=1.0, 109 | loss_weight=5.0)), 110 | panoptic_fusion_head=dict( 111 | type='MaskFormerFusionHead', 112 | num_things_classes=num_things_classes, 113 | num_stuff_classes=num_stuff_classes, 114 | loss_panoptic=None, 115 | init_cfg=None), 116 | train_cfg=dict( 117 | num_points=12544, 118 | oversample_ratio=3.0, 119 | importance_sample_ratio=0.75, 120 | assigner=dict( 121 | type='MaskHungarianAssigner', 122 | cls_cost=dict(type='ClassificationCost', weight=2.0), 123 | mask_cost=dict( 124 | type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True), 125 | dice_cost=dict( 126 | type='DiceCost', weight=5.0, pred_act=True, eps=1.0)), 127 | sampler=dict(type='MaskPseudoSampler')), 128 | test_cfg=dict( 129 | panoptic_on=True, 130 | # For now, the dataset does not support 131 | # evaluating semantic segmentation metric. 132 | semantic_on=False, 133 | instance_on=True, 134 | # max_per_image is for instance segmentation. 135 | max_per_image=100, 136 | iou_thr=0.8, 137 | # In Mask2Former's panoptic postprocessing, 138 | # it will filter mask area where score is less than 0.5 . 
139 | filter_low_score=True), 140 | init_cfg=None) 141 | 142 | # dataset settings 143 | image_size = (1024, 1024) 144 | img_norm_cfg = dict( 145 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 146 | train_pipeline = [ 147 | dict(type='LoadImageFromFile', to_float32=True), 148 | dict( 149 | type='LoadPanopticAnnotations', 150 | with_bbox=True, 151 | with_mask=True, 152 | with_seg=True), 153 | dict(type='RandomFlip', flip_ratio=0.5), 154 | # large scale jittering 155 | dict( 156 | type='Resize', 157 | img_scale=image_size, 158 | ratio_range=(0.1, 2.0), 159 | multiscale_mode='range', 160 | keep_ratio=True), 161 | dict( 162 | type='RandomCrop', 163 | crop_size=image_size, 164 | crop_type='absolute', 165 | recompute_bbox=True, 166 | allow_negative_crop=True), 167 | dict(type='Normalize', **img_norm_cfg), 168 | dict(type='Pad', size=image_size), 169 | dict(type='DefaultFormatBundle', img_to_float=True), 170 | dict( 171 | type='Collect', 172 | keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']), 173 | ] 174 | test_pipeline = [ 175 | dict(type='LoadImageFromFile'), 176 | dict( 177 | type='MultiScaleFlipAug', 178 | img_scale=(1333, 800), 179 | flip=False, 180 | transforms=[ 181 | dict(type='Resize', keep_ratio=True), 182 | dict(type='RandomFlip'), 183 | dict(type='Normalize', **img_norm_cfg), 184 | dict(type='Pad', size_divisor=32), 185 | dict(type='ImageToTensor', keys=['img']), 186 | dict(type='Collect', keys=['img']), 187 | ]) 188 | ] 189 | data_root = 'data/coco/' 190 | data = dict( 191 | samples_per_gpu=2, 192 | workers_per_gpu=2, 193 | train=dict(pipeline=train_pipeline), 194 | val=dict( 195 | pipeline=test_pipeline, 196 | ins_ann_file=data_root + 'annotations/instances_val2017.json', 197 | ), 198 | test=dict( 199 | pipeline=test_pipeline, 200 | ins_ann_file=data_root + 'annotations/instances_val2017.json', 201 | )) 202 | 203 | embed_multi = dict(lr_mult=1.0, decay_mult=0.0) 204 | # optimizer 205 | optimizer = dict( 206 | type='AdamW', 207 | lr=0.0001, 208 | weight_decay=0.05, 209 | eps=1e-8, 210 | betas=(0.9, 0.999), 211 | paramwise_cfg=dict( 212 | custom_keys={ 213 | 'backbone': dict(lr_mult=0.1, decay_mult=1.0), 214 | 'query_embed': embed_multi, 215 | 'query_feat': embed_multi, 216 | 'level_embed': embed_multi, 217 | }, 218 | norm_decay_mult=0.0)) 219 | optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2)) 220 | 221 | # learning policy 222 | lr_config = dict( 223 | policy='step', 224 | gamma=0.1, 225 | by_epoch=False, 226 | step=[327778, 355092], 227 | warmup='linear', 228 | warmup_by_epoch=False, 229 | warmup_ratio=1.0, # no warmup 230 | warmup_iters=10) 231 | 232 | max_iters = 368750 233 | runner = dict(type='IterBasedRunner', max_iters=max_iters) 234 | 235 | log_config = dict( 236 | interval=50, 237 | hooks=[ 238 | dict(type='TextLoggerHook', by_epoch=False), 239 | dict(type='TensorboardLoggerHook', by_epoch=False) 240 | ]) 241 | interval = 5000 242 | workflow = [('train', interval)] 243 | checkpoint_config = dict( 244 | by_epoch=False, interval=interval, save_last=True, max_keep_ckpts=3) 245 | 246 | # Before 365001th iteration, we do evaluation every 5000 iterations. 247 | # After 365000th iteration, we do evaluation every 368750 iterations, 248 | # which means that we do evaluation at the end of training. 
249 | dynamic_intervals = [(max_iters // interval * interval + 1, max_iters)] 250 | evaluation = dict( 251 | interval=interval, 252 | dynamic_intervals=dynamic_intervals, 253 | metric=['PQ', 'bbox', 'segm']) 254 | -------------------------------------------------------------------------------- /evaluation/gen_eval/mask2former/mask2former_r50_lsj_8x2_50e_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./mask2former_r50_lsj_8x2_50e_coco-panoptic.py'] 2 | num_things_classes = 80 3 | num_stuff_classes = 0 4 | num_classes = num_things_classes + num_stuff_classes 5 | model = dict( 6 | panoptic_head=dict( 7 | num_things_classes=num_things_classes, 8 | num_stuff_classes=num_stuff_classes, 9 | loss_cls=dict(class_weight=[1.0] * num_classes + [0.1])), 10 | panoptic_fusion_head=dict( 11 | num_things_classes=num_things_classes, 12 | num_stuff_classes=num_stuff_classes), 13 | test_cfg=dict(panoptic_on=False)) 14 | 15 | # dataset settings 16 | image_size = (1024, 1024) 17 | img_norm_cfg = dict( 18 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 19 | pad_cfg = dict(img=(128, 128, 128), masks=0, seg=255) 20 | train_pipeline = [ 21 | dict(type='LoadImageFromFile', to_float32=True), 22 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 23 | dict(type='RandomFlip', flip_ratio=0.5), 24 | # large scale jittering 25 | dict( 26 | type='Resize', 27 | img_scale=image_size, 28 | ratio_range=(0.1, 2.0), 29 | multiscale_mode='range', 30 | keep_ratio=True), 31 | dict( 32 | type='RandomCrop', 33 | crop_size=image_size, 34 | crop_type='absolute', 35 | recompute_bbox=True, 36 | allow_negative_crop=True), 37 | dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-5, 1e-5), by_mask=True), 38 | dict(type='Pad', size=image_size, pad_val=pad_cfg), 39 | dict(type='Normalize', **img_norm_cfg), 40 | dict(type='DefaultFormatBundle', img_to_float=True), 41 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 42 | ] 43 | test_pipeline = [ 44 | dict(type='LoadImageFromFile'), 45 | dict( 46 | type='MultiScaleFlipAug', 47 | img_scale=(1333, 800), 48 | flip=False, 49 | transforms=[ 50 | dict(type='Resize', keep_ratio=True), 51 | dict(type='RandomFlip'), 52 | dict(type='Pad', size_divisor=32, pad_val=pad_cfg), 53 | dict(type='Normalize', **img_norm_cfg), 54 | dict(type='ImageToTensor', keys=['img']), 55 | dict(type='Collect', keys=['img']), 56 | ]) 57 | ] 58 | dataset_type = 'CocoDataset' 59 | data_root = 'data/coco/' 60 | data = dict( 61 | _delete_=True, 62 | samples_per_gpu=2, 63 | workers_per_gpu=2, 64 | train=dict( 65 | type=dataset_type, 66 | ann_file=data_root + 'annotations/instances_train2017.json', 67 | img_prefix=data_root + 'train2017/', 68 | pipeline=train_pipeline), 69 | val=dict( 70 | type=dataset_type, 71 | ann_file=data_root + 'annotations/instances_val2017.json', 72 | img_prefix=data_root + 'val2017/', 73 | pipeline=test_pipeline), 74 | test=dict( 75 | type=dataset_type, 76 | ann_file=data_root + 'annotations/instances_val2017.json', 77 | img_prefix=data_root + 'val2017/', 78 | pipeline=test_pipeline)) 79 | evaluation = dict(metric=['bbox', 'segm']) 80 | -------------------------------------------------------------------------------- /evaluation/gen_eval/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py'] 2 | pretrained = 
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth' # noqa 3 | 4 | depths = [2, 2, 18, 2] 5 | model = dict( 6 | backbone=dict( 7 | depths=depths, init_cfg=dict(type='Pretrained', 8 | checkpoint=pretrained))) 9 | 10 | # set all layers in backbone to lr_mult=0.1 11 | # set all norm layers, position_embeding, 12 | # query_embeding, level_embeding to decay_multi=0.0 13 | backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0) 14 | backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0) 15 | embed_multi = dict(lr_mult=1.0, decay_mult=0.0) 16 | custom_keys = { 17 | 'backbone': dict(lr_mult=0.1, decay_mult=1.0), 18 | 'backbone.patch_embed.norm': backbone_norm_multi, 19 | 'backbone.norm': backbone_norm_multi, 20 | 'absolute_pos_embed': backbone_embed_multi, 21 | 'relative_position_bias_table': backbone_embed_multi, 22 | 'query_embed': embed_multi, 23 | 'query_feat': embed_multi, 24 | 'level_embed': embed_multi 25 | } 26 | custom_keys.update({ 27 | f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi 28 | for stage_id, num_blocks in enumerate(depths) 29 | for block_id in range(num_blocks) 30 | }) 31 | custom_keys.update({ 32 | f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi 33 | for stage_id in range(len(depths) - 1) 34 | }) 35 | # optimizer 36 | optimizer = dict( 37 | paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0)) 38 | -------------------------------------------------------------------------------- /evaluation/gen_eval/mask2former/mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./mask2former_r50_lsj_8x2_50e_coco.py'] 2 | pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa 3 | depths = [2, 2, 6, 2] 4 | model = dict( 5 | type='Mask2Former', 6 | backbone=dict( 7 | _delete_=True, 8 | type='SwinTransformer', 9 | embed_dims=96, 10 | depths=depths, 11 | num_heads=[3, 6, 12, 24], 12 | window_size=7, 13 | mlp_ratio=4, 14 | qkv_bias=True, 15 | qk_scale=None, 16 | drop_rate=0., 17 | attn_drop_rate=0., 18 | drop_path_rate=0.3, 19 | patch_norm=True, 20 | out_indices=(0, 1, 2, 3), 21 | with_cp=False, 22 | convert_weights=True, 23 | frozen_stages=-1, 24 | init_cfg=dict(type='Pretrained', checkpoint=pretrained)), 25 | panoptic_head=dict( 26 | type='Mask2FormerHead', in_channels=[96, 192, 384, 768]), 27 | init_cfg=None) 28 | 29 | # set all layers in backbone to lr_mult=0.1 30 | # set all norm layers, position_embeding, 31 | # query_embeding, level_embeding to decay_multi=0.0 32 | backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0) 33 | backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0) 34 | embed_multi = dict(lr_mult=1.0, decay_mult=0.0) 35 | custom_keys = { 36 | 'backbone': dict(lr_mult=0.1, decay_mult=1.0), 37 | 'backbone.patch_embed.norm': backbone_norm_multi, 38 | 'backbone.norm': backbone_norm_multi, 39 | 'absolute_pos_embed': backbone_embed_multi, 40 | 'relative_position_bias_table': backbone_embed_multi, 41 | 'query_embed': embed_multi, 42 | 'query_feat': embed_multi, 43 | 'level_embed': embed_multi 44 | } 45 | custom_keys.update({ 46 | f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi 47 | for stage_id, num_blocks in enumerate(depths) 48 | for block_id in range(num_blocks) 49 | }) 50 | custom_keys.update({ 51 | f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi 52 | for stage_id in 
range(len(depths) - 1) 53 | }) 54 | # optimizer 55 | optimizer = dict( 56 | type='AdamW', 57 | lr=0.0001, 58 | weight_decay=0.05, 59 | eps=1e-8, 60 | betas=(0.9, 0.999), 61 | paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0)) 62 | -------------------------------------------------------------------------------- /evaluation/gen_eval/prompts/create_prompts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate prompts for evaluation 3 | """ 4 | 5 | import argparse 6 | import json 7 | import os 8 | import yaml 9 | 10 | import numpy as np 11 | 12 | # Load classnames 13 | 14 | with open("object_names.txt") as cls_file: 15 | classnames = [line.strip() for line in cls_file] 16 | 17 | # Proper a vs an 18 | 19 | def with_article(name: str): 20 | if name[0] in "aeiou": 21 | return f"an {name}" 22 | return f"a {name}" 23 | 24 | # Proper plural 25 | 26 | def make_plural(name: str): 27 | if name[-1] in "s": 28 | return f"{name}es" 29 | return f"{name}s" 30 | 31 | # Generates single object samples 32 | 33 | def generate_single_object_sample(rng: np.random.Generator, size: int = None): 34 | TAG = "single_object" 35 | if size > len(classnames): 36 | size = len(classnames) 37 | print(f"Not enough distinct classes, generating only {size} samples") 38 | return_scalar = size is None 39 | size = size or 1 40 | idxs = rng.choice(len(classnames), size=size, replace=False) 41 | samples = [dict( 42 | tag=TAG, 43 | include=[ 44 | {"class": classnames[idx], "count": 1} 45 | ], 46 | prompt=f"a photo of {with_article(classnames[idx])}" 47 | ) for idx in idxs] 48 | if return_scalar: 49 | return samples[0] 50 | return samples 51 | 52 | # Generate two object samples 53 | 54 | def generate_two_object_sample(rng: np.random.Generator): 55 | TAG = "two_object" 56 | idx_a, idx_b = rng.choice(len(classnames), size=2, replace=False) 57 | return dict( 58 | tag=TAG, 59 | include=[ 60 | {"class": classnames[idx_a], "count": 1}, 61 | {"class": classnames[idx_b], "count": 1} 62 | ], 63 | prompt=f"a photo of {with_article(classnames[idx_a])} and {with_article(classnames[idx_b])}" 64 | ) 65 | 66 | # Generate counting samples 67 | 68 | numbers = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"] 69 | 70 | def generate_counting_sample(rng: np.random.Generator, max_count=4): 71 | TAG = "counting" 72 | idx = rng.choice(len(classnames)) 73 | num = int(rng.integers(2, max_count, endpoint=True)) 74 | return dict( 75 | tag=TAG, 76 | include=[ 77 | {"class": classnames[idx], "count": num} 78 | ], 79 | exclude=[ 80 | {"class": classnames[idx], "count": num + 1} 81 | ], 82 | prompt=f"a photo of {numbers[num]} {make_plural(classnames[idx])}" 83 | ) 84 | 85 | # Generate color samples 86 | 87 | colors = ["red", "orange", "yellow", "green", "blue", "purple", "pink", "brown", "black", "white"] 88 | 89 | def generate_color_sample(rng: np.random.Generator): 90 | TAG = "colors" 91 | idx = rng.choice(len(classnames) - 1) + 1 92 | idx = (idx + classnames.index("person")) % len(classnames) # No "[COLOR] person" prompts 93 | color = colors[rng.choice(len(colors))] 94 | return dict( 95 | tag=TAG, 96 | include=[ 97 | {"class": classnames[idx], "count": 1, "color": color} 98 | ], 99 | prompt=f"a photo of {with_article(color)} {classnames[idx]}" 100 | ) 101 | 102 | # Generate position samples 103 | 104 | positions = ["left of", "right of", "above", "below"] 105 | 106 | def generate_position_sample(rng: np.random.Generator): 107 | TAG = "position" 108 | idx_a, 
idx_b = rng.choice(len(classnames), size=2, replace=False) 109 | position = positions[rng.choice(len(positions))] 110 | return dict( 111 | tag=TAG, 112 | include=[ 113 | {"class": classnames[idx_b], "count": 1}, 114 | {"class": classnames[idx_a], "count": 1, "position": (position, 0)} 115 | ], 116 | prompt=f"a photo of {with_article(classnames[idx_a])} {position} {with_article(classnames[idx_b])}" 117 | ) 118 | 119 | # Generate color attribution samples 120 | 121 | def generate_color_attribution_sample(rng: np.random.Generator): 122 | TAG = "color_attr" 123 | idxs = rng.choice(len(classnames) - 1, size=2, replace=False) + 1 124 | idx_a, idx_b = (idxs + classnames.index("person")) % len(classnames) # No "[COLOR] person" prompts 125 | cidx_a, cidx_b = rng.choice(len(colors), size=2, replace=False) 126 | return dict( 127 | tag=TAG, 128 | include=[ 129 | {"class": classnames[idx_a], "count": 1, "color": colors[cidx_a]}, 130 | {"class": classnames[idx_b], "count": 1, "color": colors[cidx_b]} 131 | ], 132 | prompt=f"a photo of {with_article(colors[cidx_a])} {classnames[idx_a]} and {with_article(colors[cidx_b])} {classnames[idx_b]}" 133 | ) 134 | 135 | 136 | # Generate evaluation suite 137 | 138 | def generate_suite(rng: np.random.Generator, n: int = 100, output_path: str = ""): 139 | samples = [] 140 | # Generate single object samples for all COCO classnames 141 | samples.extend(generate_single_object_sample(rng, size=len(classnames))) 142 | # Generate two object samples (~100) 143 | for _ in range(n): 144 | samples.append(generate_two_object_sample(rng)) 145 | # Generate counting samples 146 | for _ in range(n): 147 | samples.append(generate_counting_sample(rng, max_count=4)) 148 | # Generate color samples 149 | for _ in range(n): 150 | samples.append(generate_color_sample(rng)) 151 | # Generate position samples 152 | for _ in range(n): 153 | samples.append(generate_position_sample(rng)) 154 | # Generate color attribution samples 155 | for _ in range(n): 156 | samples.append(generate_color_attribution_sample(rng)) 157 | # De-duplicate 158 | unique_samples, used_samples = [], set() 159 | for sample in samples: 160 | sample_text = yaml.safe_dump(sample) 161 | if sample_text not in used_samples: 162 | unique_samples.append(sample) 163 | used_samples.add(sample_text) 164 | 165 | # Write to files 166 | os.makedirs(output_path, exist_ok=True) 167 | with open(os.path.join(output_path, "generation_prompts.txt"), "w") as fp: 168 | for sample in unique_samples: 169 | print(sample['prompt'], file=fp) 170 | with open(os.path.join(output_path, "evaluation_metadata.jsonl"), "w") as fp: 171 | for sample in unique_samples: 172 | print(json.dumps(sample), file=fp) 173 | 174 | 175 | if __name__ == "__main__": 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument("--seed", type=int, default=43, help="generation seed (default: 43)") 178 | parser.add_argument("--num-prompts", "-n", type=int, default=100, help="number of prompts per task (default: 100)") 179 | parser.add_argument("--output-path", "-o", type=str, default="prompts", help="output folder for prompts and metadata (default: 'prompts/')") 180 | args = parser.parse_args() 181 | rng = np.random.default_rng(args.seed) 182 | generate_suite(rng, args.num_prompts, args.output_path) 183 | 184 | -------------------------------------------------------------------------------- /evaluation/gen_eval/rename.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import json 5 | 6 | 7 
| with open('/Users/bytedance/Desktop/projects/Infinity/evaluation/gen_eval/prompt_rewrite_cache_1.json', 'r') as f: 8 | correct = json.load(f) 9 | 10 | with open('/Users/bytedance/Desktop/projects/Infinity/evaluation/gen_eval/prompt_rewrite_cache_123.json', 'r') as f: 11 | false_key_dict = json.load(f) 12 | 13 | keys1_list = list(correct.keys()) 14 | keys2_list = list(false_key_dict.keys()) 15 | 16 | final_dict = {} 17 | for i in range(len(keys1_list)): 18 | key1 = keys1_list[i] 19 | key2 = keys2_list[i] 20 | final_dict[key1] = false_key_dict[key2] 21 | 22 | with open('prompt_rewrite_cache.json', 'w') as f: 23 | json.dump(final_dict, f, ensure_ascii=False, indent=2) 24 | -------------------------------------------------------------------------------- /evaluation/gen_eval/summary_scores.py: -------------------------------------------------------------------------------- 1 | # Get results of evaluation 2 | 3 | import argparse 4 | import os 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("filename", type=str) 12 | args = parser.parse_args() 13 | 14 | # Load classnames 15 | 16 | with open(os.path.join(os.path.dirname(__file__), "object_names.txt")) as cls_file: 17 | classnames = [line.strip() for line in cls_file] 18 | cls_to_idx = {"_".join(cls.split()):idx for idx, cls in enumerate(classnames)} 19 | 20 | # Load results 21 | 22 | df = pd.read_json(args.filename, orient="records", lines=True) 23 | 24 | # Measure overall success 25 | 26 | print("Summary") 27 | print("=======") 28 | print(f"Total images: {len(df)}") 29 | print(f"Total prompts: {len(df.groupby('metadata'))}") 30 | print(f"% correct images: {df['correct'].mean():.2%}") 31 | print(f"% correct prompts: {df.groupby('metadata')['correct'].any().mean():.2%}") 32 | print() 33 | 34 | # By group 35 | 36 | task_scores = [] 37 | 38 | print("Task breakdown") 39 | print("==============") 40 | for tag, task_df in df.groupby('tag', sort=False): 41 | task_scores.append(task_df['correct'].mean()) 42 | print(f"{tag:<16} = {task_df['correct'].mean():.2%} ({task_df['correct'].sum()} / {len(task_df)})") 43 | print() 44 | 45 | print(f"Overall score (avg. 
over tasks): {np.mean(task_scores):.5f}") -------------------------------------------------------------------------------- /evaluation/hpsv2/eval_hpsv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import hashlib 4 | import time 5 | import argparse 6 | import json 7 | import shutil 8 | import glob 9 | import re 10 | import sys 11 | 12 | import cv2 13 | import hpsv2 14 | import torch 15 | import numpy as np 16 | from pytorch_lightning import seed_everything 17 | 18 | from infinity.utils.csv_util import load_csv_as_dicts, write_dicts2csv_file 19 | from tools.run_infinity import * 20 | from conf import HF_TOKEN, HF_HOME 21 | 22 | # set environment variables 23 | os.environ['HF_TOKEN'] = HF_TOKEN 24 | os.environ['HF_HOME'] = HF_HOME 25 | os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1' 26 | 27 | def extract_key_val(text): 28 | pattern = r'<(.+?):(.+?)>' 29 | matches = re.findall(pattern, text) 30 | key_val = {} 31 | for match in matches: 32 | key_val[match[0]] = match[1].lstrip() 33 | return key_val 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | add_common_arguments(parser) 39 | parser.add_argument('--outdir', type=str, default='') 40 | parser.add_argument('--n_samples', type=int, default=1) 41 | parser.add_argument('--rewrite_prompt', type=int, default=0, choices=[0,1]) 42 | args = parser.parse_args() 43 | 44 | # parse cfg 45 | args.cfg = list(map(float, args.cfg.split(','))) 46 | if len(args.cfg) == 1: 47 | args.cfg = args.cfg[0] 48 | 49 | all_prompts = hpsv2.benchmark_prompts('all') 50 | seed_everything(args.seed) 51 | 52 | if args.model_type == 'sdxl': 53 | from diffusers import DiffusionPipeline 54 | base = DiffusionPipeline.from_pretrained( 55 | "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True 56 | ).to("cuda") 57 | 58 | refiner = DiffusionPipeline.from_pretrained( 59 | "stabilityai/stable-diffusion-xl-refiner-1.0", 60 | text_encoder_2=base.text_encoder_2, 61 | vae=base.vae, 62 | torch_dtype=torch.float16, 63 | use_safetensors=True, 64 | variant="fp16", 65 | ).to("cuda") 66 | elif args.model_type == 'sd3': 67 | from diffusers import StableDiffusion3Pipeline 68 | pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16) 69 | pipe = pipe.to("cuda") 70 | elif args.model_type == 'pixart_sigma': 71 | from diffusers import PixArtSigmaPipeline 72 | pipe = PixArtSigmaPipeline.from_pretrained( 73 | "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16 74 | ).to("cuda") 75 | elif args.model_type == 'flux_1_dev': 76 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda") 77 | elif args.model_type == 'flux_1_dev_schnell': 78 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda") 79 | elif 'infinity' in args.model_type: 80 | # load text encoder 81 | text_tokenizer, text_encoder = load_tokenizer(t5_path =args.text_encoder_ckpt) 82 | # load vae 83 | vae = load_visual_tokenizer(args) 84 | # load infinity 85 | infinity = load_transformer(vae, args) 86 | 87 | if args.rewrite_prompt: 88 | from tools.prompt_rewriter import PromptRewriter 89 | prompt_rewriter = PromptRewriter(system='', few_shot_history=[]) 90 | 91 | total = 0 92 | for style, prompts in all_prompts.items(): 93 | for idx, prompt in enumerate(prompts): 94 | total += 1 95 | 
ptr = 0 96 | for style, prompts in all_prompts.items(): 97 | for idx, prompt in enumerate(prompts): 98 | ptr += 1 99 | if ptr % 10 == 0: 100 | print(f'Generate {ptr}/{total} images...') 101 | 102 | image_save_file_path = os.path.join(args.outdir, style, f"{idx:05d}.jpg") 103 | os.makedirs(osp.dirname(image_save_file_path), exist_ok=True) 104 | 105 | tau = args.tau 106 | cfg = args.cfg 107 | if args.rewrite_prompt: 108 | refined_prompt = prompt_rewriter.rewrite(prompt) 109 | input_key_val = extract_key_val(refined_prompt) 110 | prompt = input_key_val['prompt'] 111 | print(f'prompt: {prompt}, refined_prompt: {refined_prompt}') 112 | 113 | images = [] 114 | for _ in range(args.n_samples): 115 | t1 = time.time() 116 | if args.model_type == 'sdxl': 117 | image = base( 118 | prompt=prompt, 119 | num_inference_steps=40, 120 | denoising_end=0.8, 121 | output_type="latent", 122 | ).images 123 | image = refiner( 124 | prompt=prompt, 125 | num_inference_steps=40, 126 | denoising_start=0.8, 127 | image=image, 128 | ).images[0] 129 | elif args.model_type == 'sd3': 130 | image = pipe( 131 | prompt, 132 | negative_prompt="", 133 | num_inference_steps=28, 134 | guidance_scale=7.0, 135 | num_images_per_prompt=1, 136 | ).images[0] 137 | elif args.model_type == 'flux_1_dev': 138 | image = pipe( 139 | prompt, 140 | height=1024, 141 | width=1024, 142 | guidance_scale=3.5, 143 | num_inference_steps=50, 144 | max_sequence_length=512, 145 | num_images_per_prompt=1, 146 | ).images[0] 147 | elif args.model_type == 'flux_1_dev_schnell': 148 | image = pipe( 149 | prompt, 150 | height=1024, 151 | width=1024, 152 | guidance_scale=0.0, 153 | num_inference_steps=4, 154 | max_sequence_length=256, 155 | generator=torch.Generator("cpu").manual_seed(0) 156 | ).images[0] 157 | elif args.model_type == 'pixart_sigma': 158 | image = pipe(prompt).images[0] 159 | elif 'infinity' in args.model_type: 160 | h_div_w_template = 1.000 161 | scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]['scales'] 162 | scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule] 163 | tgt_h, tgt_w = dynamic_resolution_h_w[h_div_w_template][args.pn]['pixel'] 164 | image = gen_one_img(infinity, vae, text_tokenizer, text_encoder, prompt, tau_list=tau, cfg_sc=3, cfg_list=cfg, scale_schedule=scale_schedule, cfg_insertion_layer=[args.cfg_insertion_layer], vae_type=args.vae_type) 165 | else: 166 | raise ValueError 167 | t2 = time.time() 168 | print(f'{args.model_type} infer one image takes {t2-t1:.2f}s') 169 | images.append(image) 170 | 171 | assert len(images) == 1 172 | for i, image in enumerate(images): 173 | if 'infinity' in args.model_type: 174 | cv2.imwrite(image_save_file_path, image.cpu().numpy()) 175 | else: 176 | image.save(image_save_file_path) 177 | 178 | hpsv2.evaluate(args.outdir, hps_version="v2.1") 179 | -------------------------------------------------------------------------------- /evaluation/image_reward/cal_imagereward.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1' 3 | import os.path as osp 4 | import json 5 | import argparse 6 | 7 | import numpy as np 8 | import ImageReward as RM 9 | 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--meta_file", type=str, default="") 14 | args = parser.parse_args() 15 | 16 | image_reward_model = RM.load("ImageReward-v1.0") 17 | clip_model = RM.load_score("CLIP") 18 | 19 | with open(args.meta_file, 'r') as f: 20 | meta_infos = 
json.load(f) 21 | 22 | average_image_reward = [] 23 | average_clip_scores = [] 24 | for meta in meta_infos: 25 | image_paths = meta['gen_image_paths'] 26 | prompt = meta['prompt'] 27 | image_rewards = image_reward_model.score(prompt, image_paths) 28 | _, clip_scores = clip_model.inference_rank(prompt, image_paths) 29 | average_image_reward.extend(image_rewards) 30 | average_clip_scores.extend(clip_scores) 31 | print(f'Average Image Reward of {len(meta_infos)} prompt and {len(average_image_reward)} images is {np.mean(average_image_reward):.4f}, Average CLIP Score is {np.mean(average_clip_scores):.4f}') 32 | print(f'Average Image Reward of {len(meta_infos)} prompt and {len(average_image_reward)} images is {np.mean(average_image_reward):.4f}, Average CLIP Score is {np.mean(average_clip_scores):.4f}') 33 | save_file = osp.join(osp.dirname(args.meta_file), 'image_reward_res.json') 34 | with open(save_file, 'w') as f: 35 | json.dump({ 36 | 'prompts': len(meta_infos), 37 | 'images': len(average_image_reward), 38 | 'average_image_reward': np.mean(average_image_reward), 39 | 'average_clip_scores': np.mean(average_clip_scores) 40 | }, f) 41 | print(f'Save to {save_file}') 42 | -------------------------------------------------------------------------------- /evaluation/image_reward/infer4eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import hashlib 4 | import time 5 | import argparse 6 | import json 7 | import shutil 8 | import glob 9 | import re 10 | import sys 11 | 12 | import cv2 13 | import tqdm 14 | import torch 15 | import numpy as np 16 | from pytorch_lightning import seed_everything 17 | 18 | from infinity.utils.csv_util import load_csv_as_dicts, write_dicts2csv_file 19 | from tools.run_infinity import * 20 | from conf import HF_TOKEN, HF_HOME 21 | 22 | # set environment variables 23 | os.environ['HF_TOKEN'] = HF_TOKEN 24 | os.environ['HF_HOME'] = HF_HOME 25 | os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1' 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | add_common_arguments(parser) 31 | parser.add_argument('--outdir', type=str, default='') 32 | parser.add_argument('--n_samples', type=int, default=10) 33 | parser.add_argument('--metadata_file', type=str, default='evaluation/image_reward/benchmark-prompts.json') 34 | parser.add_argument('--rewrite_prompt', type=int, default=0, choices=[0,1]) 35 | args = parser.parse_args() 36 | 37 | # parse cfg 38 | args.cfg = list(map(float, args.cfg.split(','))) 39 | if len(args.cfg) == 1: 40 | args.cfg = args.cfg[0] 41 | 42 | with open(args.metadata_file) as fp: 43 | metadatas = json.load(fp) 44 | 45 | if args.model_type == 'sdxl': 46 | from diffusers import DiffusionPipeline 47 | base = DiffusionPipeline.from_pretrained( 48 | "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True 49 | ).to("cuda") 50 | refiner = DiffusionPipeline.from_pretrained( 51 | "stabilityai/stable-diffusion-xl-refiner-1.0", 52 | text_encoder_2=base.text_encoder_2, 53 | vae=base.vae, 54 | torch_dtype=torch.float16, 55 | use_safetensors=True, 56 | variant="fp16", 57 | ).to("cuda") 58 | elif args.model_type == 'sd3': 59 | from diffusers import StableDiffusion3Pipeline 60 | pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16) 61 | pipe = pipe.to("cuda") 62 | elif args.model_type == 'pixart_sigma': 63 | from diffusers import 
PixArtSigmaPipeline 64 | pipe = PixArtSigmaPipeline.from_pretrained( 65 | "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16 66 | ).to("cuda") 67 | elif args.model_type == 'flux_1_dev': 68 | from diffusers import FluxPipeline 69 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda") 70 | elif args.model_type == 'flux_1_dev_schnell': 71 | from diffusers import FluxPipeline 72 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda") 73 | elif 'infinity' in args.model_type: 74 | # load text encoder 75 | text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt) 76 | # load vae 77 | vae = load_visual_tokenizer(args) 78 | # load infinity 79 | infinity = load_transformer(vae, args) 80 | 81 | if args.rewrite_prompt: 82 | from tools.prompt_rewriter import PromptRewriter 83 | prompt_rewriter = PromptRewriter(system='', few_shot_history=[]) 84 | 85 | save_metadatas = [] 86 | for index, metadata in enumerate(metadatas): 87 | seed_everything(args.seed) 88 | prompt_id = metadata['id'] 89 | prompt = metadata['prompt'] 90 | sample_path = os.path.join(args.outdir, prompt_id) 91 | os.makedirs(sample_path, exist_ok=True) 92 | print(f"Prompt ({index: >3}/{len(metadatas)}): '{prompt}'") 93 | 94 | tau = args.tau 95 | cfg = args.cfg 96 | if args.rewrite_prompt: 97 | refined_prompt = prompt_rewriter.rewrite(prompt) 98 | input_key_val = extract_key_val(refined_prompt) 99 | prompt = input_key_val['prompt'] 100 | print(f'prompt: {prompt}, refined_prompt: {refined_prompt}') 101 | 102 | images = [] 103 | for _ in range(args.n_samples): 104 | t1 = time.time() 105 | if args.model_type == 'sdxl': 106 | image = base( 107 | prompt=prompt, 108 | num_inference_steps=40, 109 | denoising_end=0.8, 110 | output_type="latent", 111 | ).images 112 | image = refiner( 113 | prompt=prompt, 114 | num_inference_steps=40, 115 | denoising_start=0.8, 116 | image=image, 117 | ).images[0] 118 | elif args.model_type == 'sd3': 119 | image = pipe( 120 | prompt, 121 | negative_prompt="", 122 | num_inference_steps=28, 123 | guidance_scale=7.0, 124 | num_images_per_prompt=1, 125 | ).images[0] 126 | elif args.model_type == 'flux_1_dev': 127 | image = pipe( 128 | prompt, 129 | height=1024, 130 | width=1024, 131 | guidance_scale=3.5, 132 | num_inference_steps=50, 133 | max_sequence_length=512, 134 | num_images_per_prompt=1, 135 | ).images[0] 136 | elif args.model_type == 'flux_1_dev_schnell': 137 | image = pipe( 138 | prompt, 139 | height=1024, 140 | width=1024, 141 | guidance_scale=0.0, 142 | num_inference_steps=4, 143 | max_sequence_length=256, 144 | generator=torch.Generator("cpu").manual_seed(0) 145 | ).images[0] 146 | elif args.model_type == 'pixart_sigma': 147 | image = pipe(prompt).images[0] 148 | elif 'infinity' in args.model_type: 149 | h_div_w_template = 1.000 150 | scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]['scales'] 151 | scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule] 152 | tgt_h, tgt_w = dynamic_resolution_h_w[h_div_w_template][args.pn]['pixel'] 153 | image = gen_one_img(infinity, vae, text_tokenizer, text_encoder, prompt, tau_list=tau, cfg_sc=3, cfg_list=cfg, scale_schedule=scale_schedule, cfg_insertion_layer=[args.cfg_insertion_layer], vae_type=args.vae_type) 154 | else: 155 | raise ValueError 156 | t2 = time.time() 157 | print(f'{args.model_type} infer one image takes {t2-t1:.2f}s') 158 | images.append(image) 159 | 160 | os.makedirs(sample_path, 
exist_ok=True) 161 | metadata['gen_image_paths'] = [] 162 | for i, image in enumerate(images): 163 | save_file_path = os.path.join(sample_path, f"{prompt_id}_{i}.jpg") 164 | if 'infinity' in args.model_type: 165 | cv2.imwrite(save_file_path, image.cpu().numpy()) 166 | else: 167 | image.save(save_file_path) 168 | metadata['gen_image_paths'].append(save_file_path) 169 | print(save_file_path) 170 | save_metadatas.append(metadata) 171 | 172 | save_metadata_file_path = os.path.join(args.outdir, "metadata.jsonl") 173 | with open(save_metadata_file_path, "w") as fp: 174 | json.dump(save_metadatas, fp) 175 | -------------------------------------------------------------------------------- /evaluation/validation_loss/validation_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import time 4 | import gc 5 | import json 6 | import math 7 | import random 8 | import sys 9 | import argparse 10 | import copy 11 | import traceback 12 | import collections 13 | from collections import deque 14 | from contextlib import nullcontext 15 | from functools import partial 16 | from typing import List, Optional, Tuple 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | import numpy as np 20 | import torch 21 | from torch.nn import functional as F 22 | from torch.utils.data import DataLoader 23 | import torch.distributed as tdist 24 | import tqdm 25 | 26 | from tools.run_infinity import * 27 | from infinity.dataset.dataset_t2i_iterable import T2IIterableDataset 28 | from infinity.models.bitwise_self_correction import BitwiseSelfCorrection 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | add_common_arguments(parser) 33 | parser.add_argument('--reweight_loss_by_scale', type=int, default=1, choices=[0,1]) 34 | parser.add_argument('--vis_model_flop_param', type=int, default=0, choices=[0,1]) 35 | parser.add_argument('--meta_folder', type=str, required=True) 36 | parser.add_argument('--save_dir', type=str, default='') 37 | parser.add_argument('--batch_size', type=int, default=2) 38 | parser.add_argument('--dataloader_workers', type=int, default=12) 39 | parser.add_argument('--noise_apply_layers', type=int, default=20) 40 | parser.add_argument('--noise_apply_requant', type=int, default=1, choices=[0,1]) 41 | parser.add_argument('--noise_apply_strength', type=float, default=0.2) 42 | parser.add_argument('--debug_bsc', type=int, default=0, choices=[0,1]) 43 | parser.add_argument('--log_freq', type=int, default=10) 44 | args = parser.parse_args() 45 | 46 | # load text encoder 47 | text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt) 48 | # load vae 49 | vae = load_visual_tokenizer(args) 50 | # load infinity 51 | infinity = load_transformer(vae, args) 52 | 53 | bitwise_self_correction = BitwiseSelfCorrection(vae, args) 54 | 55 | device = torch.device('cuda') 56 | dataset = T2IIterableDataset( 57 | args=None, 58 | meta_folder=args.meta_folder, 59 | data_load_reso=None, 60 | max_caption_len=512, 61 | short_prob=0.0, 62 | load_vae_instead_of_image=False, 63 | buffersize=100, 64 | seed=0, 65 | online_t5=True, 66 | pn=args.pn, 67 | batch_size=args.batch_size, 68 | num_replicas=1, 69 | rank=0, 70 | dataloader_workers=args.dataloader_workers, 71 | ) 72 | dataloader = DataLoader(dataset, batch_size=None, num_workers=args.dataloader_workers) 73 | print(f'len(dataloader): {len(dataloader)}, len(dataset): {len(dataset)}, total_samples: {dataset.total_samples()}') 74 | device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") 75 | t1 = time.time() 76 | dataloader.dataset.set_epoch(0) 77 | pbar = tqdm.tqdm(total=len(dataloader)) 78 | accumulate_res = collections.defaultdict(list) 79 | for i, data in enumerate(iter(dataloader)): 80 | if (i+1) % args.log_freq == 0: 81 | for k, v in accumulate_res.items(): 82 | v = np.array(v).mean(0) 83 | print(f'{k}: {v}') 84 | 85 | pbar.update(1) 86 | inp_B3HW, captions = data 87 | tokens = text_tokenizer(text=captions, max_length=text_tokenizer.model_max_length, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset 88 | input_ids = tokens.input_ids.cuda(non_blocking=True) 89 | mask = tokens.attention_mask.cuda(non_blocking=True) 90 | text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float() 91 | lens: List[int] = mask.sum(dim=-1).tolist() 92 | cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0)) 93 | Ltext = max(lens) 94 | kv_compact = [] 95 | for len_i, feat_i in zip(lens, text_features.unbind(0)): 96 | kv_compact.append(feat_i[:len_i]) 97 | kv_compact = torch.cat(kv_compact, dim=0) 98 | text_cond_tuple: Tuple[torch.FloatTensor, List[int], torch.LongTensor, int] = (kv_compact, lens, cu_seqlens_k, Ltext) 99 | 100 | h_div_w = inp_B3HW.shape[-2] / inp_B3HW.shape[-1] 101 | h_div_w_templates = np.array(list(dynamic_resolution_h_w.keys())) 102 | h_div_w_template = h_div_w_templates[np.argmin(np.abs(h_div_w-h_div_w_templates))] 103 | scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]['scales'] 104 | scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule] 105 | raw_last_l = np.array(scale_schedule[-1]).prod() 106 | 107 | # [prepare] 108 | B = inp_B3HW.shape[0] if isinstance(inp_B3HW, torch.Tensor) else inp_B3HW[0].shape[0] 109 | V = vae.vocab_size 110 | 111 | # [forward] 112 | with torch.amp.autocast('cuda', enabled=False): 113 | with torch.no_grad(): 114 | if args.apply_spatial_patchify: 115 | vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] 116 | else: 117 | vae_scale_schedule = scale_schedule 118 | raw_features, _, _ = vae.encode_for_raw_features(inp_B3HW.to(device), scale_schedule=vae_scale_schedule) 119 | 120 | x_BLC_wo_prefix, gt_ms_idx_Bl = bitwise_self_correction.flip_requant(vae_scale_schedule, inp_B3HW, raw_features, device) 121 | training_seq_len = np.array(scale_schedule).prod(axis=1).sum() 122 | x_BLC_wo_prefix = x_BLC_wo_prefix[:, :(training_seq_len-np.array(scale_schedule[0]).prod()), :] 123 | 124 | with torch.no_grad(): 125 | logits_BLV = infinity(text_cond_tuple, x_BLC_wo_prefix, scale_schedule=scale_schedule) # [bs, 1*1+...+64*64, vocab_size or log2(vocab_size)*2] 126 | 127 | if args.vis_model_flop_param: 128 | from torchinfo import summary 129 | res = summary(infinity, input_data=(text_cond_tuple, x_BLC_wo_prefix, scale_schedule)) 130 | print(res) 131 | 132 | batch_size, seq_len = logits_BLV.shape[:2] 133 | seq_len_each = [idx_Bl.shape[1] for idx_Bl in gt_ms_idx_Bl] 134 | 135 | gt_BL = torch.cat(gt_ms_idx_Bl, dim=1)[:,:training_seq_len].contiguous().type(torch.long) # [bs, 1*1+...+64*64, 16] or [bs, 1*1+...+64*64] 136 | tmp_bs, tmp_seq_len, tmp_channel = logits_BLV.shape 137 | assert tmp_channel == vae.codebook_dim * 2 138 | res_loss = torch.nn.functional.cross_entropy(logits_BLV.reshape(tmp_bs, tmp_seq_len, vae.codebook_dim, 2).permute(0,3,1,2), gt_BL, reduction='none') 139 | res_loss = res_loss.mean(dim=-1).mean(0) 140 | 141 | if 
args.reweight_loss_by_scale: 142 | lw = [] 143 | last_scale_area = np.sqrt(np.array(scale_schedule[-1]).prod()) 144 | for (ph, pw) in scale_schedule: 145 | this_scale_area = np.sqrt(ph * pw) 146 | lw.extend([last_scale_area / this_scale_area for _ in range(ph * pw)]) 147 | lw = torch.tensor(lw, device=device) 148 | lw = lw / lw.sum() 149 | else: 150 | lw = 1. / training_seq_len 151 | loss_reweight_by_scale = res_loss.mul(lw).sum(dim=-1).mean().item() 152 | 153 | bitwise_acc = (logits_BLV.reshape(B, seq_len, vae.codebook_dim, 2).argmax(dim=-1) == gt_BL).float() # shape: [bs, seq_len, codebook_dim] 154 | res_bit_acc = bitwise_acc.mean(-1).mean(0) 155 | res_token_acc = (bitwise_acc.sum(-1) == vae.codebook_dim).float().mean(0) 156 | loss_mean, acc_bit_mean, acc_token_mean = res_loss.mean().item(), res_bit_acc.mean().item() * 100., res_token_acc.mean().item() * 100. 157 | ptr = 0 158 | L_list, acc_bit_list, acc_token_list = [], [], [] 159 | for scale_ind in range(len(scale_schedule)): 160 | start, end = ptr, ptr + np.array(scale_schedule[scale_ind]).prod() 161 | L_list.append(res_loss[start:end].mean().item()) 162 | acc_bit_list.append(res_bit_acc[start:end].mean().item() * 100.) 163 | acc_token_list.append(res_token_acc[start:end].mean().item() * 100.) 164 | ptr = end 165 | accumulate_res['loss_bit_mean'].append(loss_mean) 166 | accumulate_res['acc_bit_mean'].append(acc_bit_mean) 167 | accumulate_res['acc_token_mean'].append(acc_token_mean) 168 | accumulate_res['loss_reweight_by_scale'].append(loss_reweight_by_scale) 169 | accumulate_res['loss_by_scale'].append(L_list) 170 | accumulate_res['acc_bit_list_by_scale'].append(acc_bit_list) 171 | accumulate_res['acc_token_list_by_scale'].append(acc_token_list) 172 | 173 | for k, v in accumulate_res.items(): 174 | if len(np.array(v).shape) == 1: 175 | v = np.array(v).mean(0) 176 | else: 177 | v = np.array(v).mean(0).tolist() 178 | accumulate_res[k] = v 179 | print(f'{k}: {v}') 180 | 181 | save_file = osp.join(args.save_dir, 'val_res.json') 182 | os.makedirs(osp.dirname(save_file), exist_ok=True) 183 | with open(save_file, 'w') as f: 184 | json.dump(accumulate_res, f, indent=2) 185 | print(f'Save val results to {save_file}') 186 | -------------------------------------------------------------------------------- /infinity/dataset/build.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import os.path as osp 4 | import random 5 | import subprocess 6 | from functools import partial 7 | from typing import Optional 8 | import time 9 | 10 | import pytz 11 | 12 | from infinity.dataset.dataset_t2i_iterable import T2IIterableDataset 13 | 14 | try: 15 | from grp import getgrgid 16 | from pwd import getpwuid 17 | except: 18 | pass 19 | import PIL.Image as PImage 20 | from PIL import ImageFile 21 | import numpy as np 22 | from torchvision.transforms import transforms 23 | from torchvision.transforms.functional import resize, to_tensor 24 | import torch.distributed as tdist 25 | 26 | from torchvision.transforms import InterpolationMode 27 | bicubic = InterpolationMode.BICUBIC 28 | lanczos = InterpolationMode.LANCZOS 29 | PImage.MAX_IMAGE_PIXELS = (1024 * 1024 * 1024 // 4 // 3) * 5 30 | ImageFile.LOAD_TRUNCATED_IMAGES = False 31 | 32 | 33 | def time_str(fmt='[%m-%d %H:%M:%S]'): 34 | return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt) 35 | 36 | 37 | def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1 38 | return x.add(x).add_(-1) 39 | 40 | 41 
| def denormalize_pm1_into_01(x): # denormalize x from [-1, 1] to [0, 1] 42 | return x.add(1).mul_(0.5) 43 | 44 | 45 | def center_crop_arr(pil_image, image_size): 46 | """ 47 | Center cropping implementation from ADM. 48 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 49 | """ 50 | while min(*pil_image.size) >= 2 * image_size: 51 | pil_image = pil_image.resize( 52 | tuple(x // 2 for x in pil_image.size), resample=PImage.BOX 53 | ) 54 | 55 | scale = image_size / min(*pil_image.size) 56 | pil_image = pil_image.resize( 57 | tuple(round(x * scale) for x in pil_image.size), resample=PImage.LANCZOS 58 | ) 59 | 60 | arr = np.array(pil_image) 61 | crop_y = (arr.shape[0] - image_size) // 2 62 | crop_x = (arr.shape[1] - image_size) // 2 63 | return PImage.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 64 | 65 | 66 | class RandomResize: 67 | def __init__(self, mid_reso, final_reso, interpolation): 68 | ub = max(round((mid_reso + (mid_reso-final_reso) / 8) / 4) * 4, mid_reso) 69 | self.reso_lb, self.reso_ub = final_reso, ub 70 | self.interpolation = interpolation 71 | 72 | def __call__(self, img): 73 | return resize(img, size=random.randint(self.reso_lb, self.reso_ub), interpolation=self.interpolation) 74 | 75 | def __repr__(self): 76 | return f'RandomResize(reso=({self.reso_lb}, {self.reso_ub}), interpolation={self.interpolation})' 77 | 78 | 79 | def load_save(reso=512): 80 | import os 81 | from PIL import Image as PImage 82 | from torchvision.transforms import transforms, InterpolationMode 83 | aug = transforms.Compose([ 84 | transforms.Resize(512, interpolation=InterpolationMode.LANCZOS), 85 | transforms.CenterCrop((512, 512)) 86 | ]) 87 | src_folder = r'C:\Users\16333\Pictures\imgs_to_visual_v2' 88 | ls = [os.path.join(src_folder, x) for x in ('1.jpg', '2.jpg', '3.png', '4.png', '5.png')] 89 | print(ls) 90 | imgs = [] 91 | for i, fname in enumerate(ls): 92 | assert os.path.exists(fname) 93 | with PImage.open(fname) as img: 94 | img = img.convert('RGB') 95 | img = aug(img) 96 | imgs.append(img) 97 | dst_d, dst_f = os.path.split(fname) 98 | dst = os.path.join(dst_d, f'crop{dst_f.replace(".jpg", ".png")}') 99 | img.save(dst) 100 | 101 | W, H = imgs[0].size 102 | WW = W * len(imgs) 103 | new_im = PImage.new('RGB', (WW, H)) 104 | x_offset = 0 105 | for img in imgs: 106 | new_im.paste(img, (x_offset, 0)) 107 | x_offset += W 108 | dst = os.path.join(src_folder, f'junfeng.png') 109 | new_im.save(dst) 110 | 111 | 112 | def print_aug(transform, label): 113 | print(f'Transform {label} = ') 114 | if hasattr(transform, 'transforms'): 115 | for t in transform.transforms: 116 | print(t) 117 | else: 118 | print(transform) 119 | print('---------------------------\n') 120 | 121 | 122 | def build_t2i_dataset( 123 | args, 124 | data_path: str, 125 | data_load_reso: int, 126 | max_caption_len: int, 127 | short_prob=0.2, 128 | load_vae_instead_of_image=False 129 | ): 130 | if args.use_streaming_dataset: 131 | return T2IIterableDataset( 132 | data_path, 133 | max_caption_len=max_caption_len, 134 | short_prob=short_prob, 135 | load_vae_instead_of_image=load_vae_instead_of_image, 136 | buffersize=args.iterable_data_buffersize, 137 | pn=args.pn, 138 | online_t5=args.online_t5, 139 | batch_size=args.batch_size, 140 | num_replicas=tdist.get_world_size(), # 1, 141 | rank=tdist.get_rank(), # 0 142 | dataloader_workers=args.workers, 143 | dynamic_resolution_across_gpus=args.dynamic_resolution_across_gpus, 144 | 
enable_dynamic_length_prompt=args.enable_dynamic_length_prompt, 145 | seed=args.seed if args.seed is not None else int(time.time()), 146 | ) 147 | else: 148 | raise ValueError(f'args.use_streaming_dataset={args.use_streaming_dataset} unsupported') 149 | 150 | 151 | def pil_load(path: str, proposal_size): 152 | with open(path, 'rb') as f: 153 | img: PImage.Image = PImage.open(f) 154 | w: int = img.width 155 | h: int = img.height 156 | sh: int = min(h, w) 157 | if sh > proposal_size: 158 | ratio: float = proposal_size / sh 159 | w = round(ratio * w) 160 | h = round(ratio * h) 161 | img.draft('RGB', (w, h)) 162 | img = img.convert('RGB') 163 | return img 164 | 165 | 166 | def rewrite(im: PImage, file: str, info: str): 167 | kw = dict(quality=100) 168 | if file.lower().endswith('.tif') or file.lower().endswith('.tiff'): 169 | kw['compression'] = 'none' 170 | elif file.lower().endswith('.webp'): 171 | kw['lossless'] = True 172 | 173 | st = os.stat(file) 174 | uname = getpwuid(st.st_uid).pw_name 175 | gname = getgrgid(st.st_gid).gr_name 176 | mode = oct(st.st_mode)[-3:] 177 | 178 | local_file = osp.basename(file) 179 | im.save(local_file, **kw) 180 | print(f'************* ************* @ {file}') 181 | subprocess.call(f'sudo mv {local_file} {file}; sudo chown {uname}:{gname} {file}; sudo chmod {mode} {file}', shell=True) 182 | -------------------------------------------------------------------------------- /infinity/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.loss import SoftTargetCrossEntropy 3 | 4 | from timm.models.layers import DropPath 5 | 6 | from .infinity import Infinity, sample_with_top_k_top_p_also_inplace_modifying_logits_ 7 | 8 | def _ex_repr(self): 9 | return ', '.join( 10 | f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v)) 11 | for k, v in vars(self).items() 12 | if not k.startswith('_') and k != 'training' 13 | and not isinstance(v, (torch.nn.Module, torch.Tensor)) 14 | ) 15 | for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy): # no longer __repr__ DropPath with drop_prob 16 | if hasattr(clz, 'extra_repr'): 17 | clz.extra_repr = _ex_repr 18 | else: 19 | clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})' 20 | 21 | DropPath.__repr__ = lambda self: f'{type(self).__name__}(...)' 22 | 23 | alias_dict = {} 24 | for d in range(6, 40+2, 2): 25 | alias_dict[f'd{d}'] = f'infinity_d{d}' 26 | alias_dict_inv = {v: k for k, v in alias_dict.items()} 27 | -------------------------------------------------------------------------------- /infinity/models/bitwise_self_correction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | def labels2image(all_indices, label_type='int_label', scale_schedule=None): 10 | summed_codes, recons_imgs = self.vae.decode_from_indices(all_indices, scale_schedule, label_type) 11 | recons_img = recons_imgs[0] 12 | recons_img = (recons_img + 1) / 2 13 | recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1] 14 | return recons_img 15 | 16 | def features2image(raw_features): 17 | recons_imgs = self.vae.decode(raw_features.squeeze(-3)) 18 | recons_img = recons_imgs[0] 19 | recons_img = (recons_img + 1) / 2 20 | recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1] 21 | return recons_img 22 | 23 | class 
BitwiseSelfCorrection(object): 24 | def __init__(self, vae, args): 25 | self.noise_apply_layers = args.noise_apply_layers 26 | self.noise_apply_requant = args.noise_apply_requant 27 | self.noise_apply_strength = args.noise_apply_strength 28 | self.apply_spatial_patchify = args.apply_spatial_patchify 29 | self.vae = vae 30 | self.debug_bsc = args.debug_bsc 31 | 32 | def flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device): 33 | with torch.amp.autocast('cuda', enabled = False): 34 | B = raw_features.shape[0] 35 | if raw_features.dim() == 4: 36 | codes_out = raw_features.unsqueeze(2) 37 | else: 38 | codes_out = raw_features 39 | cum_var_input = 0 40 | gt_all_bit_indices = [] 41 | pred_all_bit_indices = [] 42 | x_BLC_wo_prefix = [] 43 | for si, (pt, ph, pw) in enumerate(vae_scale_schedule): 44 | residual = codes_out - cum_var_input 45 | if si != len(vae_scale_schedule)-1: 46 | residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous() 47 | quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual) # quantized shape: [B, d_vae, 1, h, w], bit_indices shape: [B,1,h,w,d_vae] 48 | gt_all_bit_indices.append(bit_indices) 49 | if si < self.noise_apply_layers: 50 | noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01 51 | mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength 52 | pred_bit_indices = bit_indices.clone() 53 | pred_bit_indices[mask] = 1 - pred_bit_indices[mask] 54 | pred_all_bit_indices.append(pred_bit_indices) 55 | if self.noise_apply_requant: 56 | quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label') 57 | else: 58 | pred_all_bit_indices.append(bit_indices) 59 | cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous() 60 | if si < len(vae_scale_schedule)-1: 61 | this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=self.vae.quantizer.z_interplote_up).contiguous() 62 | if self.apply_spatial_patchify: 63 | # (B,d,1,H,W) -> (B,d,H,W) -> (B,4d,H/2,W/2) 64 | this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2) 65 | x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1)) # (B,H/2*W/2,4C) or (B,H*W,C) 66 | 67 | if self.apply_spatial_patchify: 68 | gt_ms_idx_Bl = [] 69 | for item in gt_all_bit_indices: 70 | # item shape: (B,1,H,W,d) 71 | item = item.squeeze(1).permute(0,3,1,2) # (B,d,H,W) 72 | # (B,d,H,W) -> (B,4d,H/2,W/2) 73 | item = torch.nn.functional.pixel_unshuffle(item, 2) 74 | # (B,4d,H/2,W/2) -> (B,H/2,W/2,4d) -> (B,H/2*w/2,4d) 75 | item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim) 76 | gt_ms_idx_Bl.append(item) 77 | else: 78 | gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices] 79 | x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1) 80 | 81 | if self.debug_bsc: 82 | self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices) 83 | 84 | return x_BLC_wo_prefix, gt_ms_idx_Bl 85 | 86 | def visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices): 87 | gt_img = (inp_B3HW.squeeze(-3) + 1) / 2 * 255 88 | gt_img = gt_img[0].permute(1,2,0).cpu().numpy().astype(np.uint8)[:,:,::-1] 89 | recons_img_2 = labels2image(gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule) 90 | recons_img_3 = 
labels2image(pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule) 91 | cat_image = np.concatenate([gt_img, recons_img_2, recons_img_3], axis=1) 92 | save_path = osp.abspath('non_teacher_force.jpg') 93 | cv2.imwrite(save_path, cat_image) 94 | print(f'Save to {save_path}') 95 | import pdb; pdb.set_trace() 96 | print(cat_image.shape) 97 | -------------------------------------------------------------------------------- /infinity/models/bsq_vae/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | import torch.nn.functional as F 5 | 6 | 7 | class Conv(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, cnn_type="2d", causal_offset=0, temporal_down=False): 9 | super().__init__() 10 | self.cnn_type = cnn_type 11 | self.slice_seq_len = 17 12 | 13 | if cnn_type == "2d": 14 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding) 15 | if cnn_type == "3d": 16 | if temporal_down == False: 17 | stride = (1, stride, stride) 18 | else: 19 | stride = (stride, stride, stride) 20 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0) 21 | if isinstance(kernel_size, int): 22 | kernel_size = (kernel_size, kernel_size, kernel_size) 23 | self.padding = ( 24 | kernel_size[0] - 1 + causal_offset, # Temporal causal padding 25 | padding, # Height padding 26 | padding # Width padding 27 | ) 28 | self.causal_offset = causal_offset 29 | self.stride = stride 30 | self.kernel_size = kernel_size 31 | 32 | def forward(self, x): 33 | if self.cnn_type == "2d": 34 | if x.ndim == 5: 35 | B, C, T, H, W = x.shape 36 | x = rearrange(x, "B C T H W -> (B T) C H W") 37 | x = self.conv(x) 38 | x = rearrange(x, "(B T) C H W -> B C T H W", T=T) 39 | return x 40 | else: 41 | return self.conv(x) 42 | if self.cnn_type == "3d": 43 | assert self.stride[0] == 1 or self.stride[0] == 2, f"only temporal stride = 1 or 2 are supported" 44 | xs = [] 45 | for i in range(0, x.shape[2], self.slice_seq_len+self.stride[0]-1): 46 | st = i 47 | en = min(i+self.slice_seq_len, x.shape[2]) 48 | _x = x[:,:,st:en,:,:] 49 | if i == 0: 50 | _x = F.pad(_x, (self.padding[2], self.padding[2], # Width 51 | self.padding[1], self.padding[1], # Height 52 | self.padding[0], 0)) # Temporal 53 | else: 54 | padding_0 = self.kernel_size[0] - 1 55 | _x = F.pad(_x, (self.padding[2], self.padding[2], # Width 56 | self.padding[1], self.padding[1], # Height 57 | padding_0, 0)) # Temporal 58 | _x[:,:,:padding_0, 59 | self.padding[1]:_x.shape[-2]-self.padding[1], 60 | self.padding[2]:_x.shape[-1]-self.padding[2]] += x[:,:,i-padding_0:i,:,:] 61 | _x = self.conv(_x) 62 | xs.append(_x) 63 | try: 64 | x = torch.cat(xs, dim=2) 65 | except: 66 | device = x.device 67 | del x 68 | xs = [_x.cpu().pin_memory() for _x in xs] 69 | torch.cuda.empty_cache() 70 | x = torch.cat([_x.cpu() for _x in xs], dim=2).to(device=device) 71 | return x -------------------------------------------------------------------------------- /infinity/models/bsq_vae/dynamic_resolution.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import tqdm 4 | 5 | vae_stride = 16 6 | ratio2hws = { 7 | 1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64)], 8 | 1.250: 
[(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56)], 9 | 1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54)], 10 | 1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52)], 11 | 1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48)], 12 | 2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45)], 13 | 2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40)], 14 | 3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37)], 15 | } 16 | full_ratio2hws = {} 17 | for ratio, hws in ratio2hws.items(): 18 | full_ratio2hws[ratio] = hws 19 | full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws] 20 | 21 | dynamic_resolution_h_w = {} 22 | predefined_HW_Scales_dynamic = {} 23 | for ratio in full_ratio2hws: 24 | dynamic_resolution_h_w[ratio] ={} 25 | for ind, leng in enumerate([7, 10, 13]): 26 | h, w = full_ratio2hws[ratio][leng-1][0], full_ratio2hws[ratio][leng-1][1] # feature map size 27 | pixel = (h * vae_stride, w * vae_stride) # The original image (H, W) 28 | dynamic_resolution_h_w[ratio][pixel[1]] = { 29 | 'pixel': pixel, 30 | 'scales': full_ratio2hws[ratio][:leng] 31 | } # W as key 32 | predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng] -------------------------------------------------------------------------------- /infinity/models/bsq_vae/vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from infinity.models.bsq_vae.flux_vqgan import AutoEncoder 5 | 6 | def load_cnn(model, state_dict, prefix, expand=False, use_linear=False): 7 | delete_keys = [] 8 | loaded_keys = [] 9 | for key in state_dict: 10 | if key.startswith(prefix): 11 | _key = key[len(prefix):] 12 | if _key in model.state_dict(): 13 | # load nn.Conv2d or nn.Linear to nn.Linear 14 | if use_linear and (".q.weight" in key or ".k.weight" in key or ".v.weight" in key or ".proj_out.weight" in key): 15 | load_weights = state_dict[key].squeeze() 16 | elif _key.endswith(".conv.weight") and expand: 17 | if model.state_dict()[_key].shape == state_dict[key].shape: 18 | # 2D cnn to 2D cnn 19 | load_weights = state_dict[key] 20 | else: 21 | # 2D cnn to 3D cnn 22 | _expand_dim = model.state_dict()[_key].shape[2] 23 | load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1) 24 | else: 25 | load_weights = state_dict[key] 26 | model.state_dict()[_key].copy_(load_weights) 27 | delete_keys.append(key) 28 | loaded_keys.append(prefix+_key) 29 | # load nn.Conv2d to Conv class 30 | conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."] 31 | if any(k in _key for k in conv_list): 32 | if _key.endswith(".weight"): 33 | conv_key = _key.replace(".weight", ".conv.weight") 34 | if conv_key and conv_key in model.state_dict(): 35 | if model.state_dict()[conv_key].shape == state_dict[key].shape: 36 | # 2D cnn to 2D cnn 37 | load_weights = state_dict[key] 38 | else: 39 | # 2D cnn to 3D cnn 40 | _expand_dim = model.state_dict()[conv_key].shape[2] 41 | load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1) 42 | model.state_dict()[conv_key].copy_(load_weights) 43 | delete_keys.append(key) 44 | loaded_keys.append(prefix+conv_key) 45 | if 
_key.endswith(".bias"): 46 | conv_key = _key.replace(".bias", ".conv.bias") 47 | if conv_key and conv_key in model.state_dict(): 48 | model.state_dict()[conv_key].copy_(state_dict[key]) 49 | delete_keys.append(key) 50 | loaded_keys.append(prefix+conv_key) 51 | # load nn.GroupNorm to Normalize class 52 | if "norm" in _key: 53 | if _key.endswith(".weight"): 54 | norm_key = _key.replace(".weight", ".norm.weight") 55 | if norm_key and norm_key in model.state_dict(): 56 | model.state_dict()[norm_key].copy_(state_dict[key]) 57 | delete_keys.append(key) 58 | loaded_keys.append(prefix+norm_key) 59 | if _key.endswith(".bias"): 60 | norm_key = _key.replace(".bias", ".norm.bias") 61 | if norm_key and norm_key in model.state_dict(): 62 | model.state_dict()[norm_key].copy_(state_dict[key]) 63 | delete_keys.append(key) 64 | loaded_keys.append(prefix+norm_key) 65 | 66 | for key in delete_keys: 67 | del state_dict[key] 68 | 69 | return model, state_dict, loaded_keys 70 | 71 | 72 | def vae_model(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size, test_mode=True, patch_size=16, encoder_ch_mult=[1, 2, 4, 4, 4], decoder_ch_mult=[1, 2, 4, 4, 4],): 73 | args=argparse.Namespace( 74 | vqgan_ckpt=vqgan_ckpt, 75 | sd_ckpt=None, 76 | inference_type='image', 77 | save='./imagenet_val_bsq', 78 | save_prediction=True, 79 | image_recon4video=False, 80 | junke_old=False, 81 | device='cuda', 82 | max_steps=1000000.0, 83 | log_every=1, 84 | visu_every=1000, 85 | ckpt_every=1000, 86 | default_root_dir='', 87 | compile='no', 88 | ema='no', 89 | lr=0.0001, 90 | beta1=0.9, 91 | beta2=0.95, 92 | warmup_steps=0, 93 | optim_type='Adam', 94 | disc_optim_type=None, 95 | lr_min=0.0, 96 | warmup_lr_init=0.0, 97 | max_grad_norm=1.0, 98 | max_grad_norm_disc=1.0, 99 | disable_sch=False, 100 | patch_size=patch_size, 101 | temporal_patch_size=4, 102 | embedding_dim=256, 103 | codebook_dim=codebook_dim, 104 | num_quantizers=8, 105 | quantizer_type='MultiScaleBSQ', 106 | use_vae=False, 107 | use_freq_enc=False, 108 | use_freq_dec=False, 109 | preserve_norm=False, 110 | ln_before_quant=False, 111 | ln_init_by_sqrt=False, 112 | use_pxsf=False, 113 | new_quant=True, 114 | use_decay_factor=False, 115 | mask_out=False, 116 | use_stochastic_depth=False, 117 | drop_rate=0.0, 118 | schedule_mode=schedule_mode, 119 | lr_drop=None, 120 | lr_drop_rate=0.1, 121 | keep_first_quant=False, 122 | keep_last_quant=False, 123 | remove_residual_detach=False, 124 | use_out_phi=False, 125 | use_out_phi_res=False, 126 | use_lecam_reg=False, 127 | lecam_weight=0.05, 128 | perceptual_model='vgg16', 129 | base_ch_disc=64, 130 | random_flip=False, 131 | flip_prob=0.5, 132 | flip_mode='stochastic', 133 | max_flip_lvl=1, 134 | not_load_optimizer=False, 135 | use_lecam_reg_zero=False, 136 | freeze_encoder=False, 137 | rm_downsample=False, 138 | random_flip_1lvl=False, 139 | flip_lvl_idx=0, 140 | drop_when_test=False, 141 | drop_lvl_idx=0, 142 | drop_lvl_num=1, 143 | disc_version='v1', 144 | magvit_disc=False, 145 | sigmoid_in_disc=False, 146 | activation_in_disc='leaky_relu', 147 | apply_blur=False, 148 | apply_noise=False, 149 | dis_warmup_steps=0, 150 | dis_lr_multiplier=1.0, 151 | dis_minlr_multiplier=False, 152 | disc_channels=64, 153 | disc_layers=3, 154 | discriminator_iter_start=0, 155 | disc_pretrain_iter=0, 156 | disc_optim_steps=1, 157 | disc_warmup=0, 158 | disc_pool='no', 159 | disc_pool_size=1000, 160 | advanced_disc=False, 161 | recon_loss_type='l1', 162 | video_perceptual_weight=0.0, 163 | image_gan_weight=1.0, 164 | video_gan_weight=1.0, 165 | 
image_disc_weight=0.0, 166 | video_disc_weight=0.0, 167 | l1_weight=4.0, 168 | gan_feat_weight=0.0, 169 | perceptual_weight=0.0, 170 | kl_weight=0.0, 171 | lfq_weight=0.0, 172 | entropy_loss_weight=0.1, 173 | commitment_loss_weight=0.25, 174 | diversity_gamma=1, 175 | norm_type='group', 176 | disc_loss_type='hinge', 177 | use_checkpoint=False, 178 | precision='fp32', 179 | encoder_dtype='fp32', 180 | upcast_attention='', 181 | upcast_tf32=False, 182 | tokenizer='flux', 183 | pretrained=None, 184 | pretrained_mode='full', 185 | inflation_pe=False, 186 | init_vgen='no', 187 | no_init_idis=False, 188 | init_idis='keep', 189 | init_vdis='no', 190 | enable_nan_detector=False, 191 | turn_on_profiler=False, 192 | profiler_scheduler_wait_steps=10, 193 | debug=True, 194 | video_logger=False, 195 | bytenas='', 196 | username='', 197 | seed=1234, 198 | vq_to_vae=False, 199 | load_not_strict=False, 200 | zero=0, 201 | bucket_cap_mb=40, 202 | manual_gc_interval=1000, 203 | data_path=[''], 204 | data_type=[''], 205 | dataset_list=['imagenet'], 206 | fps=-1, 207 | dataaug='resizecrop', 208 | multi_resolution=False, 209 | random_bucket_ratio=0.0, 210 | sequence_length=16, 211 | resolution=[256, 256], 212 | batch_size=[1], 213 | num_workers=0, 214 | image_channels=3, 215 | codebook_size=codebook_size, 216 | codebook_l2_norm=True, 217 | codebook_show_usage=True, 218 | commit_loss_beta=0.25, 219 | entropy_loss_ratio=0.0, 220 | base_ch=128, 221 | num_res_blocks=2, 222 | encoder_ch_mult=encoder_ch_mult, 223 | decoder_ch_mult=decoder_ch_mult, 224 | dropout_p=0.0, 225 | cnn_type='2d', 226 | cnn_version='v1', 227 | conv_in_out_2d='no', 228 | conv_inner_2d='no', 229 | res_conv_2d='no', 230 | cnn_attention='no', 231 | cnn_norm_axis='spatial', 232 | flux_weight=0, 233 | cycle_weight=0, 234 | cycle_feat_weight=0, 235 | cycle_gan_weight=0, 236 | cycle_loop=0, 237 | z_drop=0.0) 238 | 239 | vae = AutoEncoder(args) 240 | use_vae = vae.use_vae 241 | if not use_vae: 242 | num_codes = args.codebook_size 243 | if isinstance(vqgan_ckpt, str): 244 | state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True) 245 | else: 246 | state_dict = args.vqgan_ckpt 247 | if state_dict: 248 | if args.ema == "yes": 249 | vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["ema"], prefix="", expand=False) 250 | else: 251 | vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["vae"], prefix="", expand=False) 252 | if test_mode: 253 | vae.eval() 254 | [p.requires_grad_(False) for p in vae.parameters()] 255 | return vae -------------------------------------------------------------------------------- /infinity/models/ema.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from collections import OrderedDict 4 | 5 | 6 | def get_ema_model(model): 7 | ema_model = copy.deepcopy(model) 8 | ema_model.eval() 9 | for param in ema_model.parameters(): 10 | param.requires_grad = False 11 | return ema_model 12 | 13 | @torch.no_grad() 14 | def update_ema(ema_model, model, decay=0.9999): 15 | """ 16 | Step the EMA model towards the current model. 
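    Each parameter follows ema = decay * ema + (1 - decay) * param, i.e. an exponential moving average of the model weights with the given decay (default 0.9999).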
17 | """ 18 | ema_params = OrderedDict(ema_model.named_parameters()) 19 | model_params = OrderedDict(model.named_parameters()) 20 | 21 | for name, param in model_params.items(): 22 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 23 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 24 | -------------------------------------------------------------------------------- /infinity/models/flex_attn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrap torch's flex attention and handle mess info or potentially refactor 3 | """ 4 | from functools import partial 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | try: 10 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask 11 | flex_attention_available = True 12 | except ImportError: 13 | print(f"[Warning] flex attention need pytorch 2.5.0+ but your version is {torch.__version__}") 14 | flex_attention_available = False 15 | 16 | def _causal_mask(b, h, q_idx, kv_idx): 17 | return q_idx >= kv_idx 18 | 19 | def _length_to_offsets(lengths, device): 20 | """Converts a list of lengths to a list of offsets. 21 | 22 | Args: 23 | lengths: A list of lengths. 24 | 25 | """ 26 | offsets = [0] 27 | offsets.extend(lengths) 28 | offsets = torch.tensor(offsets, device=device, dtype=torch.int32) 29 | offsets = torch.cumsum(offsets, dim=-1) 30 | return offsets 31 | 32 | def _generate_var_mask_mod(offsets): 33 | """Generates mask mods that apply to inputs to flex attention in the sequence stacked 34 | format. 35 | 36 | Args: 37 | offsets: This tensor should be of shape(num_documents + 1) 38 | this should contain the cumulative counts of document tokens. 39 | e.g. if you have 3 documents of length 2, 4, 3 then 40 | offsets = [0, 2, 6, 9] 41 | 42 | Note: 43 | What is the sequence stacked format? When assembling batches of inputs, we 44 | take multiple sequences and stack them together to form 1 large sequence. We then 45 | use masking to ensure that the attention scores are only applied to tokens within 46 | the same document. 
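        For illustration, offsets = [0, 2, 6, 9] produces document ids [0, 0, 1, 1, 1, 1, 2, 2, 2] via _offsets_to_doc_ids_tensor below, and the returned mask mod lets q_idx attend to kv_idx whenever the two positions share a document id or q_idx >= kv_idx.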
47 | """ 48 | 49 | def _offsets_to_doc_ids_tensor(offsets): 50 | device = offsets.device 51 | counts = offsets[1:] - offsets[:-1] 52 | return torch.repeat_interleave( 53 | torch.arange(len(counts), device=device, dtype=torch.int32), counts 54 | ) 55 | 56 | document_id = _offsets_to_doc_ids_tensor(offsets) 57 | 58 | def var_mask_mod(b, h, q_idx, kv_idx): 59 | same_doc = document_id[q_idx] == document_id[kv_idx] 60 | causal_mask = _causal_mask(b, h, q_idx, kv_idx) 61 | return same_doc | causal_mask 62 | 63 | return var_mask_mod 64 | 65 | def _generate_var_infer_mask_with_kv_cache(lengths): 66 | kv_len = sum(lengths) 67 | def var_mask_mod(b, h, q_idx, kv_idx): 68 | return kv_idx < kv_len 69 | 70 | return var_mask_mod 71 | 72 | class FlexAttn(nn.Module): 73 | def __init__( 74 | self, block_scales:list, mask_type:str, B, H, L:int, auto_padding=False 75 | ): 76 | """ 77 | :param block_scales: accept VAR's block sizes like [(1,1), (2,2), (3,3)] 78 | :param mask_type: var/causal 79 | :param B: batch size 80 | :param H: heads num 81 | :param L: sequence length 82 | """ 83 | super().__init__() 84 | if not flex_attention_available: 85 | raise NotImplementedError((f"[Error] flex attention need pytorch 2.5.0+ but your version is {torch.__version__}")) 86 | 87 | self.support_mask_type = ["var", "causal", "var_infer_mask_with_kv_cache"] 88 | self.auto_padding = auto_padding 89 | 90 | self.flex_attention = torch.compile(flex_attention) 91 | 92 | self.block_scales = block_scales 93 | self.lengths = [ x * y * z for x,y,z in block_scales] 94 | 95 | self.offsets = _length_to_offsets(self.lengths, device='cuda') 96 | 97 | # if L paded to align 128, block need to cover padding area 98 | if self.offsets[-1] < L: 99 | self.offsets = torch.cat((self.offsets, torch.tensor([L], device='cuda')), dim=0) 100 | 101 | if mask_type == "var": 102 | self.mask_mod = _generate_var_mask_mod(self.offsets) 103 | self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True) 104 | elif mask_type == "causal": 105 | self.mask_mod = _causal_mask 106 | self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True) 107 | elif mask_type == 'var_infer_mask_with_kv_cache': 108 | self.mask_mod = _generate_var_infer_mask_with_kv_cache(self.lengths) 109 | self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True) 110 | else: 111 | raise NotImplementedError(f"{mask_type} not supportted in FlexAttn, support type:{self.support_mask_type}") 112 | 113 | 114 | def forward(self, q, k, v, scale = None): 115 | if self.auto_padding: 116 | q_pad_len = (128 - q.shape[-2] % 128) % 128 117 | kv_pad_len = (128 - k.shape[-2] % 128) % 128 118 | q_pad = F.pad(q, (0, 0, 0, q_pad_len)) 119 | k_pad = F.pad(k, (0, 0, 0, kv_pad_len)) 120 | v_pad = F.pad(v, (0, 0, 0, kv_pad_len)) 121 | oup = self.flex_attention(q_pad.to(v_pad.dtype), k_pad.to(v.dtype), v_pad, block_mask = self.block_mask, scale = scale) 122 | if q_pad_len > 0: 123 | oup = oup[:,:,:-q_pad_len] 124 | else: 125 | oup = self.flex_attention(q.to(v.dtype), k.to(v.dtype), v, block_mask = self.block_mask, scale = scale) 126 | return oup 127 | 128 | def extra_repr(self) -> str: 129 | tail = '' 130 | return f'block size:{self.block_scales} {tail}' 131 | -------------------------------------------------------------------------------- /infinity/models/fused_op.py: 
-------------------------------------------------------------------------------- 1 | import gc 2 | from copy import deepcopy 3 | from typing import Union 4 | 5 | import torch 6 | from torch import nn as nn 7 | from torch.nn import functional as F 8 | 9 | 10 | @torch.compile(fullgraph=True) 11 | def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float): 12 | x = x.float() 13 | return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) * weight 14 | 15 | 16 | @torch.compile(fullgraph=True) 17 | def fused_ada_layer_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor): 18 | x = x.float() 19 | x = F.layer_norm(input=x, normalized_shape=(C,), weight=None, bias=None, eps=eps) 20 | return x.mul(scale.add(1)).add_(shift) 21 | 22 | 23 | @torch.compile(fullgraph=True) 24 | def fused_ada_rms_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor): 25 | x = x.float() 26 | x = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) 27 | return x.mul(scale.add(1)).add_(shift) 28 | -------------------------------------------------------------------------------- /infinity/models/init_param.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def init_weights(model: nn.Module, conv_std_or_gain: float = 0.02, other_std: float = 0.02): 5 | """ 6 | :param model: the model to be inited 7 | :param conv_std_or_gain: how to init every conv layer `m` 8 | > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain) 9 | < 0: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain) 10 | :param other_std: how to init every linear layer or embedding layer 11 | use nn.init.trunc_normal_(m.weight.data, std=other_std) 12 | """ 13 | skip = abs(conv_std_or_gain) > 10 14 | if skip: return 15 | print(f'[init_weights] {type(model).__name__} with {"std" if conv_std_or_gain > 0 else "gain"}={abs(conv_std_or_gain):g}') 16 | for m in model.modules(): 17 | if isinstance(m, nn.Linear): 18 | nn.init.trunc_normal_(m.weight.data, std=other_std) 19 | if m.bias is not None: 20 | nn.init.constant_(m.bias.data, 0.) 21 | elif isinstance(m, nn.Embedding): 22 | nn.init.trunc_normal_(m.weight.data, std=other_std) 23 | if m.padding_idx is not None: 24 | m.weight.data[m.padding_idx].zero_() 25 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d)): 26 | nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain) if conv_std_or_gain > 0 else nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain) # todo: StyleSwin: (..., gain=.02) 27 | if hasattr(m, 'bias') and m.bias is not None: 28 | nn.init.constant_(m.bias.data, 0.) 29 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)): 30 | if m.bias is not None: 31 | nn.init.constant_(m.bias.data, 0.) 32 | if m.weight is not None: 33 | nn.init.constant_(m.weight.data, 1.) 
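# Usage sketch (illustrative only; `toy` below is a hypothetical module, not part of this repo):
#   toy = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(), nn.Linear(16, 10))
#   init_weights(toy, conv_std_or_gain=0.02, other_std=0.02)   # trunc-normal init with std 0.02
#   init_weights(toy, conv_std_or_gain=-1.0)                    # conv layers use xavier_normal_ with gain 1.0
#   init_weights(toy, conv_std_or_gain=100.0)                   # |value| > 10, so initialization is skipped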
34 | -------------------------------------------------------------------------------- /infinity/utils/amp_opt.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import signal 4 | import sys 5 | import time 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import torch 9 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 10 | # from memory_profiler import profile 11 | 12 | import infinity.utils.dist as dist 13 | from infinity.utils import misc 14 | 15 | class NullCtx: 16 | def __enter__(self): 17 | pass 18 | 19 | def __exit__(self, exc_type, exc_val, exc_tb): 20 | pass 21 | 22 | 23 | def handle_timeout(signum, frame): 24 | raise TimeoutError('took too long') 25 | 26 | 27 | def per_param_clip_grad_norm_(parameters, thresh: float, stable=False, fp=None) -> (float, float): 28 | skipped, max_grad = [], 0 29 | for pi, p in enumerate(parameters): 30 | if p.grad is not None: 31 | g = p.grad.data.norm(2).item() + 1e-7 32 | max_grad = max(max_grad, g) 33 | clip_coef = thresh / g 34 | if clip_coef < 1: 35 | if stable and clip_coef < 0.2: 36 | skipped.append(clip_coef) 37 | p.grad.data.mul_(0) # todo NOTE: inf.mul_(0)==nan will shrink the scale ratio, but inf.zero_()==0 won't 38 | else: 39 | p.grad.data.mul_(clip_coef) 40 | 41 | # if fp is not None: fp.write(f'[per_param_clip_grad_norm_:47] finished.\n'); fp.flush() 42 | return 0 if len(skipped) == 0 else math.log10(max(min(skipped), 1e-7)), max_grad 43 | 44 | 45 | class AmpOptimizer: 46 | def __init__( 47 | self, 48 | model_name_3letters: str, mixed_precision: int, 49 | optimizer: torch.optim.Optimizer, model_maybe_fsdp: Union[torch.nn.Module, FSDP], 50 | r_accu: float, grad_clip: float, zero: int, 51 | ): 52 | self.enable_amp = mixed_precision > 0 53 | self.zero = zero 54 | if self.enable_amp: 55 | self.using_fp16_rather_bf16 = mixed_precision != 2 56 | self.max_sc = float(mixed_precision if mixed_precision > 128 else 32768) 57 | 58 | # todo: on both V100 and A100, torch.get_autocast_gpu_dtype() returns fp16, not bf16. 59 | self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0) # todo: cache_enabled=False 60 | if self.using_fp16_rather_bf16: 61 | self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) 62 | else: 63 | self.scaler = None 64 | else: 65 | self.using_fp16_rather_bf16 = True 66 | self.amp_ctx = NullCtx() 67 | self.scaler = None 68 | 69 | t = torch.zeros(dist.get_world_size()) 70 | t[dist.get_rank()] = float(self.enable_amp) 71 | dist.allreduce(t) 72 | assert round(t.sum().item()) in {0, dist.get_world_size()}, f'enable_amp: {t}' 73 | 74 | t = torch.zeros(dist.get_world_size()) 75 | t[dist.get_rank()] = float(self.using_fp16_rather_bf16) 76 | dist.allreduce(t) 77 | assert round(t.sum().item()) in {0, dist.get_world_size()}, f'using_fp16_rather_bf16: {t}' 78 | 79 | self.model_name_3letters = model_name_3letters 80 | self.optimizer, self.model_maybe_fsdp = optimizer, model_maybe_fsdp 81 | self.r_accu = r_accu 82 | 83 | self.paras = self.names = ... 
# todo: solve EMA-related codes 84 | 85 | self.grad_clip, self.grad_clip_we = grad_clip, 0 # todo: disable wclip 86 | if self.grad_clip > 100: 87 | self.grad_clip %= 100 88 | self.per_param = True 89 | else: 90 | self.per_param = False 91 | self.per_param = False # todo: disable wclip 92 | 93 | self.early_clipping = grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm') 94 | self.late_clipping = grad_clip > 0 and hasattr(optimizer, 'global_grad_norm') # deepspeed's optimizer 95 | 96 | self.fp = None 97 | self.last_orig_norm: torch.Tensor = torch.tensor(0.1) 98 | 99 | @torch.no_grad() 100 | def log_param(self, ep: int): 101 | if self.zero == 0: 102 | for name, values in get_param_for_log(self.model_name_3letters, self.model_maybe_fsdp.named_parameters()).items(): 103 | values: List[float] 104 | if len(values) == 1: # e.g., cls token will only have one value 105 | values.append(values[0]) 106 | else: 107 | ... 108 | # todo: log params 109 | 110 | # @profile(precision=4, stream=open('amp_sc.log', 'w+')) 111 | def backward_clip_step( 112 | self, ep: int, it: int, g_it: int, stepping: bool, logging_params: bool, loss: torch.Tensor, clip_decay_ratio=1, stable=False, 113 | ) -> Tuple[torch.Tensor, Optional[float]]: 114 | # backward 115 | loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation 116 | orig_norm = scaler_sc = None 117 | # if self.fp is not None: 118 | # if g_it % 20 == 0: self.fp.seek(0); self.fp.truncate(0) 119 | if self.scaler is not None: 120 | self.scaler.scale(loss).backward(retain_graph=False, create_graph=False) # retain_graph=retain_graph, create_graph=create_graph 121 | else: 122 | loss.backward(retain_graph=False, create_graph=False) 123 | # if self.fp is not None: self.fp.write(f'[backward_clip_step:131] [it{it}, g_it{g_it}] after backward\n'); self.fp.flush() 124 | 125 | # clip gradients then step optimizer 126 | if stepping: 127 | if self.scaler is not None: self.scaler.unscale_(self.optimizer) # now the gradient can be correctly got 128 | # if self.fp is not None: self.fp.write(f'[backward_clip_step:137] [it{it}, g_it{g_it}] after scaler.unscale_\n'); self.fp.flush() 129 | 130 | skipped, orig_norm = 0, self.last_orig_norm 131 | # try: 132 | if self.fp is not None: 133 | if g_it % 10 == 0: self.fp.seek(0); self.fp.truncate(0) 134 | self.fp.write(f'\n'); self.fp.flush() 135 | if self.early_clipping: 136 | c = self.grad_clip * clip_decay_ratio 137 | if self.zero: 138 | orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(c) 139 | else: 140 | orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), c) 141 | 142 | # if self.fp is not None: self.fp.write(f'[backward_clip_step:175] [it{it}, g_it{g_it}] before opt step\n'); self.fp.flush() 143 | if self.scaler is not None: 144 | self.scaler: torch.cuda.amp.GradScaler 145 | if self.zero: 146 | # synchronize found_inf_per_device before calling step, so that even if only some ranks found inf on their sharded params, all other ranks will know 147 | # otherwise, when saving FSDP optimizer state, it will cause AssertionError saying "Different ranks have different values for step." 
148 | for optimizer_state in self.scaler._per_optimizer_states.values(): 149 | for t in optimizer_state['found_inf_per_device'].values(): 150 | dist.allreduce(t) # ideally, each rank only has one single t; so no need to use async allreduce 151 | 152 | self.scaler.step(self.optimizer) 153 | scaler_sc: Optional[float] = self.scaler.get_scale() 154 | if scaler_sc > self.max_sc: # fp16 will overflow when >65536, so multiply 32768 could be dangerous 155 | # print(f'[fp16 scaling] too large loss scale {scaler_sc}! (clip to {self.max_sc:g})') 156 | self.scaler.update(new_scale=self.max_sc) 157 | else: 158 | self.scaler.update() 159 | try: 160 | scaler_sc = float(math.log2(scaler_sc)) 161 | except Exception as e: 162 | print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True) 163 | time.sleep(1) 164 | print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True) 165 | raise e 166 | else: 167 | self.optimizer.step() 168 | 169 | if self.late_clipping: 170 | orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm 171 | self.last_orig_norm = orig_norm 172 | # no zero_grad calling here, gonna log those gradients! 173 | return orig_norm, scaler_sc 174 | 175 | def state_dict(self): 176 | return { 177 | 'optimizer': self.optimizer.state_dict() 178 | } if self.scaler is None else { 179 | 'scaler': self.scaler.state_dict(), 180 | 'optimizer': self.optimizer.state_dict() 181 | } 182 | 183 | def load_state_dict(self, state, strict=True): 184 | if self.scaler is not None: 185 | try: self.scaler.load_state_dict(state['scaler']) 186 | except Exception as e: print(f'[fp16 load_state_dict err] {e}') 187 | self.optimizer.load_state_dict(state['optimizer']) 188 | -------------------------------------------------------------------------------- /infinity/utils/csv_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import csv 4 | 5 | import numpy as np 6 | 7 | 8 | def write_dicts2csv_file(input_dict_list, csv_filename): 9 | os.makedirs(osp.dirname(csv_filename), exist_ok=True) 10 | with open(csv_filename, mode='w', newline='', encoding='utf-8') as file: 11 | fieldnames = input_dict_list[0].keys() 12 | writer = csv.DictWriter(file, fieldnames=fieldnames) 13 | writer.writeheader() 14 | writer.writerows(input_dict_list) 15 | print(f'"{csv_filename}" has been written.') 16 | 17 | def load_csv_as_dicts(csv_filename): 18 | with open(csv_filename, mode='r', newline='', encoding='utf-8') as csvfile: 19 | reader = csv.DictReader(csvfile) 20 | return list(reader) 21 | -------------------------------------------------------------------------------- /infinity/utils/dist.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import functools 3 | import os 4 | import sys 5 | from typing import List 6 | from typing import Union 7 | 8 | import pytz 9 | import torch 10 | import torch.distributed as tdist 11 | import torch.multiprocessing as mp 12 | 13 | 14 | __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu' 15 | __rank_str_zfill = '0' 16 | __initialized = False 17 | 18 | 19 | def initialized(): 20 | return __initialized 21 | 22 | 23 | def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30): 24 | global __device 25 | if not torch.cuda.is_available(): 26 | print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr) 27 | return 28 | elif 'RANK' not in os.environ: 29 | torch.cuda.set_device(gpu_id_if_not_distibuted) 30 | __device = 
torch.empty(1).cuda().device 31 | print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr) 32 | return 33 | # then 'RANK' must exist 34 | global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count() 35 | local_rank = global_rank % num_gpus 36 | torch.cuda.set_device(local_rank) 37 | 38 | # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 39 | """ 40 | if mp.get_start_method(allow_none=True) is None: 41 | method = 'fork' if fork else 'spawn' 42 | print(f'[dist initialize] mp method={method}') 43 | mp.set_start_method(method) 44 | """ 45 | tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60)) 46 | 47 | global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill 48 | __local_rank = local_rank 49 | __rank, __world_size = tdist.get_rank(), tdist.get_world_size() 50 | __rank_str_zfill = str(__rank).zfill(len(str(__world_size))) 51 | __device = torch.device(local_rank) 52 | __initialized = True 53 | 54 | assert tdist.is_initialized(), 'torch.distributed is not initialized!' 55 | print(f'[lrk={get_local_rank()}, rk={get_rank()}]') 56 | 57 | 58 | def get_rank(): 59 | return __rank 60 | 61 | 62 | def get_rank_given_group(group: tdist.ProcessGroup): 63 | return tdist.get_rank(group=group) 64 | 65 | 66 | def get_rank_str_zfill(): 67 | return __rank_str_zfill 68 | 69 | 70 | def get_local_rank(): 71 | return __local_rank 72 | 73 | 74 | def get_world_size(): 75 | return __world_size 76 | 77 | 78 | def get_device(): 79 | return __device 80 | 81 | 82 | def set_gpu_id(gpu_id: int): 83 | if gpu_id is None: return 84 | global __device 85 | if isinstance(gpu_id, (str, int)): 86 | torch.cuda.set_device(int(gpu_id)) 87 | __device = torch.empty(1).cuda().device 88 | else: 89 | raise NotImplementedError 90 | 91 | 92 | def is_master(): 93 | return __rank == 0 94 | 95 | 96 | def is_local_master(): 97 | return __local_rank == 0 98 | 99 | 100 | def is_visualizer(): 101 | return __rank == 0 102 | # return __rank == max(__world_size - 8, 0) 103 | 104 | 105 | def parallelize(net, syncbn=False): 106 | if syncbn: 107 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) 108 | net = net.cuda() 109 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False) 110 | return net 111 | 112 | 113 | def new_group(ranks: List[int]): 114 | if __initialized: 115 | return tdist.new_group(ranks=ranks) 116 | return None 117 | 118 | 119 | def new_local_machine_group(): 120 | if __initialized: 121 | cur_subgroup, subgroups = tdist.new_subgroups() 122 | return cur_subgroup 123 | return None 124 | 125 | 126 | def barrier(): 127 | if __initialized: 128 | tdist.barrier() 129 | 130 | 131 | def allreduce(t: torch.Tensor, async_op=False): 132 | if __initialized: 133 | if not t.is_cuda: 134 | cu = t.detach().cuda() 135 | ret = tdist.all_reduce(cu, async_op=async_op) 136 | t.copy_(cu.cpu()) 137 | else: 138 | ret = tdist.all_reduce(t, async_op=async_op) 139 | return ret 140 | return None 141 | 142 | 143 | def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: 144 | if __initialized: 145 | if not t.is_cuda: 146 | t = t.cuda() 147 | ls = [torch.empty_like(t) for _ in range(__world_size)] 148 | tdist.all_gather(ls, t) 149 | else: 150 | ls = [t] 151 | if cat: 152 | ls = torch.cat(ls, dim=0) 153 | return ls 154 | 155 | 156 | def allgather_diff_shape(t: torch.Tensor, cat=True) -> 
Union[List[torch.Tensor], torch.Tensor]: 157 | if __initialized: 158 | if not t.is_cuda: 159 | t = t.cuda() 160 | 161 | t_size = torch.tensor(t.size(), device=t.device) 162 | ls_size = [torch.empty_like(t_size) for _ in range(__world_size)] 163 | tdist.all_gather(ls_size, t_size) 164 | 165 | max_B = max(size[0].item() for size in ls_size) 166 | pad = max_B - t_size[0].item() 167 | if pad: 168 | pad_size = (pad, *t.size()[1:]) 169 | t = torch.cat((t, t.new_empty(pad_size)), dim=0) 170 | 171 | ls_padded = [torch.empty_like(t) for _ in range(__world_size)] 172 | tdist.all_gather(ls_padded, t) 173 | ls = [] 174 | for t, size in zip(ls_padded, ls_size): 175 | ls.append(t[:size[0].item()]) 176 | else: 177 | ls = [t] 178 | if cat: 179 | ls = torch.cat(ls, dim=0) 180 | return ls 181 | 182 | 183 | def broadcast(t: torch.Tensor, src_rank) -> None: 184 | if __initialized: 185 | if not t.is_cuda: 186 | cu = t.detach().cuda() 187 | tdist.broadcast(cu, src=src_rank) 188 | t.copy_(cu.cpu()) 189 | else: 190 | tdist.broadcast(t, src=src_rank) 191 | 192 | 193 | def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]: 194 | if not initialized(): 195 | return torch.tensor([val]) if fmt is None else [fmt % val] 196 | 197 | ts = torch.zeros(__world_size) 198 | ts[__rank] = val 199 | allreduce(ts) 200 | if fmt is None: 201 | return ts 202 | return [fmt % v for v in ts.cpu().numpy().tolist()] 203 | 204 | 205 | def master_only(func): 206 | @functools.wraps(func) 207 | def wrapper(*args, **kwargs): 208 | force = kwargs.pop('force', False) 209 | if force or is_master(): 210 | ret = func(*args, **kwargs) 211 | else: 212 | ret = None 213 | barrier() 214 | return ret 215 | return wrapper 216 | 217 | 218 | def local_master_only(func): 219 | @functools.wraps(func) 220 | def wrapper(*args, **kwargs): 221 | force = kwargs.pop('force', False) 222 | if force or is_local_master(): 223 | ret = func(*args, **kwargs) 224 | else: 225 | ret = None 226 | barrier() 227 | return ret 228 | return wrapper 229 | 230 | 231 | def for_visualize(func): 232 | @functools.wraps(func) 233 | def wrapper(*args, **kwargs): 234 | if is_visualizer(): 235 | # with torch.no_grad(): 236 | ret = func(*args, **kwargs) 237 | else: 238 | ret = None 239 | return ret 240 | return wrapper 241 | 242 | 243 | def finalize(): 244 | if __initialized: 245 | tdist.destroy_process_group() 246 | 247 | 248 | def init_distributed_mode(local_out_path, fork=False, only_sync_master=False, timeout_minutes=30): 249 | try: 250 | __initialize(fork=fork, timeout_minutes=timeout_minutes) 251 | barrier() 252 | except RuntimeError as e: 253 | print(f'{"!"*80} dist init error (NCCL Error?), stopping training! 
{"!"*80}', flush=True) 254 | raise e 255 | 256 | if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True) 257 | _change_builtin_print(is_local_master()) 258 | if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path): 259 | sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False) 260 | 261 | 262 | def _change_builtin_print(is_master): 263 | import builtins as __builtin__ 264 | 265 | builtin_print = __builtin__.print 266 | if type(builtin_print) != type(open): 267 | return 268 | 269 | def prt(*args, **kwargs): 270 | force = kwargs.pop('force', False) 271 | clean = kwargs.pop('clean', False) 272 | deeper = kwargs.pop('deeper', False) 273 | if is_master or force: 274 | if not clean: 275 | f_back = sys._getframe().f_back 276 | if deeper and f_back.f_back is not None: 277 | f_back = f_back.f_back 278 | file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] 279 | time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]') 280 | builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs) 281 | else: 282 | builtin_print(*args, **kwargs) 283 | 284 | __builtin__.print = prt 285 | 286 | 287 | class BackupStreamToFile(object): 288 | def __init__(self, local_output_dir, for_stdout=True): 289 | self.for_stdout = for_stdout 290 | self.terminal_stream = sys.stdout if for_stdout else sys.stderr 291 | fname = os.path.join(local_output_dir, 'b1_stdout.txt' if for_stdout else 'b2_stderr.txt') 292 | existing = os.path.exists(fname) 293 | self.file_stream = open(fname, 'a') 294 | if existing: 295 | time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]') 296 | self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n') 297 | self.file_stream.flush() 298 | os.system(f'ln -s {fname} /opt/tiger/run_trial/ >/dev/null 2>&1') 299 | self.enabled = True 300 | 301 | def write(self, message): 302 | self.terminal_stream.write(message) 303 | self.file_stream.write(message) 304 | 305 | def flush(self): 306 | self.terminal_stream.flush() 307 | self.file_stream.flush() 308 | 309 | def isatty(self): 310 | return True 311 | 312 | def close(self): 313 | if not self.enabled: 314 | return 315 | self.enabled = False 316 | self.file_stream.flush() 317 | self.file_stream.close() 318 | if self.for_stdout: 319 | sys.stdout = self.terminal_stream 320 | sys.stdout.flush() 321 | else: 322 | sys.stderr = self.terminal_stream 323 | sys.stderr.flush() 324 | 325 | def __del__(self): 326 | self.close() 327 | -------------------------------------------------------------------------------- /infinity/utils/dynamic_resolution.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import tqdm 4 | 5 | vae_stride = 16 6 | ratio2hws = { 7 | 1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64)], 8 | 1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56)], 9 | 1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54)], 10 | 1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52)], 11 | 1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48)], 12 | 2.000: 
[(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45)], 13 | 2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40)], 14 | 3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37)], 15 | } 16 | predefined_t = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 21] 17 | 18 | full_ratio2hws = {} 19 | for ratio, hws in ratio2hws.items(): 20 | full_ratio2hws[ratio] = hws 21 | if ratio != 1.000: 22 | full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws] 23 | 24 | dynamic_resolution_h_w = {} 25 | for ratio in full_ratio2hws: 26 | dynamic_resolution_h_w[ratio] ={} 27 | for ind, leng in enumerate([7, 10, 12, 13]): 28 | h_div_w = full_ratio2hws[ratio][leng-1][0] / full_ratio2hws[ratio][leng-1][1] 29 | assert np.abs(h_div_w-ratio) < 0.01, f'{full_ratio2hws[ratio][leng-1]}: {h_div_w} != {ratio}' 30 | pixel = (full_ratio2hws[ratio][leng-1][0] * vae_stride, full_ratio2hws[ratio][leng-1][1] * vae_stride) 31 | if ind == 0: 32 | total_pixels = '0.06M' 33 | elif ind == 1: 34 | total_pixels = '0.25M' 35 | elif ind == 2: 36 | total_pixels = '0.60M' 37 | else: 38 | total_pixels = '1M' 39 | 40 | scales = full_ratio2hws[ratio][:leng] 41 | scales = [ (t, h, w) for t, (h, w) in zip(predefined_t, scales) ] 42 | dynamic_resolution_h_w[ratio][total_pixels] = { 43 | 'pixel': pixel, 44 | 'scales': scales 45 | } 46 | 47 | h_div_w_templates = [] 48 | for h_div_w in dynamic_resolution_h_w.keys(): 49 | h_div_w_templates.append(h_div_w) 50 | h_div_w_templates = np.array(h_div_w_templates) 51 | 52 | def get_h_div_w_template2indices(h_div_w_list, h_div_w_templates): 53 | indices = list(range(len(h_div_w_list))) 54 | h_div_w_template2indices = {} 55 | pbar = tqdm.tqdm(total=len(indices), desc='get_h_div_w_template2indices...') 56 | for h_div_w, index in zip(h_div_w_list, indices): 57 | pbar.update(1) 58 | nearest_h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w-h_div_w_templates))] 59 | if nearest_h_div_w_template_ not in h_div_w_template2indices: 60 | h_div_w_template2indices[nearest_h_div_w_template_] = [] 61 | h_div_w_template2indices[nearest_h_div_w_template_].append(index) 62 | for h_div_w_template_, sub_indices in h_div_w_template2indices.items(): 63 | h_div_w_template2indices[h_div_w_template_] = np.array(sub_indices) 64 | return h_div_w_template2indices 65 | 66 | if __name__ == '__main__': 67 | for h_div_w_template in dynamic_resolution_h_w: 68 | for total_pixels in dynamic_resolution_h_w[h_div_w_template]: 69 | scales = np.array(dynamic_resolution_h_w[h_div_w_template][total_pixels]['scales']) 70 | seq_len = np.sum(scales[:,0]*scales[:,1]) 71 | if total_pixels == '1M': 72 | string = f'{h_div_w_template}, {total_pixels}, {dynamic_resolution_h_w[h_div_w_template][total_pixels]}, seq_len: {seq_len}'.replace(', ', ',') 73 | print(string) 74 | -------------------------------------------------------------------------------- /infinity/utils/large_file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import time 4 | import itertools 5 | import shutil 6 | import glob 7 | import argparse 8 | 9 | import tqdm 10 | import numpy as np 11 | import threading 12 | 13 | def save_lines(lines, filename): 14 | os.makedirs(osp.dirname(filename), exist_ok=True) 15 | with open(filename, 'w') as f: 16 | f.writelines(lines) 17 | del lines 18 | 19 | def get_part_jsonls(filepath, total_line_number, 
parts=512): 20 | dirname, filename, ext = osp.dirname(filepath), osp.splitext(osp.basename(filepath))[0], osp.splitext(osp.basename(filepath))[1] 21 | if parts == 1: 22 | return False, {1: filepath} 23 | save_dir = osp.join(dirname, f'{parts:04d}_parts') 24 | chunk_id2save_files = {} 25 | missing = False 26 | chunk_size = int(total_line_number/parts) 27 | for chunk_id in range(1, parts+1): 28 | if chunk_id == parts: 29 | num_of_lines = total_line_number - chunk_size * (parts-1) 30 | else: 31 | num_of_lines = chunk_size 32 | chunk_id2save_files[chunk_id] = osp.join(save_dir, f'{filename}_{chunk_id:04d}_{parts:04d}_{num_of_lines:09d}{ext}') 33 | if not osp.exists(chunk_id2save_files[chunk_id]): 34 | missing = True 35 | return missing, chunk_id2save_files 36 | 37 | def split_large_txt_files(filepath, chunk_id2save_files): 38 | thread_list = [] 39 | chunk_id = 1 40 | with open(filepath, 'r') as f: 41 | chunk = [] 42 | pbar = tqdm.tqdm(total=len(chunk_id2save_files)) 43 | for line in f: 44 | chunk.append(line) 45 | cur_chunk_size = int(osp.splitext(osp.basename(chunk_id2save_files[chunk_id]))[0].split('_')[-1]) 46 | if len(chunk) >= cur_chunk_size: 47 | pbar.update(1) 48 | thread_list.append(threading.Thread(target=save_lines, args=(chunk, chunk_id2save_files[chunk_id]))) 49 | thread_list[-1].start() 50 | chunk = [] 51 | chunk_id += 1 52 | if len(chunk): 53 | import ipdb; ipdb.set_trace() 54 | assert not len(chunk) 55 | for thread in thread_list: 56 | thread.join() 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--jsonl_folder', type=str, default='') 61 | parser.add_argument('--parts', type=int, default=600) 62 | args = parser.parse_args() 63 | for jsonl_filepath in sorted(glob.glob(osp.join(args.jsonl_folder, '*.jsonl'))): 64 | print(jsonl_filepath) 65 | t1 = time.time() 66 | line_num = int(jsonl_filepath.split('_')[-1].split('.')[0]) 67 | missing, chunk_id2save_files = get_part_jsonls(jsonl_filepath, line_num, parts=args.parts) 68 | split_large_txt_files(jsonl_filepath, chunk_id2save_files) 69 | t2 = time.time() 70 | print(f'split takes {t2-t1}s') 71 | -------------------------------------------------------------------------------- /infinity/utils/load.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import gc 3 | import os 4 | import os.path as osp 5 | import random 6 | import sys 7 | from copy import deepcopy 8 | from typing import Tuple, Union 9 | 10 | import colorama 11 | import torch 12 | import yaml 13 | 14 | import infinity.utils.dist as dist 15 | 16 | from infinity.models import Infinity 17 | from infinity.models.ema import get_ema_model 18 | from infinity.utils import arg_util, misc 19 | from infinity.utils.misc import os_system 20 | 21 | 22 | def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'): 23 | if args.vae_type in [8,16,18,20,24,32,64,128]: 24 | from infinity.models.bsq_vae.vae import vae_model 25 | schedule_mode = "dynamic" 26 | codebook_dim = args.vae_type # 18 27 | codebook_size = 2**codebook_dim 28 | if args.apply_spatial_patchify: 29 | patch_size = 8 30 | encoder_ch_mult=[1, 2, 4, 4] 31 | decoder_ch_mult=[1, 2, 4, 4] 32 | else: 33 | patch_size = 16 34 | encoder_ch_mult=[1, 2, 4, 4, 4] 35 | decoder_ch_mult=[1, 2, 4, 4, 4] 36 | vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size, 37 | encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, 
test_mode=True).to(args.device) 38 | if args.fake_vae_input: 39 | vae_local.encoder = None 40 | vae_local.decoder = None 41 | torch.cuda.empty_cache() 42 | else: 43 | raise ValueError(f"vae_type {args.vae_type} not supported") 44 | if force_flash: args.flash = True 45 | gpt_kw = dict( 46 | pretrained=False, global_pool='', 47 | text_channels=args.Ct5, text_maxlen=args.tlen, 48 | norm_eps=args.norm_eps, rms_norm=args.rms, 49 | shared_aln=args.saln, head_aln=args.haln, 50 | cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop, 51 | cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi, 52 | raw_scale_schedule=args.scale_schedule, 53 | head_depth=args.dec, 54 | top_p=args.tp, top_k=args.tk, 55 | customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm, 56 | checkpointing=args.enable_checkpointing, 57 | pad_to_multiplier=args.pad_to_multiplier, 58 | use_flex_attn=args.use_flex_attn, 59 | batch_size=args.batch_size, 60 | add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block, 61 | use_bit_label=args.use_bit_label, 62 | rope2d_each_sa_layer=args.rope2d_each_sa_layer, 63 | rope2d_normalized_by_hw=args.rope2d_normalized_by_hw, 64 | pn=args.pn, 65 | train_h_div_w_list=args.train_h_div_w_list, 66 | always_training_scales=args.always_training_scales, 67 | apply_spatial_patchify=args.apply_spatial_patchify, 68 | ) 69 | if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp 70 | if args.hd > 0: gpt_kw['num_heads'] = args.hd 71 | 72 | print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n') 73 | gpt_kw['vae_local'] = vae_local 74 | 75 | model_str = args.model.replace('vgpt', 'infinity') # legacy 76 | print(f"{model_str=}") 77 | if model_str.rsplit('c', maxsplit=1)[-1].isdecimal(): 78 | model_str, block_chunks = model_str.rsplit('c', maxsplit=1) 79 | block_chunks = int(block_chunks) 80 | else: 81 | block_chunks = 1 82 | gpt_kw['block_chunks'] = block_chunks 83 | 84 | from infinity.models import Infinity 85 | from timm.models import create_model 86 | gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw) 87 | if args.use_fsdp_model_ema: 88 | gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp) 89 | else: 90 | gpt_wo_ddp_ema = None 91 | gpt_wo_ddp = gpt_wo_ddp.to(device) 92 | 93 | assert all(not p.requires_grad for p in vae_local.parameters()) 94 | assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters()) 95 | 96 | return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema 97 | 98 | 99 | if __name__ == '__main__': 100 | ld(sys.argv[1]) 101 | -------------------------------------------------------------------------------- /infinity/utils/lr_control.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pprint import pformat 3 | from typing import Tuple, List, Dict, Union 4 | 5 | import torch.nn 6 | import infinity.utils.dist as dist 7 | 8 | 9 | def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | wp_it = round(wp_it) 12 | 13 | if cur_it < wp_it: 14 | cur_lr = wp0 + (1-wp0) * cur_it / wp_it 15 | else: 16 | pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1] 17 | rest = 1 - pasd # [1, 0] 18 | if sche_type == 'cos': 19 | cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd)) 20 | elif sche_type == 'lin': 21 | T = 0.15; max_rest = 1-T 22 | if pasd < T: cur_lr = 1 23 | else: cur_lr = wpe + (1-wpe) * rest / max_rest # 
1 to wpe 24 | elif sche_type == 'lin0': 25 | T = 0.05; max_rest = 1-T 26 | if pasd < T: cur_lr = 1 27 | else: cur_lr = wpe + (1-wpe) * rest / max_rest 28 | elif sche_type == 'lin00': 29 | cur_lr = wpe + (1-wpe) * rest 30 | elif sche_type.startswith('lin'): 31 | T = float(sche_type[3:]); max_rest = 1-T 32 | wpe_mid = wpe + (1-wpe) * max_rest 33 | wpe_mid = (1 + wpe_mid) / 2 34 | if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T 35 | else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest 36 | elif sche_type == 'exp': 37 | T = 0.15; max_rest = 1-T 38 | if pasd < T: cur_lr = 1 39 | else: 40 | expo = (pasd-T) / max_rest * math.log(wpe) 41 | cur_lr = math.exp(expo) 42 | else: 43 | raise NotImplementedError(f'unknown sche_type {sche_type}') 44 | 45 | cur_lr *= peak_lr 46 | pasd = cur_it / (max_it-1) 47 | cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd)) 48 | 49 | inf = 1e6 50 | min_lr, max_lr = inf, -1 51 | min_wd, max_wd = inf, -1 52 | for param_group in optimizer.param_groups: 53 | param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned 54 | max_lr = max(max_lr, param_group['lr']) 55 | min_lr = min(min_lr, param_group['lr']) 56 | 57 | param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1) 58 | max_wd = max(max_wd, param_group['weight_decay']) 59 | if param_group['weight_decay'] > 0: 60 | min_wd = min(min_wd, param_group['weight_decay']) 61 | 62 | if min_lr == inf: min_lr = -1 63 | if min_wd == inf: min_wd = -1 64 | return min_lr, max_lr, min_wd, max_wd 65 | 66 | 67 | def filter_params(model, ndim_dict, nowd_keys=(), lr_scale=0.0) -> Tuple[ 68 | List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]] 69 | ]: 70 | with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale <= 1 71 | print(f'[get_param_groups][lr decay] with_lr_scale={with_lr_scale}, lr_scale={lr_scale}') 72 | para_groups, para_groups_dbg = {}, {} 73 | names, paras = [], [] 74 | names_no_grad = [] 75 | count, numel = 0, 0 76 | for name, para in model.named_parameters(): 77 | name = name.replace('_fsdp_wrapped_module.', '') 78 | if not para.requires_grad: 79 | names_no_grad.append(name) 80 | continue # frozen weights 81 | count += 1 82 | numel += para.numel() 83 | names.append(name) 84 | paras.append(para) 85 | 86 | if ndim_dict.get(name, 2) == 1 or name.endswith('bias') or any(k in name for k in nowd_keys): 87 | cur_wd_sc, group_name = 0., 'ND' 88 | # elif any(k in name for k in small_wd_keys): 89 | # cur_wd_sc, group_name = small_wd, 'small_decay' 90 | else: 91 | cur_wd_sc, group_name = 1., 'D' 92 | 93 | if with_lr_scale: 94 | layer_id, scale_exp = model.get_layer_id_and_scale_exp(name) 95 | group_name = f'layer{layer_id}_' + group_name 96 | cur_lr_sc = lr_scale ** scale_exp 97 | dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]' 98 | else: 99 | cur_lr_sc = 1. 
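# [added note, not in the original source] The cur_lr_sc / cur_wd_sc values computed here become the per-group 'lr_sc' and 'wd_sc' multipliers stored in para_groups below; lr_wd_annealing() reads them back via param_group.get('lr_sc', 1) and param_group.get('wd_sc', 1). The 'ND' group (1-D params, biases, nowd_keys) therefore keeps the scheduled learning rate but gets zero weight decay, while layer-wise decay (when get_layer_id_and_scale_exp exists and 0 < lr_scale <= 1) shrinks the learning rate by lr_scale ** scale_exp per layer.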
100 | dbg = f'[no scale]' 101 | 102 | if group_name not in para_groups: 103 | para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc} 104 | para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': dbg} 105 | para_groups[group_name]['params'].append(para) 106 | para_groups_dbg[group_name]['params'].append(name) 107 | 108 | for g in para_groups_dbg.values(): 109 | g['params'] = pformat(', '.join(g['params']), width=200) 110 | 111 | print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n') 112 | 113 | for rk in range(dist.get_world_size()): 114 | dist.barrier() 115 | if dist.get_rank() == rk: 116 | print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True) 117 | print('') 118 | 119 | assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n' 120 | del ndim_dict 121 | return names, paras, list(para_groups.values()) 122 | 123 | 124 | def plot(): 125 | import matplotlib.pyplot as plt 126 | import torch.nn as nn 127 | from torch.optim import SGD 128 | # for sche in ('lin', 'lin0', 'lin00', 'lin0.5', 'lin0.75'): 129 | for sche in ('lin0', ): 130 | op = SGD(nn.Linear(3, 4).parameters(), lr=1e-3) 131 | it, lr = [], [] 132 | iters = 500 133 | wp_it, max_it = 1 * iters, 10 * iters 134 | for cur_it in range(max_it): 135 | it.append(cur_it) 136 | lr.append(lr_wd_annealing(sche, op, 0.1, 1e-5, 1e-5, cur_it, wp_it, max_it, wpe=0.3)[0]) 137 | 138 | plt.figure() 139 | plt.title(sche) 140 | plt.plot(it, lr, 'b', label=sche) 141 | plt.xlabel('it'), plt.ylabel('lr') 142 | plt.legend() 143 | 144 | plt.savefig('lr.jpg') 145 | 146 | 147 | if __name__ == '__main__': 148 | plot() 149 | -------------------------------------------------------------------------------- /infinity/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import subprocess 4 | import time 5 | import re 6 | from typing import List, Optional, Tuple 7 | 8 | import torch 9 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 10 | 11 | import glob 12 | import shutil 13 | from infinity.utils import arg_util 14 | import infinity.utils.dist as dist 15 | 16 | 17 | def glob_with_epoch_iter(pattern, recursive=False): 18 | def extract_ep_iter(filename): 19 | match = re.search(r'ep(\d+)-iter(\d+)', filename) 20 | if match: 21 | ep = int(match.group(1)) 22 | iter_idx = int(match.group(2)) 23 | return ep, iter_idx 24 | return 0, 0 25 | return sorted(glob.glob(pattern, recursive=recursive), key=lambda x: extract_ep_iter(os.path.basename(x)), reverse=True) 26 | 27 | 28 | def glob_with_global_step(pattern, recursive=False): 29 | def extract_ep_iter(filename): 30 | match = re.search(r'global_step_(\d+)', filename) 31 | if match: 32 | iter_idx = int(match.group(1)) 33 | return iter_idx 34 | return 0 35 | return sorted(glob.glob(pattern, recursive=recursive), key=lambda x: extract_ep_iter(os.path.basename(x)), reverse=True) 36 | 37 | 38 | class CKPTSaver(object): 39 | def __init__(self, is_master: bool, eval_milestone: List[Tuple[float, float]]): 40 | self.is_master = is_master 41 | self.time_stamp = torch.tensor([time.time() - 1e5, time.time()], device=dist.get_device()) 42 | self.sp_also: subprocess.Popen = None 43 | self.sp_best: subprocess.Popen = None 44 | self.sp_backup: subprocess.Popen = None 45 | self.acc_str, self.eval_milestone = '[no acc str]', 
eval_milestone 46 | 47 | def sav( 48 | self, args: arg_util.Args, g_it: int, next_ep: int, next_it: int, trainer, 49 | acc_str: Optional[str] = None, eval_milestone: Optional[List[Tuple[float, float]]] = None, 50 | also_save_to: str = None, best_save_to: str = None, 51 | ): 52 | self.time_stamp[1] = time.time() 53 | dist.broadcast(self.time_stamp, src_rank=0) 54 | last_save_time, cur_time = self.time_stamp.cpu().tolist() 55 | 56 | auto_save = cur_time - last_save_time > 20 * 60 57 | need_save = also_save_to is not None or best_save_to is not None or next_ep == args.ep or auto_save 58 | if not need_save: 59 | return 60 | 61 | if acc_str is not None: self.acc_str = acc_str 62 | if eval_milestone is not None: self.eval_milestone = eval_milestone 63 | 64 | fname = f'ar-ckpt-giter{g_it//1000:03d}K-ep{next_ep}-iter{next_it}-last.pth' if args.gpt_training else f'ckpt-last.pth' 65 | local_out_ckpt = os.path.join(args.local_out_path, fname) 66 | 67 | # NOTE: all rank should call this state_dict(), not master only! 68 | trainer_state = trainer.state_dict() 69 | 70 | if self.is_master: 71 | stt = time.time() 72 | torch.save({ 73 | 'args': args.state_dict(), 74 | 'gpt_training': args.gpt_training, 75 | 'arch': args.model if args.gpt_training else args.vv, 76 | 'epoch': next_ep, 77 | 'iter': next_it, 78 | 'trainer': trainer_state, 79 | 'acc_str': self.acc_str, 80 | 'milestones': self.eval_milestone, 81 | }, local_out_ckpt) 82 | 83 | print(f'[CKPTSaver][rank00] start: {also_save_to=} {best_save_to=} {(next_ep == args.ep)=} {auto_save=} | see {local_out_ckpt}', flush=True) 84 | print(f'[CKPTSaver][rank00] dbg: {args.bed=}', flush=True) 85 | if auto_save: 86 | if self.sp_backup is not None: 87 | self.sp_backup.wait(timeout=300); self.sp_backup.kill(); self.sp_backup.communicate() 88 | self.time_stamp[0] = time.time() 89 | 90 | def auto_sync(source_filename, target_filename): 91 | cmd = f'cp -r {source_filename} {target_filename}' 92 | self.sp_backup = subprocess.Popen(cmd, shell=True, bufsize=-1) 93 | print(f'[CKPTSaver] auto_save cmd: {cmd}', flush=True) 94 | 95 | local_files = glob.glob(f"{args.local_out_path}/*") 96 | for filename in local_files: 97 | basename = os.path.basename(filename) 98 | target_filename = f'{args.bed}/{basename}' 99 | if basename.endswith('.pth'): 100 | if not os.path.isfile(target_filename): 101 | auto_sync(filename, target_filename) 102 | else: 103 | auto_sync(filename, target_filename) 104 | cost = time.time() - stt 105 | print(f'[CKPTSaver][rank00] cost: {cost:.2f}s', flush=True) 106 | 107 | del trainer_state 108 | time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3) 109 | dist.barrier() 110 | 111 | 112 | def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, str, List[Tuple[float, float]], dict, dict]: 113 | info = [] 114 | resume = '' 115 | if args.auto_resume: 116 | for dd in (args.local_out_path, args.bed): 117 | all_ckpt = glob_with_epoch_iter(os.path.join(dd, pattern)) 118 | if len(all_ckpt): break 119 | if len(all_ckpt) == 0: 120 | info.append(f'[auto_resume] no ckpt found @ {pattern}') 121 | info.append(f'[auto_resume quit]') 122 | else: 123 | resume = all_ckpt[0] 124 | info.append(f'[auto_resume] auto load from @ {resume} ...') 125 | else: 126 | info.append(f'[auto_resume] disabled') 127 | info.append(f'[auto_resume quit]') 128 | 129 | if len(resume) == 0: 130 | return info, 0, 0, '[no acc str]', [], {}, {} 131 | 132 | print(f'auto resume from {resume}') 133 | 134 | try: 135 | ckpt = torch.load(resume, 
map_location='cpu') 136 | except Exception as e: 137 | info.append(f'[auto_resume] failed, {e} @ {resume}') 138 | if len(all_ckpt) < 2: 139 | return info, 0, 0, '[no acc str]', [], {}, {} 140 | try: # another chance to load from bytenas 141 | ckpt = torch.load(all_ckpt[1], map_location='cpu') 142 | except Exception as e: 143 | info.append(f'[auto_resume] failed, {e} @ {all_ckpt[1]}') 144 | return info, 0, 0, '[no acc str]', [], {}, {} 145 | 146 | dist.barrier() 147 | ep, it = ckpt['epoch'], ckpt['iter'] 148 | eval_milestone = ckpt.get('milestones', []) 149 | info.append(f'[auto_resume success] resume from ep{ep}, it{it}, eval_milestone: {eval_milestone}') 150 | return info, ep, it, ckpt.get('acc_str', '[no acc str]'), eval_milestone, ckpt['trainer'], ckpt['args'] 151 | -------------------------------------------------------------------------------- /infinity/utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import torch 3 | from torchvision.utils import make_grid 4 | import torch.distributed as dist 5 | from PIL import Image 6 | import os 7 | import argparse 8 | import hashlib 9 | import math 10 | 11 | 12 | def is_main_process(): 13 | return dist.get_rank() == 0 14 | 15 | def namespace_to_dict(namespace): 16 | return { 17 | k: namespace_to_dict(v) if isinstance(v, argparse.Namespace) else v 18 | for k, v in vars(namespace).items() 19 | } 20 | 21 | 22 | def generate_run_id(exp_name): 23 | # https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits 24 | return str(int(hashlib.sha256(exp_name.encode('utf-8')).hexdigest(), 16) % 10 ** 8) 25 | 26 | 27 | def initialize(args, entity, exp_name, project_name): 28 | config_dict = namespace_to_dict(args) 29 | wandb.login(key=os.environ["WANDB_KEY"]) 30 | wandb.init( 31 | entity=entity, 32 | project=project_name, 33 | name=exp_name, 34 | config=config_dict, 35 | id=generate_run_id(exp_name), 36 | resume="allow", 37 | ) 38 | 39 | 40 | def log(stats, step=None): 41 | if is_main_process(): 42 | wandb.log({k: v for k, v in stats.items()}, step=step) 43 | 44 | 45 | def log_image(name, sample, step=None): 46 | if is_main_process(): 47 | sample = array2grid(sample) 48 | wandb.log({f"{name}": wandb.Image(sample), "train_step": step}) 49 | 50 | 51 | def array2grid(x): 52 | nrow = round(math.sqrt(x.size(0))) 53 | x = make_grid(x, nrow=nrow, normalize=True, value_range=(-1,1)) 54 | x = x.mul(255).add_(0.5).clamp_(0,255).permute(1,2,0).to('cpu', torch.uint8).numpy() 55 | return x -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://cog.run/python 3 | 4 | import os 5 | import argparse 6 | import subprocess 7 | import time 8 | from cog import BasePredictor, Input, Path 9 | import torch 10 | import cv2 11 | import numpy as np 12 | from tools.run_infinity import ( 13 | load_tokenizer, 14 | load_infinity, 15 | load_visual_tokenizer, 16 | gen_one_img, 17 | ) 18 | from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates 19 | 20 | MODEL_CACHE = "model_cache" 21 | MODEL_URL = f"https://weights.replicate.delivery/default/FoundationVision/Infinity/model_cache.tar" 22 | 23 | 24 | def download_weights(url, dest): 25 | start = time.time() 26 | print("downloading url: ", url) 27 | print("downloading to: ", dest) 28 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False) 
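# [added note, not in the original source] Assumption: `pget -x` downloads the MODEL_URL tarball and extracts it into `dest`, so MODEL_CACHE ends up holding the FoundationVision/Infinity/*.pth weights and the google/flan-t5-xl text-encoder folder that setup() loads below. If pget is unavailable, a manual download of the tar plus `tar -xf` into MODEL_CACHE would be an equivalent (hypothetical) fallback.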
29 | print("downloading took: ", time.time() - start) 30 | 31 | 32 | def load_transformer(vae, args): 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | model_path = args.model_path 35 | 36 | # Define model configuration based on type 37 | model_configurations = { 38 | "infinity_2b": dict( 39 | depth=32, 40 | embed_dim=2048, 41 | num_heads=2048 // 128, 42 | drop_path_rate=0.1, 43 | mlp_ratio=4, 44 | block_chunks=8, 45 | ), 46 | "infinity_layer12": dict( 47 | depth=12, 48 | embed_dim=768, 49 | num_heads=8, 50 | drop_path_rate=0.1, 51 | mlp_ratio=4, 52 | block_chunks=4, 53 | ), 54 | "infinity_layer16": dict( 55 | depth=16, 56 | embed_dim=1152, 57 | num_heads=12, 58 | drop_path_rate=0.1, 59 | mlp_ratio=4, 60 | block_chunks=4, 61 | ), 62 | "infinity_layer24": dict( 63 | depth=24, 64 | embed_dim=1536, 65 | num_heads=16, 66 | drop_path_rate=0.1, 67 | mlp_ratio=4, 68 | block_chunks=4, 69 | ), 70 | "infinity_layer32": dict( 71 | depth=32, 72 | embed_dim=2080, 73 | num_heads=20, 74 | drop_path_rate=0.1, 75 | mlp_ratio=4, 76 | block_chunks=4, 77 | ), 78 | "infinity_layer40": dict( 79 | depth=40, 80 | embed_dim=2688, 81 | num_heads=24, 82 | drop_path_rate=0.1, 83 | mlp_ratio=4, 84 | block_chunks=4, 85 | ), 86 | "infinity_layer48": dict( 87 | depth=48, 88 | embed_dim=3360, 89 | num_heads=28, 90 | drop_path_rate=0.1, 91 | mlp_ratio=4, 92 | block_chunks=4, 93 | ), 94 | } 95 | 96 | kwargs_model = model_configurations.get(args.model_type, {}) 97 | if not kwargs_model: 98 | raise ValueError(f"Unknown model type: {args.model_type}") 99 | 100 | infinity = load_infinity( 101 | rope2d_each_sa_layer=args.rope2d_each_sa_layer, 102 | rope2d_normalized_by_hw=args.rope2d_normalized_by_hw, 103 | use_scale_schedule_embedding=args.use_scale_schedule_embedding, 104 | pn=args.pn, 105 | use_bit_label=args.use_bit_label, 106 | add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block, 107 | model_path=model_path, # Directly use model_path 108 | scale_schedule=None, 109 | vae=vae, 110 | device=device, 111 | model_kwargs=kwargs_model, 112 | text_channels=args.text_channels, 113 | apply_spatial_patchify=args.apply_spatial_patchify, 114 | use_flex_attn=args.use_flex_attn, 115 | bf16=args.bf16, 116 | ) 117 | return infinity 118 | 119 | 120 | class Predictor(BasePredictor): 121 | def setup(self) -> None: 122 | """Load the model into memory to make running multiple predictions efficient""" 123 | 124 | if not os.path.exists(MODEL_CACHE): 125 | print("downloading") 126 | download_weights(MODEL_URL, MODEL_CACHE) 127 | 128 | model_path = f"{MODEL_CACHE}/FoundationVision/Infinity/infinity_2b_reg.pth" 129 | vae_path = f"{MODEL_CACHE}/FoundationVision/Infinity/infinity_vae_d32reg.pth" 130 | text_encoder_ckpt = f"{MODEL_CACHE}/google/flan-t5-xl" 131 | self.args = argparse.Namespace( 132 | pn="1M", 133 | model_path=model_path, 134 | cfg_insertion_layer=0, 135 | vae_type=32, 136 | vae_path=vae_path, 137 | add_lvl_embeding_only_first_block=1, 138 | use_bit_label=1, 139 | model_type="infinity_2b", 140 | rope2d_each_sa_layer=1, 141 | rope2d_normalized_by_hw=2, 142 | use_scale_schedule_embedding=0, 143 | sampling_per_bits=1, 144 | text_encoder_ckpt=text_encoder_ckpt, 145 | text_channels=2048, 146 | apply_spatial_patchify=0, 147 | h_div_w_template=1.000, 148 | use_flex_attn=0, 149 | cache_dir="/tmp/cache", 150 | checkpoint_type="torch", 151 | bf16=1, 152 | ) 153 | 154 | self.text_tokenizer, self.text_encoder = load_tokenizer( 155 | t5_path=text_encoder_ckpt 156 | ) 157 | # load vae 158 | self.vae = 
load_visual_tokenizer(self.args) 159 | # load infinity 160 | self.infinity = load_transformer(self.vae, self.args) 161 | 162 | def predict( 163 | self, 164 | prompt: str = Input( 165 | description="Input prompt", 166 | default="alien spaceship enterprise", 167 | ), 168 | guidance_scale: float = Input( 169 | description="Scale for classifier-free guidance", ge=1, le=10, default=3 170 | ), 171 | tau: float = Input(description="tau in self attention", default=0.5), 172 | seed: int = Input( 173 | description="Random seed. Leave blank to randomize the seed", default=None 174 | ), 175 | ) -> Path: 176 | """Run a single prediction on the model""" 177 | if seed is None: 178 | seed = int.from_bytes(os.urandom(2), "big") 179 | print(f"Using seed: {seed}") 180 | 181 | h_div_w = 1 / 1 # aspect ratio, height:width 182 | h_div_w_template_ = h_div_w_templates[ 183 | np.argmin(np.abs(h_div_w_templates - h_div_w)) 184 | ] 185 | scale_schedule = dynamic_resolution_h_w[h_div_w_template_][self.args.pn][ 186 | "scales" 187 | ] 188 | scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule] 189 | generated_image = gen_one_img( 190 | self.infinity, 191 | self.vae, 192 | self.text_tokenizer, 193 | self.text_encoder, 194 | prompt, 195 | g_seed=seed, 196 | gt_leak=0, 197 | gt_ls_Bl=None, 198 | cfg_list=guidance_scale, 199 | tau_list=tau, 200 | scale_schedule=scale_schedule, 201 | cfg_insertion_layer=[self.args.cfg_insertion_layer], 202 | vae_type=self.args.vae_type, 203 | sampling_per_bits=self.args.sampling_per_bits, 204 | enable_positive_prompt=0, 205 | ) 206 | output_path = "/tmp/out.png" 207 | cv2.imwrite(output_path, generated_image.cpu().numpy()) 208 | return Path(output_path) 209 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict 2 | typed-argument-parser 3 | seaborn 4 | kornia 5 | gputil 6 | colorama 7 | omegaconf 8 | pandas 9 | timm==0.9.6 10 | decord 11 | transformers 12 | torch==2.5.1 13 | pytz 14 | pandas 15 | wandb 16 | colorama 17 | imageio 18 | einops 19 | openai 20 | httpx==0.20.0 21 | opencv-python 22 | flash_attn 23 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | infer_eval_image_reward() { 4 | ${pip_ext} install image-reward pytorch_lightning 5 | ${pip_ext} install -U timm diffusers 6 | ${pip_ext} install openai==1.34.0 7 | ${pip_ext} install httpx==0.20.0 8 | 9 | # step 1, infer images 10 | ${python_ext} evaluation/image_reward/infer4eval.py \ 11 | --cfg ${cfg} \ 12 | --tau ${tau} \ 13 | --pn ${pn} \ 14 | --model_path ${infinity_model_path} \ 15 | --vae_type ${vae_type} \ 16 | --vae_path ${vae_path} \ 17 | --add_lvl_embeding_only_first_block ${add_lvl_embeding_only_first_block} \ 18 | --use_bit_label ${use_bit_label} \ 19 | --model_type ${model_type} \ 20 | --rope2d_each_sa_layer ${rope2d_each_sa_layer} \ 21 | --rope2d_normalized_by_hw ${rope2d_normalized_by_hw} \ 22 | --use_scale_schedule_embedding ${use_scale_schedule_embedding} \ 23 | --cfg ${cfg} \ 24 | --tau ${tau} \ 25 | --checkpoint_type ${checkpoint_type} \ 26 | --text_encoder_ckpt ${text_encoder_ckpt} \ 27 | --text_channels ${text_channels} \ 28 | --apply_spatial_patchify ${apply_spatial_patchify} \ 29 | --cfg_insertion_layer ${cfg_insertion_layer} \ 30 | --outdir ${out_dir} 31 | 32 | # step 2, compute image reward 33 | ${pip_ext} 
install diffusers==0.16.0 34 | ${pip_ext} install git+https://github.com/openai/CLIP.git ftfy 35 | ${python_ext} evaluation/image_reward/cal_imagereward.py \ 36 | --meta_file ${out_dir}/metadata.jsonl 37 | } 38 | 39 | infer_eval_hpsv21() { 40 | ${pip_ext} install hpsv2 41 | ${pip_ext}install -U diffusers 42 | sudo apt install python3-tk 43 | wget https://dl.fbaipublicfiles.com/mmf/clip/bpe_simple_vocab_16e6.txt.gz 44 | mv bpe_simple_vocab_16e6.txt.gz /home/tiger/.local/lib/python3.9/site-packages/hpsv2/src/open_clip 45 | 46 | mkdir -p ${out_dir} 47 | ${python_ext} evaluation/hpsv2/eval_hpsv2.py \ 48 | --cfg ${cfg} \ 49 | --tau ${tau} \ 50 | --pn ${pn} \ 51 | --model_path ${infinity_model_path} \ 52 | --vae_type ${vae_type} \ 53 | --vae_path ${vae_path} \ 54 | --add_lvl_embeding_only_first_block ${add_lvl_embeding_only_first_block} \ 55 | --use_bit_label ${use_bit_label} \ 56 | --model_type ${model_type} \ 57 | --rope2d_each_sa_layer ${rope2d_each_sa_layer} \ 58 | --rope2d_normalized_by_hw ${rope2d_normalized_by_hw} \ 59 | --use_scale_schedule_embedding ${use_scale_schedule_embedding} \ 60 | --cfg ${cfg} \ 61 | --tau ${tau} \ 62 | --checkpoint_type ${checkpoint_type} \ 63 | --text_encoder_ckpt ${text_encoder_ckpt} \ 64 | --text_channels ${text_channels} \ 65 | --apply_spatial_patchify ${apply_spatial_patchify} \ 66 | --cfg_insertion_layer ${cfg_insertion_layer} \ 67 | --outdir ${out_dir}/images | tee ${out_dir}/log.txt 68 | } 69 | 70 | test_gen_eval() { 71 | ${pip_ext} install -U openmim 72 | mim install mmengine mmcv-full==1.7.2 73 | ${pip_ext} install mmdet==2.28.2 pytorch_lightning clip_benchmark open-clip-torch==2.20.0 74 | ${pip_ext} install -U diffusers 75 | sudo apt install libgl1 76 | ${pip_ext} install openai 77 | ${pip_ext} install httpx==0.20.0 78 | 79 | # run inference 80 | ${python_ext} evaluation/gen_eval/infer4eval.py \ 81 | --cfg ${cfg} \ 82 | --tau ${tau} \ 83 | --pn ${pn} \ 84 | --model_path ${infinity_model_path} \ 85 | --vae_type ${vae_type} \ 86 | --vae_path ${vae_path} \ 87 | --add_lvl_embeding_only_first_block ${add_lvl_embeding_only_first_block} \ 88 | --use_bit_label ${use_bit_label} \ 89 | --model_type ${model_type} \ 90 | --rope2d_each_sa_layer ${rope2d_each_sa_layer} \ 91 | --rope2d_normalized_by_hw ${rope2d_normalized_by_hw} \ 92 | --use_scale_schedule_embedding ${use_scale_schedule_embedding} \ 93 | --cfg ${cfg} \ 94 | --tau ${tau} \ 95 | --checkpoint_type ${checkpoint_type} \ 96 | --text_encoder_ckpt ${text_encoder_ckpt} \ 97 | --text_channels ${text_channels} \ 98 | --apply_spatial_patchify ${apply_spatial_patchify} \ 99 | --cfg_insertion_layer ${cfg_insertion_layer} \ 100 | --outdir ${out_dir}/images \ 101 | --rewrite_prompt ${rewrite_prompt} 102 | 103 | # detect objects 104 | ${python_ext} evaluation/gen_eval/evaluate_images.py ${out_dir}/images \ 105 | --outfile ${out_dir}/results/det.jsonl \ 106 | --model-config evaluation/gen_eval/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py \ 107 | --model-path weights/mask2former 108 | 109 | # accumulate results 110 | ${python_ext} evaluation/gen_eval/summary_scores.py ${out_dir}/results/det.jsonl > ${out_dir}/results/res.txt 111 | cat ${out_dir}/results/res.txt 112 | } 113 | 114 | test_fid() { 115 | ${pip_ext} install pytorch_fid 116 | 117 | # step 1, infer images 118 | ${python_ext} tools/comprehensive_infer.py \ 119 | --cfg ${cfg} \ 120 | --tau ${tau} \ 121 | --pn ${pn} \ 122 | --model_path ${infinity_model_path} \ 123 | --vae_type ${vae_type} \ 124 | --vae_path ${vae_path} \ 125 | 
--add_lvl_embeding_only_first_block ${add_lvl_embeding_only_first_block} \ 126 | --use_bit_label ${use_bit_label} \ 127 | --model_type ${model_type} \ 128 | --rope2d_each_sa_layer ${rope2d_each_sa_layer} \ 129 | --rope2d_normalized_by_hw ${rope2d_normalized_by_hw} \ 130 | --use_scale_schedule_embedding ${use_scale_schedule_embedding} \ 131 | --cfg ${cfg} \ 132 | --tau ${tau} \ 133 | --checkpoint_type ${checkpoint_type} \ 134 | --text_encoder_ckpt ${text_encoder_ckpt} \ 135 | --text_channels ${text_channels} \ 136 | --apply_spatial_patchify ${apply_spatial_patchify} \ 137 | --cfg_insertion_layer ${cfg_insertion_layer} \ 138 | --coco30k_prompts 0 \ 139 | --save4fid_eval 1 \ 140 | --jsonl_filepath ${jsonl_filepath} \ 141 | --long_caption_fid ${long_caption_fid} \ 142 | --out_dir ${out_dir} \ 143 | 144 | # step 2, compute fid 145 | ${python_ext} tools/fid_score.py \ 146 | ${out_dir}/pred \ 147 | ${out_dir}/gt | tee ${out_dir}/log.txt 148 | } 149 | 150 | test_val_loss() { 151 | ${python_ext} evaluation/validation_loss/validation_loss.py \ 152 | --cfg ${cfg} \ 153 | --tau ${tau} \ 154 | --pn ${pn} \ 155 | --model_path ${infinity_model_path} \ 156 | --vae_type ${vae_type} \ 157 | --vae_path ${vae_path} \ 158 | --add_lvl_embeding_only_first_block ${add_lvl_embeding_only_first_block} \ 159 | --use_bit_label ${use_bit_label} \ 160 | --model_type ${model_type} \ 161 | --rope2d_each_sa_layer ${rope2d_each_sa_layer} \ 162 | --rope2d_normalized_by_hw ${rope2d_normalized_by_hw} \ 163 | --use_scale_schedule_embedding ${use_scale_schedule_embedding} \ 164 | --cfg ${cfg} \ 165 | --tau ${tau} \ 166 | --checkpoint_type ${checkpoint_type} \ 167 | --text_encoder_ckpt ${text_encoder_ckpt} \ 168 | --text_channels ${text_channels} \ 169 | --apply_spatial_patchify ${apply_spatial_patchify} \ 170 | --cfg_insertion_layer ${cfg_insertion_layer} \ 171 | --save_dir ${out_dir} \ 172 | --reweight_loss_by_scale ${reweight_loss_by_scale} \ 173 | --meta_folder ${jsonl_folder} \ 174 | --noise_apply_strength ${noise_apply_strength} \ 175 | --bf16 0 \ 176 | --log_freq 10 177 | } 178 | 179 | 180 | python_ext=python3 181 | pip_ext=pip3 182 | 183 | # set arguments for inference 184 | pn=1M 185 | model_type=infinity_2b 186 | use_scale_schedule_embedding=0 187 | use_bit_label=1 188 | checkpoint_type='torch' 189 | infinity_model_path=weights/infinity_2b_reg.pth 190 | out_dir_root=output/infinity_2b_evaluation 191 | vae_type=32 192 | vae_path=weights/infinity_vae_d32_reg.pth 193 | cfg=4 194 | tau=1 195 | rope2d_normalized_by_hw=2 196 | add_lvl_embeding_only_first_block=1 197 | rope2d_each_sa_layer=1 198 | text_encoder_ckpt=weights/flan-t5-xl 199 | text_channels=2048 200 | apply_spatial_patchify=0 201 | cfg_insertion_layer=0 202 | sub_fix=cfg${cfg}_tau${tau}_cfg_insertion_layer${cfg_insertion_layer} 203 | 204 | # ImageReward 205 | out_dir=${out_dir_root}/image_reward_${sub_fix} 206 | # infer_eval_image_reward 207 | 208 | # HPS v2.1 209 | out_dir=${out_dir_root}/hpsv21_${sub_fix} 210 | # infer_eval_hpsv21 211 | 212 | # GenEval 213 | rewrite_prompt=1 214 | out_dir=${out_dir_root}/gen_eval_${sub_fix}_rewrite_prompt${rewrite_prompt}_round2_real_rewrite 215 | test_gen_eval 216 | 217 | # long caption fid 218 | long_caption_fid=1 219 | jsonl_filepath='[YOUR VAL JSONL FILEPATH]' 220 | out_dir=${out_dir_root}/val_long_caption_fid_${sub_fix} 221 | rm -rf ${out_dir} 222 | # test_fid 223 | 224 | # test val loss 225 | out_dir=${out_dir_root}/val_loss_${sub_fix} 226 | reweight_loss_by_scale=0 227 | jsonl_folder='[YOUR VAL JSONL FILEPATH]' 228 | 
noise_apply_strength=0.2 229 | # test_val_loss 230 | -------------------------------------------------------------------------------- /scripts/infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # set arguments for inference 4 | pn=1M 5 | model_type=infinity_2b 6 | use_scale_schedule_embedding=0 7 | use_bit_label=1 8 | checkpoint_type='torch' 9 | infinity_model_path=weights/infinity_2b_reg.pth 10 | vae_type=32 11 | vae_path=weights/infinity_vae_d32_reg.pth 12 | cfg=4 13 | tau=0.5 14 | rope2d_normalized_by_hw=2 15 | add_lvl_embeding_only_first_block=1 16 | rope2d_each_sa_layer=1 17 | text_encoder_ckpt=weights/flan-t5-xl 18 | text_channels=2048 19 | apply_spatial_patchify=0 20 | 21 | # run inference 22 | python3 tools/run_infinity.py \ 23 | --cfg ${cfg} \ 24 | --tau ${tau} \ 25 | --pn ${pn} \ 26 | --model_path ${infinity_model_path} \ 27 | --vae_type ${vae_type} \ 28 | --vae_path ${vae_path} \ 29 | --add_lvl_embeding_only_first_block ${add_lvl_embeding_only_first_block} \ 30 | --use_bit_label ${use_bit_label} \ 31 | --model_type ${model_type} \ 32 | --rope2d_each_sa_layer ${rope2d_each_sa_layer} \ 33 | --rope2d_normalized_by_hw ${rope2d_normalized_by_hw} \ 34 | --use_scale_schedule_embedding ${use_scale_schedule_embedding} \ 35 | --cfg ${cfg} \ 36 | --tau ${tau} \ 37 | --checkpoint_type ${checkpoint_type} \ 38 | --text_encoder_ckpt ${text_encoder_ckpt} \ 39 | --text_channels ${text_channels} \ 40 | --apply_spatial_patchify ${apply_spatial_patchify} \ 41 | --prompt "a beautifual Chinese woman in her late 30s, wearing a suit and tie, looking at the camera" \ 42 | --seed 1 \ 43 | --save_file tmp.jpg 44 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | # set dist args 6 | # SINGLE=1 7 | nproc_per_node=${ARNOLD_WORKER_GPU} 8 | 9 | if [ ! 
-z "$SINGLE" ] && [ "$SINGLE" != "0" ]; then 10 | echo "[single node alone] SINGLE=$SINGLE" 11 | nnodes=1 12 | node_rank=0 13 | nproc_per_node=1 14 | master_addr=127.0.0.1 15 | master_port=12345 16 | else 17 | MASTER_NODE_ID=0 18 | nnodes=${ARNOLD_WORKER_NUM} 19 | node_rank=${ARNOLD_ID} 20 | master_addr="METIS_WORKER_${MASTER_NODE_ID}_HOST" 21 | master_addr=${!master_addr} 22 | master_port="METIS_WORKER_${MASTER_NODE_ID}_PORT" 23 | master_port=${!master_port} 24 | ports=(`echo $master_port | tr ',' ' '`) 25 | master_port=${ports[0]} 26 | fi 27 | 28 | echo "[nproc_per_node: ${nproc_per_node}]" 29 | echo "[nnodes: ${nnodes}]" 30 | echo "[node_rank: ${node_rank}]" 31 | echo "[master_addr: ${master_addr}]" 32 | echo "[master_port: ${master_port}]" 33 | 34 | # set up envs 35 | export OMP_NUM_THREADS=8 36 | export NCCL_IB_DISABLE=0 37 | export NCCL_IB_GID_INDEX=3 38 | export NCCL_SOCKET_IFNAME=eth0 39 | 40 | 41 | BED=checkpoints 42 | LOCAL_OUT=local_output 43 | mkdir -p $BED 44 | mkdir -p $LOCAL_OUT 45 | 46 | 47 | export COMPILE_GAN=0 48 | export USE_TIMELINE_SDK=1 49 | export CUDA_TIMER_STREAM_KAFKA_CLUSTER=bmq_data_va 50 | export CUDA_TIMER_STREAM_KAFKA_TOPIC=megatron_cuda_timer_tracing_original_v2 51 | export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" 52 | 53 | wandb offline 54 | exp_name=debug 55 | bed_path=checkpoints/${exp_name}/ 56 | data_path='data/infinity_toy_data/splits' 57 | video_data_path='' 58 | local_out_path=$LOCAL_OUT/${exp_name} 59 | 60 | rm -rf ${bed_path} 61 | rm -rf ${local_out_path} 62 | 63 | torchrun \ 64 | --nproc_per_node=${nproc_per_node} \ 65 | --nnodes=${nnodes} \ 66 | --node_rank=${node_rank} \ 67 | --master_addr=${master_addr} \ 68 | --master_port=${master_port} \ 69 | train.py \ 70 | --ep=100 \ 71 | --opt=adamw \ 72 | --cum=3 \ 73 | --sche=lin0 \ 74 | --fp16=2 \ 75 | --ada=0.9_0.97 \ 76 | --tini=-1 \ 77 | --tclip=5 \ 78 | --flash=0 \ 79 | --alng=5e-06 \ 80 | --saln=1 \ 81 | --cos=1 \ 82 | --enable_checkpointing=full-block \ 83 | --local_out_path ${local_out_path} \ 84 | --task_type='t2i' \ 85 | --bed=${bed_path} \ 86 | --data_path=${data_path} \ 87 | --video_data_path=${video_data_path} \ 88 | --exp_name=${exp_name} \ 89 | --tblr=6e-3 \ 90 | --pn 0.06M \ 91 | --model=2bc8 \ 92 | --lbs=4 \ 93 | --workers=8 \ 94 | --short_cap_prob 0.5 \ 95 | --online_t5=1 \ 96 | --use_streaming_dataset 1 \ 97 | --iterable_data_buffersize 30000 \ 98 | --Ct5=2048 \ 99 | --t5_path=weights/flan-t5-xl \ 100 | --vae_type 32 \ 101 | --vae_ckpt=weights/infinity_vae_d32_rdn_short.pth \ 102 | --wp 0.00000001 \ 103 | --wpe=1 \ 104 | --dynamic_resolution_across_gpus 1 \ 105 | --enable_dynamic_length_prompt 1 \ 106 | --reweight_loss_by_scale 1 \ 107 | --add_lvl_embeding_only_first_block 1 \ 108 | --rope2d_each_sa_layer 1 \ 109 | --rope2d_normalized_by_hw 2 \ 110 | --use_fsdp_model_ema 0 \ 111 | --always_training_scales 100 \ 112 | --use_bit_label 1 \ 113 | --zero=2 \ 114 | --save_model_iters_freq 100 \ 115 | --log_freq=50 \ 116 | --checkpoint_type='torch' \ 117 | --prefetch_factor=16 \ 118 | --noise_apply_strength 0.3 \ 119 | --noise_apply_layers 13 \ 120 | --apply_spatial_patchify 0 \ 121 | --use_flex_attn=True \ 122 | --pad=128 -------------------------------------------------------------------------------- /tools/interactive_infer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import random\n", 10 | 
"import torch\n", 11 | "torch.cuda.set_device(0)\n", 12 | "import cv2\n", 13 | "import numpy as np\n", 14 | "from tools.run_infinity import *\n", 15 | "\n", 16 | "model_path='weights/infinity_2b_reg.pth'\n", 17 | "vae_path='weights/infinity_vae_d32_reg.pth'\n", 18 | "text_encoder_ckpt = 'weights/flan-t5-xl'\n", 19 | "args=argparse.Namespace(\n", 20 | " pn='1M',\n", 21 | " model_path=model_path,\n", 22 | " cfg_insertion_layer=0,\n", 23 | " vae_type=32,\n", 24 | " vae_path=vae_path,\n", 25 | " add_lvl_embeding_only_first_block=1,\n", 26 | " use_bit_label=1,\n", 27 | " model_type='infinity_2b',\n", 28 | " rope2d_each_sa_layer=1,\n", 29 | " rope2d_normalized_by_hw=2,\n", 30 | " use_scale_schedule_embedding=0,\n", 31 | " sampling_per_bits=1,\n", 32 | " text_encoder_ckpt=text_encoder_ckpt,\n", 33 | " text_channels=2048,\n", 34 | " apply_spatial_patchify=0,\n", 35 | " h_div_w_template=1.000,\n", 36 | " use_flex_attn=0,\n", 37 | " cache_dir='/dev/shm',\n", 38 | " checkpoint_type='torch',\n", 39 | " seed=0,\n", 40 | " bf16=1,\n", 41 | " save_file='tmp.jpg'\n", 42 | ")" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# load text encoder\n", 52 | "text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)\n", 53 | "# load vae\n", 54 | "vae = load_visual_tokenizer(args)\n", 55 | "# load infinity\n", 56 | "infinity = load_transformer(vae, args)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "prompt = \"\"\"alien spaceship enterprise\"\"\"\n", 66 | "cfg = 3\n", 67 | "tau = 0.5\n", 68 | "h_div_w = 1/1 # aspect ratio, height:width\n", 69 | "seed = random.randint(0, 10000)\n", 70 | "enable_positive_prompt=0\n", 71 | "\n", 72 | "h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]\n", 73 | "scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']\n", 74 | "scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]\n", 75 | "generated_image = gen_one_img(\n", 76 | " infinity,\n", 77 | " vae,\n", 78 | " text_tokenizer,\n", 79 | " text_encoder,\n", 80 | " prompt,\n", 81 | " g_seed=seed,\n", 82 | " gt_leak=0,\n", 83 | " gt_ls_Bl=None,\n", 84 | " cfg_list=cfg,\n", 85 | " tau_list=tau,\n", 86 | " scale_schedule=scale_schedule,\n", 87 | " cfg_insertion_layer=[args.cfg_insertion_layer],\n", 88 | " vae_type=args.vae_type,\n", 89 | " sampling_per_bits=args.sampling_per_bits,\n", 90 | " enable_positive_prompt=enable_positive_prompt,\n", 91 | ")\n", 92 | "args.save_file = 'ipynb_tmp.jpg'\n", 93 | "os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)\n", 94 | "cv2.imwrite(args.save_file, generated_image.cpu().numpy())\n", 95 | "print(f'Save to {osp.abspath(args.save_file)}')" 96 | ] 97 | } 98 | ], 99 | "metadata": { 100 | "fileId": "8ac263ab-b18c-41dc-b409-0fb0f32525f0", 101 | "filePath": "/mnt/bn/foundation-vision/hanjian.thu123/infinity/infinity/tools/interactive_infer.ipynb", 102 | "kernelspec": { 103 | "display_name": "Python 3", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.9.2" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 2 
122 | } 123 | -------------------------------------------------------------------------------- /tools/interactive_infer_8b.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import random\n", 10 | "import torch\n", 11 | "torch.cuda.set_device(2)\n", 12 | "import cv2\n", 13 | "import numpy as np\n", 14 | "from tools.run_infinity import *\n", 15 | "\n", 16 | "model_path='weights/infinity_8b_weights'\n", 17 | "vae_path='weights/infinity_vae_d56_f8_14_patchify.pth'\n", 18 | "text_encoder_ckpt = 'weights/flan-t5-xl-official'\n", 19 | "args=argparse.Namespace(\n", 20 | " pn='1M',\n", 21 | " model_path=model_path,\n", 22 | " cfg_insertion_layer=0,\n", 23 | " vae_type=14,\n", 24 | " vae_path=vae_path,\n", 25 | " add_lvl_embeding_only_first_block=1,\n", 26 | " use_bit_label=1,\n", 27 | " model_type='infinity_8b',\n", 28 | " rope2d_each_sa_layer=1,\n", 29 | " rope2d_normalized_by_hw=2,\n", 30 | " use_scale_schedule_embedding=0,\n", 31 | " sampling_per_bits=1,\n", 32 | " text_encoder_ckpt=text_encoder_ckpt,\n", 33 | " text_channels=2048,\n", 34 | " apply_spatial_patchify=1,\n", 35 | " h_div_w_template=1.000,\n", 36 | " use_flex_attn=0,\n", 37 | " cache_dir='/dev/shm',\n", 38 | " checkpoint_type='torch_shard',\n", 39 | " seed=0,\n", 40 | " bf16=1,\n", 41 | " save_file='tmp.jpg'\n", 42 | ")" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 10, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "[Loading tokenizer and text encoder]\n" 55 | ] 56 | }, 57 | { 58 | "data": { 59 | "application/vnd.jupyter.widget-view+json": { 60 | "model_id": "3f68ce998b1546f185e6263884b382ef", 61 | "version_major": 2, 62 | "version_minor": 0 63 | }, 64 | "text/plain": [ 65 | "Loading checkpoint shards: 0%| | 0/2 [00:00" will trigger your partner bot to output an image of a forest morning, as described. 47 | You will be prompted by users looking to create detailed, amazing images. The way to accomplish this is to refine their short prompts and make them extremely detailed and descriptive. 48 | - You will only ever output a single image description sentence per user request. 49 | - Each image description sentence should be consist of "", where is the image description, is the parameter that control the image generation. 50 | Here are the guidelines to generate image description : 51 | - Refine users' prompts and make them extremely detailed and descriptive but keep the meaning unchanged (very important). 52 | - For particularly long users' prompts (>50 words), they can be outputted directly without refining. Image descriptions must be between 8-512 words. Extra words will be ignored. 53 | - If the user's prompt requires rendering text, enclose the text with single quotation marks and prefix it with "the text". 54 | Here are the guidelines to set : 55 | - Please first determine whether the image to be generated based on the user prompt is likely to contain a clear face. If it does, set ; if not, set . 
56 | """ 57 | 58 | FEW_SHOT_HISTORY = [ 59 | {"role": "user", "content": "a tree"}, 60 | {"role": "assistant", "content": ""}, 61 | {"role": "user", "content": "a young girl with red hair"}, 62 | {"role": "assistant", "content": ""}, 63 | {"role": "user", "content": "a man, close-up"}, 64 | {"role": "assistant", "content": ""}, 65 | {"role": "user", "content": "Generate Never Stop Learning"}, 66 | {"role": "assistant", "content": ""}, 67 | ] 68 | 69 | class PromptRewriter(object): 70 | def __init__(self, system, few_shot_history): 71 | if not system: 72 | system = SYSTEM 73 | if not len(few_shot_history): 74 | few_shot_history = FEW_SHOT_HISTORY 75 | self.system = [{"role": "system", "content": system}] 76 | self.few_shot_history = few_shot_history 77 | 78 | def rewrite(self, prompt): 79 | messages = self.system + self.few_shot_history + [{"role": "user", "content": prompt}] 80 | result, _ = get_gpt_result(model_name='gpt-4o-2024-08-06', messages=messages, retry=5, ak=GPT_AK, return_json=False) 81 | assert result 82 | return result 83 | 84 | 85 | def get_gpt_result(model_name='gpt-4o-2024-05-13', messages=None, retry=5, ak=None, return_json=False): 86 | """ 87 | Retrieves a chat response using the GPT-4 model. 88 | Args: 89 | model_name (str, optional): The name of the GPT model to use. Defaults to 'gpt-4'. [gpt-3.5-turbo, gpt-4] 90 | retry (int, optional): The number of times to retry the chat API if there is an error. Defaults to 5. 91 | Returns: 92 | tuple: A tuple containing the chat response content (str) and the API usage (dict). 93 | Raises: 94 | Exception: If there is an error retrieving the chat response. 95 | """ 96 | openai_ak = ak 97 | client = openai.AzureOpenAI( 98 | azure_endpoint="https://search-va.byteintl.net/gpt/openapi/online/multimodal/crawl", 99 | api_version="2023-07-01-preview", 100 | api_key=openai_ak 101 | ) 102 | for i in range(retry): 103 | try: 104 | if return_json: 105 | completion = client.chat.completions.create( 106 | model=model_name, 107 | messages=messages, 108 | response_format={ "type": "json_object" }, 109 | ) 110 | else: 111 | completion = client.chat.completions.create( 112 | model=model_name, 113 | messages=messages, 114 | ) 115 | result = json.loads(completion.model_dump_json())['choices'][0]['message']['content'] 116 | return result,None 117 | except Exception as e: 118 | traceback.print_exc() 119 | if isinstance(e,KeyboardInterrupt): 120 | exit(0) 121 | sleep_time = 10 + random.randint(2,5)**(i+1) 122 | time.sleep(sleep_time) 123 | return None, -1 124 | 125 | if __name__ == '__main__': 126 | times = 0 127 | prompt_list = [] 128 | 129 | var_t2i_prompt_rewriter = PromptRewriter(system='', few_shot_history=[]) 130 | 131 | prompt_list = [ 132 | 'a tree', 133 | 'two dogs', 134 | 'an oil painting of a house', 135 | 'a Chinese model sits in the train. 
Magazine style', 136 | 'two girls', 137 | 'countryside', 138 | 'a rabbit fights with a tiger', 139 | 'a beach in Hawaii', 140 | ] 141 | 142 | for prompt in prompt_list: 143 | times += 1 144 | result = var_t2i_prompt_rewriter.rewrite(prompt) 145 | print(f'prompt: {prompt}, result: {result}') 146 | -------------------------------------------------------------------------------- /tools/reproduce.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import os 4 | import os.path as osp 5 | import cv2 6 | import numpy as np 7 | from run_infinity import * 8 | 9 | torch.cuda.set_device(0) 10 | model_path = '/workspace/Infinity/weights/infinity_2b_reg.pth' 11 | vae_path = '/workspace/Infinity/weights/infinity_vae_d32reg.pth' 12 | text_encoder_ckpt = '/workspace/Infinity/weights/flan-t5-xl' 13 | 14 | # SET 15 | args = argparse.Namespace( 16 | pn='1M', 17 | model_path=model_path, 18 | cfg_insertion_layer=0, 19 | vae_type=32, 20 | vae_path=vae_path, 21 | add_lvl_embeding_only_first_block=1, 22 | use_bit_label=1, 23 | model_type='infinity_2b', 24 | rope2d_each_sa_layer=1, 25 | rope2d_normalized_by_hw=2, 26 | use_scale_schedule_embedding=0, 27 | sampling_per_bits=1, 28 | text_encoder_ckpt=text_encoder_ckpt, 29 | text_channels=2048, 30 | apply_spatial_patchify=0, 31 | h_div_w_template=1.000, 32 | use_flex_attn=0, 33 | cache_dir='/dev/shm', 34 | checkpoint_type='torch', 35 | seed=0, 36 | bf16=1, 37 | save_file='tmp.jpg', 38 | enable_model_cache=0 39 | ) 40 | 41 | # LOAD 42 | text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt) 43 | vae = load_visual_tokenizer(args) 44 | infinity = load_transformer(vae, args) 45 | 46 | # PROMPT 47 | prompts = { 48 | "vintage_insect": "Insect made from vintage 1960s electronic components, capacitors, resistors, transistors, wires, diodes, solder, circuitboard.", 49 | "macro_closeup": "Denis Villeneuve's extreme macro cinematographic close-up in water.", 50 | "3d_school": "A creative 3D image to be placed at the bottom of a mobile application's homepage, depicting a miniature school and children carrying backpacks.", 51 | "explore_more": "Create an image with 'Explore More' in an adventurous font over a picturesque hiking trail.", 52 | "toy_car": "Close-up shot of a diecast toy car, diorama, night, lights from windows, bokeh, snow.", 53 | "fairy_house": "House: white; pink tinted windows; surrounded by flowers; cute; scenic; garden; fairy-like; epic; photography; photorealistic; insanely detailed and intricate; textures; grain; ultra-realistic.", 54 | "cat_fashion": "Hyperrealistic black and white photography of cats fashion show in style of Helmut Newton.", 55 | "spacefrog_astroduck": "Two superheroes called Spacefrog (a dashing green cartoon-like frog with a red cape) and Astroduck (a yellow fuzzy duck, part-robot, with blue/grey armor), near a garden pond, next to their spaceship, a classic flying saucer, called the Tadpole 3000. Photorealistic.", 56 | "miniature_village": "An enchanted miniature village bustling with activity, featuring tiny houses, markets, and residents.", 57 | "corgi_dog": "A close-up photograph of a Corgi dog. The dog is wearing a black hat and round, dark sunglasses. 
The Corgi has a joyful expression, with its mouth open and tongue sticking out, giving an impression of happiness or excitement.", 58 | "robot_eggplant": "a robot holding a huge eggplant, sunny nature background", 59 | "perfume_product": "Product photography, a perfume placed on a white marble table with pineapple, coconut, lime next to it as decoration, white curtains, full of intricate details, realistic, minimalist, layered gestures in a bright and concise atmosphere, minimalist style.", 60 | "mountain_landscape": "The image presents a picturesque mountainous landscape under a cloudy sky. The mountains, blanketed in lush greenery, rise majestically, their slopes dotted with clusters of trees and shrubs. The sky above is a canvas of blue, adorned with fluffy white clouds that add a sense of tranquility to the scene. In the foreground, a valley unfolds, nestled between the towering mountains. It appears to be a rural area, with a few buildings and structures visible, suggesting the presence of a small settlement. The buildings are scattered, blending harmoniously with the natural surroundings. The image is captured from a high vantage point, providing a sweeping view of the valley and the mountains." 61 | } 62 | 63 | # OUTPUT 64 | output_dir = "outputs" 65 | os.makedirs(output_dir, exist_ok=True) 66 | 67 | # GEN IMG 68 | for category, prompt in prompts.items(): 69 | cfg = 3 70 | tau = 0.5 71 | h_div_w = 1/1 # Aspect Ratio 72 | seed = random.randint(0, 10000) 73 | enable_positive_prompt = 0 74 | 75 | h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))] 76 | scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales'] 77 | scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule] 78 | 79 | # GEN 80 | generated_image = gen_one_img( 81 | infinity, 82 | vae, 83 | text_tokenizer, 84 | text_encoder, 85 | prompt, 86 | g_seed=seed, 87 | gt_leak=0, 88 | gt_ls_Bl=None, 89 | cfg_list=cfg, 90 | tau_list=tau, 91 | scale_schedule=scale_schedule, 92 | cfg_insertion_layer=[args.cfg_insertion_layer], 93 | vae_type=args.vae_type, 94 | sampling_per_bits=args.sampling_per_bits, 95 | enable_positive_prompt=enable_positive_prompt, 96 | ) 97 | 98 | # SAVE 99 | save_path = osp.join(output_dir, f"re_{category}_test.jpg") 100 | cv2.imwrite(save_path, generated_image.cpu().numpy()) 101 | print(f"{category} image saved to {save_path}") 102 | -------------------------------------------------------------------------------- /tools/run_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 3 | import os.path as osp 4 | from typing import List 5 | import math 6 | import time 7 | import hashlib 8 | import yaml 9 | import argparse 10 | 11 | import numpy as np 12 | import torch 13 | import pandas as pd 14 | from tqdm import tqdm 15 | from PIL import Image, ImageEnhance 16 | import torch.nn.functional as F 17 | import torchvision 18 | 19 | 20 | # for distributed evaluation 21 | import torch.distributed as dist 22 | from torch.multiprocessing import spawn 23 | # for metrics 24 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 25 | from skimage.metrics import structural_similarity as ssim_loss 26 | from tools.fid_score import calculate_frechet_distance 27 | from tools.inception import InceptionV3 28 | import lpips 29 | import warnings 30 | 31 | warnings.filterwarnings("ignore") 32 | from infinity.models.bsq_vae.vae import vae_model 33 | from torchvision 
import transforms 34 | from torchvision.transforms import InterpolationMode 35 | 36 | 37 | def _pil_interp(method): 38 | if method == 'bicubic': 39 | return InterpolationMode.BICUBIC 40 | elif method == 'lanczos': 41 | return InterpolationMode.LANCZOS 42 | elif method == 'hamming': 43 | return InterpolationMode.HAMMING 44 | else: 45 | # default bilinear, do we want to allow nearest? 46 | return InterpolationMode.BILINEAR 47 | 48 | def vae_encode_decode_norm(vae, image_path, tgt_h, tgt_w, device, augmentations): 49 | # get normalized gt_img and recons_img in [-1, 1] 50 | pil_image = Image.open(image_path).convert('RGB') 51 | # inp = crop_to_tensor(pil_image, tgt_h, tgt_w) 52 | inp = augmentations(pil_image) 53 | inp = inp * 2 - 1 54 | 55 | inp = inp.unsqueeze(0).to(device) 56 | 57 | # decode by vae 58 | # Both inputs and outputs are in [-1, 1] 59 | recons_img, vq_output = vae(inp) 60 | gt_img = inp 61 | 62 | return gt_img, recons_img # (1, 3, H, W) 63 | 64 | def inference_eval(rank, world_size, args, vae, return_dict, val_txt, tgt_h, tgt_w, augmentations): 65 | # Don't remove this setup!!! dist.init_process_group is important for building loader (data.distributed.DistributedSampler) 66 | setup(rank, world_size) 67 | 68 | device = torch.device(f"cuda:{rank}") 69 | 70 | for param in vae.parameters(): 71 | param.requires_grad = False 72 | vae.to(device).eval() 73 | 74 | save_dir = 'results/%s'%(args.save) 75 | print('generating and saving video to %s...'%save_dir) 76 | os.makedirs(save_dir, exist_ok=True) 77 | 78 | # data = VideoData(args) 79 | 80 | # loader = data.val_dataloader() 81 | 82 | dims = 2048 83 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 84 | inception_model = InceptionV3([block_idx]).to(device) 85 | inception_model.eval() 86 | 87 | # loader_iter = iter(loader) 88 | 89 | pred_xs = [] 90 | pred_recs = [] 91 | # LPIPS score related 92 | loss_fn_alex = lpips.LPIPS(net='alex').to(device) # best forward scores 93 | loss_fn_vgg = lpips.LPIPS(net='vgg').to(device) # closer to "traditional" perceptual loss, when used for optimization 94 | lpips_alex = 0.0 95 | lpips_vgg = 0.0 96 | 97 | # SSIM score related 98 | ssim_value = 0.0 99 | 100 | # PSNR score related 101 | psnr_value = 0.0 102 | 103 | # num_images = len(loader) 104 | assert len(val_txt) % world_size == 0 105 | num_images = len(val_txt) // world_size 106 | start_idx, end_idx = num_images * rank, num_images * (rank + 1) 107 | print(f"Testing {num_images} files") 108 | num_iter = 0 109 | 110 | for idx in tqdm(range(start_idx, end_idx)): 111 | rel_path = val_txt[idx] 112 | image_path = rel_path 113 | with torch.no_grad(): 114 | torch.cuda.empty_cache() 115 | # x: [-1, 1] 116 | # x_recons, vq_output = vae(x.to(device), 2, 0, is_train=False) 117 | # x_recons = x_recons.cpu() 118 | x, x_recons = vae_encode_decode_norm(vae, image_path, tgt_h, tgt_w, device, augmentations) 119 | x_recons = x_recons.cpu() 120 | 121 | # paths = batch["path"] 122 | # assert len(paths) == x.shape[0] 123 | paths = [rel_path] 124 | 125 | for p, input_, recon_ in zip(paths, x, x_recons): 126 | if os.path.isabs(p): 127 | p = "/".join(p.split("/")[6:]) 128 | assert not os.path.isabs(p), f"{p} should not be abspath" 129 | path = os.path.join(save_dir, "input_recon", os.path.basename(p)) 130 | os.makedirs(os.path.split(path)[0], exist_ok=True) 131 | input_ = ((input_ + 1) / 2).unsqueeze(0).to(device) # [-1, 1] -> [0, 1] 132 | 133 | pred_x = inception_model(input_)[0] 134 | pred_x = pred_x.squeeze(3).squeeze(2).cpu().numpy() 135 | 136 | recon_ = ((recon_ + 1) / 
2).unsqueeze(0).to(device) # [-1, 1] -> [0, 1] 137 | # recon_ = recon_.permute(1, 2, 0).detach().cpu() 138 | with torch.no_grad(): 139 | pred_rec = inception_model(recon_)[0] 140 | pred_rec = pred_rec.squeeze(3).squeeze(2).cpu().numpy() 141 | if args.save_prediction: 142 | if input_.dim() == 4: 143 | input_image = input_.squeeze(0) 144 | if recon_.dim() == 4: 145 | recon_image = recon_.squeeze(0) 146 | input_recon = torch.cat([input_image, recon_image], dim=-1) 147 | input_recon = Image.fromarray((torch.clamp(input_recon.permute(1, 2, 0).detach().cpu(), 0, 1).numpy() * 255).astype(np.uint8)) 148 | input_recon.save(path) 149 | 150 | pred_xs.append(pred_x) 151 | pred_recs.append(pred_rec) 152 | 153 | # calculate lpips 154 | with torch.no_grad(): 155 | lpips_alex += loss_fn_alex(input_, recon_, normalize=True).sum() 156 | lpips_vgg += loss_fn_vgg(input_, recon_, normalize=True).sum() 157 | 158 | #calculate PSNR and SSIM 159 | rgb_restored = (recon_ * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 160 | rgb_gt = (input_ * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 161 | rgb_restored = rgb_restored.astype(np.float32) / 255. 162 | rgb_gt = rgb_gt.astype(np.float32) / 255. 163 | ssim_temp = 0 164 | psnr_temp = 0 165 | B, _, _, _ = rgb_restored.shape 166 | for i in range(B): 167 | rgb_restored_s, rgb_gt_s = rgb_restored[i], rgb_gt[i] 168 | with torch.no_grad(): 169 | ssim_temp += ssim_loss(rgb_restored_s, rgb_gt_s, data_range=1.0, channel_axis=-1) 170 | psnr_temp += psnr_loss(rgb_gt, rgb_restored) 171 | ssim_value += ssim_temp / B 172 | psnr_value += psnr_temp / B 173 | num_iter += 1 174 | 175 | pred_xs = np.concatenate(pred_xs, axis=0) 176 | pred_recs = np.concatenate(pred_recs, axis=0) 177 | temp_dict = { 178 | 'pred_xs':pred_xs, 179 | 'pred_recs':pred_recs, 180 | 'lpips_alex':lpips_alex.cpu(), 181 | 'lpips_vgg':lpips_vgg.cpu(), 182 | 'ssim_value': ssim_value, 183 | 'psnr_value': psnr_value, 184 | 'num_iter': num_iter, 185 | } 186 | return_dict[rank] = temp_dict 187 | 188 | if dist.is_initialized(): 189 | dist.barrier() 190 | cleanup() 191 | 192 | def image_eval(pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter): 193 | mu_x = np.mean(pred_xs, axis=0) 194 | sigma_x = np.cov(pred_xs, rowvar=False) 195 | mu_rec = np.mean(pred_recs, axis=0) 196 | sigma_rec = np.cov(pred_recs, rowvar=False) 197 | 198 | fid_value = calculate_frechet_distance(mu_x, sigma_x, mu_rec, sigma_rec) 199 | lpips_alex_value = lpips_alex / num_iter 200 | lpips_vgg_value = lpips_vgg / num_iter 201 | ssim_value = ssim_value / num_iter 202 | psnr_value = psnr_value / num_iter 203 | 204 | 205 | result_str = f""" 206 | FID = {fid_value:.4f} 207 | LPIPS_VGG: {lpips_vgg_value.item():.4f} 208 | LPIPS_ALEX: {lpips_alex_value.item():.4f} 209 | SSIM: {ssim_value:.4f} 210 | PSNR: {psnr_value:.3f} 211 | """ 212 | return result_str 213 | 214 | def get_args(): 215 | parser = argparse.ArgumentParser() 216 | parser.add_argument('--vqgan_ckpt', type=str, default="infinity_vae_d32.pth") 217 | parser.add_argument('--codebook_dim', type=int, default=32) 218 | parser.add_argument('--save_prediction', action='store_true') 219 | parser.add_argument('--save', type=str, default='imageNet_val') 220 | parser.add_argument('--tgt_size', type=int, default=256, help="input size during inference") 221 | parser.add_argument('--image_path', type=str, default=None) # "data/infinity_toy_data/images/5134521536907147208.jpg" 222 | args = parser.parse_args() 223 | return args 224 | 225 | def setup(rank, 
world_size): 226 | os.environ['MASTER_ADDR'] = 'localhost' 227 | os.environ['MASTER_PORT'] = '12356' 228 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 229 | 230 | def cleanup(): 231 | dist.destroy_process_group() 232 | 233 | if __name__ == '__main__': 234 | args = get_args() 235 | 236 | # load bsq vae 237 | vqgan_ckpt = args.vqgan_ckpt 238 | schedule_mode = "dynamic" 239 | codebook_dim = args.codebook_dim 240 | codebook_size = 2**codebook_dim 241 | vae = vae_model(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size) 242 | vae.eval() 243 | 244 | # read images 245 | if args.image_path is not None: # read a single image 246 | val_txt = [args.image_path] 247 | world_size = 1 248 | else: # test on benchmark 249 | val_txt_path = "data/labels/imagenet/val.txt" 250 | val_txt = open(val_txt_path, 'r').readlines() 251 | val_txt = [x.split("\t")[0] for x in val_txt if x.strip()] 252 | world_size = torch.cuda.device_count() 253 | 254 | tgt_h, tgt_w = args.tgt_size, args.tgt_size 255 | resolution = (tgt_h, tgt_w) 256 | augmentations = transforms.Compose([ 257 | transforms.Resize(min(resolution), interpolation=_pil_interp("bicubic")), 258 | transforms.CenterCrop(resolution), 259 | transforms.ToTensor(), 260 | ]) 261 | # get evaluation metrics 262 | manager = torch.multiprocessing.Manager() 263 | return_dict = manager.dict() 264 | 265 | spawn(inference_eval, args=(world_size, args, vae, return_dict, val_txt, tgt_h, tgt_w, augmentations), nprocs=world_size, join=True) 266 | 267 | pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter = [], [], 0, 0, 0, 0, 0 268 | for rank in range(world_size): 269 | pred_xs.append(return_dict[rank]['pred_xs']) 270 | pred_recs.append(return_dict[rank]['pred_recs']) 271 | lpips_alex += return_dict[rank]['lpips_alex'] 272 | lpips_vgg += return_dict[rank]['lpips_vgg'] 273 | ssim_value += return_dict[rank]['ssim_value'] 274 | psnr_value += return_dict[rank]['psnr_value'] 275 | num_iter += return_dict[rank]['num_iter'] 276 | pred_xs = np.concatenate(pred_xs, 0) 277 | pred_recs = np.concatenate(pred_recs, 0) 278 | result_str = image_eval(pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter) 279 | print(result_str) --------------------------------------------------------------------------------
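Referring back to scripts/train.sh above: its SINGLE branch lets the whole training loop be smoke-tested on one GPU with the bundled toy split, without any of the ARNOLD_*/METIS_* cluster variables. A minimal launch sketch, assuming the flan-t5-xl and VAE checkpoints referenced by the script are already under weights/:

# Run from the repository root. Exporting SINGLE=1 routes scripts/train.sh into its
# single-node branch (nnodes=1, nproc_per_node=1, master_addr=127.0.0.1,
# master_port=12345), so no cluster scheduler environment is required.
export SINGLE=1
bash scripts/train.sh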
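And a hedged invocation sketch for the reconstruction benchmark in tools/run_tokenizer.py above. Passing --image_path makes the script evaluate a single image on one GPU instead of the full ImageNet val list; the checkpoint location below (weights/infinity_vae_d32.pth) is an assumption about where the d32 VAE has been downloaded, mirroring the --vqgan_ckpt default filename.

# Single-image reconstruction sanity check. FID over one image is not meaningful,
# but the LPIPS/SSIM/PSNR numbers and the saved input|reconstruction pair under
# results/toy_recon/input_recon/ are. All paths are illustrative.
python3 tools/run_tokenizer.py \
    --vqgan_ckpt weights/infinity_vae_d32.pth \
    --codebook_dim 32 \
    --tgt_size 256 \
    --save_prediction \
    --save toy_recon \
    --image_path data/infinity_toy_data/images/5134521536907147208.jpg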