├── .gitignore ├── LICENSE ├── README.md ├── assets ├── docs │ ├── installation_xOS.md │ └── v1_models.md ├── gradio.png ├── gradio_error_img.png ├── logo.png ├── pipeline.png ├── teaser.png └── visual_results │ ├── bfr1.png │ ├── bfr2.png │ ├── bfr4.png │ ├── bid1.png │ ├── bid2.png │ ├── bid3.png │ ├── bsr1.png │ ├── bsr2.png │ ├── bsr3.png │ ├── bsr4.png │ ├── bsr5.png │ ├── bsr6.png │ ├── bsr7.png │ ├── tiled_sampling.png │ ├── whole_image1.png │ └── whole_image2.png ├── configs ├── inference │ ├── bsrnet.yaml │ ├── cldm.yaml │ ├── diffusion.yaml │ ├── diffusion_v2.1.yaml │ ├── scunet.yaml │ └── swinir.yaml └── train │ ├── train_stage1.yaml │ ├── train_stage2.yaml │ └── train_stage2_v2.1.yaml ├── diffbir ├── dataset │ ├── batch_transform.py │ ├── codeformer.py │ ├── degradation.py │ ├── diffjpeg.py │ ├── file_backend.py │ ├── realesrgan.py │ └── utils.py ├── inference │ ├── __init__.py │ ├── bfr_loop.py │ ├── bid_loop.py │ ├── bsr_loop.py │ ├── custom_loop.py │ ├── loop.py │ ├── pretrained_models.py │ └── unaligned_bfr_loop.py ├── model │ ├── __init__.py │ ├── attention.py │ ├── bsrnet.py │ ├── cldm.py │ ├── clip.py │ ├── config.py │ ├── controlnet.py │ ├── distributions.py │ ├── gaussian_diffusion.py │ ├── open_clip │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── model.py │ │ ├── tokenizer.py │ │ └── transformer.py │ ├── scunet.py │ ├── swinir.py │ ├── unet.py │ ├── util.py │ └── vae.py ├── pipeline.py ├── sampler │ ├── __init__.py │ ├── ddim_sampler.py │ ├── dpm_solver_pytorch.py │ ├── dpms_sampler.py │ ├── edm_sampler.py │ ├── k_diffusion.py │ ├── sampler.py │ └── spaced_sampler.py └── utils │ ├── caption.py │ ├── common.py │ ├── cond_fn.py │ ├── face.py │ └── tilevae │ ├── __init__.py │ ├── attn.py │ └── tilevae.py ├── inference.py ├── inputs ├── demo │ ├── bfr │ │ ├── aligned │ │ │ ├── 0229.png │ │ │ ├── 0427.png │ │ │ ├── 0722.png │ │ │ └── hermione.jpg │ │ └── whole_img │ │ │ ├── 01.jpg │ │ │ ├── 02.png │ │ │ ├── Audrey_Hepburn.jpg │ │ │ 
├── Blake_Lively.jpg │ │ │ ├── Harry_Potter.jpg │ │ │ ├── Queen.jpeg │ │ │ └── real47_1.jpg │ ├── bid │ │ ├── Audrey_Hepburn.jpg │ │ ├── Bears.png │ │ ├── Flowers.png │ │ ├── Movie.png │ │ ├── Postcards.png │ │ ├── cty_fnb_0047.png │ │ ├── kf_fnb_0058.png │ │ └── palace.png │ └── bsr │ │ ├── 14.jpg │ │ ├── 29.jpg │ │ ├── 49.jpg │ │ ├── 53.jpeg │ │ └── comic3.png └── real47 │ ├── 1.jpg │ ├── 11.jpg │ ├── 12.jpg │ ├── 13.jpg │ ├── 14.jpg │ ├── 15.jpg │ ├── 16.jpg │ ├── 17.jpg │ ├── 19.jpg │ ├── 2.jpg │ ├── 20.jpg │ ├── 21.jpg │ ├── 22.jpg │ ├── 23.jpg │ ├── 24.jpg │ ├── 26.jpg │ ├── 27.jpg │ ├── 29.jpg │ ├── 3.jpg │ ├── 32.jpg │ ├── 33.jpg │ ├── 34.jpg │ ├── 35.jpg │ ├── 36.png │ ├── 38.jpg │ ├── 4.jpg │ ├── 40.jpg │ ├── 41.jpg │ ├── 42.jpg │ ├── 43.jpg │ ├── 44.jpg │ ├── 45.jpg │ ├── 46.jpg │ ├── 47.jpg │ ├── 48.jpg │ ├── 49.jpg │ ├── 5.jpg │ ├── 50.jpg │ ├── 51.jpg │ ├── 52.jpg │ ├── 53.jpeg │ ├── 54.jpeg │ ├── 55.jpg │ ├── 56.jpg │ ├── 6.jpg │ ├── 7.jpg │ └── 9.jpg ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── eval_gpt_review.py │ ├── eval_gpt_review_bench.py │ ├── eval_gpt_review_visual.py │ ├── eval_pope.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── eval_textvqa.py │ ├── generate_webpage_data_from_table.py │ ├── m4c_evaluator.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_loader.py │ ├── model_vqa_mmbench.py │ ├── model_vqa_science.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ ├── table │ │ ├── answer │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ ├── answer_bard.jsonl │ │ │ ├── answer_gpt35.jsonl │ │ │ ├── answer_llama-13b.jsonl │ │ │ └── answer_vicuna-13b.jsonl │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── prompt.jsonl │ │ ├── question.jsonl │ │ ├── results │ │ │ ├── test_sqa_llava_13b_v0.json │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json │ │ ├── review │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl 
│ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ └── webpage │ │ ├── figures │ │ ├── alpaca.png │ │ ├── bard.jpg │ │ ├── chatgpt.svg │ │ ├── llama.jpg │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ └── vicuna.jpeg │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ └── llava_mpt.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ ├── train_mem.py │ └── train_xformers.py └── utils.py ├── ram ├── __init__.py ├── configs │ ├── finetune.yaml │ ├── finetune_tag2text.yaml │ ├── med_config.json │ ├── pretrain.yaml │ ├── pretrain_tag2text.yaml │ ├── q2l_config.json │ └── swin │ │ ├── config_swinB_224.json │ │ ├── config_swinB_384.json │ │ ├── config_swinL_224.json │ │ └── config_swinL_384.json ├── data │ ├── __init__.py │ ├── dataset.py │ ├── ram_tag_list.txt │ ├── ram_tag_list_chinese.txt │ ├── ram_tag_list_threshold.txt │ ├── randaugment.py │ ├── tag2text_ori_tag_list.txt │ ├── tag_list.txt │ └── utils.py ├── inference.py ├── models │ ├── __init__.py │ ├── bert.py │ ├── ram.py │ ├── ram_plus.py │ ├── swin_transformer.py │ ├── tag2text.py │ ├── utils.py │ └── vit.py ├── transform.py └── utils │ ├── __init__.py │ ├── metrics.py │ └── openset_utils.py ├── requirements.txt ├── run_gradio.py ├── 
scripts └── convert_diffusers_to_sd.py ├── train_stage1.py └── train_stage2.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.pth 4 | /data 5 | /exps 6 | *.sh 7 | !install_env.sh 8 | /weights 9 | /temp 10 | /results 11 | .ipynb_checkpoints/ 12 | /TODO.txt 13 | /deprecated 14 | /temp_scripts 15 | /.vscode 16 | /runs 17 | /tests 18 | /meta_files 19 | -------------------------------------------------------------------------------- /assets/docs/installation_xOS.md: -------------------------------------------------------------------------------- 1 | # Linux 2 | Please follow the primary README.md of this repo. 3 | 4 | # Windows 5 | Windows users may stumble when installing the package `triton`. 6 | You can choose to run on **CPU** without `xformers` and `triton` installed. 7 | 8 | To use **CUDA**, please refer to [issue#24](https://github.com/XPixelGroup/DiffBIR/issues/24) to try to solve the problem of `triton` installation. 9 | 10 | # MacOS 11 | <!-- Currently only CPU device is supported to run DiffBIR on Apple Silicon since most GPU acceleration packages are compatible with CUDA only. 12 | 13 | We are still trying to support MPS device. Stay tuned for our progress! --> 14 | 15 | You can try to set up according to the following steps to use CPU or MPS device. 16 | 17 | 1. Install **torch (Preview/Nightly version)**. 18 | 19 | ```bash 20 | # MPS acceleration is available on MacOS 12.3+ 21 | pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu 22 | ``` 23 | Check more details in [official document](https://pytorch.org/get-started/locally/). 24 | 25 | 2. Packages `triton` and `xformers` are not needed since they only work with CUDA. Remove the related packages. 
26 | 27 | Your requirements.txt should look like: 28 | ```bash 29 | # requirements.txt 30 | pytorch_lightning==1.4.2 31 | einops 32 | open-clip-torch 33 | omegaconf 34 | torchmetrics==0.6.0 35 | opencv-python-headless 36 | scipy 37 | matplotlib 38 | lpips 39 | gradio 40 | chardet 41 | transformers 42 | facexlib 43 | ``` 44 | 45 | ```bash 46 | pip install -r requirements.txt 47 | ``` 48 | 49 | 3. [Run the inference script](https://github.com/XPixelGroup/DiffBIR#general_image_inference) and specify `--device cpu` or `--device mps`. Using MPS can accelerate your inference. 50 | 51 | You can specify `--tiled` and related arguments to avoid OOM. -------------------------------------------------------------------------------- /assets/docs/v1_models.md: -------------------------------------------------------------------------------- 1 | | Model Name | Description | HuggingFace | BaiduNetdisk | OpenXLab | 2 | | :--------- | :---------- | :---------- | :---------- | :---------- | 3 | | general_swinir_v1.ckpt | Stage1 model (SwinIR) for general image restoration. | [download](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt) | [download](https://pan.baidu.com/s/1uvSvJgcoL_Knj0h22-9TvA?pwd=v3v6) (pwd: v3v6) | [download](https://download.openxlab.org.cn/models/linxinqi/DiffBIR/weight//diffbir_general_swinir_v1) | 4 | | general_full_v1.ckpt | Full model for general image restoration. "Full" means it contains both the stage1 and stage2 model. | [download](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt) | [download](https://pan.baidu.com/s/1gLvW1nvkJStdVAKROqaYaA?pwd=86zi) (pwd: 86zi) | [download](https://download.openxlab.org.cn/models/linxinqi/DiffBIR/weight//diffbir_general_full_v1) | 5 | | face_swinir_v1.ckpt | Stage1 model (SwinIR) for face restoration. 
| [download](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt) | [download](https://pan.baidu.com/s/1cnBBC8437BJiM3q6suaK8g?pwd=xk5u) (pwd: xk5u) | [download](https://download.openxlab.org.cn/models/linxinqi/DiffBIR/weight//diffbir_face_swinir_v1) | 6 | | face_full_v1.ckpt | Full model for face restoration. | [download](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) | [download](https://pan.baidu.com/s/1pc04xvQybkynRfzK5Y8K0Q?pwd=ov8i) (pwd: ov8i) | [download](https://download.openxlab.org.cn/models/linxinqi/DiffBIR/weight//diffbir_face_full_v1) | -------------------------------------------------------------------------------- /assets/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/gradio.png -------------------------------------------------------------------------------- /assets/gradio_error_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/gradio_error_img.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/logo.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/pipeline.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/teaser.png -------------------------------------------------------------------------------- /assets/visual_results/bfr1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bfr1.png -------------------------------------------------------------------------------- /assets/visual_results/bfr2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bfr2.png -------------------------------------------------------------------------------- /assets/visual_results/bfr4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bfr4.png -------------------------------------------------------------------------------- /assets/visual_results/bid1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bid1.png -------------------------------------------------------------------------------- /assets/visual_results/bid2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bid2.png -------------------------------------------------------------------------------- /assets/visual_results/bid3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bid3.png -------------------------------------------------------------------------------- /assets/visual_results/bsr1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr1.png -------------------------------------------------------------------------------- /assets/visual_results/bsr2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr2.png -------------------------------------------------------------------------------- /assets/visual_results/bsr3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr3.png -------------------------------------------------------------------------------- /assets/visual_results/bsr4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr4.png -------------------------------------------------------------------------------- /assets/visual_results/bsr5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr5.png -------------------------------------------------------------------------------- /assets/visual_results/bsr6.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr6.png -------------------------------------------------------------------------------- /assets/visual_results/bsr7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr7.png -------------------------------------------------------------------------------- /assets/visual_results/tiled_sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/tiled_sampling.png -------------------------------------------------------------------------------- /assets/visual_results/whole_image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/whole_image1.png -------------------------------------------------------------------------------- /assets/visual_results/whole_image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/whole_image2.png -------------------------------------------------------------------------------- /configs/inference/bsrnet.yaml: -------------------------------------------------------------------------------- 1 | target: diffbir.model.RRDBNet 2 | params: 3 | in_nc: 3 4 | out_nc: 3 5 | nf: 64 6 | nb: 23 7 | gc: 32 8 | sf: 4 9 | -------------------------------------------------------------------------------- /configs/inference/cldm.yaml: -------------------------------------------------------------------------------- 1 | target: 
diffbir.model.ControlLDM 2 | params: 3 | latent_scale_factor: 0.18215 4 | unet_cfg: 5 | use_checkpoint: True 6 | image_size: 32 # unused 7 | in_channels: 4 8 | out_channels: 4 9 | model_channels: 320 10 | attention_resolutions: [ 4, 2, 1 ] 11 | num_res_blocks: 2 12 | channel_mult: [ 1, 2, 4, 4 ] 13 | num_head_channels: 64 # need to fix for flash-attn 14 | use_spatial_transformer: True 15 | use_linear_in_transformer: True 16 | transformer_depth: 1 17 | context_dim: 1024 18 | legacy: False 19 | vae_cfg: 20 | embed_dim: 4 21 | ddconfig: 22 | double_z: true 23 | z_channels: 4 24 | resolution: 256 25 | in_channels: 3 26 | out_ch: 3 27 | ch: 128 28 | ch_mult: 29 | - 1 30 | - 2 31 | - 4 32 | - 4 33 | num_res_blocks: 2 34 | attn_resolutions: [] 35 | dropout: 0.0 36 | clip_cfg: 37 | embed_dim: 1024 38 | vision_cfg: 39 | image_size: 224 40 | layers: 32 41 | width: 1280 42 | head_width: 80 43 | patch_size: 14 44 | text_cfg: 45 | context_length: 77 46 | vocab_size: 49408 47 | width: 1024 48 | heads: 16 49 | layers: 24 50 | layer: "penultimate" 51 | controlnet_cfg: 52 | use_checkpoint: True 53 | image_size: 32 # unused 54 | in_channels: 4 55 | hint_channels: 4 56 | model_channels: 320 57 | attention_resolutions: [ 4, 2, 1 ] 58 | num_res_blocks: 2 59 | channel_mult: [ 1, 2, 4, 4 ] 60 | num_head_channels: 64 # need to fix for flash-attn 61 | use_spatial_transformer: True 62 | use_linear_in_transformer: True 63 | transformer_depth: 1 64 | context_dim: 1024 65 | legacy: False 66 | -------------------------------------------------------------------------------- /configs/inference/diffusion.yaml: -------------------------------------------------------------------------------- 1 | target: diffbir.model.Diffusion 2 | params: 3 | linear_start: 0.00085 4 | linear_end: 0.0120 5 | timesteps: 1000 6 | -------------------------------------------------------------------------------- /configs/inference/diffusion_v2.1.yaml: 
-------------------------------------------------------------------------------- 1 | target: diffbir.model.Diffusion 2 | params: 3 | linear_start: 0.00085 4 | linear_end: 0.0120 5 | timesteps: 1000 6 | zero_snr: True 7 | parameterization: v 8 | -------------------------------------------------------------------------------- /configs/inference/scunet.yaml: -------------------------------------------------------------------------------- 1 | target: diffbir.model.SCUNet 2 | params: 3 | in_nc: 3 4 | config: [4,4,4,4,4,4,4] 5 | dim: 64 6 | -------------------------------------------------------------------------------- /configs/inference/swinir.yaml: -------------------------------------------------------------------------------- 1 | target: diffbir.model.SwinIR 2 | params: 3 | img_size: 64 4 | patch_size: 1 5 | in_chans: 3 6 | embed_dim: 180 7 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 8 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 9 | window_size: 8 10 | mlp_ratio: 2 11 | sf: 8 12 | img_range: 1.0 13 | upsampler: "nearest+conv" 14 | resi_connection: "1conv" 15 | unshuffle: True 16 | unshuffle_scale: 8 17 | -------------------------------------------------------------------------------- /configs/train/train_stage1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | swinir: 3 | target: diffbir.model.swinir.SwinIR 4 | params: 5 | img_size: 64 6 | patch_size: 1 7 | in_chans: 3 8 | embed_dim: 180 9 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 10 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 11 | window_size: 8 12 | mlp_ratio: 2 13 | sf: 8 14 | img_range: 1.0 15 | upsampler: "nearest+conv" 16 | resi_connection: "1conv" 17 | unshuffle: True 18 | unshuffle_scale: 8 19 | 20 | dataset: 21 | train: 22 | target: diffbir.dataset.codeformer.CodeformerDataset 23 | params: 24 | # training file list path 25 | file_list: 26 | file_backend_cfg: 27 | target: diffbir.dataset.file_backend.HardDiskBackend 28 | out_size: 512 29 | crop_type: center 30 | blur_kernel_size: 41 
31 | kernel_list: ['iso', 'aniso'] 32 | kernel_prob: [0.5, 0.5] 33 | blur_sigma: [0.1, 12] 34 | downsample_range: [1, 12] 35 | noise_range: [0, 15] 36 | jpeg_range: [30, 100] 37 | val: 38 | target: diffbir.dataset.codeformer.CodeformerDataset 39 | params: 40 | # validation file list path 41 | file_list: 42 | file_backend_cfg: 43 | target: diffbir.dataset.file_backend.HardDiskBackend 44 | out_size: 512 45 | crop_type: center 46 | blur_kernel_size: 41 47 | kernel_list: ['iso', 'aniso'] 48 | kernel_prob: [0.5, 0.5] 49 | blur_sigma: [0.1, 12] 50 | downsample_range: [1, 12] 51 | noise_range: [0, 15] 52 | jpeg_range: [30, 100] 53 | 54 | batch_transform: 55 | target: diffbir.dataset.batch_transform.IdentityBatchTransform 56 | 57 | train: 58 | # experiment directory path 59 | exp_dir: 60 | learning_rate: 1e-4 61 | # total batch size 62 | batch_size: 96 63 | num_workers: 64 | train_steps: 150000 65 | log_every: 50 66 | ckpt_every: 10000 67 | image_every: 1000 68 | val_every: 1000 69 | resume: ~ 70 | -------------------------------------------------------------------------------- /configs/train/train_stage2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | cldm: 3 | target: diffbir.model.cldm.ControlLDM 4 | params: 5 | latent_scale_factor: 0.18215 6 | unet_cfg: 7 | use_checkpoint: True 8 | image_size: 32 # unused 9 | in_channels: 4 10 | out_channels: 4 11 | model_channels: 320 12 | attention_resolutions: [ 4, 2, 1 ] 13 | num_res_blocks: 2 14 | channel_mult: [ 1, 2, 4, 4 ] 15 | num_head_channels: 64 # need to fix for flash-attn 16 | use_spatial_transformer: True 17 | use_linear_in_transformer: True 18 | transformer_depth: 1 19 | context_dim: 1024 20 | legacy: False 21 | vae_cfg: 22 | embed_dim: 4 23 | ddconfig: 24 | double_z: true 25 | z_channels: 4 26 | resolution: 256 27 | in_channels: 3 28 | out_ch: 3 29 | ch: 128 30 | ch_mult: 31 | - 1 32 | - 2 33 | - 4 34 | - 4 35 | num_res_blocks: 2 36 | attn_resolutions: [] 37 | 
dropout: 0.0 38 | clip_cfg: 39 | embed_dim: 1024 40 | vision_cfg: 41 | image_size: 224 42 | layers: 32 43 | width: 1280 44 | head_width: 80 45 | patch_size: 14 46 | text_cfg: 47 | context_length: 77 48 | vocab_size: 49408 49 | width: 1024 50 | heads: 16 51 | layers: 24 52 | layer: "penultimate" 53 | controlnet_cfg: 54 | use_checkpoint: True 55 | image_size: 32 # unused 56 | in_channels: 4 57 | hint_channels: 4 58 | model_channels: 320 59 | attention_resolutions: [ 4, 2, 1 ] 60 | num_res_blocks: 2 61 | channel_mult: [ 1, 2, 4, 4 ] 62 | num_head_channels: 64 # need to fix for flash-attn 63 | use_spatial_transformer: True 64 | use_linear_in_transformer: True 65 | transformer_depth: 1 66 | context_dim: 1024 67 | legacy: False 68 | 69 | swinir: 70 | target: diffbir.model.swinir.SwinIR 71 | params: 72 | img_size: 64 73 | patch_size: 1 74 | in_chans: 3 75 | embed_dim: 180 76 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 77 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 78 | window_size: 8 79 | mlp_ratio: 2 80 | sf: 8 81 | img_range: 1.0 82 | upsampler: "nearest+conv" 83 | resi_connection: "1conv" 84 | unshuffle: True 85 | unshuffle_scale: 8 86 | 87 | diffusion: 88 | target: diffbir.model.gaussian_diffusion.Diffusion 89 | params: 90 | linear_start: 0.00085 91 | linear_end: 0.0120 92 | timesteps: 1000 93 | zero_snr: False 94 | parameterization: eps 95 | 96 | dataset: 97 | train: 98 | target: diffbir.dataset.codeformer.CodeformerDataset 99 | params: 100 | # training file list path 101 | file_list: 102 | file_backend_cfg: 103 | target: diffbir.dataset.file_backend.HardDiskBackend 104 | out_size: 512 105 | crop_type: center 106 | blur_kernel_size: 41 107 | kernel_list: ['iso', 'aniso'] 108 | kernel_prob: [0.5, 0.5] 109 | blur_sigma: [0.1, 12] 110 | downsample_range: [1, 12] 111 | noise_range: [0, 15] 112 | jpeg_range: [30, 100] 113 | 114 | batch_transform: 115 | target: diffbir.dataset.batch_transform.IdentityBatchTransform 116 | 117 | train: 118 | # pretrained sd v2.1 path 119 | sd_path: 120 | # 
experiment directory path 121 | exp_dir: 122 | # stage 1 swinir path. 123 | # In our paper, we use SwinIR trained on ImageNet-1k with codeformer degradation. 124 | swinir_path: 125 | learning_rate: 1e-4 126 | # ImageNet 1k (1.3M images) 127 | # batch size = 192, lr = 1e-4, total training steps = 25k 128 | # Our filtered laion2b-en (15M images) 129 | # batch size = 256, lr = 1e-4 (first 30k), 1e-5 (next 50k), total training steps = 80k 130 | batch_size: 256 131 | num_workers: 132 | train_steps: 30000 133 | log_every: 50 134 | ckpt_every: 10000 135 | image_every: 1000 136 | resume: ~ 137 | noise_aug_timestep: 0 138 | -------------------------------------------------------------------------------- /configs/train/train_stage2_v2.1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | cldm: 3 | target: diffbir.model.cldm.ControlLDM 4 | params: 5 | latent_scale_factor: 0.18215 6 | unet_cfg: 7 | use_checkpoint: True 8 | image_size: 32 # unused 9 | in_channels: 4 10 | out_channels: 4 11 | model_channels: 320 12 | attention_resolutions: [ 4, 2, 1 ] 13 | num_res_blocks: 2 14 | channel_mult: [ 1, 2, 4, 4 ] 15 | num_head_channels: 64 # need to fix for flash-attn 16 | use_spatial_transformer: True 17 | use_linear_in_transformer: True 18 | transformer_depth: 1 19 | context_dim: 1024 20 | legacy: False 21 | vae_cfg: 22 | embed_dim: 4 23 | ddconfig: 24 | double_z: true 25 | z_channels: 4 26 | resolution: 256 27 | in_channels: 3 28 | out_ch: 3 29 | ch: 128 30 | ch_mult: 31 | - 1 32 | - 2 33 | - 4 34 | - 4 35 | num_res_blocks: 2 36 | attn_resolutions: [] 37 | dropout: 0.0 38 | clip_cfg: 39 | embed_dim: 1024 40 | vision_cfg: 41 | image_size: 224 42 | layers: 32 43 | width: 1280 44 | head_width: 80 45 | patch_size: 14 46 | text_cfg: 47 | context_length: 77 48 | vocab_size: 49408 49 | width: 1024 50 | heads: 16 51 | layers: 24 52 | layer: "penultimate" 53 | controlnet_cfg: 54 | use_checkpoint: True 55 | image_size: 32 # unused 56 | 
in_channels: 4 57 | hint_channels: 4 58 | model_channels: 320 59 | attention_resolutions: [ 4, 2, 1 ] 60 | num_res_blocks: 2 61 | channel_mult: [ 1, 2, 4, 4 ] 62 | num_head_channels: 64 # need to fix for flash-attn 63 | use_spatial_transformer: True 64 | use_linear_in_transformer: True 65 | transformer_depth: 1 66 | context_dim: 1024 67 | legacy: False 68 | 69 | swinir: 70 | target: diffbir.model.swinir.SwinIR 71 | params: 72 | img_size: 64 73 | patch_size: 1 74 | in_chans: 3 75 | embed_dim: 180 76 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 77 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 78 | window_size: 8 79 | mlp_ratio: 2 80 | sf: 8 81 | img_range: 1.0 82 | upsampler: "nearest+conv" 83 | resi_connection: "1conv" 84 | unshuffle: True 85 | unshuffle_scale: 8 86 | 87 | diffusion: 88 | target: diffbir.model.gaussian_diffusion.Diffusion 89 | params: 90 | linear_start: 0.00085 91 | linear_end: 0.0120 92 | timesteps: 1000 93 | zero_snr: True 94 | parameterization: v 95 | 96 | dataset: 97 | train: 98 | target: diffbir.dataset.realesrgan.RealESRGANDataset 99 | params: 100 | # Path to the file list. 101 | file_metas: 102 | # The training set is formatted as a parquet file. 103 | # Each row contains file path, long caption and short caption of a high-quality image. 
104 | - file_list: 105 | image_path_key: image_path 106 | short_prompt_key: llava_short 107 | long_prompt_key: llava_long 108 | 109 | p_long_prompt: 0.2 110 | 111 | file_backend_cfg: 112 | target: diffbir.dataset.file_backend.HardDiskBackend 113 | 114 | out_size: 512 115 | crop_type: none 116 | 117 | use_hflip: false 118 | use_rot: false 119 | 120 | blur_kernel_size: 21 121 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 122 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 123 | sinc_prob: 0.1 124 | blur_sigma: [0.2, 3] 125 | betag_range: [0.5, 4] 126 | betap_range: [1, 2] 127 | 128 | blur_kernel_size2: 21 129 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 130 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 131 | sinc_prob2: 0.1 132 | blur_sigma2: [0.2, 1.5] 133 | betag_range2: [0.5, 4] 134 | betap_range2: [1, 2] 135 | 136 | final_sinc_prob: 0.8 137 | 138 | p_empty_prompt: 0.2 139 | 140 | batch_transform: 141 | target: diffbir.dataset.batch_transform.RealESRGANBatchTransform 142 | params: 143 | use_sharpener: true 144 | # Queue size of training pool, this should be multiples of batch_size (per GPU). 
class CodeformerDataset(data.Dataset):
    """Dataset that pairs ground-truth images with synthetically degraded copies.

    Each item is produced by loading a ground-truth (GT) image through the
    configured file backend and degrading it on the fly with, in order:
    blur -> downsample -> Gaussian noise -> JPEG compression -> resize back
    to the GT resolution.

    Args:
        file_list: Path handed to ``load_file_list``. Every resulting entry
            must provide the keys ``"image_path"`` and ``"prompt"``.
        file_backend_cfg: Config handed to ``instantiate_from_config``. The
            resulting object must expose ``get(path)`` returning the encoded
            image bytes, or ``None`` on failure.
        out_size: Side length (pixels) of the square GT image.
        crop_type: One of ``"none"``, ``"center"`` or ``"random"``. With
            ``"none"`` the stored image must already be
            ``out_size x out_size``.
        blur_kernel_size: Kernel size forwarded to ``random_mixed_kernels``.
        kernel_list: Kernel type names forwarded to ``random_mixed_kernels``.
        kernel_prob: Sampling probabilities matching ``kernel_list``.
        blur_sigma: Sigma range of the blur kernel; the same range is passed
            for both sigma arguments of ``random_mixed_kernels``.
        downsample_range: ``(min, max)`` of the uniformly sampled
            downsampling factor.
        noise_range: Range forwarded to ``random_add_gaussian_noise``;
            ``None`` disables the noise step.
        jpeg_range: Range forwarded to ``random_add_jpg_compression``;
            ``None`` disables the JPEG step.
    """

    def __init__(
        self,
        file_list: str,
        file_backend_cfg: Mapping[str, Any],
        out_size: int,
        crop_type: str,
        blur_kernel_size: int,
        kernel_list: Sequence[str],
        kernel_prob: Sequence[float],
        blur_sigma: Sequence[float],
        downsample_range: Sequence[float],
        noise_range: Sequence[float],
        jpeg_range: Sequence[int],
    ) -> None:
        super().__init__()
        self.file_list = file_list
        self.image_files = load_file_list(file_list)
        self.file_backend = instantiate_from_config(file_backend_cfg)
        self.out_size = out_size
        self.crop_type = crop_type
        assert self.crop_type in ["none", "center", "random"]
        # degradation configurations
        self.blur_kernel_size = blur_kernel_size
        self.kernel_list = kernel_list
        self.kernel_prob = kernel_prob
        self.blur_sigma = blur_sigma
        self.downsample_range = downsample_range
        self.noise_range = noise_range
        self.jpeg_range = jpeg_range

    def load_gt_image(
        self, image_path: str, max_retry: int = 5
    ) -> Optional[np.ndarray]:
        """Load one ground-truth image and crop it to ``out_size``.

        Args:
            image_path: Path understood by the file backend.
            max_retry: Maximum number of read attempts. Values ``<= 0``
                perform no attempt at all.

        Returns:
            The image as an ``(h, w, c)`` RGB ``uint8`` array in ``[0, 255]``,
            or ``None`` when the backend never returned any bytes.
        """
        image_bytes = None
        # A bounded ``for`` loop (instead of counting down to exactly zero)
        # guarantees termination even for a negative ``max_retry``.
        for attempt in range(max_retry):
            image_bytes = self.file_backend.get(image_path)
            if image_bytes is not None:
                break
            # Back off only when another attempt follows; sleeping after the
            # last failure would merely delay the caller's fallback.
            if attempt + 1 < max_retry:
                time.sleep(0.5)
        if image_bytes is None:
            return None

        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if self.crop_type != "none":
            if image.height == self.out_size and image.width == self.out_size:
                # Already the target size: cropping would be a no-op.
                image = np.array(image)
            else:
                if self.crop_type == "center":
                    image = center_crop_arr(image, self.out_size)
                elif self.crop_type == "random":
                    image = random_crop_arr(image, self.out_size, min_crop_frac=0.7)
        else:
            # Without cropping the stored image must already match out_size.
            assert image.height == self.out_size and image.width == self.out_size
            image = np.array(image)
        # hwc, rgb, 0,255, uint8
        return image

    def __getitem__(self, index: int) -> Sequence[Union[np.ndarray, str]]:
        """Return the ``(gt, lq, prompt)`` triple for ``index``.

        If the requested image cannot be loaded, a randomly chosen other
        index is tried until one succeeds.

        Returns:
            A 3-tuple ``(gt, lq, prompt)``: ``gt`` is the RGB ground truth as
            float32 in ``[-1, 1]``; ``lq`` is the degraded RGB image as
            float32, nominally in ``[0, 1]``; ``prompt`` is the caption,
            replaced by ``""`` with probability 0.5.
        """
        # load gt image
        img_gt = None
        while img_gt is None:
            # load meta file
            image_file = self.image_files[index]
            gt_path = image_file["image_path"]
            prompt = image_file["prompt"]
            img_gt = self.load_gt_image(gt_path)
            if img_gt is None:
                print(f"failed to load {gt_path}, try another image")
                index = random.randint(0, len(self) - 1)

        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        img_gt = (img_gt[..., ::-1] / 255.0).astype(np.float32)
        h, w, _ = img_gt.shape
        # Randomly drop the caption for half of the samples.
        if np.random.uniform() < 0.5:
            prompt = ""

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma,
            [-math.pi, math.pi],
            noise_range=None,
        )
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(
            img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR
        )
        # noise
        if self.noise_range is not None:
            img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # BGR to RGB, [-1, 1]
        gt = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
        # BGR to RGB, [0, 1]
        lq = img_lq[..., ::-1].astype(np.float32)

        return gt, lq, prompt

    def __len__(self) -> int:
        """Return the number of entries in the loaded file list."""
        return len(self.image_files)
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract base class shared by all storage backends.

    A concrete backend must implement :meth:`get`, which reads the file at
    a given path and returns its content as bytes.
    """

    @abstractmethod
    def get(self, filepath: str) -> bytes:
        """Return the content stored at ``filepath`` as bytes."""

    @property
    def name(self) -> str:
        """Class name of the concrete backend."""
        return type(self).__name__
42 | 43 | Examples: 44 | >>> filepath1 = 's3://path/of/file' 45 | >>> filepath2 = 'cluster-name:s3://path/of/file' 46 | >>> client = PetrelBackend() 47 | >>> client.get(filepath1) # get data from default cluster 48 | >>> client.get(filepath2) # get data from 'cluster-name' cluster 49 | """ 50 | 51 | def __init__(self, 52 | path_mapping: Optional[dict] = None, 53 | enable_mc: bool = False, 54 | conf_path: str = None): 55 | try: 56 | from petrel_client import client 57 | except ImportError: 58 | raise ImportError('Please install petrel_client to enable ' 59 | 'PetrelBackend.') 60 | 61 | self._client = client.Client(conf_path=conf_path, enable_mc=enable_mc) 62 | assert isinstance(path_mapping, dict) or path_mapping is None 63 | self.path_mapping = path_mapping 64 | 65 | def _map_path(self, filepath: Union[str, Path]) -> str: 66 | """Map ``filepath`` to a string path whose prefix will be replaced by 67 | :attr:`self.path_mapping`. 68 | 69 | Args: 70 | filepath (str): Path to be mapped. 71 | """ 72 | filepath = str(filepath) 73 | if self.path_mapping is not None: 74 | for k, v in self.path_mapping.items(): 75 | filepath = filepath.replace(k, v, 1) 76 | return filepath 77 | 78 | def _format_path(self, filepath: str) -> str: 79 | """Convert a ``filepath`` to standard format of petrel oss. 80 | 81 | If the ``filepath`` is concatenated by ``os.path.join``, in a Windows 82 | environment, the ``filepath`` will be the format of 83 | 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the 84 | above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. 85 | 86 | Args: 87 | filepath (str): Path to be formatted. 88 | """ 89 | return re.sub(r'\\+', '/', filepath) 90 | 91 | def get(self, filepath: Union[str, Path]) -> bytes: 92 | """Read data from a given ``filepath`` with 'rb' mode. 93 | 94 | Args: 95 | filepath (str or Path): Path to read data. 96 | 97 | Returns: 98 | bytes: The loaded bytes. 
class HardDiskBackend(BaseStorageBackend):
    """Storage backend that reads files with the built-in ``open``."""

    def get(self, filepath: Union[str, Path]) -> bytes:
        """Read the whole file at ``filepath`` in 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: The complete file content.
        """
        with open(filepath, 'rb') as stream:
            return stream.read()
class BIDInferenceLoop(InferenceLoop):
    """Inference loop for the BID task.

    The stage-1 cleaner depends on ``self.args.version``: "v1" uses the
    general SwinIR weights, "v2" uses SCUNet-PSNR, and any other value uses
    the Real-ESRGAN SwinIR weights. ``load_pipeline`` follows the same rule
    so the pipeline class always matches the cleaner that was loaded.
    """

    def load_cleaner(self) -> None:
        """Instantiate the stage-1 model, load its weights and move it to the device."""
        if self.args.version == "v1":
            config = "configs/inference/swinir.yaml"
            weight = MODELS["swinir_general"]
        elif self.args.version == "v2":
            config = "configs/inference/scunet.yaml"
            weight = MODELS["scunet_psnr"]
        else:
            config = "configs/inference/swinir.yaml"
            weight = MODELS["swinir_realesrgan"]
        self.cleaner: SCUNet | SwinIR = instantiate_from_config(OmegaConf.load(config))
        model_weight = load_model_from_url(weight)
        self.cleaner.load_state_dict(model_weight, strict=True)
        self.cleaner.eval().to(self.args.device)

    def load_pipeline(self) -> None:
        """Build the pipeline whose type matches the cleaner from ``load_cleaner``."""
        # Only "v2" loads SCUNet; every other version (including unrecognised
        # ones) loads a SwinIR cleaner, so it must get the SwinIR pipeline.
        if self.args.version == "v2":
            pipeline_class = SCUNetPipeline
        else:
            pipeline_class = SwinIRPipeline
        self.pipeline = pipeline_class(
            self.cleaner,
            self.cldm,
            self.diffusion,
            self.cond_fn,
            self.args.device,
        )

    def after_load_lq(self, lq: Image.Image) -> np.ndarray:
        """Bicubically resize ``lq`` by ``args.upscale`` before the base-class processing."""
        lq = lq.resize(
            tuple(int(x * self.args.upscale) for x in lq.size), Image.BICUBIC
        )
        return super().after_load_lq(lq)
return super().after_load_lq(lq) 60 | -------------------------------------------------------------------------------- /diffbir/inference/custom_loop.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from omegaconf import OmegaConf 6 | import torch 7 | 8 | from .loop import InferenceLoop 9 | from ..utils.common import ( 10 | instantiate_from_config, 11 | VRAMPeakMonitor, 12 | ) 13 | from ..pipeline import ( 14 | SwinIRPipeline, 15 | Pipeline, 16 | ) 17 | from ..model import SwinIR, ControlLDM, Diffusion 18 | 19 | 20 | class CustomInferenceLoop(InferenceLoop): 21 | 22 | def __init__(self, args: Namespace) -> "InferenceLoop": 23 | self.args = args 24 | self.train_cfg = OmegaConf.load(args.train_cfg) 25 | self.loop_ctx = {} 26 | self.pipeline: Pipeline = None 27 | with VRAMPeakMonitor("loading cleaner model"): 28 | self.load_cleaner() 29 | with VRAMPeakMonitor("loading cldm model"): 30 | self.load_cldm() 31 | self.load_cond_fn() 32 | self.load_pipeline() 33 | with VRAMPeakMonitor("loading captioner"): 34 | self.load_captioner() 35 | 36 | def load_cldm(self) -> None: 37 | self.cldm: ControlLDM = instantiate_from_config(self.train_cfg.model.cldm) 38 | 39 | # load pre-trained SD weight 40 | sd_weight = torch.load(self.train_cfg.train.sd_path, map_location="cpu") 41 | sd_weight = sd_weight["state_dict"] 42 | unused, missing = self.cldm.load_pretrained_sd(sd_weight) 43 | print( 44 | f"load pretrained stable diffusion, " 45 | f"unused weights: {unused}, missing weights: {missing}" 46 | ) 47 | # load controlnet weight 48 | control_weight = torch.load(self.args.ckpt, map_location="cpu") 49 | self.cldm.load_controlnet_from_ckpt(control_weight) 50 | print(f"load controlnet weight") 51 | self.cldm.eval().to(self.args.device) 52 | cast_type = { 53 | "fp32": torch.float32, 54 | "fp16": torch.float16, 55 | "bf16": torch.bfloat16, 56 | }[self.args.precision] 
57 | self.cldm.cast_dtype(cast_type) 58 | 59 | # load diffusion 60 | self.diffusion: Diffusion = instantiate_from_config( 61 | self.train_cfg.model.diffusion 62 | ) 63 | self.diffusion.to(self.args.device) 64 | 65 | def load_cleaner(self) -> None: 66 | # NOTE: Use SwinIR as stage-1 model. Change it if you want. 67 | self.cleaner: SwinIR = instantiate_from_config(self.train_cfg.model.swinir) 68 | weight = torch.load(self.train_cfg.train.swinir_path, map_location="cpu") 69 | if "state_dict" in weight: 70 | weight = weight["state_dict"] 71 | weight = { 72 | (k[len("module.") :] if k.startswith("module.") else k): v 73 | for k, v in weight.items() 74 | } 75 | self.cleaner.load_state_dict(weight, strict=True) 76 | self.cleaner.eval().to(self.args.device) 77 | 78 | def load_pipeline(self) -> None: 79 | # NOTE: Choose the correct pipeline if SwinIR is not your stage-1 model. 80 | self.pipeline = SwinIRPipeline( 81 | self.cleaner, 82 | self.cldm, 83 | self.diffusion, 84 | self.cond_fn, 85 | self.args.device, 86 | ) 87 | 88 | def after_load_lq(self, lq: Image.Image) -> np.ndarray: 89 | # For SwinIRPipeline, upscaling is achieved by resizing input LQ. 90 | lq = lq.resize( 91 | tuple(int(x * self.args.upscale) for x in lq.size), Image.BICUBIC 92 | ) 93 | return super().after_load_lq(lq) 94 | -------------------------------------------------------------------------------- /diffbir/inference/pretrained_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | All models used in inference: 3 | - DiffBIR-v1 4 | All tasks share the same pre-trained stable diffusion v2.1 (sd_v2.1). 5 | -- BSR task 6 | stage-1 model (swinir_general): SwinIR trained on ImageNet-1k with Real-ESRGAN degradation. 7 | stage-2 model (v1_general): IRControlNet trained on ImageNet-1k. 
8 | -- BFR task 9 | stage-1 model (swinir_face): SwinIR pre-trained on FFHQ, borrowed from DifFace (https://github.com/zsyOAOA/DifFace.git) 10 | stage-2 model (v1_face): IRControlNet trained on FFHQ. 11 | -- BID task 12 | The same as BSR task. 13 | 14 | - DiffBIR-v2 15 | All tasks share the same pre-trained stable diffusion v2.1 (sd_v2.1). 16 | All tasks share the same stage-2 model (v2). 17 | -- BSR task 18 | stage-1 model (bsrnet): BSRNet borrowed from BSRGAN (https://github.com/cszn/BSRGAN.git). 19 | -- BFR task 20 | stage-1 model (swinir_face): SwinIR pre-trained on FFHQ, borrowed from DifFace (https://github.com/zsyOAOA/DifFace.git) 21 | -- BID task 22 | stage-1 model (scunet_psnr): SCUNet-PSNR borrowed from SCUNet (https://github.com/cszn/SCUNet.git) 23 | 24 | - DiffBIR-v2.1 25 | All tasks share the same pre-trained stable diffusion v2.1-zsnr (sd_v2.1_zsnr). 26 | All tasks share the same stage-2 model (v2.1). 27 | -- BSR task 28 | stage-1 model (swinir_realesrgan): SwinIR trained on ImageNet-1k with Real-ESRGAN degradation. 29 | -- BFR task 30 | stage-1 model (swinir_face): SwinIR pre-trained on FFHQ, borrowed from DifFace (https://github.com/zsyOAOA/DifFace.git) 31 | -- BID task 32 | The same as BSR task. 
33 | """ 34 | MODELS = { 35 | # --------------- stage-1 model weights --------------- 36 | "bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth", 37 | # the following checkpoint is up-to-date, we use the old version in our paper 38 | # "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth", 39 | "swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt", 40 | "scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth", 41 | "swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt", 42 | "swinir_realesrgan": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/realesrgan_s4_swinir_100k.pth", 43 | # --------------- pre-trained stable diffusion weights --------------- 44 | "sd_v2.1": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt", 45 | "sd_v2.1_zsnr": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/sd2.1-base-zsnr-laionaes5.ckpt", 46 | # --------------- IRControlNet weights --------------- 47 | "v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth", 48 | "v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth", 49 | "v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth", 50 | "v2.1": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/DiffBIR_v2.1.pt", 51 | } 52 | -------------------------------------------------------------------------------- /diffbir/model/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
def initialize_weights(net_l, scale=1):
    """Re-initialise the conv, linear and batch-norm layers of one or more networks.

    Args:
        net_l: A module, or a list of modules, whose sub-modules are visited.
        scale: Factor applied to the Kaiming-initialised weights of
            ``nn.Conv2d`` and ``nn.Linear`` layers (used for residual blocks).
    """
    networks = net_l if isinstance(net_l, list) else [net_l]
    for network in networks:
        for module in network.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # Kaiming-normal weights, scaled down; biases reset to zero.
                init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                module.weight.data *= scale
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                # Batch norm starts out as the identity transform.
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)
class RRDB(nn.Module):
    '''Residual in Residual Dense Block.

    Chains three ``ResidualDenseBlock_5C`` modules and adds their output,
    scaled by 0.2, back onto the block input.
    '''

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        '''Return ``RDB3(RDB2(RDB1(x))) * 0.2 + x``.'''
        residual = self.RDB3(self.RDB2(self.RDB1(x)))
        return residual * 0.2 + x
class FrozenOpenCLIPEmbedder(nn.Module):
    """
    Text encoder built on the OpenCLIP transformer.

    Only the text branch is kept; the visual tower of the CLIP model is
    deleted right after construction.
    """
    # "pooled" is intentionally not offered.
    LAYERS = [
        "last",
        "penultimate"
    ]

    def __init__(self, embed_dim, vision_cfg, text_cfg, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        clip_model = CLIP(embed_dim, dict(vision_cfg), dict(text_cfg))
        del clip_model.visual
        self.model = clip_model

        self.layer = layer
        # Number of trailing transformer blocks skipped by the forward pass.
        skipped_blocks = {"last": 0, "penultimate": 1}.get(self.layer)
        if skipped_blocks is None:
            raise NotImplementedError()
        self.layer_idx = skipped_blocks

    def forward(self, tokens):
        """Encode already-tokenized text."""
        return self.encode_with_transformer(tokens)

    def encode_with_transformer(self, text):
        """Embed tokens, run the transformer blocks and apply the final layer norm."""
        hidden = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.model.positional_embedding
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.text_transformer_forward(hidden, attn_mask=self.model.attn_mask)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        return self.model.ln_final(hidden)

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
        """Run all residual blocks except the last ``self.layer_idx`` ones."""
        blocks = self.model.transformer.resblocks
        n_active = len(blocks) - self.layer_idx
        for _, block in zip(range(n_active), blocks):
            use_checkpoint = (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            )
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x

    def encode(self, text: List[str]) -> torch.Tensor:
        """Tokenize a batch of strings and encode them on the model's device."""
        # convert a batch of text to tensor
        token_ids = tokenize(text)
        # move tensor to model device
        model_device = next(self.model.parameters()).device
        return self(token_ids.to(model_device))
class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance, parameterised by mean and log-variance.

    ``parameters`` is split into two equal chunks along dim 1: the first is
    the mean and the second the log-variance, which is clamped to
    [-30, 20]. With ``deterministic=True`` the standard deviation and
    variance are zero, so sampling returns the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw ``mean + std * noise`` with standard-normal noise."""
        # The noise is deliberately drawn on the default (CPU) generator and
        # then moved, so seeded runs keep producing the same samples.
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Return the KL divergence summed over dims 1, 2 and 3.

        Args:
            other: Distribution to compare against. ``None`` means the
                standard normal.

        Returns:
            One value per batch element, or a single zero (on the same device
            and with the same dtype as the mean) in deterministic mode.
        """
        if self.deterministic:
            return self.mean.new_zeros(1)
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Return the negative log-likelihood of ``sample`` summed over ``dims``.

        In deterministic mode a single zero is returned, on the same device
        and with the same dtype as the mean.
        """
        if self.deterministic:
            return self.mean.new_zeros(1)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
class DPMSolverSampler(Sampler):
    """Sampler that delegates the reverse process to ``DPM_Solver``."""

    def __init__(
        self,
        betas: np.ndarray,
        parameterization: Literal["eps", "v"],
        rescale_cfg: bool,
        model_spec: str,
    ) -> None:
        """Configure the solver from a compact spec string.

        Args:
            betas: Beta schedule, forwarded to the base class and registered
                as the ``betas`` buffer.
            parameterization: "eps" or "v"; mapped to the solver model types
                "noise" and "v".
            rescale_cfg: Forwarded to the base class; later passed to
                ``model_wrapper`` as ``cfg_rescale``.
            model_spec: ``"<solver>_<method><order>"``, e.g. ``dpm++_s2``
                => solver_type=dpmsolver++, method=singlestep, order=2.

        Raises:
            ValueError: Unknown ``parameterization`` or a ``model_spec`` that
                does not split into the expected parts.
            KeyError: Unknown solver, method or order token.
        """
        super().__init__(betas, parameterization, rescale_cfg)
        if parameterization == "eps":
            self.model_type = "noise"
        elif parameterization == "v":
            self.model_type = "v"
        else:
            raise ValueError(parameterization)
        # parse sampling args from string
        solver_key, method_and_order = model_spec.split("_")
        method_key, order_key = method_and_order
        self.solver_type = {"dpm": "dpmsolver", "dpm++": "dpmsolver++"}[solver_key]
        self.method = {"s": "singlestep", "m": "multistep"}[method_key]
        self.order = {"1": 1, "2": 2, "3": 3}[order_key]
        self.register("betas", betas)

    @torch.no_grad()
    def sample(
        self,
        model: ControlLDM,
        device: str,
        steps: int,
        x_size: Tuple[int],
        cond: Dict[str, torch.Tensor],
        uncond: Dict[str, torch.Tensor],
        cfg_scale: float,
        tiled: bool = False,
        tile_size: int = -1,
        tile_stride: int = -1,
        x_T: torch.Tensor | None = None,
        progress: bool = True,
    ) -> torch.Tensor:
        """Run DPM-Solver sampling and return the final sample.

        Args:
            model: Network called as ``model(x, t, c)``.
            device: Device on which the initial noise is created.
            steps: Number of solver steps.
            x_size: Shape of the initial noise when ``x_T`` is None.
            cond, uncond: Conditional and unconditional inputs for
                classifier-free guidance.
            cfg_scale: Guidance scale.
            tiled: When True, ``model.forward`` is temporarily replaced by a
                tiled version that crops ``cond["c_img"]`` per tile.
            tile_size, tile_stride: Tiling parameters for ``make_tiled_fn``.
            x_T: Optional initial noise. Drawn from a standard normal if None.
            progress: Accepted for interface compatibility; not used here.
        """
        original_forward = None
        if tiled:
            original_forward = model.forward
            model.forward = make_tiled_fn(
                lambda x_tile, t, cond, hi, hi_end, wi, wi_end: (
                    original_forward(
                        x_tile,
                        t,
                        {
                            "c_txt": cond["c_txt"],
                            "c_img": cond["c_img"][..., hi:hi_end, wi:wi_end],
                        },
                    )
                ),
                tile_size,
                tile_stride,
            )
        try:
            if x_T is None:
                x_T = torch.randn(x_size, device=device, dtype=torch.float32)

            noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas)
            model_fn = model_wrapper(
                lambda x, t, c: model(x, t, c),
                noise_schedule,
                model_type=self.model_type,
                guidance_type="classifier-free",
                condition=cond,
                unconditional_condition=uncond,
                guidance_scale=cfg_scale,
                cfg_rescale=self.rescale_cfg,
            )
            dpm_solver = DPM_Solver(
                model_fn, noise_schedule, algorithm_type=self.solver_type
            )
            return dpm_solver.sample(
                x_T,
                steps=steps,
                skip_type="time_uniform",
                method=self.method,
                order=self.order,
                return_intermediate=False,
            )
        finally:
            # Always undo the monkey-patch, even when sampling raises, so the
            # model is never left with the tiled forward installed.
            if tiled:
                model.forward = original_forward
class Guidance:
    """Base class for restoration guidance.

    Holds the guidance configuration and an optional target tensor.
    Subclasses supply the actual loss/gradient computation by overriding
    ``_forward``; calling an instance detaches its inputs and delegates to
    that method.
    """

    def __init__(
        self, scale: float, t_start: int, t_stop: int, space: str, repeat: int
    ) -> None:
        """
        Initialize restoration guidance.

        Args:
            scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
                the closer the final result will be to the output of the first stage model.
                The value is stored multiplied by 3000.
            t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling
                process starts from t=1000 to t=0, the `t_start` should be larger than `t_stop`.
            space (str): The data space for computing loss function (rgb or latent).
            repeat (int): Stored as ``self.repeat``; not read by this class
                (presumably consumed by the sampler, confirm against callers).

        Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
        Thanks for their work!
        """
        self.scale = scale * 3000
        self.t_start = t_start
        self.t_stop = t_stop
        self.target = None
        self.space = space
        self.repeat = repeat

    def load_target(self, target: torch.Tensor) -> None:
        """Store ``target`` (by reference) as ``self.target``."""
        self.target = target

    def __call__(
        self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int
    ) -> Tuple[torch.Tensor, float]:
        """Compute guidance for ``pred_x0`` towards ``target_x0`` at timestep ``t``.

        Both tensors are detached and cloned first, so gradients computed by
        ``_forward`` never propagate to the caller's graph.

        Returns:
            Whatever the subclass ``_forward`` returns: a ``(gradient, loss)``
            pair.
        """
        # avoid propagating gradient out of this scope
        pred_x0 = pred_x0.detach().clone()
        target_x0 = target_x0.detach().clone()
        return self._forward(target_x0, pred_x0, t)

    def _forward(
        self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int
    ) -> Tuple[torch.Tensor, float]:
        """Subclass hook that computes the guidance gradient and loss value.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        # An explicit abstract stub: typing.overload is meant for type-checker
        # signatures only and gives a misleading error when used this way.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _forward()"
        )
def get_attn_func():
    """Return the attention forward function for the configured attention mode.

    Looks up ``Config.attn_mode`` in a fixed mode-to-function table and
    returns the matching module-level function (vanilla, xformers or SDP).

    Raises:
        KeyError: If ``Config.attn_mode`` is not one of the three known modes.
    """
    dispatch = {
        AttnMode.VANILLA: forward,
        AttnMode.XFORMERS: xformers_forward,
        AttnMode.SDP: sdp_forward,
    }
    selected_mode = Config.attn_mode
    return dispatch[selected_mode]
def forward(self, x):
    """Plain softmax self-attention over spatial positions.

    Applies ``self.q``, ``self.k`` and ``self.v`` to ``x``, attends across
    the flattened ``h * w`` positions with ``1/sqrt(c)`` scaling, and passes
    the result through ``self.proj_out``. Normalization and the residual
    connection are intentionally left out; the caller applies them.

    Args:
        x: 4-D input; the projections of it are unpacked as ``(b, c, h, w)``.

    Returns:
        The output of ``self.proj_out`` on the attended features (no
        residual added).
    """
    query = self.q(x)
    key = self.k(x)
    value = self.v(x)

    batch, channels, height, width = query.shape
    tokens = height * width

    # Attention logits: scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j]
    query = query.reshape(batch, channels, tokens).permute(0, 2, 1)  # b, hw, c
    key = key.reshape(batch, channels, tokens)  # b, c, hw
    scores = torch.bmm(query, key)
    scores = scores * (int(channels) ** (-0.5))
    attn = torch.nn.functional.softmax(scores, dim=2)

    # Weighted sum of values: out[b, c, j] = sum_i value[b, c, i] * attn_t[b, i, j]
    value = value.reshape(batch, channels, tokens)
    attn_t = attn.permute(0, 2, 1)  # b, hw (key), hw (query)
    attended = torch.bmm(value, attn_t)
    attended = attended.reshape(batch, channels, height, width)

    return self.proj_out(attended)
C) 99 | .permute(0, 2, 1, 3) 100 | .reshape(B * 1, t.shape[1], C) 101 | .contiguous(), 102 | (q, k, v), 103 | ) 104 | out = F.scaled_dot_product_attention(q, k, v) 105 | 106 | out = ( 107 | out.unsqueeze(0) 108 | .reshape(B, 1, out.shape[1], C) 109 | .permute(0, 2, 1, 3) 110 | .reshape(B, out.shape[1], C) 111 | ) 112 | out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) 113 | out = self.proj_out(out) 114 | # return x + out 115 | return out 116 | -------------------------------------------------------------------------------- /inputs/demo/bfr/aligned/0229.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/aligned/0229.png -------------------------------------------------------------------------------- /inputs/demo/bfr/aligned/0427.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/aligned/0427.png -------------------------------------------------------------------------------- /inputs/demo/bfr/aligned/0722.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/aligned/0722.png -------------------------------------------------------------------------------- /inputs/demo/bfr/aligned/hermione.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/aligned/hermione.jpg -------------------------------------------------------------------------------- /inputs/demo/bfr/whole_img/01.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/01.jpg -------------------------------------------------------------------------------- /inputs/demo/bfr/whole_img/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/02.png -------------------------------------------------------------------------------- /inputs/demo/bfr/whole_img/Audrey_Hepburn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/Audrey_Hepburn.jpg -------------------------------------------------------------------------------- /inputs/demo/bfr/whole_img/Blake_Lively.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/Blake_Lively.jpg -------------------------------------------------------------------------------- /inputs/demo/bfr/whole_img/Harry_Potter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/Harry_Potter.jpg -------------------------------------------------------------------------------- /inputs/demo/bfr/whole_img/Queen.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/Queen.jpeg -------------------------------------------------------------------------------- /inputs/demo/bfr/whole_img/real47_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/real47_1.jpg -------------------------------------------------------------------------------- /inputs/demo/bid/Audrey_Hepburn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Audrey_Hepburn.jpg -------------------------------------------------------------------------------- /inputs/demo/bid/Bears.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Bears.png -------------------------------------------------------------------------------- /inputs/demo/bid/Flowers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Flowers.png -------------------------------------------------------------------------------- /inputs/demo/bid/Movie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Movie.png -------------------------------------------------------------------------------- /inputs/demo/bid/Postcards.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Postcards.png -------------------------------------------------------------------------------- /inputs/demo/bid/cty_fnb_0047.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/cty_fnb_0047.png -------------------------------------------------------------------------------- /inputs/demo/bid/kf_fnb_0058.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/kf_fnb_0058.png -------------------------------------------------------------------------------- /inputs/demo/bid/palace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/palace.png -------------------------------------------------------------------------------- /inputs/demo/bsr/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/14.jpg -------------------------------------------------------------------------------- /inputs/demo/bsr/29.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/29.jpg -------------------------------------------------------------------------------- /inputs/demo/bsr/49.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/49.jpg -------------------------------------------------------------------------------- /inputs/demo/bsr/53.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/53.jpeg 
-------------------------------------------------------------------------------- /inputs/demo/bsr/comic3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/comic3.png -------------------------------------------------------------------------------- /inputs/real47/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/1.jpg -------------------------------------------------------------------------------- /inputs/real47/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/11.jpg -------------------------------------------------------------------------------- /inputs/real47/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/12.jpg -------------------------------------------------------------------------------- /inputs/real47/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/13.jpg -------------------------------------------------------------------------------- /inputs/real47/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/14.jpg -------------------------------------------------------------------------------- /inputs/real47/15.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/15.jpg -------------------------------------------------------------------------------- /inputs/real47/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/16.jpg -------------------------------------------------------------------------------- /inputs/real47/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/17.jpg -------------------------------------------------------------------------------- /inputs/real47/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/19.jpg -------------------------------------------------------------------------------- /inputs/real47/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/2.jpg -------------------------------------------------------------------------------- /inputs/real47/20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/20.jpg -------------------------------------------------------------------------------- /inputs/real47/21.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/21.jpg -------------------------------------------------------------------------------- /inputs/real47/22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/22.jpg -------------------------------------------------------------------------------- /inputs/real47/23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/23.jpg -------------------------------------------------------------------------------- /inputs/real47/24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/24.jpg -------------------------------------------------------------------------------- /inputs/real47/26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/26.jpg -------------------------------------------------------------------------------- /inputs/real47/27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/27.jpg -------------------------------------------------------------------------------- /inputs/real47/29.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/29.jpg 
-------------------------------------------------------------------------------- /inputs/real47/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/3.jpg -------------------------------------------------------------------------------- /inputs/real47/32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/32.jpg -------------------------------------------------------------------------------- /inputs/real47/33.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/33.jpg -------------------------------------------------------------------------------- /inputs/real47/34.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/34.jpg -------------------------------------------------------------------------------- /inputs/real47/35.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/35.jpg -------------------------------------------------------------------------------- /inputs/real47/36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/36.png -------------------------------------------------------------------------------- /inputs/real47/38.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/38.jpg -------------------------------------------------------------------------------- /inputs/real47/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/4.jpg -------------------------------------------------------------------------------- /inputs/real47/40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/40.jpg -------------------------------------------------------------------------------- /inputs/real47/41.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/41.jpg -------------------------------------------------------------------------------- /inputs/real47/42.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/42.jpg -------------------------------------------------------------------------------- /inputs/real47/43.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/43.jpg -------------------------------------------------------------------------------- /inputs/real47/44.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/44.jpg 
-------------------------------------------------------------------------------- /inputs/real47/45.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/45.jpg -------------------------------------------------------------------------------- /inputs/real47/46.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/46.jpg -------------------------------------------------------------------------------- /inputs/real47/47.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/47.jpg -------------------------------------------------------------------------------- /inputs/real47/48.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/48.jpg -------------------------------------------------------------------------------- /inputs/real47/49.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/49.jpg -------------------------------------------------------------------------------- /inputs/real47/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/5.jpg -------------------------------------------------------------------------------- /inputs/real47/50.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/50.jpg -------------------------------------------------------------------------------- /inputs/real47/51.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/51.jpg -------------------------------------------------------------------------------- /inputs/real47/52.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/52.jpg -------------------------------------------------------------------------------- /inputs/real47/53.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/53.jpeg -------------------------------------------------------------------------------- /inputs/real47/54.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/54.jpeg -------------------------------------------------------------------------------- /inputs/real47/55.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/55.jpg -------------------------------------------------------------------------------- /inputs/real47/56.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/56.jpg 
-------------------------------------------------------------------------------- /inputs/real47/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/6.jpg -------------------------------------------------------------------------------- /inputs/real47/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/7.jpg -------------------------------------------------------------------------------- /inputs/real47/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/9.jpg -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
def parse_score(review):
    """Extract the two scores from the first line of a GPT review.

    The first line is expected to hold exactly two numbers separated by
    commas and/or whitespace (e.g. ``"8 9"``, ``"8,9"`` or ``"8, 9"``).

    Args:
        review: full review text returned by the model.

    Returns:
        ``[score1, score2]`` as floats, or ``[-1, -1]`` when the first line
        cannot be parsed as a pair of numbers.
    """
    try:
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        # split() with no argument collapses runs of whitespace, so "8, 9"
        # (comma followed by a space) yields two tokens instead of three.
        sp = score_pair.split()
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        else:
            print('error', review)
            return [-1, -1]
    except Exception as e:
        # Best-effort parsing: a malformed review must not abort the run.
        print(e)
        print('error', review)
        return [-1, -1]
'--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- 
if __name__ == '__main__':
    # Script entry point: for every (question, answer1, answer2) triple, build
    # a review prompt from the image context and the category rule, ask GPT
    # for a review, and append the scored review to the output JSONL file.
    # Rows already present in the output file are skipped so a run can resume.
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question')
    parser.add_argument('-c', '--context')
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule')
    parser.add_argument('-o', '--output')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    # Expand "~" once and use the same path for both the resume check and the
    # append, so the two can never point at different files.
    output_path = os.path.expanduser(args.output)

    with open(os.path.expanduser(args.rule), 'r') as f:
        rule_dict = json.load(f)

    if os.path.isfile(output_path):
        with open(output_path) as f:
            cur_reviews = [json.loads(line) for line in f]
    else:
        cur_reviews = []

    with open(os.path.expanduser(args.context)) as f:
        context_list = [json.loads(line) for line in f]
    image_to_context = {context['image']: context for context in context_list}

    with open(os.path.expanduser(args.question)) as f_q, \
            open(os.path.expanduser(args.answer_list[0])) as f_ans1, \
            open(os.path.expanduser(args.answer_list[1])) as f_ans2, \
            open(output_path, 'a') as review_file:
        for idx, (ques_js, ans1_js, ans2_js) in enumerate(zip(f_q, f_ans1, f_ans2)):
            ques = json.loads(ques_js)
            ans1 = json.loads(ans1_js)
            ans2 = json.loads(ans2_js)

            inst = image_to_context[ques['image']]

            if isinstance(inst['caption'], list):
                cap_str = '\n'.join(inst['caption'])
            else:
                cap_str = inst['caption']

            category = 'llava_bench_' + ques['category']
            if category not in rule_dict:
                raise ValueError(f"Visual QA category not found in rule file: {category}.")
            rule = rule_dict[category]
            prompt = rule['prompt']
            role = rule['role']
            content = (f'[Context]\n{cap_str}\n\n'
                       f'[Question]\n{ques["text"]}\n\n'
                       f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
                       f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
                       f'[System]\n{prompt}\n\n')
            cur_js = {
                'id': idx+1,
                'question_id': ques['question_id'],
                # Fall back to the question id only when no answer id exists;
                # the fallback is looked up lazily so it cannot raise when the
                # answer id is present.
                'answer1_id': ans1['answer_id'] if 'answer_id' in ans1 else ans1['question_id'],
                'answer2_id': ans2['answer_id'] if 'answer_id' in ans2 else ans2['question_id'],
                'category': category
            }
            if idx >= len(cur_reviews):
                review = get_eval(content, args.max_tokens)
                scores = parse_score(review)
                cur_js['content'] = review
                cur_js['tuple'] = scores
                review_file.write(json.dumps(cur_js) + '\n')
                # Flush after every review so an interrupted run keeps its work.
                review_file.flush()
            else:
                print(f'Skipping {idx} as we already have it.')
            print(idx + 1)
def parse_score(review):
    """Extract the two scores from the first line of a GPT review.

    The first line is expected to hold exactly two numbers separated by
    commas and/or whitespace (e.g. ``"8 9"``, ``"8,9"`` or ``"8, 9"``).

    Args:
        review: full review text returned by the model.

    Returns:
        ``[score1, score2]`` as floats, or ``[-1, -1]`` when the first line
        cannot be parsed as a pair of numbers.
    """
    try:
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        # split() with no argument collapses runs of whitespace, so "8, 9"
        # (comma followed by a space) yields two tokens instead of three.
        sp = score_pair.split()
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        else:
            print('error', review)
            return [-1, -1]
    except Exception as e:
        # Best-effort parsing: a malformed review must not abort the run.
        print(e)
        print('error', review)
        return [-1, -1]
def eval_pope(answers, label_file):
    """Score yes/no answers against a POPE label file and print the metrics.

    Each answer's ``'text'`` is reduced to ``'yes'`` or ``'no'`` IN PLACE:
    only the first sentence is kept, commas are dropped, and the answer counts
    as ``'no'`` when it contains the word ``No``, ``no`` or ``not``.
    Predictions and labels are paired positionally.

    Args:
        answers: list of dicts with a ``'text'`` key (mutated in place).
        label_file: path to a JSON-lines file whose records have a
            ``'label'`` key (``'no'`` is negative, anything else positive).

    Prints the confusion counts, accuracy, precision, recall, F1 and yes
    ratio. A ratio whose denominator is zero is reported as ``0.0`` instead
    of raising ``ZeroDivisionError``.
    """
    def ratio(numerator, denominator):
        # Degenerate splits (no positives predicted, empty input, ...) would
        # otherwise divide by zero.
        return numerator / denominator if denominator else 0.0

    with open(label_file, 'r') as f:
        label_list = [json.loads(q)['label'] for q in f]

    for answer in answers:
        text = answer['text']

        # Only keep the first sentence
        if text.find('.') != -1:
            text = text.split('.')[0]

        text = text.replace(',', '')
        words = text.split(' ')
        if 'No' in words or 'not' in words or 'no' in words:
            answer['text'] = 'no'
        else:
            answer['text'] = 'yes'

    label_list = [0 if label == 'no' else 1 for label in label_list]
    pred_list = [0 if answer['text'] == 'no' else 1 for answer in answers]

    pos = 1
    neg = 0
    yes_ratio = ratio(pred_list.count(1), len(pred_list))

    TP, TN, FP, FN = 0, 0, 0, 0
    for pred, label in zip(pred_list, label_list):
        if pred == pos and label == pos:
            TP += 1
        elif pred == pos and label == neg:
            FP += 1
        elif pred == neg and label == neg:
            TN += 1
        elif pred == neg and label == pos:
            FN += 1

    print('TP\tFP\tTN\tFN\t')
    print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))

    precision = ratio(float(TP), float(TP + FP))
    recall = ratio(float(TP), float(TP + FN))
    f1 = ratio(2*precision*recall, precision + recall)
    acc = ratio(TP + TN, TP + TN + FP + FN)
    print('Accuracy: {}'.format(acc))
    print('Precision: {}'.format(precision))
    print('Recall: {}'.format(recall))
    print('F1 score: {}'.format(f1))
    print('Yes ratio: {}'.format(yes_ratio))
    print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
def get_pred_idx(prediction, choices, options):
    """
    Get the index (e.g. 2) from the prediction (e.g. 'C')

    Only the first ``len(choices)`` entries of ``options`` are valid letters
    for a problem. Returns -1 when ``prediction`` is not one of them, so an
    unparseable answer is never counted as correct.
    """
    if prediction in options[:len(choices)]:
        return options.index(prediction)
    return -1
74 | else: 75 | answer = "FAILED" 76 | 77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 78 | 79 | analysis = { 80 | 'question_id': prob_id, 81 | 'parsed_ans': answer, 82 | 'ground_truth': args.options[prob['answer']], 83 | 'question': pred['prompt'], 84 | 'pred': pred_text, 85 | 'is_multimodal': '<image>' in pred['prompt'], 86 | } 87 | 88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 89 | sqa_results['outputs'][prob_id] = pred_text 90 | 91 | if pred_idx == prob['answer']: 92 | results['correct'].append(analysis) 93 | else: 94 | results['incorrect'].append(analysis) 95 | 96 | correct = len(results['correct']) 97 | total = len(results['correct']) + len(results['incorrect']) 98 | 99 | ###### IMG ###### 100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 102 | multimodal_total = multimodal_correct + multimodal_incorrect 103 | ###### IMG ###### 104 | 105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 106 | 107 | sqa_results['acc'] = correct / total * 100 108 | sqa_results['correct'] = correct 109 | sqa_results['count'] = total 110 | 111 | with open(args.output_file, 'w') as f: 112 | json.dump(results, f, indent=2) 113 | with open(args.output_result, 'w') as f: 114 | json.dump(sqa_results, f, indent=2) 115 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', 
def convert_caps(results):
    """Convert result records into caption records.

    Each input record contributes one dict whose ``"image_id"`` is the
    record's ``'question_id'`` cast to ``int`` and whose ``"caption"`` is the
    record's ``'text'``. Input order is preserved.
    """
    # Both fields are read before the int() cast, one record at a time.
    pairs = ((record['question_id'], record['text']) for record in results)
    return [{"image_id": int(qid), "caption": text} for qid, text in pairs]
def prompt_processor(prompt):
    """Recover the lower-cased question text from a TextVQA prompt.

    Recognised prompt layouts:
      * starts with ``'OCR tokens: '``: the question is the text between
        ``'Question: '`` and ``' Short answer:'``;
      * three lines containing ``'Reference OCR token: '``: the question is
        the second line when the prompt starts with the OCR line, otherwise
        the first line;
      * two lines: the question is the first line.

    Raises:
        ValueError: if the prompt matches none of the layouts above. An
            explicit raise is used because an ``assert`` is stripped under
            ``python -O``.
    """
    lines = prompt.split('\n')
    if prompt.startswith('OCR tokens: '):
        pattern = r"Question: (.*?) Short answer:"
        match = re.search(pattern, prompt, re.DOTALL)
        if match is None:
            raise ValueError(f'Cannot find the question in prompt: {prompt!r}')
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(lines) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = lines[1]
        else:
            question = lines[0]
    elif len(lines) == 2:
        question = lines[0]
    else:
        raise ValueError(f'Unrecognized prompt format: {prompt!r}')

    return question.lower()
def read_jsonl(path: str, key: str = None):
    """Load a JSON-lines file.

    Args:
        path: file to read; a leading ``~`` is expanded.
        key: optional field name. When given, the records are sorted by this
            field and returned as a dict mapping ``record[key]`` to the
            record.

    Returns:
        A list of records in file order when ``key`` is ``None``, otherwise
        a dict keyed (and ordered) by ``key``.
    """
    data = []
    with open(os.path.expanduser(path)) as f:
        for line in f:
            # Lines read from a file keep their trailing newline, so a blank
            # line is "\n" (truthy); strip before testing for emptiness.
            if not line.strip():
                continue
            data.append(json.loads(line))
    if key is not None:
        data.sort(key=lambda x: x[key])
        data = {item[key]: item for item in data}
    return data
key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is 
optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | @torch.inference_mode() 14 | def eval_model(model_name, questions_file, answers_file): 15 | # Model 16 | disable_torch_init() 17 | model_name = os.path.expanduser(model_name) 18 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 19 | model = AutoModelForCausalLM.from_pretrained(model_name, 20 | torch_dtype=torch.float16).cuda() 21 | 22 | 23 | ques_file = open(os.path.expanduser(questions_file), "r") 24 | ans_file = open(os.path.expanduser(answers_file), "w") 25 | for i, line in enumerate(tqdm(ques_file)): 26 | idx = json.loads(line)["question_id"] 27 | qs = json.loads(line)["text"] 28 | cat = json.loads(line)["category"] 29 | conv = default_conversation.copy() 30 | conv.append_message(conv.roles[0], qs) 31 | prompt = conv.get_prompt() 32 | inputs = tokenizer([prompt]) 33 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 34 | output_ids = model.generate( 35 | input_ids, 36 | do_sample=True, 37 | use_cache=True, 38 | 
def split_list(lst, n):
    """Split a list into at most n consecutive chunks of (roughly) equal size.

    Every chunk except possibly the last holds ``ceil(len(lst) / n)`` items,
    so fewer than ``n`` chunks can be returned. An empty list yields an
    empty list of chunks.
    """
    # Ceiling division; clamp to 1 so an empty list does not produce a zero
    # step, which range() rejects.
    chunk_size = max(1, math.ceil(len(lst) / n))
    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
skip_special_tokens=True)[0].strip() 75 | 76 | ans_id = shortuuid.uuid() 77 | ans_file.write(json.dumps({"question_id": idx, 78 | "prompt": cur_prompt, 79 | "text": outputs, 80 | "answer_id": ans_id, 81 | "model_id": model_name, 82 | "metadata": {}}) + "\n") 83 | ans_file.flush() 84 | ans_file.close() 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 89 | parser.add_argument("--model-base", type=str, default=None) 90 | parser.add_argument("--image-folder", type=str, default="") 91 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 92 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 93 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 94 | parser.add_argument("--num-chunks", type=int, default=1) 95 | parser.add_argument("--chunk-idx", type=int, default=0) 96 | parser.add_argument("--temperature", type=float, default=0.2) 97 | parser.add_argument("--top_p", type=float, default=None) 98 | parser.add_argument("--num_beams", type=int, default=1) 99 | args = parser.parse_args() 100 | 101 | eval_model(args) 102 | -------------------------------------------------------------------------------- /llava/eval/model_vqa_science.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | 
"""Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for i, line in enumerate(tqdm(questions)): 42 | idx = line["id"] 43 | question = line['conversations'][0] 44 | qs = question['value'].replace('<image>', '').strip() 45 | cur_prompt = qs 46 | 47 | if 'image' in line: 48 | image_file = line["image"] 49 | image = Image.open(os.path.join(args.image_folder, image_file)) 50 | image_tensor = process_images([image], image_processor, model.config)[0] 51 | images = image_tensor.unsqueeze(0).half().cuda() 52 | image_sizes = [image.size] 53 | if getattr(model.config, 'mm_use_im_start_end', False): 54 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 55 | else: 56 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 57 | cur_prompt = '<image>' + '\n' + cur_prompt 58 | else: 59 | images = None 60 | image_sizes = None 61 | 62 | if args.single_pred_prompt: 63 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly." 64 | cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly." 
65 | 66 | conv = conv_templates[args.conv_mode].copy() 67 | conv.append_message(conv.roles[0], qs) 68 | conv.append_message(conv.roles[1], None) 69 | prompt = conv.get_prompt() 70 | 71 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 72 | 73 | with torch.inference_mode(): 74 | output_ids = model.generate( 75 | input_ids, 76 | images=images, 77 | image_sizes=image_sizes, 78 | do_sample=True if args.temperature > 0 else False, 79 | temperature=args.temperature, 80 | max_new_tokens=1024, 81 | use_cache=True, 82 | ) 83 | 84 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 85 | 86 | ans_id = shortuuid.uuid() 87 | ans_file.write(json.dumps({"question_id": idx, 88 | "prompt": cur_prompt, 89 | "text": outputs, 90 | "answer_id": ans_id, 91 | "model_id": model_name, 92 | "metadata": {}}) + "\n") 93 | ans_file.flush() 94 | ans_file.close() 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 99 | parser.add_argument("--model-base", type=str, default=None) 100 | parser.add_argument("--image-folder", type=str, default="") 101 | parser.add_argument("--question-file", type=str, default="tables/question.json") 102 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 103 | parser.add_argument("--conv-mode", type=str, default="llava_v0") 104 | parser.add_argument("--num-chunks", type=int, default=1) 105 | parser.add_argument("--chunk-idx", type=int, default=0) 106 | parser.add_argument("--temperature", type=float, default=0.2) 107 | parser.add_argument("--answer-prompter", action="store_true") 108 | parser.add_argument("--single-pred-prompt", action="store_true") 109 | args = parser.parse_args() 110 | 111 | eval_model(args) 112 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: 
-------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in 
tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import ( 5 | IMAGE_TOKEN_INDEX, 6 | DEFAULT_IMAGE_TOKEN, 7 | DEFAULT_IM_START_TOKEN, 8 | DEFAULT_IM_END_TOKEN, 9 | IMAGE_PLACEHOLDER, 10 | ) 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from llava.mm_utils import ( 15 | process_images, 16 | tokenizer_image_token, 17 | get_model_name_from_path, 18 | ) 19 | 20 | from PIL import Image 21 | 22 | import requests 23 | from PIL import Image 24 | from io import BytesIO 25 | import re 26 | 27 | 28 | def image_parser(args): 29 | out = args.image_file.split(args.sep) 30 | return out 31 | 32 | 33 | def load_image(image_file): 34 | if image_file.startswith("http") or image_file.startswith("https"): 35 | response = requests.get(image_file) 36 | image = Image.open(BytesIO(response.content)).convert("RGB") 37 | else: 38 | image = Image.open(image_file).convert("RGB") 39 | return image 40 | 41 | 42 | def load_images(image_files): 43 | out = [] 44 | for image_file in image_files: 45 | image = load_image(image_file) 46 | out.append(image) 47 | return out 48 | 49 | 50 | def eval_model(args): 51 | # Model 52 | disable_torch_init() 53 | 54 | model_name = get_model_name_from_path(args.model_path) 55 | tokenizer, model, image_processor, context_len = load_pretrained_model( 56 | args.model_path, args.model_base, model_name 57 | ) 58 | 59 | qs = 
args.query 60 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 61 | if IMAGE_PLACEHOLDER in qs: 62 | if model.config.mm_use_im_start_end: 63 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 64 | else: 65 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 66 | else: 67 | if model.config.mm_use_im_start_end: 68 | qs = image_token_se + "\n" + qs 69 | else: 70 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 71 | 72 | if "llama-2" in model_name.lower(): 73 | conv_mode = "llava_llama_2" 74 | elif "mistral" in model_name.lower(): 75 | conv_mode = "mistral_instruct" 76 | elif "v1.6-34b" in model_name.lower(): 77 | conv_mode = "chatml_direct" 78 | elif "v1" in model_name.lower(): 79 | conv_mode = "llava_v1" 80 | elif "mpt" in model_name.lower(): 81 | conv_mode = "mpt" 82 | else: 83 | conv_mode = "llava_v0" 84 | 85 | if args.conv_mode is not None and conv_mode != args.conv_mode: 86 | print( 87 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 88 | conv_mode, args.conv_mode, args.conv_mode 89 | ) 90 | ) 91 | else: 92 | args.conv_mode = conv_mode 93 | 94 | conv = conv_templates[args.conv_mode].copy() 95 | conv.append_message(conv.roles[0], qs) 96 | conv.append_message(conv.roles[1], None) 97 | prompt = conv.get_prompt() 98 | 99 | image_files = image_parser(args) 100 | images = load_images(image_files) 101 | image_sizes = [x.size for x in images] 102 | images_tensor = process_images( 103 | images, 104 | image_processor, 105 | model.config 106 | ).to(model.device, dtype=torch.float16) 107 | 108 | input_ids = ( 109 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 110 | .unsqueeze(0) 111 | .cuda() 112 | ) 113 | 114 | with torch.inference_mode(): 115 | output_ids = model.generate( 116 | input_ids, 117 | images=images_tensor, 118 | image_sizes=image_sizes, 119 | do_sample=True if args.temperature > 0 else False, 120 | temperature=args.temperature, 121 | 
top_p=args.top_p, 122 | num_beams=args.num_beams, 123 | max_new_tokens=args.max_new_tokens, 124 | use_cache=True, 125 | ) 126 | 127 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 128 | print(outputs) 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 134 | parser.add_argument("--model-base", type=str, default=None) 135 | parser.add_argument("--image-file", type=str, required=True) 136 | parser.add_argument("--query", type=str, required=True) 137 | parser.add_argument("--conv-mode", type=str, default=None) 138 | parser.add_argument("--sep", type=str, default=",") 139 | parser.add_argument("--temperature", type=float, default=0.2) 140 | parser.add_argument("--top_p", type=float, default=None) 141 | parser.add_argument("--num_beams", type=int, default=1) 142 | parser.add_argument("--max_new_tokens", type=int, default=512) 143 | args = parser.parse_args() 144 | 145 | eval_model(args) 146 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore = [int(x) for x in args.ignore] 24 | 25 | if 
len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | 
{"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. 
Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. 
These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 2406 2406"><path d="M1 578.4C1 259.5 259.5 1 578.4 1h1249.1c319 0 577.5 258.5 577.5 577.4V2406H578.4C259.5 2406 1 2147.5 1 1828.6V578.4z" fill="#74aa9c"/><path d="M1107.3 299.1c-198 0-373.9 127.3-435.2 315.3C544.8 640.6 434.9 720.2 370.5 833c-99.3 171.4-76.6 386.9 56.4 533.8-41.1 123.1-27 257.7 38.6 369.2 98.7 172 297.3 260.2 491.6 219.2 86.1 97 209.8 152.3 339.6 151.8 198 0 373.9-127.3 435.3-315.3 127.5-26.3 237.2-105.9 301-218.5 99.9-171.4 77.2-386.9-55.8-533.9v-.6c41.1-123.1 27-257.8-38.6-369.8-98.7-171.4-297.3-259.6-491-218.6-86.6-96.8-210.5-151.8-340.3-151.2zm0 117.5-.6.6c79.7 0 156.3 27.5 217.6 78.4-2.5 1.2-7.4 4.3-11 6.1L952.8 709.3c-18.4 10.4-29.4 30-29.4 51.4V1248l-155.1-89.4V755.8c-.1-187.1 151.6-338.9 339-339.2zm434.2 141.9c121.6-.2 234 64.5 294.7 169.8 39.2 68.6 53.9 148.8 40.4 226.5-2.5-1.8-7.3-4.3-10.4-6.1l-360.4-208.2c-18.4-10.4-41-10.4-59.4 0L1024 984.2V805.4L1372.7 604c51.3-29.7 109.5-45.4 168.8-45.5zM650 743.5v427.9c0 21.4 11 40.4 29.4 51.4l421.7 243-155.7 90L597.2 1355c-162-93.8-217.4-300.9-123.8-462.8C513.1 823.6 575.5 771 650 743.5zm807.9 106 348.8 200.8c162.5 93.7 217.6 300.6 123.8 462.8l.6.6c-39.8 68.6-102.4 121.2-176.5 148.2v-428c0-21.4-11-41-29.4-51.4l-422.3-243.7 155-89.3zM1201.7 997l177.8 102.8v205.1l-177.8 
102.8-177.8-102.8v-205.1L1201.7 997zm279.5 161.6 155.1 89.4v402.2c0 187.3-152 339.2-339 339.2v-.6c-79.1 0-156.3-27.6-217-78.4 2.5-1.2 8-4.3 11-6.1l360.4-207.5c18.4-10.4 30-30 29.4-51.4l.1-486.8zM1380 1421.9v178.8l-348.8 200.8c-162.5 93.1-369.6 38-463.4-123.7h.6c-39.8-68-54-148.8-40.5-226.5 2.5 1.8 7.4 4.3 10.4 6.1l360.4 208.2c18.4 10.4 41 10.4 59.4 0l421.9-243.7z" fill="white"/></svg> -------------------------------------------------------------------------------- /llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | <svg xmlns="http://www.w3.org/2000/svg" height="48" viewBox="0 96 960 960" width="48"><path d="m762.846 947.614-124.77-124.769-88 88-30.306-30.692q-16.616-16.231-16.616-40.077 0-23.846 16.616-40.461L708 611.385q16.23-16.231 40.076-16.231t40.462 16.231l30.307 30.691-88 88 124.154 124.77q8.615 8.615 8.615 20.23 0 11.616-8.615 20.231l-51.692 52.307q-8.615 9-20.231 9-11.615 0-20.23-9Zm97.153-624.076L412.768 771.153l27.847 28.077q16.231 16.616 16.231 40.462 0 23.846-16.231 40.077l-30.691 30.691-88-88-124.77 124.769q-8.615 9-20.23 9-11.616 0-20.231-9l-52.307-52.307q-9-8.615-9-20.23 0-11.616 9-20.231l124.769-124.769-88-88L171.847 611q16.231-16.23 40.077-16.23 23.846 0 40.461 16.23l28.462 28.232 447.615-447.231h131.537v131.537ZM323.846 483.769l33.769-34.154 34.154-34.153-34.154 34.153-33.769 34.154Zm-31.999 31.999-191.846-192.23V192.001h131.537l191.461 191.846-31.23 31.615-179.077-178.077h-67.307v67.307l178.461 179.077-31.999 31.999Zm87.691 222.77 435.077-433.846v-67.307h-67.307L312.231 670.846l67.307 67.692Zm0 0L346.385 
704l-34.154-33.154L346.385 704l33.153 34.538Z"/></svg> -------------------------------------------------------------------------------- /llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | 
.expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | except: 6 | pass 7 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | 
print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | 
print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, 
labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 98 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if 
name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | use_s2 = getattr(vision_tower_cfg, 's2', False) 9 | if is_absolute_path_exists or vision_tower.startswith("openai") or 
vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 10 | if use_s2: 11 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 12 | else: 13 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 14 | 15 | raise ValueError(f'Unknown vision tower: {vision_tower}') 16 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | 
-------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path 9 | 10 | from PIL import Image 
11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "mistral" in model_name.lower(): 37 | conv_mode = "mistral_instruct" 38 | elif "v1.6-34b" in model_name.lower(): 39 | conv_mode = "chatml_direct" 40 | elif "v1" in model_name.lower(): 41 | conv_mode = "llava_v1" 42 | elif "mpt" in model_name.lower(): 43 | conv_mode = "mpt" 44 | else: 45 | conv_mode = "llava_v0" 46 | 47 | if args.conv_mode is not None and conv_mode != args.conv_mode: 48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 49 | else: 50 | args.conv_mode = conv_mode 51 | 52 | conv = conv_templates[args.conv_mode].copy() 53 | if "mpt" in model_name.lower(): 54 | roles = ('user', 'assistant') 55 | else: 56 | roles = conv.roles 57 | 58 | image = load_image(args.image_file) 59 | image_size = image.size 60 | # Similar operation in model_worker.py 61 | image_tensor = process_images([image], image_processor, model.config) 62 | if type(image_tensor) is list: 63 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 64 | else: 65 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 66 | 67 | while 
True: 68 | try: 69 | inp = input(f"{roles[0]}: ") 70 | except EOFError: 71 | inp = "" 72 | if not inp: 73 | print("exit...") 74 | break 75 | 76 | print(f"{roles[1]}: ", end="") 77 | 78 | if image is not None: 79 | # first message 80 | if model.config.mm_use_im_start_end: 81 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 82 | else: 83 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 84 | image = None 85 | 86 | conv.append_message(conv.roles[0], inp) 87 | conv.append_message(conv.roles[1], None) 88 | prompt = conv.get_prompt() 89 | 90 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 91 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 92 | keywords = [stop_str] 93 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 94 | 95 | with torch.inference_mode(): 96 | output_ids = model.generate( 97 | input_ids, 98 | images=image_tensor, 99 | image_sizes=[image_size], 100 | do_sample=True if args.temperature > 0 else False, 101 | temperature=args.temperature, 102 | max_new_tokens=args.max_new_tokens, 103 | streamer=streamer, 104 | use_cache=True) 105 | 106 | outputs = tokenizer.decode(output_ids[0]).strip() 107 | conv.messages[-1][-1] = outputs 108 | 109 | if args.debug: 110 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 116 | parser.add_argument("--model-base", type=str, default=None) 117 | parser.add_argument("--image-file", type=str, required=True) 118 | parser.add_argument("--device", type=str, default="cuda") 119 | parser.add_argument("--conv-mode", type=str, default=None) 120 | parser.add_argument("--temperature", type=float, default=0.2) 121 | parser.add_argument("--max-new-tokens", type=int, default=512) 122 | 
parser.add_argument("--load-8bit", action="store_true") 123 | parser.add_argument("--load-4bit", action="store_true") 124 | parser.add_argument("--debug", action="store_true") 125 | args = parser.parse_args() 126 | main(args) 127 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": 
conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) 
-> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 
78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | 
self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, 
past_key_value 130 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train(attn_implementation="flash_attention_2") 5 | -------------------------------------------------------------------------------- /llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
def build_logger(logger_name, logger_filename):
    """
    Return the logger called ``logger_name`` and route stdout/stderr into logging.

    Side effects (all process-wide):
      * the first root handler gets the shared timestamped formatter
        (``logging.basicConfig`` is called first if the root logger has none);
      * ``sys.stdout`` / ``sys.stderr`` are replaced by ``StreamToLogger``
        wrappers feeding the "stdout" (INFO) and "stderr" (ERROR) loggers;
      * on the first call a daily-rotating file handler is created under
        ``LOGDIR`` using ``logger_filename``; later calls reuse that handler,
        so ``logger_filename`` is ignored after the first call.

    Args:
        logger_name: name passed to ``logging.getLogger``.
        logger_filename: file name (inside ``LOGDIR``) for the shared file
            handler; only used the first time this function runs.

    Returns:
        The ``logging.Logger`` registered under ``logger_name``, level INFO.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sys.stdout = StreamToLogger(stdout_logger, logging.INFO)

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sys.stderr = StreamToLogger(stderr_logger, logging.ERROR)

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Create the shared file handler once per process.
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

    # Attach the file handler on every call, not only when it is created, so
    # that loggers registered after the first call also write to the file.
    # Logger.addHandler ignores a handler that is already attached, so
    # repeating this for existing loggers does not duplicate output.
    for item in logging.root.manager.loggerDict.values():
        if isinstance(item, logging.Logger):
            item.addHandler(handler)

    return logger
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    The check is best-effort: any network failure or unexpected response
    shape is treated as "not flagged" rather than raised to the caller.

    Args:
        text: the user text to check. Newlines are removed before sending.

    Returns:
        The ``flagged`` value of the first moderation result, or ``False``
        when the request fails or the response cannot be interpreted.

    Raises:
        KeyError: if the ``OPENAI_API_KEY`` environment variable is not set.
    """
    # Local import keeps this block self-contained; the module's import
    # header does not bring in ``json``.
    import json

    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Serialize with json.dumps instead of string concatenation: quotes,
    # backslashes and control characters in user text are escaped, so the
    # payload is always valid JSON and cannot inject extra fields.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        # Connection problems, timeouts, and (in recent versions of
        # requests) undecodable response bodies.
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError):
        # Error payloads without "results", an empty result list, a
        # non-object body, or a JSON decode failure from older requests.
        flagged = False

    return flagged
# size of vit model; base or large 7 | vit: 'swin_b' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 384 12 | batch_size: 36 13 | 14 | # optimizer 15 | weight_decay: 0.05 16 | init_lr: 5e-06 17 | min_lr: 0 18 | max_epoch: 2 19 | warmup_steps: 3000 20 | 21 | class_num: 4585 22 | 23 | -------------------------------------------------------------------------------- /ram/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } -------------------------------------------------------------------------------- /ram/configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | 'datasets/train/vg_ram.json', 4 | 'datasets/train/sbu_ram.json', 5 | 'datasets/train/cc3m_train_ram.json', 6 | 'datasets/train/cc3m_val_ram.json', 7 | 'datasets/train/cc12m_ram.json', 8 | ] 9 | image_path_root: "" 10 | 11 | # size of vit model; base or large 12 | vit: 'swin_l' 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | 16 | image_size: 224 17 | batch_size: 52 18 | 19 | # optimizer 20 | weight_decay: 0.05 21 | init_lr: 1e-4 22 | min_lr: 5e-7 23 | warmup_lr: 5e-7 24 | lr_decay_rate: 0.9 25 | max_epoch: 5 26 | warmup_steps: 3000 27 | 28 | class_num: 4585 29 | 30 | -------------------------------------------------------------------------------- /ram/configs/pretrain_tag2text.yaml: 
-------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | 'datasets/train/vg_ram.json', 4 | 'datasets/train/sbu_ram.json', 5 | 'datasets/train/cc3m_train_ram.json', 6 | 'datasets/train/cc3m_val_ram.json', 7 | 'datasets/train/cc12m_ram.json', 8 | ] 9 | image_path_root: "" 10 | 11 | # size of vit model; base or large 12 | vit: 'swin_b' 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | 16 | image_size: 224 17 | batch_size: 80 18 | 19 | # optimizer 20 | weight_decay: 0.05 21 | init_lr: 1e-4 22 | min_lr: 5e-7 23 | warmup_lr: 5e-7 24 | lr_decay_rate: 0.9 25 | max_epoch: 5 26 | warmup_steps: 3000 27 | 28 | class_num: 4585 29 | 30 | -------------------------------------------------------------------------------- /ram/configs/q2l_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 4, 15 | "num_hidden_layers": 2, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true, 21 | "add_tag_cross_attention": false 22 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinB_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", 3 | "vision_width": 1024, 4 | "image_res": 224, 5 | "window_size": 7, 6 | "embed_dim": 128, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 4, 8, 16, 32 ] 9 | } 
def create_dataset(dataset, config, min_scale=0.5):
    """
    Build the training dataset selected by ``dataset``.

    Two augmentation pipelines are always constructed: one cropping to
    ``config['image_size']`` and one cropping to a fixed 224 pixels. The
    'pretrain' dataset receives only the first; the 'finetune' dataset
    receives both.

    Args:
        dataset: 'pretrain' or 'finetune'. Any other value falls through
            and the function returns ``None``.
        config: mapping providing 'image_size', 'train_file', 'class_num'
            and 'image_path_root'.
        min_scale: lower bound of the ``RandomResizedCrop`` scale range.

    Returns:
        A ``pretrain_dataset`` or ``finetune_dataset`` instance, or ``None``.
    """
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )

    def build_train_transform(crop_size):
        # Same augmentation recipe for every crop size; a fresh op list is
        # handed to each RandomAugment instance.
        augment_ops = ['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
                       'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']
        return transforms.Compose([
            transforms.RandomResizedCrop(
                crop_size,
                scale=(min_scale, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 5, isPIL=True, augs=augment_ops),
            transforms.ToTensor(),
            normalize,
        ])

    transform_train = build_train_transform(config['image_size'])
    transform_inputsize_224 = build_train_transform(224)

    if dataset == 'pretrain':
        return pretrain_dataset(
            config['train_file'],
            transform_train,
            class_num=config['class_num'],
            root=config['image_path_root'],
        )

    if dataset == 'finetune':
        return finetune_dataset(
            config['train_file'],
            transform_train,
            transform_inputsize_224,
            class_num=config['class_num'],
            root=config['image_path_root'],
        )
class pretrain_dataset(Dataset):
    """
    Image / caption / tag dataset assembled from one or more JSON annotation files.

    Each annotation entry is read for the keys 'image_path', 'caption',
    'parse_label_id' and, optionally, 'union_label_id'.
    """

    def __init__(self, ann_file, transform, class_num = 4585, root = ''):
        """
        Args:
            ann_file: iterable of paths to JSON annotation files; their
                contents are concatenated in order.
            transform: callable applied to the RGB PIL image.
            class_num: length of the multi-hot tag vectors.
            root: directory joined in front of every 'image_path'.
        """
        self.ann = []
        for f in ann_file:
            print('loading '+f)
            # Close each annotation file as soon as it has been parsed
            # instead of leaving the handle to the garbage collector.
            with open(f, 'r') as fp:
                self.ann += json.load(fp)

        self.transform = transform
        self.class_num = class_num
        self.root = root

    def __len__(self):
        return len(self.ann)

    def _multi_hot(self, label_ids):
        """Return a long tensor of length ``class_num`` with ``label_ids`` set to 1."""
        tag = np.zeros([self.class_num])
        tag[label_ids] = 1
        return torch.tensor(tag, dtype = torch.long)

    def __getitem__(self, index):
        """
        Returns:
            ``(image, caption, image_tag, parse_tag)`` where ``image_tag`` is
            ``None`` when the entry has no 'union_label_id', and the caption
            (with its matching 'parse_label_id') is picked at random.
        """
        ann = self.ann[index]

        image_path_use = os.path.join(self.root, ann['image_path'])
        # convert() returns a fully loaded copy, so the underlying file can
        # be released immediately rather than lingering in the worker.
        with Image.open(image_path_use) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        # required for tag2text support
        if ann.get('union_label_id') is not None:
            image_tag = self._multi_hot(ann['union_label_id'])
        else:
            image_tag = None

        caption_index = np.random.randint(0, len(ann['caption']))

        caption = pre_caption(ann['caption'][caption_index],30)

        parse_tag = self._multi_hot(ann['parse_label_id'][caption_index])

        return image, caption, image_tag, parse_tag
def save_result(result, result_dir, filename, remove_duplicate=''):
    """
    Write this process's results to disk and merge all ranks on the main process.

    Every process dumps ``result`` to ``<filename>_rank<rank>.json`` inside
    ``result_dir`` and then waits at ``dist.barrier()``. The main process
    afterwards concatenates the per-rank files in rank order, optionally
    drops later entries whose ``remove_duplicate`` key repeats, and writes
    ``<filename>.json``.

    Args:
        result: JSON-serializable list of result entries for this process.
        result_dir: directory for the per-rank and merged files.
        filename: base name (without extension) of the output files.
        remove_duplicate: key used to de-duplicate merged entries; an empty
            string disables de-duplication.

    Returns:
        Path of the merged result file (returned on every process, although
        only the main process writes it).
    """
    result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json'%filename)

    # Close (and therefore flush) the per-rank file before the barrier so
    # the main process is guaranteed to read complete data afterwards.
    with open(result_file, 'w') as f:
        json.dump(result, f)

    dist.barrier()

    if utils.is_main_process():
        # combine results from all processes
        result = []

        for rank in range(utils.get_world_size()):
            result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
            with open(result_file, 'r') as f:
                result += json.load(f)

        if remove_duplicate:
            # Keep the first entry for each id. A list is used for the
            # membership test so that unhashable id values still work.
            result_new = []
            id_list = []
            for res in result:
                if res[remove_duplicate] not in id_list:
                    id_list.append(res[remove_duplicate])
                    result_new.append(res)
            result = result_new

        with open(final_result_file, 'w') as f:
            json.dump(result, f)
        print('result file saved to %s'%final_result_file)

    return final_result_file
def inference_tag2text(image, model, input_tag="None"):
    """
    Run Tag2Text captioning, optionally guided by user-specified tags.

    The model is always run once without tag guidance to obtain its own
    predicted tags. When the caller supplies tags, it is run a second time
    with those tags and the caption of that second run is returned.

    Args:
        image: input passed unchanged to ``model.generate``.
        model: object exposing ``generate(image, tag_input=..., max_length=...,
            return_tag_predict=True)`` returning ``(captions, tags)``.
        input_tag: comma-separated user tags. ``None``, an empty or
            whitespace-only string, and any capitalization of "none" all
            mean "no user tags".

    Returns:
        ``(predicted_tags, user_tags, caption)`` for the first item of the
        batch; ``user_tags`` is ``None`` when no tags were specified.
    """
    with torch.no_grad():
        caption, tag_predict = model.generate(image,
                                              tag_input=None,
                                              max_length=50,
                                              return_tag_predict=True)

    # Treat None and every spelling of "none" as "no tags" instead of
    # crashing on None or sending the word "NONE" to the model as a tag.
    if input_tag is None or input_tag.strip().lower() in ('', 'none'):
        return tag_predict[0], None, caption[0]

    # If user input specified tags:
    input_tag_list = [input_tag.replace(',', ' | ')]

    with torch.no_grad():
        caption, specified_tags = model.generate(image,
                                                 tag_input=input_tag_list,
                                                 max_length=50,
                                                 return_tag_predict=True)

    return tag_predict[0], specified_tags[0], caption[0]
def get_PR(
    pred_file: str,
    gt_file: str,
    taglist: List[str]
) -> Tuple[float, float, ndarray, ndarray]:
    """
    Compute per-tag and mean precision / recall from prediction and ground-truth files.

    Both files are UTF-8 text with one sample per line in the form
    ``<first field>,<tag>,<tag>,...``; the first field is ignored and the
    remaining fields are tag names that must appear in ``taglist``.

    Args:
        pred_file: path to the predictions file.
        gt_file: path to the ground-truth file; must have the same number
            of lines as ``pred_file``.
        taglist: tag names, one per output column. Duplicate names are
            allowed; every column sharing a name is set together.

    Returns:
        ``(mean_precision, mean_recall, precisions, recalls)`` where the two
        arrays have one entry per element of ``taglist``.

    Raises:
        KeyError: if a file mentions a tag that is not in ``taglist``.
        AssertionError: if the two files have different numbers of lines.
    """
    # When mapping categories from test datasets to our system, there might be
    # multiple vs one situation due to different semantic definitions of tags.
    # So there can be duplicate tags in `taglist`. This special case is taken
    # into account.
    tag2idxs = {}
    for idx, tag in enumerate(taglist):
        tag2idxs.setdefault(tag, []).append(idx)

    def _load_multi_hot(path: str) -> ndarray:
        """Read a tag file into a (num_lines, len(taglist)) boolean matrix."""
        with open(path, "r", encoding="utf-8") as f:
            lines = [line.strip().split(",") for line in f.readlines()]
        # One column per taglist entry (not per unique tag): the indices in
        # tag2idxs range over positions in taglist, so sizing by the number
        # of unique tags would overflow when taglist contains duplicates.
        matrix = np.zeros((len(lines), len(taglist)), dtype=bool)
        for i, line in enumerate(lines):
            for tag in line[1:]:
                matrix[i, tag2idxs[tag]] = True
        return matrix

    # build preds
    preds = _load_multi_hot(pred_file)

    # build targets
    targets = _load_multi_hot(gt_file)

    assert preds.shape == targets.shape

    # calculate P and R
    TPs = ( preds &  targets).sum(axis=0)  # noqa: E201, E222
    FPs = ( preds & ~targets).sum(axis=0)  # noqa: E201, E222
    FNs = (~preds &  targets).sum(axis=0)  # noqa: E201, E222
    eps = 1.e-9  # avoids division by zero for tags that never occur
    Ps = TPs / (TPs + FPs + eps)
    Rs = TPs / (TPs + FNs + eps)

    return Ps.mean(), Rs.mean(), Ps, Rs
bitsandbytes==0.44.1 24 | 25 | # requirements for llava 26 | transformers==4.37.2 27 | tokenizers==0.15.1 28 | sentencepiece==0.1.99 29 | 30 | # requirements for ram 31 | fairscale==0.4.4 32 | --------------------------------------------------------------------------------