├── .gitignore
├── LICENSE
├── README.md
├── assets
    ├── docs
    │   ├── installation_xOS.md
    │   └── v1_models.md
    ├── gradio.png
    ├── gradio_error_img.png
    ├── logo.png
    ├── pipeline.png
    ├── teaser.png
    └── visual_results
    │   ├── bfr1.png
    │   ├── bfr2.png
    │   ├── bfr4.png
    │   ├── bid1.png
    │   ├── bid2.png
    │   ├── bid3.png
    │   ├── bsr1.png
    │   ├── bsr2.png
    │   ├── bsr3.png
    │   ├── bsr4.png
    │   ├── bsr5.png
    │   ├── bsr6.png
    │   ├── bsr7.png
    │   ├── tiled_sampling.png
    │   ├── whole_image1.png
    │   └── whole_image2.png
├── configs
    ├── inference
    │   ├── bsrnet.yaml
    │   ├── cldm.yaml
    │   ├── diffusion.yaml
    │   ├── diffusion_v2.1.yaml
    │   ├── scunet.yaml
    │   └── swinir.yaml
    └── train
    │   ├── train_stage1.yaml
    │   ├── train_stage2.yaml
    │   └── train_stage2_v2.1.yaml
├── diffbir
    ├── dataset
    │   ├── batch_transform.py
    │   ├── codeformer.py
    │   ├── degradation.py
    │   ├── diffjpeg.py
    │   ├── file_backend.py
    │   ├── realesrgan.py
    │   └── utils.py
    ├── inference
    │   ├── __init__.py
    │   ├── bfr_loop.py
    │   ├── bid_loop.py
    │   ├── bsr_loop.py
    │   ├── custom_loop.py
    │   ├── loop.py
    │   ├── pretrained_models.py
    │   └── unaligned_bfr_loop.py
    ├── model
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── bsrnet.py
    │   ├── cldm.py
    │   ├── clip.py
    │   ├── config.py
    │   ├── controlnet.py
    │   ├── distributions.py
    │   ├── gaussian_diffusion.py
    │   ├── open_clip
    │   │   ├── __init__.py
    │   │   ├── bpe_simple_vocab_16e6.txt.gz
    │   │   ├── model.py
    │   │   ├── tokenizer.py
    │   │   └── transformer.py
    │   ├── scunet.py
    │   ├── swinir.py
    │   ├── unet.py
    │   ├── util.py
    │   └── vae.py
    ├── pipeline.py
    ├── sampler
    │   ├── __init__.py
    │   ├── ddim_sampler.py
    │   ├── dpm_solver_pytorch.py
    │   ├── dpms_sampler.py
    │   ├── edm_sampler.py
    │   ├── k_diffusion.py
    │   ├── sampler.py
    │   └── spaced_sampler.py
    └── utils
    │   ├── caption.py
    │   ├── common.py
    │   ├── cond_fn.py
    │   ├── face.py
    │   └── tilevae
    │       ├── __init__.py
    │       ├── attn.py
    │       └── tilevae.py
├── inference.py
├── inputs
    ├── demo
    │   ├── bfr
    │   │   ├── aligned
    │   │   │   ├── 0229.png
    │   │   │   ├── 0427.png
    │   │   │   ├── 0722.png
    │   │   │   └── hermione.jpg
    │   │   └── whole_img
    │   │   │   ├── 01.jpg
    │   │   │   ├── 02.png
    │   │   │   ├── Audrey_Hepburn.jpg
    │   │   │   ├── Blake_Lively.jpg
    │   │   │   ├── Harry_Potter.jpg
    │   │   │   ├── Queen.jpeg
    │   │   │   └── real47_1.jpg
    │   ├── bid
    │   │   ├── Audrey_Hepburn.jpg
    │   │   ├── Bears.png
    │   │   ├── Flowers.png
    │   │   ├── Movie.png
    │   │   ├── Postcards.png
    │   │   ├── cty_fnb_0047.png
    │   │   ├── kf_fnb_0058.png
    │   │   └── palace.png
    │   └── bsr
    │   │   ├── 14.jpg
    │   │   ├── 29.jpg
    │   │   ├── 49.jpg
    │   │   ├── 53.jpeg
    │   │   └── comic3.png
    └── real47
    │   ├── 1.jpg
    │   ├── 11.jpg
    │   ├── 12.jpg
    │   ├── 13.jpg
    │   ├── 14.jpg
    │   ├── 15.jpg
    │   ├── 16.jpg
    │   ├── 17.jpg
    │   ├── 19.jpg
    │   ├── 2.jpg
    │   ├── 20.jpg
    │   ├── 21.jpg
    │   ├── 22.jpg
    │   ├── 23.jpg
    │   ├── 24.jpg
    │   ├── 26.jpg
    │   ├── 27.jpg
    │   ├── 29.jpg
    │   ├── 3.jpg
    │   ├── 32.jpg
    │   ├── 33.jpg
    │   ├── 34.jpg
    │   ├── 35.jpg
    │   ├── 36.png
    │   ├── 38.jpg
    │   ├── 4.jpg
    │   ├── 40.jpg
    │   ├── 41.jpg
    │   ├── 42.jpg
    │   ├── 43.jpg
    │   ├── 44.jpg
    │   ├── 45.jpg
    │   ├── 46.jpg
    │   ├── 47.jpg
    │   ├── 48.jpg
    │   ├── 49.jpg
    │   ├── 5.jpg
    │   ├── 50.jpg
    │   ├── 51.jpg
    │   ├── 52.jpg
    │   ├── 53.jpeg
    │   ├── 54.jpeg
    │   ├── 55.jpg
    │   ├── 56.jpg
    │   ├── 6.jpg
    │   ├── 7.jpg
    │   └── 9.jpg
├── llava
    ├── __init__.py
    ├── constants.py
    ├── conversation.py
    ├── eval
    │   ├── eval_gpt_review.py
    │   ├── eval_gpt_review_bench.py
    │   ├── eval_gpt_review_visual.py
    │   ├── eval_pope.py
    │   ├── eval_science_qa.py
    │   ├── eval_science_qa_gpt4.py
    │   ├── eval_science_qa_gpt4_requery.py
    │   ├── eval_textvqa.py
    │   ├── generate_webpage_data_from_table.py
    │   ├── m4c_evaluator.py
    │   ├── model_qa.py
    │   ├── model_vqa.py
    │   ├── model_vqa_loader.py
    │   ├── model_vqa_mmbench.py
    │   ├── model_vqa_science.py
    │   ├── qa_baseline_gpt35.py
    │   ├── run_llava.py
    │   ├── summarize_gpt_review.py
    │   ├── table
    │   │   ├── answer
    │   │   │   ├── answer_alpaca-13b.jsonl
    │   │   │   ├── answer_bard.jsonl
    │   │   │   ├── answer_gpt35.jsonl
    │   │   │   ├── answer_llama-13b.jsonl
    │   │   │   └── answer_vicuna-13b.jsonl
    │   │   ├── caps_boxes_coco2014_val_80.jsonl
    │   │   ├── model.jsonl
    │   │   ├── prompt.jsonl
    │   │   ├── question.jsonl
    │   │   ├── results
    │   │   │   ├── test_sqa_llava_13b_v0.json
    │   │   │   └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json
    │   │   ├── review
    │   │   │   ├── review_alpaca-13b_vicuna-13b.jsonl
    │   │   │   ├── review_bard_vicuna-13b.jsonl
    │   │   │   ├── review_gpt35_vicuna-13b.jsonl
    │   │   │   └── review_llama-13b_vicuna-13b.jsonl
    │   │   ├── reviewer.jsonl
    │   │   └── rule.json
    │   └── webpage
    │   │   ├── figures
    │   │       ├── alpaca.png
    │   │       ├── bard.jpg
    │   │       ├── chatgpt.svg
    │   │       ├── llama.jpg
    │   │       ├── swords_FILL0_wght300_GRAD0_opsz48.svg
    │   │       └── vicuna.jpeg
    │   │   ├── index.html
    │   │   ├── script.js
    │   │   └── styles.css
    ├── mm_utils.py
    ├── model
    │   ├── __init__.py
    │   ├── apply_delta.py
    │   ├── builder.py
    │   ├── consolidate.py
    │   ├── language_model
    │   │   ├── llava_llama.py
    │   │   ├── llava_mistral.py
    │   │   └── llava_mpt.py
    │   ├── llava_arch.py
    │   ├── make_delta.py
    │   ├── multimodal_encoder
    │   │   ├── builder.py
    │   │   └── clip_encoder.py
    │   ├── multimodal_projector
    │   │   └── builder.py
    │   └── utils.py
    ├── serve
    │   ├── __init__.py
    │   ├── cli.py
    │   ├── controller.py
    │   ├── examples
    │   │   ├── extreme_ironing.jpg
    │   │   └── waterview.jpg
    │   ├── gradio_web_server.py
    │   ├── model_worker.py
    │   ├── register_worker.py
    │   ├── sglang_worker.py
    │   └── test_message.py
    ├── train
    │   ├── llama_flash_attn_monkey_patch.py
    │   ├── llama_xformers_attn_monkey_patch.py
    │   ├── llava_trainer.py
    │   ├── train.py
    │   ├── train_mem.py
    │   └── train_xformers.py
    └── utils.py
├── ram
    ├── __init__.py
    ├── configs
    │   ├── finetune.yaml
    │   ├── finetune_tag2text.yaml
    │   ├── med_config.json
    │   ├── pretrain.yaml
    │   ├── pretrain_tag2text.yaml
    │   ├── q2l_config.json
    │   └── swin
    │   │   ├── config_swinB_224.json
    │   │   ├── config_swinB_384.json
    │   │   ├── config_swinL_224.json
    │   │   └── config_swinL_384.json
    ├── data
    │   ├── __init__.py
    │   ├── dataset.py
    │   ├── ram_tag_list.txt
    │   ├── ram_tag_list_chinese.txt
    │   ├── ram_tag_list_threshold.txt
    │   ├── randaugment.py
    │   ├── tag2text_ori_tag_list.txt
    │   ├── tag_list.txt
    │   └── utils.py
    ├── inference.py
    ├── models
    │   ├── __init__.py
    │   ├── bert.py
    │   ├── ram.py
    │   ├── ram_plus.py
    │   ├── swin_transformer.py
    │   ├── tag2text.py
    │   ├── utils.py
    │   └── vit.py
    ├── transform.py
    └── utils
    │   ├── __init__.py
    │   ├── metrics.py
    │   └── openset_utils.py
├── requirements.txt
├── run_gradio.py
├── scripts
    └── convert_diffusers_to_sd.py
├── train_stage1.py
└── train_stage2.py


/.gitignore:
--------------------------------------------------------------------------------
 1 | __pycache__
 2 | *.ckpt
 3 | *.pth
 4 | /data
 5 | /exps
 6 | *.sh
 7 | !install_env.sh
 8 | /weights
 9 | /temp
10 | /results
11 | .ipynb_checkpoints/
12 | /TODO.txt
13 | /deprecated
14 | /temp_scripts
15 | /.vscode
16 | /runs
17 | /tests
18 | /meta_files
19 | 


--------------------------------------------------------------------------------
/assets/docs/installation_xOS.md:
--------------------------------------------------------------------------------
 1 | # Linux
 2 | Please follow the primary README.md of this repo.
 3 | 
 4 | # Windows
 5 | Windows users may stumble when installing the package `triton`. 
 6 | You can choose to run on **CPU** without `xformers` and `triton` installed.
 7 | 
 8 | To use **CUDA**, please refer to [issue#24](https://github.com/XPixelGroup/DiffBIR/issues/24) to try to solve the problem of `triton` installation.
 9 | 
10 | # MacOS
11 | <!-- Currently only CPU device is supported to run DiffBIR on Apple Silicon since most GPU acceleration packages are compatible with CUDA only. 
12 | 
13 | We are still trying to support MPS device. Stay tuned for our progress! -->
14 | 
15 | You can follow the steps below to set up the environment and run DiffBIR with the CPU or MPS device.
16 | 
17 | 1. Install **torch (Preview/Nightly version)**.
18 | 
19 |     ```bash
20 |     # MPS acceleration is available on MacOS 12.3+
21 |     pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
22 |     ```
23 |     Check more details in [official document](https://pytorch.org/get-started/locally/).
24 | 
25 | 2. The packages `triton` and `xformers` are not needed since they only work with CUDA. Remove the related packages. 
26 | 
27 |     Your requirements.txt should look like:
28 |     ```bash
29 |     # requirements.txt
30 |     pytorch_lightning==1.4.2
31 |     einops
32 |     open-clip-torch
33 |     omegaconf
34 |     torchmetrics==0.6.0
35 |     opencv-python-headless
36 |     scipy
37 |     matplotlib
38 |     lpips
39 |     gradio
40 |     chardet
41 |     transformers
42 |     facexlib
43 |     ```
44 | 
45 |     ```bash
46 |     pip install -r requirements.txt
47 |     ```
48 | 
49 | 3. [Run the inference script](https://github.com/XPixelGroup/DiffBIR#general_image_inference) and specify `--device cpu` or `--device mps`. Using MPS can accelerate your inference (see the device check sketched after these steps).
50 | 
51 |     You can specify `--tiled` and related arguments to avoid out-of-memory (OOM) errors. 
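
If you are unsure which device flag applies to your machine, the following check (an illustrative sketch, not part of this repo; it only relies on the standard `torch.backends.mps` API) prints the flag to pass:

```python
# Illustrative sketch: pick the device flag for the inference script on macOS.
# Assumes the PyTorch Preview/Nightly build from step 1 is installed.
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"run the inference script with: --device {device}")
```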


--------------------------------------------------------------------------------
/assets/docs/v1_models.md:
--------------------------------------------------------------------------------
1 | | Model Name | Description | HuggingFace | BaiduNetdisk | OpenXLab |
2 | | :--------- | :---------- | :---------- | :---------- | :---------- |
3 | | general_swinir_v1.ckpt | Stage1 model (SwinIR) for general image restoration. | [download](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt) | [download](https://pan.baidu.com/s/1uvSvJgcoL_Knj0h22-9TvA?pwd=v3v6) (pwd: v3v6) | [download](https://download.openxlab.org.cn/models/linxinqi/DiffBIR/weight//diffbir_general_swinir_v1) |
4 | | general_full_v1.ckpt | Full model for general image restoration. "Full" means it contains both the stage1 and stage2 model. | [download](https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt) | [download](https://pan.baidu.com/s/1gLvW1nvkJStdVAKROqaYaA?pwd=86zi) (pwd: 86zi) | [download](https://download.openxlab.org.cn/models/linxinqi/DiffBIR/weight//diffbir_general_full_v1) |
5 | | face_swinir_v1.ckpt | Stage1 model (SwinIR) for face restoration. | [download](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt) | [download](https://pan.baidu.com/s/1cnBBC8437BJiM3q6suaK8g?pwd=xk5u) (pwd: xk5u) | [download](https://download.openxlab.org.cn/models/linxinqi/DiffBIR/weight//diffbir_face_swinir_v1) |
6 | | face_full_v1.ckpt | Full model for face restoration. | [download](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) | [download](https://pan.baidu.com/s/1pc04xvQybkynRfzK5Y8K0Q?pwd=ov8i) (pwd: ov8i) | [download](https://download.openxlab.org.cn/models/linxinqi/DiffBIR/weight//diffbir_face_full_v1) |
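
As a quick sanity check after downloading, a checkpoint can be inspected with `torch.load`. The sketch below is illustrative only; the local path is an assumption (adjust it to wherever the file was saved):

```python
# Illustrative sketch: inspect a downloaded DiffBIR v1 checkpoint.
# "weights/general_full_v1.ckpt" is an assumed local path, not a repo file.
import torch

ckpt = torch.load("weights/general_full_v1.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # some checkpoints wrap weights in "state_dict"
print(f"{len(state_dict)} tensors")
print("example keys:", list(state_dict)[:5])
```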


--------------------------------------------------------------------------------
/assets/gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/gradio.png


--------------------------------------------------------------------------------
/assets/gradio_error_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/gradio_error_img.png


--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/logo.png


--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/pipeline.png


--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/teaser.png


--------------------------------------------------------------------------------
/assets/visual_results/bfr1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bfr1.png


--------------------------------------------------------------------------------
/assets/visual_results/bfr2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bfr2.png


--------------------------------------------------------------------------------
/assets/visual_results/bfr4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bfr4.png


--------------------------------------------------------------------------------
/assets/visual_results/bid1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bid1.png


--------------------------------------------------------------------------------
/assets/visual_results/bid2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bid2.png


--------------------------------------------------------------------------------
/assets/visual_results/bid3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bid3.png


--------------------------------------------------------------------------------
/assets/visual_results/bsr1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr1.png


--------------------------------------------------------------------------------
/assets/visual_results/bsr2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr2.png


--------------------------------------------------------------------------------
/assets/visual_results/bsr3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr3.png


--------------------------------------------------------------------------------
/assets/visual_results/bsr4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr4.png


--------------------------------------------------------------------------------
/assets/visual_results/bsr5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr5.png


--------------------------------------------------------------------------------
/assets/visual_results/bsr6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr6.png


--------------------------------------------------------------------------------
/assets/visual_results/bsr7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/bsr7.png


--------------------------------------------------------------------------------
/assets/visual_results/tiled_sampling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/tiled_sampling.png


--------------------------------------------------------------------------------
/assets/visual_results/whole_image1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/whole_image1.png


--------------------------------------------------------------------------------
/assets/visual_results/whole_image2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/assets/visual_results/whole_image2.png


--------------------------------------------------------------------------------
/configs/inference/bsrnet.yaml:
--------------------------------------------------------------------------------
1 | target: diffbir.model.RRDBNet
2 | params:
3 |   in_nc: 3
4 |   out_nc: 3
5 |   nf: 64
6 |   nb: 23
7 |   gc: 32
8 |   sf: 4
9 | 


--------------------------------------------------------------------------------
/configs/inference/cldm.yaml:
--------------------------------------------------------------------------------
 1 | target: diffbir.model.ControlLDM
 2 | params:
 3 |   latent_scale_factor: 0.18215
 4 |   unet_cfg:
 5 |     use_checkpoint: True
 6 |     image_size: 32 # unused
 7 |     in_channels: 4
 8 |     out_channels: 4
 9 |     model_channels: 320
10 |     attention_resolutions: [ 4, 2, 1 ]
11 |     num_res_blocks: 2
12 |     channel_mult: [ 1, 2, 4, 4 ]
13 |     num_head_channels: 64 # need to fix for flash-attn
14 |     use_spatial_transformer: True
15 |     use_linear_in_transformer: True
16 |     transformer_depth: 1
17 |     context_dim: 1024
18 |     legacy: False
19 |   vae_cfg:
20 |     embed_dim: 4
21 |     ddconfig:
22 |       double_z: true
23 |       z_channels: 4
24 |       resolution: 256
25 |       in_channels: 3
26 |       out_ch: 3
27 |       ch: 128
28 |       ch_mult:
29 |       - 1
30 |       - 2
31 |       - 4
32 |       - 4
33 |       num_res_blocks: 2
34 |       attn_resolutions: []
35 |       dropout: 0.0
36 |   clip_cfg:
37 |     embed_dim: 1024
38 |     vision_cfg:
39 |       image_size: 224
40 |       layers: 32
41 |       width: 1280
42 |       head_width: 80
43 |       patch_size: 14
44 |     text_cfg:
45 |       context_length: 77
46 |       vocab_size: 49408
47 |       width: 1024
48 |       heads: 16
49 |       layers: 24
50 |     layer: "penultimate"
51 |   controlnet_cfg:
52 |     use_checkpoint: True
53 |     image_size: 32 # unused
54 |     in_channels: 4
55 |     hint_channels: 4
56 |     model_channels: 320
57 |     attention_resolutions: [ 4, 2, 1 ]
58 |     num_res_blocks: 2
59 |     channel_mult: [ 1, 2, 4, 4 ]
60 |     num_head_channels: 64 # need to fix for flash-attn
61 |     use_spatial_transformer: True
62 |     use_linear_in_transformer: True
63 |     transformer_depth: 1
64 |     context_dim: 1024
65 |     legacy: False
66 | 


--------------------------------------------------------------------------------
/configs/inference/diffusion.yaml:
--------------------------------------------------------------------------------
1 | target: diffbir.model.Diffusion
2 | params:
3 |   linear_start: 0.00085
4 |   linear_end: 0.0120
5 |   timesteps: 1000
6 | 


--------------------------------------------------------------------------------
/configs/inference/diffusion_v2.1.yaml:
--------------------------------------------------------------------------------
1 | target: diffbir.model.Diffusion
2 | params:
3 |   linear_start: 0.00085
4 |   linear_end: 0.0120
5 |   timesteps: 1000
6 |   zero_snr: True
7 |   parameterization: v
8 | 


--------------------------------------------------------------------------------
/configs/inference/scunet.yaml:
--------------------------------------------------------------------------------
1 | target: diffbir.model.SCUNet
2 | params:
3 |   in_nc: 3
4 |   config: [4,4,4,4,4,4,4]
5 |   dim: 64
6 | 


--------------------------------------------------------------------------------
/configs/inference/swinir.yaml:
--------------------------------------------------------------------------------
 1 | target: diffbir.model.SwinIR
 2 | params:
 3 |   img_size: 64
 4 |   patch_size: 1
 5 |   in_chans: 3
 6 |   embed_dim: 180
 7 |   depths: [6, 6, 6, 6, 6, 6, 6, 6]
 8 |   num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
 9 |   window_size: 8
10 |   mlp_ratio: 2
11 |   sf: 8
12 |   img_range: 1.0
13 |   upsampler: "nearest+conv"
14 |   resi_connection: "1conv"
15 |   unshuffle: True
16 |   unshuffle_scale: 8
17 | 
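
These inference configs are OmegaConf `target`/`params` specs. Below is a minimal sketch of how they are consumed, mirroring the pattern in `diffbir/inference/bfr_loop.py` (weight loading is omitted here):

```python
# Minimal sketch: build the stage-1 model from its inference config.
# Mirrors diffbir/inference/bfr_loop.py; in the actual inference loops the
# weights are loaded separately via load_model_from_url.
from omegaconf import OmegaConf
from diffbir.model import SwinIR
from diffbir.utils.common import instantiate_from_config

cleaner: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
print(type(cleaner).__name__)  # SwinIR
```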


--------------------------------------------------------------------------------
/configs/train/train_stage1.yaml:
--------------------------------------------------------------------------------
 1 | model:
 2 |   swinir:
 3 |     target: diffbir.model.swinir.SwinIR
 4 |     params:
 5 |       img_size: 64
 6 |       patch_size: 1
 7 |       in_chans: 3
 8 |       embed_dim: 180
 9 |       depths: [6, 6, 6, 6, 6, 6, 6, 6]
10 |       num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
11 |       window_size: 8
12 |       mlp_ratio: 2
13 |       sf: 8
14 |       img_range: 1.0
15 |       upsampler: "nearest+conv"
16 |       resi_connection: "1conv"
17 |       unshuffle: True
18 |       unshuffle_scale: 8
19 | 
20 | dataset:
21 |   train:
22 |     target: diffbir.dataset.codeformer.CodeformerDataset
23 |     params:
24 |       # training file list path
25 |       file_list: 
26 |       file_backend_cfg:
27 |         target: diffbir.dataset.file_backend.HardDiskBackend
28 |       out_size: 512
29 |       crop_type: center
30 |       blur_kernel_size: 41
31 |       kernel_list: ['iso', 'aniso']
32 |       kernel_prob: [0.5, 0.5]
33 |       blur_sigma: [0.1, 12]
34 |       downsample_range: [1, 12]
35 |       noise_range: [0, 15]
36 |       jpeg_range: [30, 100]
37 |   val:
38 |     target: diffbir.dataset.codeformer.CodeformerDataset
39 |     params:
40 |       # validation file list path
41 |       file_list: 
42 |       file_backend_cfg:
43 |         target: diffbir.dataset.file_backend.HardDiskBackend
44 |       out_size: 512
45 |       crop_type: center
46 |       blur_kernel_size: 41
47 |       kernel_list: ['iso', 'aniso']
48 |       kernel_prob: [0.5, 0.5]
49 |       blur_sigma: [0.1, 12]
50 |       downsample_range: [1, 12]
51 |       noise_range: [0, 15]
52 |       jpeg_range: [30, 100]
53 | 
54 | batch_transform:
55 |   target: diffbir.dataset.batch_transform.IdentityBatchTransform
56 | 
57 | train:
58 |   # experiment directory path
59 |   exp_dir: 
60 |   learning_rate: 1e-4
61 |   # total batch size
62 |   batch_size: 96
63 |   num_workers: 
64 |   train_steps: 150000
65 |   log_every: 50
66 |   ckpt_every: 10000
67 |   image_every: 1000
68 |   val_every: 1000
69 |   resume: ~
70 | 
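
A hedged sketch of how this stage-1 config can be consumed outside the training script, assuming the empty fields above (`file_list`, `exp_dir`, `num_workers`) have been filled in; the `(gt, lq, prompt)` sample layout follows `diffbir/dataset/codeformer.py`:

```python
# Hedged sketch (not the actual training entry point, which is train_stage1.py).
# Assumes file_list and the other empty fields in the config above are filled in.
from omegaconf import OmegaConf
from diffbir.utils.common import instantiate_from_config

cfg = OmegaConf.load("configs/train/train_stage1.yaml")
swinir = instantiate_from_config(cfg.model.swinir)        # stage-1 SwinIR
train_set = instantiate_from_config(cfg.dataset.train)    # CodeformerDataset

gt, lq, prompt = train_set[0]  # HWC float32 arrays: gt in [-1, 1], lq in [0, 1]
print(gt.shape, lq.shape, repr(prompt))
```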


--------------------------------------------------------------------------------
/configs/train/train_stage2.yaml:
--------------------------------------------------------------------------------
  1 | model:
  2 |   cldm:
  3 |     target: diffbir.model.cldm.ControlLDM
  4 |     params:
  5 |       latent_scale_factor: 0.18215
  6 |       unet_cfg:
  7 |         use_checkpoint: True
  8 |         image_size: 32 # unused
  9 |         in_channels: 4
 10 |         out_channels: 4
 11 |         model_channels: 320
 12 |         attention_resolutions: [ 4, 2, 1 ]
 13 |         num_res_blocks: 2
 14 |         channel_mult: [ 1, 2, 4, 4 ]
 15 |         num_head_channels: 64 # need to fix for flash-attn
 16 |         use_spatial_transformer: True
 17 |         use_linear_in_transformer: True
 18 |         transformer_depth: 1
 19 |         context_dim: 1024
 20 |         legacy: False
 21 |       vae_cfg:
 22 |         embed_dim: 4
 23 |         ddconfig:
 24 |           double_z: true
 25 |           z_channels: 4
 26 |           resolution: 256
 27 |           in_channels: 3
 28 |           out_ch: 3
 29 |           ch: 128
 30 |           ch_mult:
 31 |           - 1
 32 |           - 2
 33 |           - 4
 34 |           - 4
 35 |           num_res_blocks: 2
 36 |           attn_resolutions: []
 37 |           dropout: 0.0
 38 |       clip_cfg:
 39 |         embed_dim: 1024
 40 |         vision_cfg:
 41 |           image_size: 224
 42 |           layers: 32
 43 |           width: 1280
 44 |           head_width: 80
 45 |           patch_size: 14
 46 |         text_cfg:
 47 |           context_length: 77
 48 |           vocab_size: 49408
 49 |           width: 1024
 50 |           heads: 16
 51 |           layers: 24
 52 |         layer: "penultimate"
 53 |       controlnet_cfg:
 54 |         use_checkpoint: True
 55 |         image_size: 32 # unused
 56 |         in_channels: 4
 57 |         hint_channels: 4
 58 |         model_channels: 320
 59 |         attention_resolutions: [ 4, 2, 1 ]
 60 |         num_res_blocks: 2
 61 |         channel_mult: [ 1, 2, 4, 4 ]
 62 |         num_head_channels: 64 # need to fix for flash-attn
 63 |         use_spatial_transformer: True
 64 |         use_linear_in_transformer: True
 65 |         transformer_depth: 1
 66 |         context_dim: 1024
 67 |         legacy: False
 68 | 
 69 |   swinir:
 70 |     target: diffbir.model.swinir.SwinIR
 71 |     params:
 72 |       img_size: 64
 73 |       patch_size: 1
 74 |       in_chans: 3
 75 |       embed_dim: 180
 76 |       depths: [6, 6, 6, 6, 6, 6, 6, 6]
 77 |       num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
 78 |       window_size: 8
 79 |       mlp_ratio: 2
 80 |       sf: 8
 81 |       img_range: 1.0
 82 |       upsampler: "nearest+conv"
 83 |       resi_connection: "1conv"
 84 |       unshuffle: True
 85 |       unshuffle_scale: 8
 86 | 
 87 |   diffusion:
 88 |     target: diffbir.model.gaussian_diffusion.Diffusion
 89 |     params:
 90 |       linear_start: 0.00085
 91 |       linear_end: 0.0120
 92 |       timesteps: 1000
 93 |       zero_snr: False
 94 |       parameterization: eps
 95 | 
 96 | dataset:
 97 |   train:
 98 |     target: diffbir.dataset.codeformer.CodeformerDataset
 99 |     params:
100 |       # training file list path
101 |       file_list: 
102 |       file_backend_cfg:
103 |         target: diffbir.dataset.file_backend.HardDiskBackend
104 |       out_size: 512
105 |       crop_type: center
106 |       blur_kernel_size: 41
107 |       kernel_list: ['iso', 'aniso']
108 |       kernel_prob: [0.5, 0.5]
109 |       blur_sigma: [0.1, 12]
110 |       downsample_range: [1, 12]
111 |       noise_range: [0, 15]
112 |       jpeg_range: [30, 100]
113 | 
114 | batch_transform:
115 |   target: diffbir.dataset.batch_transform.IdentityBatchTransform
116 | 
117 | train:
118 |   # pretrained sd v2.1 path
119 |   sd_path: 
120 |   # experiment directory path
121 |   exp_dir: 
122 |   # stage 1 swinir path.
123 |   # In our paper, we use SwinIR trained on ImageNet-1k with codeformer degradation.
124 |   swinir_path: 
125 |   learning_rate: 1e-4
126 |   # ImageNet 1k (1.3M images)
127 |   # batch size = 192, lr = 1e-4, total training steps = 25k
128 |   # Our filtered laion2b-en (15M images)
129 |   # batch size = 256, lr = 1e-4 (first 30k), 1e-5 (next 50k), total training steps = 80k
130 |   batch_size: 256
131 |   num_workers: 
132 |   train_steps: 30000
133 |   log_every: 50
134 |   ckpt_every: 10000
135 |   image_every: 1000
136 |   resume: ~
137 |   noise_aug_timestep: 0
138 | 


--------------------------------------------------------------------------------
/configs/train/train_stage2_v2.1.yaml:
--------------------------------------------------------------------------------
  1 | model:
  2 |   cldm:
  3 |     target: diffbir.model.cldm.ControlLDM
  4 |     params:
  5 |       latent_scale_factor: 0.18215
  6 |       unet_cfg:
  7 |         use_checkpoint: True
  8 |         image_size: 32 # unused
  9 |         in_channels: 4
 10 |         out_channels: 4
 11 |         model_channels: 320
 12 |         attention_resolutions: [ 4, 2, 1 ]
 13 |         num_res_blocks: 2
 14 |         channel_mult: [ 1, 2, 4, 4 ]
 15 |         num_head_channels: 64 # need to fix for flash-attn
 16 |         use_spatial_transformer: True
 17 |         use_linear_in_transformer: True
 18 |         transformer_depth: 1
 19 |         context_dim: 1024
 20 |         legacy: False
 21 |       vae_cfg:
 22 |         embed_dim: 4
 23 |         ddconfig:
 24 |           double_z: true
 25 |           z_channels: 4
 26 |           resolution: 256
 27 |           in_channels: 3
 28 |           out_ch: 3
 29 |           ch: 128
 30 |           ch_mult:
 31 |           - 1
 32 |           - 2
 33 |           - 4
 34 |           - 4
 35 |           num_res_blocks: 2
 36 |           attn_resolutions: []
 37 |           dropout: 0.0
 38 |       clip_cfg:
 39 |         embed_dim: 1024
 40 |         vision_cfg:
 41 |           image_size: 224
 42 |           layers: 32
 43 |           width: 1280
 44 |           head_width: 80
 45 |           patch_size: 14
 46 |         text_cfg:
 47 |           context_length: 77
 48 |           vocab_size: 49408
 49 |           width: 1024
 50 |           heads: 16
 51 |           layers: 24
 52 |         layer: "penultimate"
 53 |       controlnet_cfg:
 54 |         use_checkpoint: True
 55 |         image_size: 32 # unused
 56 |         in_channels: 4
 57 |         hint_channels: 4
 58 |         model_channels: 320
 59 |         attention_resolutions: [ 4, 2, 1 ]
 60 |         num_res_blocks: 2
 61 |         channel_mult: [ 1, 2, 4, 4 ]
 62 |         num_head_channels: 64 # need to fix for flash-attn
 63 |         use_spatial_transformer: True
 64 |         use_linear_in_transformer: True
 65 |         transformer_depth: 1
 66 |         context_dim: 1024
 67 |         legacy: False
 68 | 
 69 |   swinir:
 70 |     target: diffbir.model.swinir.SwinIR
 71 |     params:
 72 |       img_size: 64
 73 |       patch_size: 1
 74 |       in_chans: 3
 75 |       embed_dim: 180
 76 |       depths: [6, 6, 6, 6, 6, 6, 6, 6]
 77 |       num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
 78 |       window_size: 8
 79 |       mlp_ratio: 2
 80 |       sf: 8
 81 |       img_range: 1.0
 82 |       upsampler: "nearest+conv"
 83 |       resi_connection: "1conv"
 84 |       unshuffle: True
 85 |       unshuffle_scale: 8
 86 | 
 87 |   diffusion:
 88 |     target: diffbir.model.gaussian_diffusion.Diffusion
 89 |     params:
 90 |       linear_start: 0.00085
 91 |       linear_end: 0.0120
 92 |       timesteps: 1000
 93 |       zero_snr: True
 94 |       parameterization: v
 95 | 
 96 | dataset:
 97 |   train:
 98 |     target: diffbir.dataset.realesrgan.RealESRGANDataset
 99 |     params:
100 |       # Path to the file list.
101 |       file_metas:
102 |         # The training set is formatted as a parquet file.
103 |         # Each row contains file path, long caption and short caption of a high-quality image.
104 |         - file_list: 
105 |           image_path_key: image_path
106 |           short_prompt_key: llava_short
107 |           long_prompt_key: llava_long
108 | 
109 |       p_long_prompt: 0.2
110 | 
111 |       file_backend_cfg:
112 |         target: diffbir.dataset.file_backend.HardDiskBackend
113 | 
114 |       out_size: 512
115 |       crop_type: none
116 | 
117 |       use_hflip: false
118 |       use_rot: false
119 | 
120 |       blur_kernel_size: 21
121 |       kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
122 |       kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
123 |       sinc_prob: 0.1
124 |       blur_sigma: [0.2, 3]
125 |       betag_range: [0.5, 4]
126 |       betap_range: [1, 2]
127 | 
128 |       blur_kernel_size2: 21
129 |       kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
130 |       kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
131 |       sinc_prob2: 0.1
132 |       blur_sigma2: [0.2, 1.5]
133 |       betag_range2: [0.5, 4]
134 |       betap_range2: [1, 2]
135 | 
136 |       final_sinc_prob: 0.8
137 | 
138 |       p_empty_prompt: 0.2
139 | 
140 | batch_transform:
141 |   target: diffbir.dataset.batch_transform.RealESRGANBatchTransform
142 |   params:
143 |     use_sharpener: true
144 |     # Queue size of the training pool; this should be a multiple of batch_size (per GPU).
145 |     queue_size: 256
146 |     # the first degradation process
147 |     resize_prob: [0.2, 0.7, 0.1] # up, down, keep
148 |     resize_range: [0.15, 1.5]
149 |     gaussian_noise_prob: 0.5
150 |     noise_range: [1, 30]
151 |     poisson_scale_range: [0.05, 3]
152 |     gray_noise_prob: 0.4
153 |     jpeg_range: [30, 95]
154 | 
155 |     # the second degradation process
156 |     stage2_scale: 4
157 |     second_blur_prob: 0.8
158 |     resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
159 |     resize_range2: [0.3, 1.2]
160 |     gaussian_noise_prob2: 0.5
161 |     noise_range2: [1, 25]
162 |     poisson_scale_range2: [0.05, 2.5]
163 |     gray_noise_prob2: 0.4
164 |     jpeg_range2: [30, 95]
165 | 
166 | train:
167 |   # pretrained sd v2.1-zsnr path
168 |   sd_path: 
169 |   # experiment directory path
170 |   exp_dir: 
171 |   # stage 1 swinir path.
172 |   # For DiffBIR v2.1, we use SwinIR trained on ImageNet-1k with RealESRGAN degradation.
173 |   swinir_path: 
174 |   learning_rate: 1e-5
175 |   batch_size: 512
176 |   num_workers: 16
177 |   train_steps: 1000000
178 |   log_every: 100
179 |   ckpt_every: 10000
180 |   image_every: 1000
181 |   resume: 
182 |   noise_aug_timestep: 200
183 | 


--------------------------------------------------------------------------------
/diffbir/dataset/codeformer.py:
--------------------------------------------------------------------------------
  1 | from typing import Sequence, Dict, Union, List, Mapping, Any, Optional, Tuple
  2 | import math
  3 | import time
  4 | import io
  5 | import random
  6 | 
  7 | import numpy as np
  8 | import cv2
  9 | from PIL import Image
 10 | import torch.utils.data as data
 11 | 
 12 | from .degradation import (
 13 |     random_mixed_kernels,
 14 |     random_add_gaussian_noise,
 15 |     random_add_jpg_compression,
 16 | )
 17 | from .utils import load_file_list, center_crop_arr, random_crop_arr
 18 | from ..utils.common import instantiate_from_config
 19 | 
 20 | 
 21 | class CodeformerDataset(data.Dataset):
 22 | 
 23 |     def __init__(
 24 |         self,
 25 |         file_list: str,
 26 |         file_backend_cfg: Mapping[str, Any],
 27 |         out_size: int,
 28 |         crop_type: str,
 29 |         blur_kernel_size: int,
 30 |         kernel_list: Sequence[str],
 31 |         kernel_prob: Sequence[float],
 32 |         blur_sigma: Sequence[float],
 33 |         downsample_range: Sequence[float],
 34 |         noise_range: Sequence[float],
 35 |         jpeg_range: Sequence[int],
 36 |     ) -> None:
 37 |         super(CodeformerDataset, self).__init__()
 38 |         self.file_list = file_list
 39 |         self.image_files = load_file_list(file_list)
 40 |         self.file_backend = instantiate_from_config(file_backend_cfg)
 41 |         self.out_size = out_size
 42 |         self.crop_type = crop_type
 43 |         assert self.crop_type in ["none", "center", "random"]
 44 |         # degradation configurations
 45 |         self.blur_kernel_size = blur_kernel_size
 46 |         self.kernel_list = kernel_list
 47 |         self.kernel_prob = kernel_prob
 48 |         self.blur_sigma = blur_sigma
 49 |         self.downsample_range = downsample_range
 50 |         self.noise_range = noise_range
 51 |         self.jpeg_range = jpeg_range
 52 | 
 53 |     def load_gt_image(
 54 |         self, image_path: str, max_retry: int = 5
 55 |     ) -> Optional[np.ndarray]:
 56 |         image_bytes = None
 57 |         while image_bytes is None:
 58 |             if max_retry == 0:
 59 |                 return None
 60 |             image_bytes = self.file_backend.get(image_path)
 61 |             max_retry -= 1
 62 |             if image_bytes is None:
 63 |                 time.sleep(0.5)
 64 |         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 65 |         if self.crop_type != "none":
 66 |             if image.height == self.out_size and image.width == self.out_size:
 67 |                 image = np.array(image)
 68 |             else:
 69 |                 if self.crop_type == "center":
 70 |                     image = center_crop_arr(image, self.out_size)
 71 |                 elif self.crop_type == "random":
 72 |                     image = random_crop_arr(image, self.out_size, min_crop_frac=0.7)
 73 |         else:
 74 |             assert image.height == self.out_size and image.width == self.out_size
 75 |             image = np.array(image)
 76 |         # hwc, rgb, 0,255, uint8
 77 |         return image
 78 | 
 79 |     def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray, str]:
 80 |         # load gt image
 81 |         img_gt = None
 82 |         while img_gt is None:
 83 |             # load meta file
 84 |             image_file = self.image_files[index]
 85 |             gt_path = image_file["image_path"]
 86 |             prompt = image_file["prompt"]
 87 |             img_gt = self.load_gt_image(gt_path)
 88 |             if img_gt is None:
 89 |                 print(f"failed to load {gt_path}, try another image")
 90 |                 index = random.randint(0, len(self) - 1)
 91 | 
 92 |         # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
 93 |         img_gt = (img_gt[..., ::-1] / 255.0).astype(np.float32)
 94 |         h, w, _ = img_gt.shape
 95 |         if np.random.uniform() < 0.5:
 96 |             prompt = ""
 97 | 
 98 |         # ------------------------ generate lq image ------------------------ #
 99 |         # blur
100 |         kernel = random_mixed_kernels(
101 |             self.kernel_list,
102 |             self.kernel_prob,
103 |             self.blur_kernel_size,
104 |             self.blur_sigma,
105 |             self.blur_sigma,
106 |             [-math.pi, math.pi],
107 |             noise_range=None,
108 |         )
109 |         img_lq = cv2.filter2D(img_gt, -1, kernel)
110 |         # downsample
111 |         scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
112 |         img_lq = cv2.resize(
113 |             img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR
114 |         )
115 |         # noise
116 |         if self.noise_range is not None:
117 |             img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
118 |         # jpeg compression
119 |         if self.jpeg_range is not None:
120 |             img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)
121 | 
122 |         # resize to original size
123 |         img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
124 | 
125 |         # BGR to RGB, [-1, 1]
126 |         gt = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
127 |         # BGR to RGB, [0, 1]
128 |         lq = img_lq[..., ::-1].astype(np.float32)
129 | 
130 |         return gt, lq, prompt
131 | 
132 |     def __len__(self) -> int:
133 |         return len(self.image_files)
134 | 


--------------------------------------------------------------------------------
/diffbir/dataset/file_backend.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) OpenMMLab. All rights reserved.
  2 | # https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
  3 | import re
  4 | from abc import ABCMeta, abstractmethod
  5 | from pathlib import Path
  6 | from typing import Optional, Union
  7 | 
  8 | 
  9 | class BaseStorageBackend(metaclass=ABCMeta):
 10 |     """Abstract class of storage backends.
 11 | 
 12 |     All backends need to implement two apis: ``get()`` and ``get_text()``.
 13 |     ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
 14 |     as texts.
 15 |     """
 16 | 
 17 |     @property
 18 |     def name(self) -> str:
 19 |         return self.__class__.__name__
 20 | 
 21 |     @abstractmethod
 22 |     def get(self, filepath: str) -> bytes:
 23 |         pass
 24 | 
 25 | 
 26 | class PetrelBackend(BaseStorageBackend):
 27 |     """Petrel storage backend (for internal use).
 28 | 
 29 |     PetrelBackend supports reading and writing data to multiple clusters.
 30 |     If the file path contains the cluster name, PetrelBackend will read data
 31 |     from specified cluster or write data to it. Otherwise, PetrelBackend will
 32 |     access the default cluster.
 33 | 
 34 |     Args:
 35 |         path_mapping (dict, optional): Path mapping dict from local path to
 36 |             Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
 37 |             ``filepath`` will be replaced by ``dst``. Default: None.
 38 |         enable_mc (bool, optional): Whether to enable memcached support.
 39 |             Default: True.
 40 |         conf_path (str, optional): Config path of Petrel client. Default: None.
 41 |             `New in version 1.7.1`.
 42 | 
 43 |     Examples:
 44 |         >>> filepath1 = 's3://path/of/file'
 45 |         >>> filepath2 = 'cluster-name:s3://path/of/file'
 46 |         >>> client = PetrelBackend()
 47 |         >>> client.get(filepath1)  # get data from default cluster
 48 |         >>> client.get(filepath2)  # get data from 'cluster-name' cluster
 49 |     """
 50 | 
 51 |     def __init__(self,
 52 |                  path_mapping: Optional[dict] = None,
 53 |                  enable_mc: bool = False,
 54 |                  conf_path: str = None):
 55 |         try:
 56 |             from petrel_client import client
 57 |         except ImportError:
 58 |             raise ImportError('Please install petrel_client to enable '
 59 |                               'PetrelBackend.')
 60 | 
 61 |         self._client = client.Client(conf_path=conf_path, enable_mc=enable_mc)
 62 |         assert isinstance(path_mapping, dict) or path_mapping is None
 63 |         self.path_mapping = path_mapping
 64 | 
 65 |     def _map_path(self, filepath: Union[str, Path]) -> str:
 66 |         """Map ``filepath`` to a string path whose prefix will be replaced by
 67 |         :attr:`self.path_mapping`.
 68 | 
 69 |         Args:
 70 |             filepath (str): Path to be mapped.
 71 |         """
 72 |         filepath = str(filepath)
 73 |         if self.path_mapping is not None:
 74 |             for k, v in self.path_mapping.items():
 75 |                 filepath = filepath.replace(k, v, 1)
 76 |         return filepath
 77 | 
 78 |     def _format_path(self, filepath: str) -> str:
 79 |         """Convert a ``filepath`` to standard format of petrel oss.
 80 | 
 81 |         If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
 82 |         environment, the ``filepath`` will be the format of
 83 |         's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
 84 |         above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
 85 | 
 86 |         Args:
 87 |             filepath (str): Path to be formatted.
 88 |         """
 89 |         return re.sub(r'\\+', '/', filepath)
 90 | 
 91 |     def get(self, filepath: Union[str, Path]) -> bytes:
 92 |         """Read data from a given ``filepath`` with 'rb' mode.
 93 | 
 94 |         Args:
 95 |             filepath (str or Path): Path to read data.
 96 | 
 97 |         Returns:
 98 |             bytes: The loaded bytes.
 99 |         """
100 |         filepath = self._map_path(filepath)
101 |         filepath = self._format_path(filepath)
102 |         value = self._client.Get(filepath)
103 |         return value
104 | 
105 | 
106 | class HardDiskBackend(BaseStorageBackend):
107 |     """Raw hard disks storage backend."""
108 | 
109 |     def get(self, filepath: Union[str, Path]) -> bytes:
110 |         """Read data from a given ``filepath`` with 'rb' mode.
111 | 
112 |         Args:
113 |             filepath (str or Path): Path to read data.
114 | 
115 |         Returns:
116 |             bytes: Expected bytes object.
117 |         """
118 |         with open(filepath, 'rb') as f:
119 |             value_buf = f.read()
120 |         return value_buf
121 | 
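
A minimal usage sketch of `HardDiskBackend`, mirroring how `CodeformerDataset.load_gt_image` decodes the returned bytes; the image path below is just one of the demo inputs shipped with the repo:

```python
# Minimal sketch: read raw bytes from disk and decode them into an RGB image,
# mirroring diffbir/dataset/codeformer.py::CodeformerDataset.load_gt_image.
import io
from PIL import Image
from diffbir.dataset.file_backend import HardDiskBackend

backend = HardDiskBackend()
image_bytes = backend.get("inputs/demo/bsr/14.jpg")  # any local image path works
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
print(backend.name, image.size)
```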


--------------------------------------------------------------------------------
/diffbir/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from .bsr_loop import BSRInferenceLoop
2 | from .bfr_loop import BFRInferenceLoop
3 | from .bid_loop import BIDInferenceLoop
4 | from .unaligned_bfr_loop import UnAlignedBFRInferenceLoop
5 | from .custom_loop import CustomInferenceLoop
6 | 


--------------------------------------------------------------------------------
/diffbir/inference/bfr_loop.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from PIL import Image
 3 | from omegaconf import OmegaConf
 4 | 
 5 | from .loop import InferenceLoop, MODELS
 6 | from ..utils.common import (
 7 |     instantiate_from_config,
 8 |     load_model_from_url,
 9 |     trace_vram_usage,
10 | )
11 | from ..pipeline import SwinIRPipeline
12 | from ..model import SwinIR
13 | 
14 | 
15 | class BFRInferenceLoop(InferenceLoop):
16 | 
17 |     def load_cleaner(self) -> None:
18 |         self.cleaner: SwinIR = instantiate_from_config(
19 |             OmegaConf.load("configs/inference/swinir.yaml")
20 |         )
21 |         weight = load_model_from_url(MODELS["swinir_face"])
22 |         self.cleaner.load_state_dict(weight, strict=True)
23 |         self.cleaner.eval().to(self.args.device)
24 | 
25 |     def load_pipeline(self) -> None:
26 |         self.pipeline = SwinIRPipeline(
27 |             self.cleaner, self.cldm, self.diffusion, self.cond_fn, self.args.device
28 |         )
29 | 
30 |     def after_load_lq(self, lq: Image.Image) -> np.ndarray:
31 |         lq = lq.resize(
32 |             tuple(int(x * self.args.upscale) for x in lq.size), Image.BICUBIC
33 |         )
34 |         return super().after_load_lq(lq)
35 | 


--------------------------------------------------------------------------------
/diffbir/inference/bid_loop.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from PIL import Image
 3 | from omegaconf import OmegaConf
 4 | 
 5 | from .loop import InferenceLoop, MODELS
 6 | from ..utils.common import (
 7 |     instantiate_from_config,
 8 |     load_model_from_url,
 9 |     trace_vram_usage,
10 | )
11 | from ..pipeline import (
12 |     SwinIRPipeline,
13 |     SCUNetPipeline,
14 | )
15 | from ..model import SwinIR, SCUNet
16 | 
17 | 
18 | class BIDInferenceLoop(InferenceLoop):
19 | 
20 |     def load_cleaner(self) -> None:
21 |         if self.args.version == "v1":
22 |             config = "configs/inference/swinir.yaml"
23 |             weight = MODELS["swinir_general"]
24 |         elif self.args.version == "v2":
25 |             config = "configs/inference/scunet.yaml"
26 |             weight = MODELS["scunet_psnr"]
27 |         else:
28 |             config = "configs/inference/swinir.yaml"
29 |             weight = MODELS["swinir_realesrgan"]
30 |         self.cleaner: SCUNet | SwinIR = instantiate_from_config(OmegaConf.load(config))
31 |         model_weight = load_model_from_url(weight)
32 |         self.cleaner.load_state_dict(model_weight, strict=True)
33 |         self.cleaner.eval().to(self.args.device)
34 | 
35 |     def load_pipeline(self) -> None:
36 |         if self.args.version == "v1" or self.args.version == "v2.1":
37 |             pipeline_class = SwinIRPipeline
38 |         else:
39 |             pipeline_class = SCUNetPipeline
40 |         self.pipeline = pipeline_class(
41 |             self.cleaner,
42 |             self.cldm,
43 |             self.diffusion,
44 |             self.cond_fn,
45 |             self.args.device,
46 |         )
47 | 
48 |     def after_load_lq(self, lq: Image.Image) -> np.ndarray:
49 |         lq = lq.resize(
50 |             tuple(int(x * self.args.upscale) for x in lq.size), Image.BICUBIC
51 |         )
52 |         return super().after_load_lq(lq)
53 | 


--------------------------------------------------------------------------------
/diffbir/inference/bsr_loop.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from PIL import Image
 3 | from omegaconf import OmegaConf
 4 | 
 5 | from .loop import InferenceLoop, MODELS
 6 | from ..utils.common import (
 7 |     instantiate_from_config,
 8 |     load_model_from_url,
 9 |     trace_vram_usage,
10 | )
11 | from ..pipeline import (
12 |     BSRNetPipeline,
13 |     SwinIRPipeline,
14 | )
15 | from ..model import RRDBNet, SwinIR
16 | 
17 | 
18 | class BSRInferenceLoop(InferenceLoop):
19 | 
20 |     def load_cleaner(self) -> None:
21 |         if self.args.version == "v1":
22 |             config = "configs/inference/swinir.yaml"
23 |             weight = MODELS["swinir_general"]
24 |         elif self.args.version == "v2":
25 |             config = "configs/inference/bsrnet.yaml"
26 |             weight = MODELS["bsrnet"]
27 |         else:
28 |             config = "configs/inference/swinir.yaml"
29 |             weight = MODELS["swinir_realesrgan"]
30 |         self.cleaner: RRDBNet | SwinIR = instantiate_from_config(OmegaConf.load(config))
31 |         model_weight = load_model_from_url(weight)
32 |         self.cleaner.load_state_dict(model_weight, strict=True)
33 |         self.cleaner.eval().to(self.args.device)
34 | 
35 |     def load_pipeline(self) -> None:
36 |         if self.args.version == "v1" or self.args.version == "v2.1":
37 |             self.pipeline = SwinIRPipeline(
38 |                 self.cleaner,
39 |                 self.cldm,
40 |                 self.diffusion,
41 |                 self.cond_fn,
42 |                 self.args.device,
43 |             )
44 |         else:
45 |             self.pipeline = BSRNetPipeline(
46 |                 self.cleaner,
47 |                 self.cldm,
48 |                 self.diffusion,
49 |                 self.cond_fn,
50 |                 self.args.device,
51 |                 self.args.upscale,
52 |             )
53 | 
54 |     def after_load_lq(self, lq: Image.Image) -> np.ndarray:
55 |         if self.args.version == "v1" or self.args.version == "v2.1":
56 |             lq = lq.resize(
57 |                 tuple(int(x * self.args.upscale) for x in lq.size), Image.BICUBIC
58 |             )
59 |         return super().after_load_lq(lq)
60 | 


--------------------------------------------------------------------------------
/diffbir/inference/custom_loop.py:
--------------------------------------------------------------------------------
 1 | from argparse import Namespace
 2 | 
 3 | import numpy as np
 4 | from PIL import Image
 5 | from omegaconf import OmegaConf
 6 | import torch
 7 | 
 8 | from .loop import InferenceLoop
 9 | from ..utils.common import (
10 |     instantiate_from_config,
11 |     VRAMPeakMonitor,
12 | )
13 | from ..pipeline import (
14 |     SwinIRPipeline,
15 |     Pipeline,
16 | )
17 | from ..model import SwinIR, ControlLDM, Diffusion
18 | 
19 | 
20 | class CustomInferenceLoop(InferenceLoop):
21 | 
22 |     def __init__(self, args: Namespace) -> None:
23 |         self.args = args
24 |         self.train_cfg = OmegaConf.load(args.train_cfg)
25 |         self.loop_ctx = {}
26 |         self.pipeline: Pipeline = None
27 |         with VRAMPeakMonitor("loading cleaner model"):
28 |             self.load_cleaner()
29 |         with VRAMPeakMonitor("loading cldm model"):
30 |             self.load_cldm()
31 |         self.load_cond_fn()
32 |         self.load_pipeline()
33 |         with VRAMPeakMonitor("loading captioner"):
34 |             self.load_captioner()
35 | 
36 |     def load_cldm(self) -> None:
37 |         self.cldm: ControlLDM = instantiate_from_config(self.train_cfg.model.cldm)
38 | 
39 |         # load pre-trained SD weight
40 |         sd_weight = torch.load(self.train_cfg.train.sd_path, map_location="cpu")
41 |         sd_weight = sd_weight["state_dict"]
42 |         unused, missing = self.cldm.load_pretrained_sd(sd_weight)
43 |         print(
44 |             f"load pretrained stable diffusion, "
45 |             f"unused weights: {unused}, missing weights: {missing}"
46 |         )
47 |         # load controlnet weight
48 |         control_weight = torch.load(self.args.ckpt, map_location="cpu")
49 |         self.cldm.load_controlnet_from_ckpt(control_weight)
50 |         print(f"load controlnet weight")
51 |         self.cldm.eval().to(self.args.device)
52 |         cast_type = {
53 |             "fp32": torch.float32,
54 |             "fp16": torch.float16,
55 |             "bf16": torch.bfloat16,
56 |         }[self.args.precision]
57 |         self.cldm.cast_dtype(cast_type)
58 | 
59 |         # load diffusion
60 |         self.diffusion: Diffusion = instantiate_from_config(
61 |             self.train_cfg.model.diffusion
62 |         )
63 |         self.diffusion.to(self.args.device)
64 | 
65 |     def load_cleaner(self) -> None:
66 |         # NOTE: Use SwinIR as stage-1 model. Change it if you want.
67 |         self.cleaner: SwinIR = instantiate_from_config(self.train_cfg.model.swinir)
68 |         weight = torch.load(self.train_cfg.train.swinir_path, map_location="cpu")
69 |         if "state_dict" in weight:
70 |             weight = weight["state_dict"]
71 |         weight = {
72 |             (k[len("module.") :] if k.startswith("module.") else k): v
73 |             for k, v in weight.items()
74 |         }
75 |         self.cleaner.load_state_dict(weight, strict=True)
76 |         self.cleaner.eval().to(self.args.device)
77 | 
78 |     def load_pipeline(self) -> None:
79 |         # NOTE: Choose the correct pipeline if SwinIR is not your stage-1 model.
80 |         self.pipeline = SwinIRPipeline(
81 |             self.cleaner,
82 |             self.cldm,
83 |             self.diffusion,
84 |             self.cond_fn,
85 |             self.args.device,
86 |         )
87 | 
88 |     def after_load_lq(self, lq: Image.Image) -> np.ndarray:
89 |         # For SwinIRPipeline, upscaling is achieved by resizing input LQ.
90 |         lq = lq.resize(
91 |             tuple(int(x * self.args.upscale) for x in lq.size), Image.BICUBIC
92 |         )
93 |         return super().after_load_lq(lq)
94 | 
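
A minimal sketch of the pre-resize step that `after_load_lq` performs before `SwinIRPipeline` runs (the input path and upscale factor below are illustrative):

    from PIL import Image

    upscale = 4  # stands in for args.upscale
    lq = Image.open("inputs/demo/bsr/14.jpg").convert("RGB")
    # SwinIRPipeline expects the LQ image already at the target resolution,
    # so upscaling is done up front with a simple bicubic resize.
    lq = lq.resize(tuple(int(x * upscale) for x in lq.size), Image.BICUBIC)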


--------------------------------------------------------------------------------
/diffbir/inference/pretrained_models.py:
--------------------------------------------------------------------------------
 1 | """
 2 | All models used in inference:
 3 | - DiffBIR-v1
 4 |   All tasks share the same pre-trained stable diffusion v2.1 (sd_v2.1).
 5 | -- BSR task
 6 |     stage-1 model (swinir_general): SwinIR trained on ImageNet-1k with Real-ESRGAN degradation.
 7 |     stage-2 model (v1_general): IRControlNet trained on ImageNet-1k.
 8 | -- BFR task
 9 |     stage-1 model (swinir_face): SwinIR pre-trained on FFHQ, borrowed from DifFace (https://github.com/zsyOAOA/DifFace.git)
10 |     stage-2 model (v1_face): IRControlNet trained on FFHQ.
11 | -- BID task
12 |     The same as BSR task.
13 | 
14 | - DiffBIR-v2
15 |   All tasks share the same pre-trained stable diffusion v2.1 (sd_v2.1).
16 |   All tasks share the same stage-2 model (v2).
17 | -- BSR task
18 |     stage-1 model (bsrnet): BSRNet borrowed from BSRGAN (https://github.com/cszn/BSRGAN.git).
19 | -- BFR task
20 |     stage-1 model (swinir_face): SwinIR pre-trained on FFHQ, borrowed from DifFace (https://github.com/zsyOAOA/DifFace.git)
21 | -- BID task
22 |     stage-1 model (scunet_psnr): SCUNet-PSNR borrowed from SCUNet (https://github.com/cszn/SCUNet.git)
23 | 
24 | - DiffBIR-v2.1
25 |   All tasks share the same pre-trained stable diffusion v2.1-zsnr (sd_v2.1_zsnr).
26 |   All tasks share the same stage-2 model (v2.1).
27 | -- BSR task
28 |     stage-1 model (swinir_realesrgan): SwinIR trained on ImageNet-1k with Real-ESRGAN degradation.
29 | -- BFR task
30 |     stage-1 model (swinir_face): SwinIR pre-trained on FFHQ, borrowed from DifFace (https://github.com/zsyOAOA/DifFace.git)
31 | -- BID task
32 |     The same as BSR task.
33 | """
34 | MODELS = {
35 |     # --------------- stage-1 model weights ---------------
36 |     "bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth",
37 |     # the commented-out checkpoint below is the up-to-date DifFace release; we keep the older version used in our paper
38 |     # "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth",
39 |     "swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt",
40 |     "scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth",
41 |     "swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt",
42 |     "swinir_realesrgan": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/realesrgan_s4_swinir_100k.pth",
43 |     # --------------- pre-trained stable diffusion weights ---------------
44 |     "sd_v2.1": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt",
45 |     "sd_v2.1_zsnr": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/sd2.1-base-zsnr-laionaes5.ckpt",
46 |     # --------------- IRControlNet weights ---------------
47 |     "v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth",
48 |     "v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth",
49 |     "v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth",
50 |     "v2.1": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/DiffBIR_v2.1.pt",
51 | }
52 | 
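
A minimal sketch of fetching a checkpoint from the MODELS registry above (the `weights/` directory is an arbitrary choice for this example, not something the repo requires):

    import os
    import torch

    from diffbir.inference.pretrained_models import MODELS

    url = MODELS["v2"]  # stage-2 IRControlNet weight for DiffBIR-v2
    os.makedirs("weights", exist_ok=True)
    dst = os.path.join("weights", url.split("/")[-1])
    if not os.path.exists(dst):
        torch.hub.download_url_to_file(url, dst)
    state_dict = torch.load(dst, map_location="cpu")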


--------------------------------------------------------------------------------
/diffbir/model/__init__.py:
--------------------------------------------------------------------------------
 1 | from . import config
 2 | 
 3 | from .controlnet import ControlledUnetModel, ControlNet
 4 | from .vae import AutoencoderKL
 5 | from .clip import FrozenOpenCLIPEmbedder
 6 | 
 7 | from .cldm import ControlLDM
 8 | from .gaussian_diffusion import Diffusion
 9 | 
10 | from .swinir import SwinIR
11 | from .bsrnet import RRDBNet
12 | from .scunet import SCUNet
13 | 


--------------------------------------------------------------------------------
/diffbir/model/bsrnet.py:
--------------------------------------------------------------------------------
  1 | # From BSRGAN
  2 | import functools
  3 | import torch
  4 | import torch.nn as nn
  5 | import torch.nn.functional as F
  6 | import torch.nn.init as init
  7 | 
  8 | 
  9 | def initialize_weights(net_l, scale=1):
 10 |     if not isinstance(net_l, list):
 11 |         net_l = [net_l]
 12 |     for net in net_l:
 13 |         for m in net.modules():
 14 |             if isinstance(m, nn.Conv2d):
 15 |                 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
 16 |                 m.weight.data *= scale  # for residual block
 17 |                 if m.bias is not None:
 18 |                     m.bias.data.zero_()
 19 |             elif isinstance(m, nn.Linear):
 20 |                 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
 21 |                 m.weight.data *= scale
 22 |                 if m.bias is not None:
 23 |                     m.bias.data.zero_()
 24 |             elif isinstance(m, nn.BatchNorm2d):
 25 |                 init.constant_(m.weight, 1)
 26 |                 init.constant_(m.bias.data, 0.0)
 27 | 
 28 | 
 29 | def make_layer(block, n_layers):
 30 |     layers = []
 31 |     for _ in range(n_layers):
 32 |         layers.append(block())
 33 |     return nn.Sequential(*layers)
 34 | 
 35 | 
 36 | class ResidualDenseBlock_5C(nn.Module):
 37 |     def __init__(self, nf=64, gc=32, bias=True):
 38 |         super(ResidualDenseBlock_5C, self).__init__()
 39 |         # gc: growth channel, i.e. intermediate channels
 40 |         self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
 41 |         self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
 42 |         self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
 43 |         self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
 44 |         self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
 45 |         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 46 | 
 47 |         # initialization
 48 |         initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
 49 | 
 50 |     def forward(self, x):
 51 |         x1 = self.lrelu(self.conv1(x))
 52 |         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
 53 |         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
 54 |         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
 55 |         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
 56 |         return x5 * 0.2 + x
 57 | 
 58 | 
 59 | class RRDB(nn.Module):
 60 |     '''Residual in Residual Dense Block'''
 61 | 
 62 |     def __init__(self, nf, gc=32):
 63 |         super(RRDB, self).__init__()
 64 |         self.RDB1 = ResidualDenseBlock_5C(nf, gc)
 65 |         self.RDB2 = ResidualDenseBlock_5C(nf, gc)
 66 |         self.RDB3 = ResidualDenseBlock_5C(nf, gc)
 67 | 
 68 |     def forward(self, x):
 69 |         out = self.RDB1(x)
 70 |         out = self.RDB2(out)
 71 |         out = self.RDB3(out)
 72 |         return out * 0.2 + x
 73 | 
 74 | 
 75 | class RRDBNet(nn.Module):
 76 |     def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
 77 |         super(RRDBNet, self).__init__()
 78 |         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
 79 |         self.sf = sf
 80 |         print([in_nc, out_nc, nf, nb, gc, sf])
 81 | 
 82 |         self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
 83 |         self.RRDB_trunk = make_layer(RRDB_block_f, nb)
 84 |         self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
 85 |         #### upsampling
 86 |         self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
 87 |         if self.sf==4:
 88 |             self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
 89 |         self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
 90 |         self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
 91 | 
 92 |         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 93 | 
 94 |     def forward(self, x):
 95 |         fea = self.conv_first(x)
 96 |         trunk = self.trunk_conv(self.RRDB_trunk(fea))
 97 |         fea = fea + trunk
 98 | 
 99 |         fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
100 |         if self.sf==4:
101 |             fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
102 |         out = self.conv_last(self.lrelu(self.HRconv(fea)))
103 | 
104 |         return out
105 | 
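
A quick shape check for the RRDBNet above (random input and random weights; real weights come from the `bsrnet` entry in MODELS):

    import torch

    from diffbir.model.bsrnet import RRDBNet

    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4).eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 64, 64))
    print(out.shape)  # torch.Size([1, 3, 256, 256]): two nearest-neighbor 2x upsamples for sf=4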


--------------------------------------------------------------------------------
/diffbir/model/clip.py:
--------------------------------------------------------------------------------
 1 | from typing import List
 2 | import torch
 3 | import torch.nn as nn
 4 | from torch.utils.checkpoint import checkpoint
 5 | from .open_clip import CLIP, tokenize
 6 | 
 7 | 
 8 | class FrozenOpenCLIPEmbedder(nn.Module):
 9 |     """
10 |     Uses the OpenCLIP transformer encoder for text
11 |     """
12 |     LAYERS = [
13 |         #"pooled",
14 |         "last",
15 |         "penultimate"
16 |     ]
17 |     def __init__(self, embed_dim, vision_cfg, text_cfg, layer="last"):
18 |         super().__init__()
19 |         assert layer in self.LAYERS
20 |         # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
21 |         model = CLIP(embed_dim, dict(vision_cfg), dict(text_cfg))
22 |         del model.visual
23 |         self.model = model
24 |         
25 |         self.layer = layer
26 |         if self.layer == "last":
27 |             self.layer_idx = 0
28 |         elif self.layer == "penultimate":
29 |             self.layer_idx = 1
30 |         else:
31 |             raise NotImplementedError()
32 | 
33 |     def forward(self, tokens):
34 |         z = self.encode_with_transformer(tokens)
35 |         return z
36 | 
37 |     def encode_with_transformer(self, text):
38 |         x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
39 |         x = x + self.model.positional_embedding
40 |         x = x.permute(1, 0, 2)  # NLD -> LND
41 |         x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
42 |         x = x.permute(1, 0, 2)  # LND -> NLD
43 |         x = self.model.ln_final(x)
44 |         return x
45 | 
46 |     def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
47 |         for i, r in enumerate(self.model.transformer.resblocks):
48 |             if i == len(self.model.transformer.resblocks) - self.layer_idx:
49 |                 break
50 |             if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
51 |                 x = checkpoint(r, x, attn_mask)
52 |             else:
53 |                 x = r(x, attn_mask=attn_mask)
54 |         return x
55 | 
56 |     def encode(self, text: List[str]) -> torch.Tensor:
57 |         # convert a batch of text to tensor
58 |         tokens = tokenize(text)
59 |         # move tensor to model device
60 |         tokens = tokens.to(next(self.model.parameters()).device)
61 |         return self(tokens)
62 | 
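
A small sketch of how the embedder is driven; the constructor arguments (`embed_dim`, `vision_cfg`, `text_cfg`) are omitted here because they come from the cldm config rather than this file:

    from diffbir.model.open_clip import tokenize

    # tokenize() converts a batch of prompts into a LongTensor of token ids,
    # shape (batch, context_length); the standard CLIP tokenizer uses context_length=77.
    tokens = tokenize(["a high quality photo", ""])

    # given an embedder built from the config, encode() wraps tokenize + forward:
    #   z = clip_embedder.encode(["a high quality photo", ""])  # (2, context_length, d_model)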


--------------------------------------------------------------------------------
/diffbir/model/config.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from typing import Optional, Literal
 3 | from types import ModuleType
 4 | import enum
 5 | from packaging import version
 6 | 
 7 | import torch
 8 | 
 9 | # collect system information
10 | if version.parse(torch.__version__) >= version.parse("2.0.0"):
11 |     SDP_IS_AVAILABLE = True
12 | else:
13 |     SDP_IS_AVAILABLE = False
14 | 
15 | try:
16 |     import xformers
17 |     import xformers.ops
18 |     XFORMERS_IS_AVAILBLE = True
19 | except ImportError:
20 |     XFORMERS_IS_AVAILBLE = False
21 | 
22 | 
23 | class AttnMode(enum.Enum):
24 |     SDP = 0
25 |     XFORMERS = 1
26 |     VANILLA = 2
27 | 
28 | 
29 | class Config:
30 |     xformers: Optional[ModuleType] = None
31 |     attn_mode: AttnMode = AttnMode.VANILLA
32 | 
33 | 
34 | # initialize attention mode
35 | if XFORMERS_IS_AVAILBLE:
36 |     Config.attn_mode = AttnMode.XFORMERS
37 |     print("use xformers attention as default")
38 | elif SDP_IS_AVAILABLE:
39 |     Config.attn_mode = AttnMode.SDP
40 |     print("use sdp attention as default")
41 | else:
42 |     print("neither sdp attention nor xformers is available, use vanilla attention (very expensive) as default")
43 | 
44 | if XFORMERS_IS_AVAILBLE:
45 |     Config.xformers = xformers
46 | 
47 | 
48 | # user-specified attention mode
49 | ATTN_MODE = os.environ.get("ATTN_MODE", None)
50 | if ATTN_MODE is not None:
51 |     assert ATTN_MODE in ["vanilla", "sdp", "xformers"]
52 |     if ATTN_MODE == "sdp":
53 |         assert SDP_IS_AVAILABLE
54 |         Config.attn_mode = AttnMode.SDP
55 |     elif ATTN_MODE == "xformers":
56 |         assert XFORMERS_IS_AVAILBLE
57 |         Config.attn_mode = AttnMode.XFORMERS
58 |     else:
59 |         Config.attn_mode = AttnMode.VANILLA
60 |     print(f"set attention mode to {ATTN_MODE}")
61 | else:
62 |     print("keep default attention mode")
63 | 
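
The ATTN_MODE override is read at import time, so it must be set before `diffbir.model.config` (or anything importing it) is loaded; a minimal sketch:

    import os

    os.environ["ATTN_MODE"] = "sdp"  # one of: vanilla, sdp, xformers

    from diffbir.model import config  # prints "set attention mode to sdp"
    print(config.Config.attn_mode)    # AttnMode.SDP (requires torch >= 2.0.0)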


--------------------------------------------------------------------------------
/diffbir/model/distributions.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | 
 4 | 
 5 | class AbstractDistribution:
 6 |     def sample(self):
 7 |         raise NotImplementedError()
 8 | 
 9 |     def mode(self):
10 |         raise NotImplementedError()
11 | 
12 | 
13 | class DiracDistribution(AbstractDistribution):
14 |     def __init__(self, value):
15 |         self.value = value
16 | 
17 |     def sample(self):
18 |         return self.value
19 | 
20 |     def mode(self):
21 |         return self.value
22 | 
23 | 
24 | class DiagonalGaussianDistribution(object):
25 |     def __init__(self, parameters, deterministic=False):
26 |         self.parameters = parameters
27 |         self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 |         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 |         self.deterministic = deterministic
30 |         self.std = torch.exp(0.5 * self.logvar)
31 |         self.var = torch.exp(self.logvar)
32 |         if self.deterministic:
33 |             self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 | 
35 |     def sample(self):
36 |         x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 |         return x
38 | 
39 |     def kl(self, other=None):
40 |         if self.deterministic:
41 |             return torch.Tensor([0.])
42 |         else:
43 |             if other is None:
44 |                 return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 |                                        + self.var - 1.0 - self.logvar,
46 |                                        dim=[1, 2, 3])
47 |             else:
48 |                 return 0.5 * torch.sum(
49 |                     torch.pow(self.mean - other.mean, 2) / other.var
50 |                     + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 |                     dim=[1, 2, 3])
52 | 
53 |     def nll(self, sample, dims=[1,2,3]):
54 |         if self.deterministic:
55 |             return torch.Tensor([0.])
56 |         logtwopi = np.log(2.0 * np.pi)
57 |         return 0.5 * torch.sum(
58 |             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 |             dim=dims)
60 | 
61 |     def mode(self):
62 |         return self.mean
63 | 
64 | 
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 |     """
67 |     source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 |     Compute the KL divergence between two gaussians.
69 |     Shapes are automatically broadcasted, so batches can be compared to
70 |     scalars, among other use cases.
71 |     """
72 |     tensor = None
73 |     for obj in (mean1, logvar1, mean2, logvar2):
74 |         if isinstance(obj, torch.Tensor):
75 |             tensor = obj
76 |             break
77 |     assert tensor is not None, "at least one argument must be a Tensor"
78 | 
79 |     # Force variances to be Tensors. Broadcasting helps convert scalars to
80 |     # Tensors, but it does not work for torch.exp().
81 |     logvar1, logvar2 = [
82 |         x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 |         for x in (logvar1, logvar2)
84 |     ]
85 | 
86 |     return 0.5 * (
87 |         -1.0
88 |         + logvar2
89 |         - logvar1
90 |         + torch.exp(logvar1 - logvar2)
91 |         + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 |     )
93 | 
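
A small sketch of the diagonal Gaussian posterior above (shapes are illustrative; in the VAE, `params` is the encoder output with mean and log-variance concatenated along the channel dimension):

    import torch

    from diffbir.model.distributions import DiagonalGaussianDistribution

    params = torch.randn(2, 8, 32, 32)  # 2 * 4 latent channels
    posterior = DiagonalGaussianDistribution(params)
    z = posterior.sample()  # (2, 4, 32, 32), mean + std * noise
    kl = posterior.kl()     # KL against a standard normal, one value per sample: shape (2,)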


--------------------------------------------------------------------------------
/diffbir/model/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import CLIP
2 | from .tokenizer import tokenize
3 | 
4 | __all__ = ["CLIP", "tokenize"]
5 | 


--------------------------------------------------------------------------------
/diffbir/model/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/diffbir/model/open_clip/bpe_simple_vocab_16e6.txt.gz


--------------------------------------------------------------------------------
/diffbir/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | from .spaced_sampler import SpacedSampler
2 | from .ddim_sampler import DDIMSampler
3 | from .dpms_sampler import DPMSolverSampler
4 | from .edm_sampler import EDMSampler
5 | 


--------------------------------------------------------------------------------
/diffbir/sampler/dpms_sampler.py:
--------------------------------------------------------------------------------
  1 | from typing import Tuple, Dict, Literal
  2 | 
  3 | import torch
  4 | import numpy as np
  5 | 
  6 | from .sampler import Sampler
  7 | from .dpm_solver_pytorch import (
  8 |     NoiseScheduleVP,
  9 |     model_wrapper,
 10 |     DPM_Solver,
 11 | )
 12 | from ..utils.cond_fn import Guidance
 13 | from ..model.cldm import ControlLDM
 14 | from ..utils.common import make_tiled_fn, trace_vram_usage
 15 | 
 16 | 
 17 | class DPMSolverSampler(Sampler):
 18 | 
 19 |     def __init__(
 20 |         self,
 21 |         betas: np.ndarray,
 22 |         parameterization: Literal["eps", "v"],
 23 |         rescale_cfg: bool,
 24 |         model_spec: str,
 25 |     ) -> None:
 26 |         super().__init__(betas, parameterization, rescale_cfg)
 27 |         if parameterization == "eps":
 28 |             self.model_type = "noise"
 29 |         elif parameterization == "v":
 30 |             self.model_type = "v"
 31 |         else:
 32 |             raise ValueError(parameterization)
 33 |         # parse sampling args from string
 34 |         # e.g. dpm++_s2 => solver_type=dpmsolver++, method=singlestep, order=2
 35 |         solver_type, (method, order) = model_spec.split("_")
 36 |         self.solver_type = {"dpm": "dpmsolver", "dpm++": "dpmsolver++"}[solver_type]
 37 |         self.method = {"s": "singlestep", "m": "multistep"}[method]
 38 |         self.order = {"1": 1, "2": 2, "3": 3}[order]
 39 |         self.register("betas", betas)
 40 | 
 41 |     @torch.no_grad()
 42 |     def sample(
 43 |         self,
 44 |         model: ControlLDM,
 45 |         device: str,
 46 |         steps: int,
 47 |         x_size: Tuple[int],
 48 |         cond: Dict[str, torch.Tensor],
 49 |         uncond: Dict[str, torch.Tensor],
 50 |         cfg_scale: float,
 51 |         tiled: bool = False,
 52 |         tile_size: int = -1,
 53 |         tile_stride: int = -1,
 54 |         x_T: torch.Tensor | None = None,
 55 |         progress: bool = True,
 56 |     ) -> torch.Tensor:
 57 |         if tiled:
 58 |             forward = model.forward
 59 |             model.forward = make_tiled_fn(
 60 |                 lambda x_tile, t, cond, hi, hi_end, wi, wi_end: (
 61 |                     forward(
 62 |                         x_tile,
 63 |                         t,
 64 |                         {
 65 |                             "c_txt": cond["c_txt"],
 66 |                             "c_img": cond["c_img"][..., hi:hi_end, wi:wi_end],
 67 |                         },
 68 |                     )
 69 |                 ),
 70 |                 tile_size,
 71 |                 tile_stride,
 72 |             )
 73 |         if x_T is None:
 74 |             x_T = torch.randn(x_size, device=device, dtype=torch.float32)
 75 |         x = x_T
 76 | 
 77 |         noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas)
 78 |         model_fn = model_wrapper(
 79 |             lambda x, t, c: model(x, t, c),
 80 |             noise_schedule,
 81 |             model_type=self.model_type,
 82 |             guidance_type="classifier-free",
 83 |             condition=cond,
 84 |             unconditional_condition=uncond,
 85 |             guidance_scale=cfg_scale,
 86 |             cfg_rescale=self.rescale_cfg,
 87 |         )
 88 |         dpm_solver = DPM_Solver(
 89 |             model_fn, noise_schedule, algorithm_type=self.solver_type
 90 |         )
 91 |         x = dpm_solver.sample(
 92 |             x_T,
 93 |             steps=steps,
 94 |             skip_type="time_uniform",
 95 |             method=self.method,
 96 |             order=self.order,
 97 |             return_intermediate=False,
 98 |         )
 99 |         if tiled:
100 |             model.forward = forward
101 |         return x
102 | 
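
The `model_spec` string is parsed by tuple-unpacking the part after the underscore, e.g.:

    solver_type, (method, order) = "dpm++_m2".split("_")
    # solver_type == "dpm++"  ->  "dpmsolver++"
    # method      == "m"      ->  "multistep"
    # order       == "2"      ->  2
    # so valid specs look like dpm_s1, dpm_m2, dpm++_s2, dpm++_m3, ...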


--------------------------------------------------------------------------------
/diffbir/sampler/sampler.py:
--------------------------------------------------------------------------------
 1 | from typing import Literal, overload, Dict, Optional, Tuple
 2 | import math
 3 | import torch
 4 | from torch import nn
 5 | import numpy as np
 6 | 
 7 | from ..model.cldm import ControlLDM
 8 | 
 9 | 
10 | class Sampler(nn.Module):
11 | 
12 |     def __init__(
13 |         self,
14 |         betas: np.ndarray,
15 |         parameterization: Literal["eps", "v"],
16 |         rescale_cfg: bool,
17 |     ) -> None:
18 |         super().__init__()
19 |         self.num_timesteps = len(betas)
20 |         self.training_betas = betas
21 |         self.training_alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
22 |         self.context = {}
23 |         self.parameterization = parameterization
24 |         self.rescale_cfg = rescale_cfg
25 | 
26 |     def register(
27 |         self, name: str, value: np.ndarray, dtype: torch.dtype = torch.float32
28 |     ) -> None:
29 |         self.register_buffer(name, torch.tensor(value, dtype=dtype))
30 | 
31 |     def get_cfg_scale(self, default_cfg_scale: float, model_t: int) -> float:
32 |         if self.rescale_cfg and default_cfg_scale > 1:
33 |             cfg_scale = 1 + default_cfg_scale * (
34 |                 (1 - math.cos(math.pi * ((1000 - model_t) / 1000) ** 5.0)) / 2
35 |             )
36 |         else:
37 |             cfg_scale = default_cfg_scale
38 |         return cfg_scale
39 | 
40 |     @overload
41 |     def sample(
42 |         self,
43 |         model: ControlLDM,
44 |         device: str,
45 |         steps: int,
46 |         x_size: Tuple[int],
47 |         cond: Dict[str, torch.Tensor],
48 |         uncond: Dict[str, torch.Tensor],
49 |         cfg_scale: float,
50 |         tiled: bool = False,
51 |         tile_size: int = -1,
52 |         tile_stride: int = -1,
53 |         x_T: Optional[torch.Tensor] = None,
54 |         progress: bool = True,
55 |     ) -> torch.Tensor: ...
56 | 
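
A quick check of the CFG rescaling curve in `get_cfg_scale` (a standalone restatement of the same formula, for illustration only):

    import math

    def rescaled_cfg(default_cfg_scale: float, model_t: int) -> float:
        return 1 + default_cfg_scale * (
            (1 - math.cos(math.pi * ((1000 - model_t) / 1000) ** 5.0)) / 2
        )

    print(rescaled_cfg(4.0, 1000))  # 1.0   -- effectively no extra guidance at the start of sampling
    print(rescaled_cfg(4.0, 500))   # ~1.01 -- still weak half-way through
    print(rescaled_cfg(4.0, 0))     # 5.0   -- full 1 + cfg_scale by the end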


--------------------------------------------------------------------------------
/diffbir/utils/cond_fn.py:
--------------------------------------------------------------------------------
  1 | from typing import overload, Tuple
  2 | import torch
  3 | from torch.nn import functional as F
  4 | 
  5 | 
  6 | class Guidance:
  7 | 
  8 |     def __init__(
  9 |         self, scale: float, t_start: int, t_stop: int, space: str, repeat: int
 10 |     ) -> None:
 11 |         """
 12 |         Initialize restoration guidance.
 13 | 
 14 |         Args:
 15 |             scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
 16 |                 the closer the final result will be to the output of the first stage model.
 17 |             t_start (int), t_stop (int): The timesteps at which guidance starts and stops. Note that the
 18 |                 sampling process runs from t=1000 down to t=0, so `t_start` should be larger than `t_stop`.
 19 |             space (str): The data space in which the loss is computed ("rgb" or "latent").
 20 | 
 21 |         Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
 22 |         Thanks for their work!
 23 |         """
 24 |         self.scale = scale * 3000
 25 |         self.t_start = t_start
 26 |         self.t_stop = t_stop
 27 |         self.target = None
 28 |         self.space = space
 29 |         self.repeat = repeat
 30 | 
 31 |     def load_target(self, target: torch.Tensor) -> None:
 32 |         self.target = target
 33 | 
 34 |     def __call__(
 35 |         self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int
 36 |     ) -> Tuple[torch.Tensor, float]:
 37 |         # avoid propagating gradient out of this scope
 38 |         pred_x0 = pred_x0.detach().clone()
 39 |         target_x0 = target_x0.detach().clone()
 40 |         return self._forward(target_x0, pred_x0, t)
 41 | 
 42 |     @overload
 43 |     def _forward(
 44 |         self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int
 45 |     ) -> Tuple[torch.Tensor, float]: ...
 46 | 
 47 | 
 48 | class MSEGuidance(Guidance):
 49 | 
 50 |     def _forward(
 51 |         self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int
 52 |     ) -> Tuple[torch.Tensor, float]:
 53 |         # inputs: [-1, 1], nchw, rgb
 54 |         with torch.enable_grad():
 55 |             pred_x0.requires_grad_(True)
 56 |             loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
 57 |         scale = self.scale
 58 |         g = -torch.autograd.grad(loss, pred_x0)[0] * scale
 59 |         return g, loss.item()
 60 | 
 61 | 
 62 | class WeightedMSEGuidance(Guidance):
 63 | 
 64 |     def _get_weight(self, target: torch.Tensor) -> torch.Tensor:
 65 |         # convert RGB to G
 66 |         rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)
 67 |         target = torch.sum(
 68 |             target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True
 69 |         )
 70 |         # initialize sobel kernel in x and y axis
 71 |         G_x = [[1, 0, -1], [2, 0, -2], [1, 0, -1]]
 72 |         G_y = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
 73 |         G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None]
 74 |         G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None]
 75 |         G = torch.stack((G_x, G_y))
 76 | 
 77 |         target = F.pad(target, (1, 1, 1, 1), mode="replicate")  # padding = 1
 78 |         grad = F.conv2d(target, G, stride=1)
 79 |         mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt()
 80 | 
 81 |         n, c, h, w = mag.size()
 82 |         block_size = 2
 83 |         blocks = (
 84 |             mag.view(n, c, h // block_size, block_size, w // block_size, block_size)
 85 |             .permute(0, 1, 2, 4, 3, 5)
 86 |             .contiguous()
 87 |         )
 88 |         block_mean = (
 89 |             blocks.sum(dim=(-2, -1), keepdim=True)
 90 |             .tanh()
 91 |             .repeat(1, 1, 1, 1, block_size, block_size)
 92 |             .permute(0, 1, 2, 4, 3, 5)
 93 |             .contiguous()
 94 |         )
 95 |         block_mean = block_mean.view(n, c, h, w)
 96 |         weight_map = 1 - block_mean
 97 | 
 98 |         return weight_map
 99 | 
100 |     def _forward(
101 |         self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int
102 |     ) -> Tuple[torch.Tensor, float]:
103 |         # inputs: [-1, 1], nchw, rgb
104 |         with torch.no_grad():
105 |             w = self._get_weight((target_x0 + 1) / 2)
106 |         with torch.enable_grad():
107 |             pred_x0.requires_grad_(True)
108 |             loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum()
109 |         scale = self.scale
110 |         g = -torch.autograd.grad(loss, pred_x0)[0] * scale
111 |         return g, loss.item()
112 | 
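
A minimal sketch of driving a guidance object (shapes and hyper-parameters are illustrative; during sampling the returned gradient is used to nudge the current x0 prediction toward the stage-1 output):

    import torch

    from diffbir.utils.cond_fn import MSEGuidance

    cond_fn = MSEGuidance(scale=0.1, t_start=601, t_stop=-1, space="latent", repeat=1)

    target_x0 = torch.randn(1, 4, 64, 64)  # stage-1 output (the guidance target)
    pred_x0 = torch.randn(1, 4, 64, 64)    # current x0 estimate from the sampler
    g, loss = cond_fn(target_x0, pred_x0, t=500)
    print(g.shape, loss)  # g matches pred_x0's shape; loss is the plain MSE value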


--------------------------------------------------------------------------------
/diffbir/utils/tilevae/__init__.py:
--------------------------------------------------------------------------------
1 | from .tilevae import VAEHook
2 | 


--------------------------------------------------------------------------------
/diffbir/utils/tilevae/attn.py:
--------------------------------------------------------------------------------
  1 | '''
  2 |     This file is modified from sd_hijack_optimizations.py to remove the residual and norm parts,
  3 |     so that the Tiled VAE can support other types of attention.
  4 | '''
  5 | import torch
  6 | from torch.nn import functional as F
  7 | from einops import rearrange
  8 | 
  9 | from ...model.config import Config, AttnMode
 10 | 
 11 | 
 12 | def get_attn_func():
 13 |     return {
 14 |         AttnMode.VANILLA: forward,
 15 |         AttnMode.XFORMERS: xformers_forward,
 16 |         AttnMode.SDP: sdp_forward,
 17 |     }[Config.attn_mode]
 18 | 
 19 | # The following functions are all copied from modules.sd_hijack_optimizations
 20 | # However, the residual & normalization are removed and computed separately.
 21 | 
 22 | def forward(self, x):
 23 |     h_ = x
 24 |     # h_ = self.norm(h_)
 25 |     q = self.q(h_)
 26 |     k = self.k(h_)
 27 |     v = self.v(h_)
 28 | 
 29 |     # compute attention
 30 |     b, c, h, w = q.shape
 31 |     q = q.reshape(b, c, h * w)
 32 |     q = q.permute(0, 2, 1)  # b,hw,c
 33 |     k = k.reshape(b, c, h * w)  # b,c,hw
 34 |     w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
 35 |     w_ = w_ * (int(c) ** (-0.5))
 36 |     w_ = torch.nn.functional.softmax(w_, dim=2)
 37 | 
 38 |     # attend to values
 39 |     v = v.reshape(b, c, h * w)
 40 |     w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
 41 |     h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
 42 |     h_ = h_.reshape(b, c, h, w)
 43 | 
 44 |     h_ = self.proj_out(h_)
 45 | 
 46 |     # return x + h_
 47 |     return h_
 48 | 
 49 | 
 50 | def xformers_forward(self, x):
 51 |     h_ = x
 52 |     # h_ = self.norm(h_)
 53 |     q = self.q(h_)
 54 |     k = self.k(h_)
 55 |     v = self.v(h_)
 56 | 
 57 |     # compute attention
 58 |     B, C, H, W = q.shape
 59 |     q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
 60 | 
 61 |     q, k, v = map(
 62 |         lambda t: t.unsqueeze(3)
 63 |         .reshape(B, t.shape[1], 1, C)
 64 |         .permute(0, 2, 1, 3)
 65 |         .reshape(B * 1, t.shape[1], C)
 66 |         .contiguous(),
 67 |         (q, k, v),
 68 |     )
 69 |     out = Config.xformers.ops.memory_efficient_attention(
 70 |         q, k, v, attn_bias=None, op=self.attention_op
 71 |     )
 72 | 
 73 |     out = (
 74 |         out.unsqueeze(0)
 75 |         .reshape(B, 1, out.shape[1], C)
 76 |         .permute(0, 2, 1, 3)
 77 |         .reshape(B, out.shape[1], C)
 78 |     )
 79 |     out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
 80 |     out = self.proj_out(out)
 81 |     # return x + out
 82 |     return out
 83 | 
 84 | 
 85 | def sdp_forward(self, x):
 86 |     h_ = x
 87 |     # h_ = self.norm(h_)
 88 |     q = self.q(h_)
 89 |     k = self.k(h_)
 90 |     v = self.v(h_)
 91 | 
 92 |     # compute attention
 93 |     B, C, H, W = q.shape
 94 |     q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
 95 | 
 96 |     q, k, v = map(
 97 |         lambda t: t.unsqueeze(3)
 98 |         .reshape(B, t.shape[1], 1, C)
 99 |         .permute(0, 2, 1, 3)
100 |         .reshape(B * 1, t.shape[1], C)
101 |         .contiguous(),
102 |         (q, k, v),
103 |     )
104 |     out = F.scaled_dot_product_attention(q, k, v)
105 | 
106 |     out = (
107 |         out.unsqueeze(0)
108 |         .reshape(B, 1, out.shape[1], C)
109 |         .permute(0, 2, 1, 3)
110 |         .reshape(B, out.shape[1], C)
111 |     )
112 |     out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
113 |     out = self.proj_out(out)
114 |     # return x + out
115 |     return out
116 | 
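
These forwards compute attention only (the residual add and normalization are left to the caller); `get_attn_func` picks the variant matching the global attention mode, and the Tiled VAE later binds it to an attention block that provides the q/k/v and proj_out layers:

    from diffbir.model.config import Config
    from diffbir.utils.tilevae.attn import get_attn_func

    attn_forward = get_attn_func()  # forward / xformers_forward / sdp_forward
    print(Config.attn_mode, attn_forward.__name__)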


--------------------------------------------------------------------------------
/inputs/demo/bfr/aligned/0229.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/aligned/0229.png


--------------------------------------------------------------------------------
/inputs/demo/bfr/aligned/0427.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/aligned/0427.png


--------------------------------------------------------------------------------
/inputs/demo/bfr/aligned/0722.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/aligned/0722.png


--------------------------------------------------------------------------------
/inputs/demo/bfr/aligned/hermione.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/aligned/hermione.jpg


--------------------------------------------------------------------------------
/inputs/demo/bfr/whole_img/01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/01.jpg


--------------------------------------------------------------------------------
/inputs/demo/bfr/whole_img/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/02.png


--------------------------------------------------------------------------------
/inputs/demo/bfr/whole_img/Audrey_Hepburn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/Audrey_Hepburn.jpg


--------------------------------------------------------------------------------
/inputs/demo/bfr/whole_img/Blake_Lively.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/Blake_Lively.jpg


--------------------------------------------------------------------------------
/inputs/demo/bfr/whole_img/Harry_Potter.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/Harry_Potter.jpg


--------------------------------------------------------------------------------
/inputs/demo/bfr/whole_img/Queen.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/Queen.jpeg


--------------------------------------------------------------------------------
/inputs/demo/bfr/whole_img/real47_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bfr/whole_img/real47_1.jpg


--------------------------------------------------------------------------------
/inputs/demo/bid/Audrey_Hepburn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Audrey_Hepburn.jpg


--------------------------------------------------------------------------------
/inputs/demo/bid/Bears.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Bears.png


--------------------------------------------------------------------------------
/inputs/demo/bid/Flowers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Flowers.png


--------------------------------------------------------------------------------
/inputs/demo/bid/Movie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Movie.png


--------------------------------------------------------------------------------
/inputs/demo/bid/Postcards.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/Postcards.png


--------------------------------------------------------------------------------
/inputs/demo/bid/cty_fnb_0047.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/cty_fnb_0047.png


--------------------------------------------------------------------------------
/inputs/demo/bid/kf_fnb_0058.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/kf_fnb_0058.png


--------------------------------------------------------------------------------
/inputs/demo/bid/palace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bid/palace.png


--------------------------------------------------------------------------------
/inputs/demo/bsr/14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/14.jpg


--------------------------------------------------------------------------------
/inputs/demo/bsr/29.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/29.jpg


--------------------------------------------------------------------------------
/inputs/demo/bsr/49.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/49.jpg


--------------------------------------------------------------------------------
/inputs/demo/bsr/53.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/53.jpeg


--------------------------------------------------------------------------------
/inputs/demo/bsr/comic3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/demo/bsr/comic3.png


--------------------------------------------------------------------------------
/inputs/real47/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/1.jpg


--------------------------------------------------------------------------------
/inputs/real47/11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/11.jpg


--------------------------------------------------------------------------------
/inputs/real47/12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/12.jpg


--------------------------------------------------------------------------------
/inputs/real47/13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/13.jpg


--------------------------------------------------------------------------------
/inputs/real47/14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/14.jpg


--------------------------------------------------------------------------------
/inputs/real47/15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/15.jpg


--------------------------------------------------------------------------------
/inputs/real47/16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/16.jpg


--------------------------------------------------------------------------------
/inputs/real47/17.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/17.jpg


--------------------------------------------------------------------------------
/inputs/real47/19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/19.jpg


--------------------------------------------------------------------------------
/inputs/real47/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/2.jpg


--------------------------------------------------------------------------------
/inputs/real47/20.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/20.jpg


--------------------------------------------------------------------------------
/inputs/real47/21.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/21.jpg


--------------------------------------------------------------------------------
/inputs/real47/22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/22.jpg


--------------------------------------------------------------------------------
/inputs/real47/23.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/23.jpg


--------------------------------------------------------------------------------
/inputs/real47/24.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/24.jpg


--------------------------------------------------------------------------------
/inputs/real47/26.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/26.jpg


--------------------------------------------------------------------------------
/inputs/real47/27.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/27.jpg


--------------------------------------------------------------------------------
/inputs/real47/29.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/29.jpg


--------------------------------------------------------------------------------
/inputs/real47/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/3.jpg


--------------------------------------------------------------------------------
/inputs/real47/32.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/32.jpg


--------------------------------------------------------------------------------
/inputs/real47/33.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/33.jpg


--------------------------------------------------------------------------------
/inputs/real47/34.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/34.jpg


--------------------------------------------------------------------------------
/inputs/real47/35.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/35.jpg


--------------------------------------------------------------------------------
/inputs/real47/36.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/36.png


--------------------------------------------------------------------------------
/inputs/real47/38.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/38.jpg


--------------------------------------------------------------------------------
/inputs/real47/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/4.jpg


--------------------------------------------------------------------------------
/inputs/real47/40.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/40.jpg


--------------------------------------------------------------------------------
/inputs/real47/41.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/41.jpg


--------------------------------------------------------------------------------
/inputs/real47/42.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/42.jpg


--------------------------------------------------------------------------------
/inputs/real47/43.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/43.jpg


--------------------------------------------------------------------------------
/inputs/real47/44.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/44.jpg


--------------------------------------------------------------------------------
/inputs/real47/45.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/45.jpg


--------------------------------------------------------------------------------
/inputs/real47/46.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/46.jpg


--------------------------------------------------------------------------------
/inputs/real47/47.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/47.jpg


--------------------------------------------------------------------------------
/inputs/real47/48.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/48.jpg


--------------------------------------------------------------------------------
/inputs/real47/49.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/49.jpg


--------------------------------------------------------------------------------
/inputs/real47/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/5.jpg


--------------------------------------------------------------------------------
/inputs/real47/50.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/50.jpg


--------------------------------------------------------------------------------
/inputs/real47/51.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/51.jpg


--------------------------------------------------------------------------------
/inputs/real47/52.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/52.jpg


--------------------------------------------------------------------------------
/inputs/real47/53.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/53.jpeg


--------------------------------------------------------------------------------
/inputs/real47/54.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/54.jpeg


--------------------------------------------------------------------------------
/inputs/real47/55.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/55.jpg


--------------------------------------------------------------------------------
/inputs/real47/56.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/56.jpg


--------------------------------------------------------------------------------
/inputs/real47/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/6.jpg


--------------------------------------------------------------------------------
/inputs/real47/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/7.jpg


--------------------------------------------------------------------------------
/inputs/real47/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/inputs/real47/9.jpg


--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 | 


--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
 2 | WORKER_HEART_BEAT_INTERVAL = 15
 3 | 
 4 | LOGDIR = "."
 5 | 
 6 | # Model Constants
 7 | IGNORE_INDEX = -100
 8 | IMAGE_TOKEN_INDEX = -200
 9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 | 


--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import json
  3 | import os
  4 | 
  5 | import openai
  6 | import tqdm
  7 | import ray
  8 | import time
  9 | 
 10 | NUM_SECONDS_TO_SLEEP = 3
 11 | 
 12 | @ray.remote(num_cpus=4)
 13 | def get_eval(content: str, max_tokens: int):
 14 |     while True:
 15 |         try:
 16 |             response = openai.ChatCompletion.create(
 17 |                 model='gpt-4',
 18 |                 messages=[{
 19 |                     'role': 'system',
 20 |                     'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
 21 |                 }, {
 22 |                     'role': 'user',
 23 |                     'content': content,
 24 |                 }],
 25 |                 temperature=0.2,  # TODO: figure out which temperature is best for evaluation
 26 |                 max_tokens=max_tokens,
 27 |             )
 28 |             break
 29 |         except openai.error.RateLimitError:
 30 |             pass
 31 |         except Exception as e:
 32 |             print(e)
 33 |         time.sleep(NUM_SECONDS_TO_SLEEP)
 34 | 
 35 |     print('success!')
 36 |     return response['choices'][0]['message']['content']
 37 | 
 38 | 
 39 | def parse_score(review):
 40 |     try:
 41 |         score_pair = review.split('\n')[0]
 42 |         score_pair = score_pair.replace(',', ' ')
 43 |         sp = score_pair.split(' ')
 44 |         if len(sp) == 2:
 45 |             return [float(sp[0]), float(sp[1])]
 46 |         else:
 47 |             print('error', review)
 48 |             return [-1, -1]
 49 |     except Exception as e:
 50 |         print(e)
 51 |         print('error', review)
 52 |         return [-1, -1]
 53 | 
 54 | 
 55 | if __name__ == '__main__':
 56 |     parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
 57 |     parser.add_argument('-q', '--question')
 58 |     # parser.add_argument('-a', '--answer')
 59 |     parser.add_argument('-a', '--answer-list', nargs='+', default=[])
 60 |     parser.add_argument('-r', '--rule')
 61 |     parser.add_argument('-o', '--output')
 62 |     parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
 63 |     args = parser.parse_args()
 64 | 
 65 |     ray.init()
 66 | 
 67 |     f_q = open(os.path.expanduser(args.question))
 68 |     f_ans1 = open(os.path.expanduser(args.answer_list[0]))
 69 |     f_ans2 = open(os.path.expanduser(args.answer_list[1]))
 70 |     rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
 71 | 
 72 |     review_file = open(f'{args.output}', 'w')
 73 | 
 74 |     js_list = []
 75 |     handles = []
 76 |     idx = 0
 77 |     for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
 78 |         # if idx == 1:
 79 |         #     break
 80 | 
 81 |         ques = json.loads(ques_js)
 82 |         ans1 = json.loads(ans1_js)
 83 |         ans2 = json.loads(ans2_js)
 84 | 
 85 |         category = json.loads(ques_js)['category']
 86 |         if category in rule_dict:
 87 |             rule = rule_dict[category]
 88 |         else:
 89 |             rule = rule_dict['default']
 90 |         prompt = rule['prompt']
 91 |         role = rule['role']
 92 |         content = (f'[Question]\n{ques["text"]}\n\n'
 93 |                    f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
 94 |                    f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
 95 |                    f'[System]\n{prompt}\n\n')
 96 |         js_list.append({
 97 |             'id': idx+1,
 98 |             'question_id': ques['question_id'],
 99 |             'answer1_id': ans1['answer_id'],
100 |             'answer2_id': ans2['answer_id'],
101 |             'category': category})
102 |         idx += 1
103 |         handles.append(get_eval.remote(content, args.max_tokens))
104 |         # To avoid the rate limit set by OpenAI
105 |         time.sleep(NUM_SECONDS_TO_SLEEP)
106 | 
107 |     reviews = ray.get(handles)
108 |     for idx, review in enumerate(reviews):
109 |         scores = parse_score(review)
110 |         js_list[idx]['content'] = review
111 |         js_list[idx]['tuple'] = scores
112 |         review_file.write(json.dumps(js_list[idx]) + '\n')
113 |     review_file.close()
114 | 
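
Note: parse_score above expects the reviewer's reply to begin with a line holding the two scores. A minimal standalone sketch of that parsing convention (the sample review text is made up):

# Standalone illustration of the score-parsing convention used by parse_score.
sample_review = "8 9\nAssistant 1 gave a concise answer; Assistant 2 added more detail."

first_line = sample_review.split('\n')[0].replace(',', ' ')
parts = first_line.split(' ')
if len(parts) == 2:
    scores = [float(parts[0]), float(parts[1])]
else:
    scores = [-1, -1]  # same fallback as parse_score

print(scores)  # [8.0, 9.0]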


--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review_bench.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import json
  3 | import os
  4 | 
  5 | import openai
  6 | import time
  7 | 
  8 | NUM_SECONDS_TO_SLEEP = 0.5
  9 | 
 10 | 
 11 | def get_eval(content: str, max_tokens: int):
 12 |     while True:
 13 |         try:
 14 |             response = openai.ChatCompletion.create(
 15 |                 model='gpt-4-0314',
 16 |                 messages=[{
 17 |                     'role': 'system',
 18 |                     'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
 19 |                 }, {
 20 |                     'role': 'user',
 21 |                     'content': content,
 22 |                 }],
 23 |                 temperature=0.2,  # TODO: figure out which temperature is best for evaluation
 24 |                 max_tokens=max_tokens,
 25 |             )
 26 |             break
 27 |         except openai.error.RateLimitError:
 28 |             pass
 29 |         except Exception as e:
 30 |             print(e)
 31 |         time.sleep(NUM_SECONDS_TO_SLEEP)
 32 | 
 33 |     return response['choices'][0]['message']['content']
 34 | 
 35 | 
 36 | def parse_score(review):
 37 |     try:
 38 |         score_pair = review.split('\n')[0]
 39 |         score_pair = score_pair.replace(',', ' ')
 40 |         sp = score_pair.split(' ')
 41 |         if len(sp) == 2:
 42 |             return [float(sp[0]), float(sp[1])]
 43 |         else:
 44 |             print('error', review)
 45 |             return [-1, -1]
 46 |     except Exception as e:
 47 |         print(e)
 48 |         print('error', review)
 49 |         return [-1, -1]
 50 | 
 51 | 
 52 | if __name__ == '__main__':
 53 |     parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
 54 |     parser.add_argument('-q', '--question')
 55 |     parser.add_argument('-c', '--context')
 56 |     parser.add_argument('-a', '--answer-list', nargs='+', default=[])
 57 |     parser.add_argument('-r', '--rule')
 58 |     parser.add_argument('-o', '--output')
 59 |     parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
 60 |     args = parser.parse_args()
 61 | 
 62 |     f_q = open(os.path.expanduser(args.question))
 63 |     f_ans1 = open(os.path.expanduser(args.answer_list[0]))
 64 |     f_ans2 = open(os.path.expanduser(args.answer_list[1]))
 65 |     rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
 66 | 
 67 |     if os.path.isfile(os.path.expanduser(args.output)):
 68 |         cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
 69 |     else:
 70 |         cur_reviews = []
 71 | 
 72 |     review_file = open(f'{args.output}', 'a')
 73 | 
 74 |     context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
 75 |     image_to_context = {context['image']: context for context in context_list}
 76 | 
 77 |     handles = []
 78 |     idx = 0
 79 |     for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
 80 |         ques = json.loads(ques_js)
 81 |         ans1 = json.loads(ans1_js)
 82 |         ans2 = json.loads(ans2_js)
 83 | 
 84 |         inst = image_to_context[ques['image']]
 85 | 
 86 |         if isinstance(inst['caption'], list):
 87 |             cap_str = '\n'.join(inst['caption'])
 88 |         else:
 89 |             cap_str = inst['caption']
 90 | 
 91 |         category = 'llava_bench_' + json.loads(ques_js)['category']
 92 |         if category in rule_dict:
 93 |             rule = rule_dict[category]
 94 |         else:
 95 |             assert False, f"Visual QA category not found in rule file: {category}."
 96 |         prompt = rule['prompt']
 97 |         role = rule['role']
 98 |         content = (f'[Context]\n{cap_str}\n\n'
 99 |                    f'[Question]\n{ques["text"]}\n\n'
100 |                    f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
101 |                    f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
102 |                    f'[System]\n{prompt}\n\n')
103 |         cur_js = {
104 |             'id': idx+1,
105 |             'question_id': ques['question_id'],
106 |             'answer1_id': ans1.get('answer_id', ans1['question_id']),
107 |             'answer2_id': ans2.get('answer_id', ans2['question_id']),
108 |             'category': category
109 |         }
110 |         if idx >= len(cur_reviews):
111 |             review = get_eval(content, args.max_tokens)
112 |             scores = parse_score(review)
113 |             cur_js['content'] = review
114 |             cur_js['tuple'] = scores
115 |             review_file.write(json.dumps(cur_js) + '\n')
116 |             review_file.flush()
117 |         else:
118 |             print(f'Skipping {idx} as we already have it.')
119 |         idx += 1
120 |         print(idx)
121 |     review_file.close()
122 | 


--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review_visual.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import json
  3 | import os
  4 | 
  5 | import openai
  6 | import time
  7 | 
  8 | NUM_SECONDS_TO_SLEEP = 0.5
  9 | 
 10 | 
 11 | def get_eval(content: str, max_tokens: int):
 12 |     while True:
 13 |         try:
 14 |             response = openai.ChatCompletion.create(
 15 |                 model='gpt-4-0314',
 16 |                 messages=[{
 17 |                     'role': 'system',
 18 |                     'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
 19 |                 }, {
 20 |                     'role': 'user',
 21 |                     'content': content,
 22 |                 }],
 23 |                 temperature=0.2,  # TODO: figure out which temperature is best for evaluation
 24 |                 max_tokens=max_tokens,
 25 |             )
 26 |             break
 27 |         except openai.error.RateLimitError:
 28 |             pass
 29 |         except Exception as e:
 30 |             print(e)
 31 |         time.sleep(NUM_SECONDS_TO_SLEEP)
 32 | 
 33 |     return response['choices'][0]['message']['content']
 34 | 
 35 | 
 36 | def parse_score(review):
 37 |     try:
 38 |         score_pair = review.split('\n')[0]
 39 |         score_pair = score_pair.replace(',', ' ')
 40 |         sp = score_pair.split(' ')
 41 |         if len(sp) == 2:
 42 |             return [float(sp[0]), float(sp[1])]
 43 |         else:
 44 |             print('error', review)
 45 |             return [-1, -1]
 46 |     except Exception as e:
 47 |         print(e)
 48 |         print('error', review)
 49 |         return [-1, -1]
 50 | 
 51 | 
 52 | if __name__ == '__main__':
 53 |     parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
 54 |     parser.add_argument('-q', '--question')
 55 |     parser.add_argument('-c', '--context')
 56 |     parser.add_argument('-a', '--answer-list', nargs='+', default=[])
 57 |     parser.add_argument('-r', '--rule')
 58 |     parser.add_argument('-o', '--output')
 59 |     parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
 60 |     args = parser.parse_args()
 61 | 
 62 |     f_q = open(os.path.expanduser(args.question))
 63 |     f_ans1 = open(os.path.expanduser(args.answer_list[0]))
 64 |     f_ans2 = open(os.path.expanduser(args.answer_list[1]))
 65 |     rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
 66 | 
 67 |     if os.path.isfile(os.path.expanduser(args.output)):
 68 |         cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
 69 |     else:
 70 |         cur_reviews = []
 71 | 
 72 |     review_file = open(f'{args.output}', 'a')
 73 | 
 74 |     context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
 75 |     image_to_context = {context['image']: context for context in context_list}
 76 | 
 77 |     handles = []
 78 |     idx = 0
 79 |     for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
 80 |         ques = json.loads(ques_js)
 81 |         ans1 = json.loads(ans1_js)
 82 |         ans2 = json.loads(ans2_js)
 83 | 
 84 |         inst = image_to_context[ques['image']]
 85 |         cap_str = '\n'.join(inst['captions'])
 86 |         box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
 87 | 
 88 |         category = json.loads(ques_js)['category']
 89 |         if category in rule_dict:
 90 |             rule = rule_dict[category]
 91 |         else:
 92 |             assert False, f"Visual QA category not found in rule file: {category}."
 93 |         prompt = rule['prompt']
 94 |         role = rule['role']
 95 |         content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
 96 |                    f'[Question]\n{ques["text"]}\n\n'
 97 |                    f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
 98 |                    f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
 99 |                    f'[System]\n{prompt}\n\n')
100 |         cur_js = {
101 |             'id': idx+1,
102 |             'question_id': ques['question_id'],
103 |             'answer1_id': ans1.get('answer_id', ans1['question_id']),
104 |             'answer2_id': ans2.get('answer_id', ans2['question_id']),
105 |             'category': category
106 |         }
107 |         if idx >= len(cur_reviews):
108 |             review = get_eval(content, args.max_tokens)
109 |             scores = parse_score(review)
110 |             cur_js['content'] = review
111 |             cur_js['tuple'] = scores
112 |             review_file.write(json.dumps(cur_js) + '\n')
113 |             review_file.flush()
114 |         else:
115 |             print(f'Skipping {idx} as we already have it.')
116 |         idx += 1
117 |         print(idx)
118 |     review_file.close()
119 | 


--------------------------------------------------------------------------------
/llava/eval/eval_pope.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import json
 3 | import argparse
 4 | 
 5 | def eval_pope(answers, label_file):
 6 |     label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
 7 | 
 8 |     for answer in answers:
 9 |         text = answer['text']
10 | 
11 |         # Only keep the first sentence
12 |         if text.find('.') != -1:
13 |             text = text.split('.')[0]
14 | 
15 |         text = text.replace(',', '')
16 |         words = text.split(' ')
17 |         if 'No' in words or 'not' in words or 'no' in words:
18 |             answer['text'] = 'no'
19 |         else:
20 |             answer['text'] = 'yes'
21 | 
22 |     for i in range(len(label_list)):
23 |         if label_list[i] == 'no':
24 |             label_list[i] = 0
25 |         else:
26 |             label_list[i] = 1
27 | 
28 |     pred_list = []
29 |     for answer in answers:
30 |         if answer['text'] == 'no':
31 |             pred_list.append(0)
32 |         else:
33 |             pred_list.append(1)
34 | 
35 |     pos = 1
36 |     neg = 0
37 |     yes_ratio = pred_list.count(1) / len(pred_list)
38 | 
39 |     TP, TN, FP, FN = 0, 0, 0, 0
40 |     for pred, label in zip(pred_list, label_list):
41 |         if pred == pos and label == pos:
42 |             TP += 1
43 |         elif pred == pos and label == neg:
44 |             FP += 1
45 |         elif pred == neg and label == neg:
46 |             TN += 1
47 |         elif pred == neg and label == pos:
48 |             FN += 1
49 | 
50 |     print('TP\tFP\tTN\tFN\t')
51 |     print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52 | 
53 |     precision = float(TP) / float(TP + FP)
54 |     recall = float(TP) / float(TP + FN)
55 |     f1 = 2*precision*recall / (precision + recall)
56 |     acc = (TP + TN) / (TP + TN + FP + FN)
57 |     print('Accuracy: {}'.format(acc))
58 |     print('Precision: {}'.format(precision))
59 |     print('Recall: {}'.format(recall))
60 |     print('F1 score: {}'.format(f1))
61 |     print('Yes ratio: {}'.format(yes_ratio))
62 |     print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63 | 
64 | if __name__ == "__main__":
65 |     parser = argparse.ArgumentParser()
66 |     parser.add_argument("--annotation-dir", type=str)
67 |     parser.add_argument("--question-file", type=str)
68 |     parser.add_argument("--result-file", type=str)
69 |     args = parser.parse_args()
70 | 
71 |     questions = [json.loads(line) for line in open(args.question_file)]
72 |     questions = {question['question_id']: question for question in questions}
73 |     answers = [json.loads(q) for q in open(args.result_file)]
74 |     for file in os.listdir(args.annotation_dir):
75 |         assert file.startswith('coco_pope_')
76 |         assert file.endswith('.json')
77 |         category = file[10:-5]
78 |         cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79 |         print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80 |         eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81 |         print("====================================")
82 | 
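
Note: a toy walk-through of the normalization and metric computation performed by eval_pope, with made-up answers and labels (not part of the repo):

# Toy example mirroring eval_pope's yes/no normalization and precision/recall/F1.
answers = ["Yes, there is a dog.", "No, I do not see one.", "Yes."]
labels = ["yes", "no", "no"]

preds = []
for text in answers:
    text = text.split('.')[0].replace(',', '')   # keep first sentence, drop commas
    words = text.split(' ')
    preds.append(0 if ('No' in words or 'no' in words or 'not' in words) else 1)

gold = [0 if label == 'no' else 1 for label in labels]
TP = sum(p == 1 and g == 1 for p, g in zip(preds, gold))
FP = sum(p == 1 and g == 0 for p, g in zip(preds, gold))
TN = sum(p == 0 and g == 0 for p, g in zip(preds, gold))
FN = sum(p == 0 and g == 1 for p, g in zip(preds, gold))
precision, recall = TP / (TP + FP), TP / (TP + FN)
print(precision, recall, 2 * precision * recall / (precision + recall))  # 0.5 1.0 0.666...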


--------------------------------------------------------------------------------
/llava/eval/eval_science_qa.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import json
  3 | import os
  4 | import re
  5 | import random
  6 | 
  7 | 
  8 | def get_args():
  9 |     parser = argparse.ArgumentParser()
 10 |     parser.add_argument('--base-dir', type=str)
 11 |     parser.add_argument('--result-file', type=str)
 12 |     parser.add_argument('--output-file', type=str)
 13 |     parser.add_argument('--output-result', type=str)
 14 |     parser.add_argument('--split', type=str, default='test')
 15 |     parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
 16 |     return parser.parse_args()
 17 | 
 18 | 
 19 | def convert_caps(results):
 20 |     fakecaps = []
 21 |     for result in results:
 22 |         image_id = result['question_id']
 23 |         caption = result['text']
 24 |         fakecaps.append({"image_id": int(image_id), "caption": caption})
 25 |     return fakecaps
 26 | 
 27 | 
 28 | def get_pred_idx(prediction, choices, options):
 29 |     """
 30 |     Get the index (e.g. 2) from the prediction (e.g. 'C')
 31 |     """
 32 |     if prediction in options[:len(choices)]:
 33 |         return options.index(prediction)
 34 |     else:
 35 |         return -1
 36 |         # unreachable: the random-choice fallback is only used in eval_science_qa_gpt4.py
 37 | 
 38 | 
 39 | if __name__ == "__main__":
 40 |     args = get_args()
 41 | 
 42 |     base_dir = args.base_dir
 43 |     split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
 44 |     problems = json.load(open(os.path.join(base_dir, "problems.json")))
 45 |     predictions = [json.loads(line) for line in open(args.result_file)]
 46 |     predictions = {pred['question_id']: pred for pred in predictions}
 47 |     split_problems = {idx: problems[idx] for idx in split_indices}
 48 | 
 49 |     results = {'correct': [], 'incorrect': []}
 50 |     sqa_results = {}
 51 |     sqa_results['acc'] = None
 52 |     sqa_results['correct'] = None
 53 |     sqa_results['count'] = None
 54 |     sqa_results['results'] = {}
 55 |     sqa_results['outputs'] = {}
 56 | 
 57 |     for prob_id, prob in split_problems.items():
 58 |         if prob_id not in predictions:
 59 |             pred = {'text': 'FAILED', 'prompt': 'Unknown'}
 60 |             pred_text = 'FAILED'
 61 |         else:
 62 |             pred = predictions[prob_id]
 63 |             pred_text = pred['text']
 64 | 
 65 |         if pred_text in args.options:
 66 |             answer = pred_text
 67 |         elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
 68 |             answer = pred_text[0]
 69 |         else:
 70 |             pattern = re.compile(r'The answer is ([A-Z]).')
 71 |             res = pattern.findall(pred_text)
 72 |             if len(res) == 1:
 73 |                 answer = res[0]  # 'A', 'B', ...
 74 |             else:
 75 |                 answer = "FAILED"
 76 | 
 77 |         pred_idx = get_pred_idx(answer, prob['choices'], args.options)
 78 | 
 79 |         analysis = {
 80 |             'question_id': prob_id,
 81 |             'parsed_ans': answer,
 82 |             'ground_truth': args.options[prob['answer']],
 83 |             'question': pred['prompt'],
 84 |             'pred': pred_text,
 85 |             'is_multimodal': '<image>' in pred['prompt'],
 86 |         }
 87 | 
 88 |         sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
 89 |         sqa_results['outputs'][prob_id] = pred_text
 90 | 
 91 |         if pred_idx == prob['answer']:
 92 |             results['correct'].append(analysis)
 93 |         else:
 94 |             results['incorrect'].append(analysis)
 95 | 
 96 |     correct = len(results['correct'])
 97 |     total = len(results['correct']) + len(results['incorrect'])
 98 | 
 99 |     ###### IMG ######
100 |     multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101 |     multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102 |     multimodal_total = multimodal_correct + multimodal_incorrect
103 |     ###### IMG ######
104 | 
105 |     print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106 | 
107 |     sqa_results['acc'] = correct / total * 100
108 |     sqa_results['correct'] = correct
109 |     sqa_results['count'] = total
110 | 
111 |     with open(args.output_file, 'w') as f:
112 |         json.dump(results, f, indent=2)
113 |     with open(args.output_result, 'w') as f:
114 |         json.dump(sqa_results, f, indent=2)
115 | 
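
Note: the answer-extraction fallbacks above can be illustrated in isolation; the prediction text below is made up:

import re

options = ["A", "B", "C", "D", "E"]
pred_text = "The answer is B. The diagram shows a food chain."

if pred_text in options:                       # bare letter, e.g. "B"
    answer = pred_text
elif len(pred_text) >= 3 and pred_text[0] in options and pred_text[1:3] == ". ":
    answer = pred_text[0]                      # "B. ..." style
else:
    matches = re.findall(r'The answer is ([A-Z]).', pred_text)
    answer = matches[0] if len(matches) == 1 else "FAILED"

print(answer)  # B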


--------------------------------------------------------------------------------
/llava/eval/eval_science_qa_gpt4.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import json
  3 | import os
  4 | import re
  5 | import random
  6 | from collections import defaultdict
  7 | 
  8 | 
  9 | def get_args():
 10 |     parser = argparse.ArgumentParser()
 11 |     parser.add_argument('--base-dir', type=str)
 12 |     parser.add_argument('--gpt4-result', type=str)
 13 |     parser.add_argument('--our-result', type=str)
 14 |     parser.add_argument('--split', type=str, default='test')
 15 |     parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
 16 |     return parser.parse_args()
 17 | 
 18 | 
 19 | def convert_caps(results):
 20 |     fakecaps = []
 21 |     for result in results:
 22 |         image_id = result['question_id']
 23 |         caption = result['text']
 24 |         fakecaps.append({"image_id": int(image_id), "caption": caption})
 25 |     return fakecaps
 26 | 
 27 | 
 28 | def get_pred_idx(prediction, choices, options):
 29 |     """
 30 |     Get the index (e.g. 2) from the prediction (e.g. 'C')
 31 |     """
 32 |     if prediction in options[:len(choices)]:
 33 |         return options.index(prediction)
 34 |     else:
 35 |         return random.choice(range(len(choices)))
 36 | 
 37 | 
 38 | if __name__ == "__main__":
 39 |     args = get_args()
 40 | 
 41 |     base_dir = args.base_dir
 42 |     split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
 43 |     problems = json.load(open(os.path.join(base_dir, "problems.json")))
 44 |     our_predictions = [json.loads(line) for line in open(args.our_result)]
 45 |     our_predictions = {pred['question_id']: pred for pred in our_predictions}
 46 |     split_problems = {idx: problems[idx] for idx in split_indices}
 47 | 
 48 |     gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
 49 | 
 50 |     results = defaultdict(lambda: 0)
 51 | 
 52 |     for prob_id, prob in split_problems.items():
 53 |         if prob_id not in our_predictions:
 54 |             continue
 55 |         if prob_id not in gpt4_predictions:
 56 |             continue
 57 |         our_pred = our_predictions[prob_id]['text']
 58 |         gpt4_pred = gpt4_predictions[prob_id]
 59 | 
 60 |         pattern = re.compile(r'The answer is ([A-Z]).')
 61 |         our_res = pattern.findall(our_pred)
 62 |         if len(our_res) == 1:
 63 |             our_answer = our_res[0]  # 'A', 'B', ...
 64 |         else:
 65 |             our_answer = "FAILED"
 66 |         gpt4_res = pattern.findall(gpt4_pred)
 67 |         if len(gpt4_res) == 1:
 68 |             gpt4_answer = gpt4_res[0]  # 'A', 'B', ...
 69 |         else:
 70 |             gpt4_answer = "FAILED"
 71 | 
 72 |         our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
 73 |         gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
 74 | 
 75 |         if gpt4_answer == 'FAILED':
 76 |             results['gpt4_failed'] += 1
 77 |             # continue
 78 |             gpt4_pred_idx = our_pred_idx
 79 |             # if our_pred_idx != prob['answer']:
 80 |             #     print(our_predictions[prob_id]['prompt'])
 81 |             #     print('-----------------')
 82 |             #     print(f'LECTURE: {prob["lecture"]}')
 83 |             #     print(f'SOLUTION: {prob["solution"]}')
 84 |             #     print('=====================')
 85 |         else:
 86 |             # continue
 87 |             pass
 88 |         # gpt4_pred_idx = our_pred_idx
 89 | 
 90 |         if gpt4_pred_idx == prob['answer']:
 91 |             results['correct'] += 1
 92 |         else:
 93 |             results['incorrect'] += 1
 94 | 
 95 | 
 96 |         if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
 97 |             results['correct_upperbound'] += 1
 98 | 
 99 |     correct = results['correct']
100 |     total = results['correct'] + results['incorrect']
101 |     print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102 |     print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103 |     print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104 | 
105 | 


--------------------------------------------------------------------------------
/llava/eval/eval_textvqa.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import argparse
 3 | import json
 4 | import re
 5 | 
 6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
 7 | 
 8 | 
 9 | def get_args():
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument('--annotation-file', type=str)
12 |     parser.add_argument('--result-file', type=str)
13 |     parser.add_argument('--result-dir', type=str)
14 |     return parser.parse_args()
15 | 
16 | 
17 | def prompt_processor(prompt):
18 |     if prompt.startswith('OCR tokens: '):
19 |         pattern = r"Question: (.*?) Short answer:"
20 |         match = re.search(pattern, prompt, re.DOTALL)
21 |         question = match.group(1)
22 |     elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23 |         if prompt.startswith('Reference OCR token:'):
24 |             question = prompt.split('\n')[1]
25 |         else:
26 |             question = prompt.split('\n')[0]
27 |     elif len(prompt.split('\n')) == 2:
28 |         question = prompt.split('\n')[0]
29 |     else:
30 |         assert False
31 | 
32 |     return question.lower()
33 | 
34 | 
35 | def eval_single(annotation_file, result_file):
36 |     experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37 |     print(experiment_name)
38 |     annotations = json.load(open(annotation_file))['data']
39 |     annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40 |     results = [json.loads(line) for line in open(result_file)]
41 | 
42 |     pred_list = []
43 |     for result in results:
44 |         annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45 |         pred_list.append({
46 |             "pred_answer": result['text'],
47 |             "gt_answers": annotation['answers'],
48 |         })
49 | 
50 |     evaluator = TextVQAAccuracyEvaluator()
51 |     print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     args = get_args()
56 | 
57 |     if args.result_file is not None:
58 |         eval_single(args.annotation_file, args.result_file)
59 | 
60 |     if args.result_dir is not None:
61 |         for result_file in sorted(os.listdir(args.result_dir)):
62 |             if not result_file.endswith('.jsonl'):
63 |                 print(f'Skipping {result_file}')
64 |                 continue
65 |             eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
66 | 
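
Note: prompt_processor recovers the raw question from the templated prompt. A small sketch of its "OCR tokens" branch, using a made-up TextVQA-style prompt:

import re

prompt = "OCR tokens: stop, main st Question: What does the sign say? Short answer:"
match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
print(match.group(1).lower())  # what does the sign say?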


--------------------------------------------------------------------------------
/llava/eval/generate_webpage_data_from_table.py:
--------------------------------------------------------------------------------
  1 | """Generate json file for webpage."""
  2 | import json
  3 | import os
  4 | import re
  5 | 
  6 | # models = ['llama', 'alpaca', 'gpt35', 'bard']
  7 | models = ['vicuna']
  8 | 
  9 | 
 10 | def read_jsonl(path: str, key: str=None):
 11 |     data = []
 12 |     with open(os.path.expanduser(path)) as f:
 13 |         for line in f:
 14 |             if not line:
 15 |                 continue
 16 |             data.append(json.loads(line))
 17 |     if key is not None:
 18 |         data.sort(key=lambda x: x[key])
 19 |         data = {item[key]: item for item in data}
 20 |     return data
 21 | 
 22 | 
 23 | def trim_hanging_lines(s: str, n: int) -> str:
 24 |     s = s.strip()
 25 |     for _ in range(n):
 26 |         s = s.split('\n', 1)[1].strip()
 27 |     return s
 28 | 
 29 | 
 30 | if __name__ == '__main__':
 31 |     questions = read_jsonl('table/question.jsonl', key='question_id')
 32 | 
 33 |     # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
 34 |     # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
 35 |     # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
 36 |     # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
 37 |     vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
 38 |     ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
 39 | 
 40 |     review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
 41 |     # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
 42 |     # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
 43 |     # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
 44 |     # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
 45 | 
 46 |     records = []
 47 |     for qid in questions.keys():
 48 |         r = {
 49 |             'id': qid,
 50 |             'category': questions[qid]['category'],
 51 |             'question': questions[qid]['text'],
 52 |             'answers': {
 53 |                 # 'alpaca': alpaca_answers[qid]['text'],
 54 |                 # 'llama': llama_answers[qid]['text'],
 55 |                 # 'bard': bard_answers[qid]['text'],
 56 |                 # 'gpt35': gpt35_answers[qid]['text'],
 57 |                 'vicuna': vicuna_answers[qid]['text'],
 58 |                 'ours': ours_answers[qid]['text'],
 59 |             },
 60 |             'evaluations': {
 61 |                 # 'alpaca': review_alpaca[qid]['text'],
 62 |                 # 'llama': review_llama[qid]['text'],
 63 |                 # 'bard': review_bard[qid]['text'],
 64 |                 'vicuna': review_vicuna[qid]['content'],
 65 |                 # 'gpt35': review_gpt35[qid]['text'],
 66 |             },
 67 |             'scores': {
 68 |                 'vicuna': review_vicuna[qid]['tuple'],
 69 |                 # 'alpaca': review_alpaca[qid]['score'],
 70 |                 # 'llama': review_llama[qid]['score'],
 71 |                 # 'bard': review_bard[qid]['score'],
 72 |                 # 'gpt35': review_gpt35[qid]['score'],
 73 |             },
 74 |         }
 75 | 
 76 |         # cleanup data
 77 |         cleaned_evals = {}
 78 |         for k, v in r['evaluations'].items():
 79 |             v = v.strip()
 80 |             lines = v.split('\n')
 81 |             # trim the first line if it's a pair of numbers
 82 |             if re.match(r'\d+[, ]+\d+', lines[0]):
 83 |                 lines = lines[1:]
 84 |             v = '\n'.join(lines)
 85 |             cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
 86 | 
 87 |         r['evaluations'] = cleaned_evals
 88 |         records.append(r)
 89 | 
 90 |     # Reorder the records, this is optional
 91 |     for r in records:
 92 |         if r['id'] <= 20:
 93 |             r['id'] += 60
 94 |         else:
 95 |             r['id'] -= 20
 96 |     for r in records:
 97 |         if r['id'] <= 50:
 98 |             r['id'] += 10
 99 |         elif 50 < r['id'] <= 60:
100 |             r['id'] -= 50
101 |     for r in records:
102 |         if r['id'] == 7:
103 |             r['id'] = 1
104 |         elif r['id'] < 7:
105 |             r['id'] += 1 
106 | 
107 |     records.sort(key=lambda x: x['id'])
108 | 
109 |     # Write to file
110 |     with open('webpage/data.json', 'w') as f:
111 |         json.dump({'questions': records, 'models': models}, f, indent=2)
112 | 


--------------------------------------------------------------------------------
/llava/eval/model_qa.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
 3 | import torch
 4 | import os
 5 | import json
 6 | from tqdm import tqdm
 7 | import shortuuid
 8 | 
 9 | from llava.conversation import default_conversation
10 | from llava.utils import disable_torch_init
11 | 
12 | 
13 | @torch.inference_mode()
14 | def eval_model(model_name, questions_file, answers_file):
15 |     # Model
16 |     disable_torch_init()
17 |     model_name = os.path.expanduser(model_name)
18 |     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
19 |     model = AutoModelForCausalLM.from_pretrained(model_name,
20 |         torch_dtype=torch.float16).cuda()
21 | 
22 | 
23 |     ques_file = open(os.path.expanduser(questions_file), "r")
24 |     ans_file = open(os.path.expanduser(answers_file), "w")
25 |     for i, line in enumerate(tqdm(ques_file)):
26 |         idx = json.loads(line)["question_id"]
27 |         qs = json.loads(line)["text"]
28 |         cat = json.loads(line)["category"]
29 |         conv = default_conversation.copy()
30 |         conv.append_message(conv.roles[0], qs)
31 |         prompt = conv.get_prompt()
32 |         inputs = tokenizer([prompt])
33 |         input_ids = torch.as_tensor(inputs.input_ids).cuda()
34 |         output_ids = model.generate(
35 |             input_ids,
36 |             do_sample=True,
37 |             use_cache=True,
38 |             temperature=0.7,
39 |             max_new_tokens=1024,)
40 |         outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
41 |         try:
42 |             index = outputs.index(conv.sep, len(prompt))
43 |         except ValueError:
44 |             outputs += conv.sep
45 |             index = outputs.index(conv.sep, len(prompt))
46 | 
47 |         outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
48 |         ans_id = shortuuid.uuid()
49 |         ans_file.write(json.dumps({"question_id": idx,
50 |                                    "text": outputs,
51 |                                    "answer_id": ans_id,
52 |                                    "model_id": model_name,
53 |                                    "metadata": {}}) + "\n")
54 |         ans_file.flush()
55 |     ans_file.close()
56 | 
57 | if __name__ == "__main__":
58 |     parser = argparse.ArgumentParser()
59 |     parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
60 |     parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
61 |     parser.add_argument("--answers-file", type=str, default="answer.jsonl")
62 |     args = parser.parse_args()
63 | 
64 |     eval_model(args.model_name, args.question_file, args.answers_file)
65 | 


--------------------------------------------------------------------------------
/llava/eval/model_vqa.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import torch
  3 | import os
  4 | import json
  5 | from tqdm import tqdm
  6 | import shortuuid
  7 | 
  8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
  9 | from llava.conversation import conv_templates, SeparatorStyle
 10 | from llava.model.builder import load_pretrained_model
 11 | from llava.utils import disable_torch_init
 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
 13 | 
 14 | from PIL import Image
 15 | import math
 16 | 
 17 | 
 18 | def split_list(lst, n):
 19 |     """Split a list into n (roughly) equal-sized chunks"""
 20 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
 21 |     return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
 22 | 
 23 | 
 24 | def get_chunk(lst, n, k):
 25 |     chunks = split_list(lst, n)
 26 |     return chunks[k]
 27 | 
 28 | 
 29 | def eval_model(args):
 30 |     # Model
 31 |     disable_torch_init()
 32 |     model_path = os.path.expanduser(args.model_path)
 33 |     model_name = get_model_name_from_path(model_path)
 34 |     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
 35 | 
 36 |     questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
 37 |     questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
 38 |     answers_file = os.path.expanduser(args.answers_file)
 39 |     os.makedirs(os.path.dirname(answers_file), exist_ok=True)
 40 |     ans_file = open(answers_file, "w")
 41 |     for line in tqdm(questions):
 42 |         idx = line["question_id"]
 43 |         image_file = line["image"]
 44 |         qs = line["text"]
 45 |         cur_prompt = qs
 46 |         if model.config.mm_use_im_start_end:
 47 |             qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
 48 |         else:
 49 |             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 50 | 
 51 |         conv = conv_templates[args.conv_mode].copy()
 52 |         conv.append_message(conv.roles[0], qs)
 53 |         conv.append_message(conv.roles[1], None)
 54 |         prompt = conv.get_prompt()
 55 | 
 56 |         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
 57 | 
 58 |         image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
 59 |         image_tensor = process_images([image], image_processor, model.config)[0]
 60 | 
 61 |         with torch.inference_mode():
 62 |             output_ids = model.generate(
 63 |                 input_ids,
 64 |                 images=image_tensor.unsqueeze(0).half().cuda(),
 65 |                 image_sizes=[image.size],
 66 |                 do_sample=True if args.temperature > 0 else False,
 67 |                 temperature=args.temperature,
 68 |                 top_p=args.top_p,
 69 |                 num_beams=args.num_beams,
 70 |                 # no_repeat_ngram_size=3,
 71 |                 max_new_tokens=1024,
 72 |                 use_cache=True)
 73 | 
 74 |         outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
 75 | 
 76 |         ans_id = shortuuid.uuid()
 77 |         ans_file.write(json.dumps({"question_id": idx,
 78 |                                    "prompt": cur_prompt,
 79 |                                    "text": outputs,
 80 |                                    "answer_id": ans_id,
 81 |                                    "model_id": model_name,
 82 |                                    "metadata": {}}) + "\n")
 83 |         ans_file.flush()
 84 |     ans_file.close()
 85 | 
 86 | if __name__ == "__main__":
 87 |     parser = argparse.ArgumentParser()
 88 |     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
 89 |     parser.add_argument("--model-base", type=str, default=None)
 90 |     parser.add_argument("--image-folder", type=str, default="")
 91 |     parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
 92 |     parser.add_argument("--answers-file", type=str, default="answer.jsonl")
 93 |     parser.add_argument("--conv-mode", type=str, default="llava_v1")
 94 |     parser.add_argument("--num-chunks", type=int, default=1)
 95 |     parser.add_argument("--chunk-idx", type=int, default=0)
 96 |     parser.add_argument("--temperature", type=float, default=0.2)
 97 |     parser.add_argument("--top_p", type=float, default=None)
 98 |     parser.add_argument("--num_beams", type=int, default=1)
 99 |     args = parser.parse_args()
100 | 
101 |     eval_model(args)
102 | 
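
Note: the --num-chunks / --chunk-idx options partition the question list via split_list / get_chunk. A toy sketch of that chunking (mirroring the helpers above on dummy data):

import math

def split_list(lst, n):
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

questions = list(range(10))
print(split_list(questions, 3))      # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(split_list(questions, 3)[1])   # chunk_idx=1 -> [4, 5, 6, 7]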


--------------------------------------------------------------------------------
/llava/eval/model_vqa_science.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import torch
  3 | import os
  4 | import json
  5 | from tqdm import tqdm
  6 | import shortuuid
  7 | 
  8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
  9 | from llava.conversation import conv_templates, SeparatorStyle
 10 | from llava.model.builder import load_pretrained_model
 11 | from llava.utils import disable_torch_init
 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
 13 | 
 14 | from PIL import Image
 15 | import math
 16 | 
 17 | 
 18 | def split_list(lst, n):
 19 |     """Split a list into n (roughly) equal-sized chunks"""
 20 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
 21 |     return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
 22 | 
 23 | 
 24 | def get_chunk(lst, n, k):
 25 |     chunks = split_list(lst, n)
 26 |     return chunks[k]
 27 | 
 28 | 
 29 | def eval_model(args):
 30 |     # Model
 31 |     disable_torch_init()
 32 |     model_path = os.path.expanduser(args.model_path)
 33 |     model_name = get_model_name_from_path(model_path)
 34 |     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
 35 | 
 36 |     questions = json.load(open(os.path.expanduser(args.question_file), "r"))
 37 |     questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
 38 |     answers_file = os.path.expanduser(args.answers_file)
 39 |     os.makedirs(os.path.dirname(answers_file), exist_ok=True)
 40 |     ans_file = open(answers_file, "w")
 41 |     for i, line in enumerate(tqdm(questions)):
 42 |         idx = line["id"]
 43 |         question = line['conversations'][0]
 44 |         qs = question['value'].replace('<image>', '').strip()
 45 |         cur_prompt = qs
 46 | 
 47 |         if 'image' in line:
 48 |             image_file = line["image"]
 49 |             image = Image.open(os.path.join(args.image_folder, image_file))
 50 |             image_tensor = process_images([image], image_processor, model.config)[0]
 51 |             images = image_tensor.unsqueeze(0).half().cuda()
 52 |             image_sizes = [image.size]
 53 |             if getattr(model.config, 'mm_use_im_start_end', False):
 54 |                 qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
 55 |             else:
 56 |                 qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 57 |             cur_prompt = '<image>' + '\n' + cur_prompt
 58 |         else:
 59 |             images = None
 60 |             image_sizes = None
 61 | 
 62 |         if args.single_pred_prompt:
 63 |             qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
 64 |             cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
 65 | 
 66 |         conv = conv_templates[args.conv_mode].copy()
 67 |         conv.append_message(conv.roles[0], qs)
 68 |         conv.append_message(conv.roles[1], None)
 69 |         prompt = conv.get_prompt()
 70 | 
 71 |         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
 72 | 
 73 |         with torch.inference_mode():
 74 |             output_ids = model.generate(
 75 |                 input_ids,
 76 |                 images=images,
 77 |                 image_sizes=image_sizes,
 78 |                 do_sample=True if args.temperature > 0 else False,
 79 |                 temperature=args.temperature,
 80 |                 max_new_tokens=1024,
 81 |                 use_cache=True,
 82 |             )
 83 | 
 84 |         outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
 85 | 
 86 |         ans_id = shortuuid.uuid()
 87 |         ans_file.write(json.dumps({"question_id": idx,
 88 |                                    "prompt": cur_prompt,
 89 |                                    "text": outputs,
 90 |                                    "answer_id": ans_id,
 91 |                                    "model_id": model_name,
 92 |                                    "metadata": {}}) + "\n")
 93 |         ans_file.flush()
 94 |     ans_file.close()
 95 | 
 96 | if __name__ == "__main__":
 97 |     parser = argparse.ArgumentParser()
 98 |     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
 99 |     parser.add_argument("--model-base", type=str, default=None)
100 |     parser.add_argument("--image-folder", type=str, default="")
101 |     parser.add_argument("--question-file", type=str, default="tables/question.json")
102 |     parser.add_argument("--answers-file", type=str, default="answer.jsonl")
103 |     parser.add_argument("--conv-mode", type=str, default="llava_v0")
104 |     parser.add_argument("--num-chunks", type=int, default=1)
105 |     parser.add_argument("--chunk-idx", type=int, default=0)
106 |     parser.add_argument("--temperature", type=float, default=0.2)
107 |     parser.add_argument("--answer-prompter", action="store_true")
108 |     parser.add_argument("--single-pred-prompt", action="store_true")
109 |     args = parser.parse_args()
110 | 
111 |     eval_model(args)
112 | 


--------------------------------------------------------------------------------
/llava/eval/qa_baseline_gpt35.py:
--------------------------------------------------------------------------------
 1 | """Generate answers with GPT-3.5"""
 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
 3 | import argparse
 4 | import json
 5 | import os
 6 | import time
 7 | import concurrent.futures
 8 | 
 9 | import openai
10 | import tqdm
11 | import shortuuid
12 | 
13 | MODEL = 'gpt-3.5-turbo'
14 | MODEL_ID = 'gpt-3.5-turbo:20230327'
15 | 
16 | def get_answer(question_id: int, question: str, max_tokens: int):
17 |     ans = {
18 |         'answer_id': shortuuid.uuid(),
19 |         'question_id': question_id,
20 |         'model_id': MODEL_ID,
21 |     }
22 |     for _ in range(3):
23 |         try:
24 |             response = openai.ChatCompletion.create(
25 |                 model=MODEL,
26 |                 messages=[{
27 |                     'role': 'system',
28 |                     'content': 'You are a helpful assistant.'
29 |                 }, {
30 |                     'role': 'user',
31 |                     'content': question,
32 |                 }],
33 |                 max_tokens=max_tokens,
34 |             )
35 |             ans['text'] = response['choices'][0]['message']['content']
36 |             return ans
37 |         except Exception as e:
38 |             print('[ERROR]', e)
39 |             ans['text'] = '#ERROR#'
40 |             time.sleep(1)
41 |     return ans
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46 |     parser.add_argument('-q', '--question')
47 |     parser.add_argument('-o', '--output')
48 |     parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49 |     args = parser.parse_args()
50 | 
51 |     questions_dict = {}
52 |     with open(os.path.expanduser(args.question)) as f:
53 |         for line in f:
54 |             if not line:
55 |                 continue
56 |             q = json.loads(line)
57 |             questions_dict[q['question_id']] = q['text']
58 | 
59 |     answers = []
60 | 
61 |     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62 |         futures = []
63 |         for qid, question in questions_dict.items():
64 |             future = executor.submit(get_answer, qid, question, args.max_tokens)
65 |             futures.append(future)
66 | 
67 |         for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68 |             answers.append(future.result())
69 | 
70 |     answers.sort(key=lambda x: x['question_id'])
71 | 
72 |     with open(os.path.expanduser(args.output), 'w') as f:
73 |         table = [json.dumps(ans) for ans in answers]
74 |         f.write('\n'.join(table))
75 | 
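
Note: for reference, the question JSONL line consumed above and the answer record written out look roughly like the sketch below; the ids and text are made up, and the real "text" field comes from the ChatCompletion response:

import json

question_line = json.dumps({"question_id": 1,
                            "text": "What is the capital of France?",
                            "category": "knowledge"})
q = json.loads(question_line)

answer_record = {
    "answer_id": "hypothetical-shortuuid",
    "question_id": q["question_id"],
    "model_id": "gpt-3.5-turbo:20230327",
    "text": "Paris.",  # filled from the API response by get_answer
}
print(json.dumps(answer_record))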


--------------------------------------------------------------------------------
/llava/eval/run_llava.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import torch
  3 | 
  4 | from llava.constants import (
  5 |     IMAGE_TOKEN_INDEX,
  6 |     DEFAULT_IMAGE_TOKEN,
  7 |     DEFAULT_IM_START_TOKEN,
  8 |     DEFAULT_IM_END_TOKEN,
  9 |     IMAGE_PLACEHOLDER,
 10 | )
 11 | from llava.conversation import conv_templates, SeparatorStyle
 12 | from llava.model.builder import load_pretrained_model
 13 | from llava.utils import disable_torch_init
 14 | from llava.mm_utils import (
 15 |     process_images,
 16 |     tokenizer_image_token,
 17 |     get_model_name_from_path,
 18 | )
 19 | 
 20 | import re
 21 | import requests
 22 | 
 23 | from PIL import Image
 24 | from io import BytesIO
 25 | 
 26 | 
 27 | 
 28 | def image_parser(args):
 29 |     out = args.image_file.split(args.sep)
 30 |     return out
 31 | 
 32 | 
 33 | def load_image(image_file):
 34 |     if image_file.startswith("http") or image_file.startswith("https"):
 35 |         response = requests.get(image_file)
 36 |         image = Image.open(BytesIO(response.content)).convert("RGB")
 37 |     else:
 38 |         image = Image.open(image_file).convert("RGB")
 39 |     return image
 40 | 
 41 | 
 42 | def load_images(image_files):
 43 |     out = []
 44 |     for image_file in image_files:
 45 |         image = load_image(image_file)
 46 |         out.append(image)
 47 |     return out
 48 | 
 49 | 
 50 | def eval_model(args):
 51 |     # Model
 52 |     disable_torch_init()
 53 | 
 54 |     model_name = get_model_name_from_path(args.model_path)
 55 |     tokenizer, model, image_processor, context_len = load_pretrained_model(
 56 |         args.model_path, args.model_base, model_name
 57 |     )
 58 | 
 59 |     qs = args.query
 60 |     image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
 61 |     if IMAGE_PLACEHOLDER in qs:
 62 |         if model.config.mm_use_im_start_end:
 63 |             qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
 64 |         else:
 65 |             qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
 66 |     else:
 67 |         if model.config.mm_use_im_start_end:
 68 |             qs = image_token_se + "\n" + qs
 69 |         else:
 70 |             qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
 71 | 
 72 |     if "llama-2" in model_name.lower():
 73 |         conv_mode = "llava_llama_2"
 74 |     elif "mistral" in model_name.lower():
 75 |         conv_mode = "mistral_instruct"
 76 |     elif "v1.6-34b" in model_name.lower():
 77 |         conv_mode = "chatml_direct"
 78 |     elif "v1" in model_name.lower():
 79 |         conv_mode = "llava_v1"
 80 |     elif "mpt" in model_name.lower():
 81 |         conv_mode = "mpt"
 82 |     else:
 83 |         conv_mode = "llava_v0"
 84 | 
 85 |     if args.conv_mode is not None and conv_mode != args.conv_mode:
 86 |         print(
 87 |             "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
 88 |                 conv_mode, args.conv_mode, args.conv_mode
 89 |             )
 90 |         )
 91 |     else:
 92 |         args.conv_mode = conv_mode
 93 | 
 94 |     conv = conv_templates[args.conv_mode].copy()
 95 |     conv.append_message(conv.roles[0], qs)
 96 |     conv.append_message(conv.roles[1], None)
 97 |     prompt = conv.get_prompt()
 98 | 
 99 |     image_files = image_parser(args)
100 |     images = load_images(image_files)
101 |     image_sizes = [x.size for x in images]
102 |     images_tensor = process_images(
103 |         images,
104 |         image_processor,
105 |         model.config
106 |     ).to(model.device, dtype=torch.float16)
107 | 
108 |     input_ids = (
109 |         tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
110 |         .unsqueeze(0)
111 |         .cuda()
112 |     )
113 | 
114 |     with torch.inference_mode():
115 |         output_ids = model.generate(
116 |             input_ids,
117 |             images=images_tensor,
118 |             image_sizes=image_sizes,
119 |             do_sample=True if args.temperature > 0 else False,
120 |             temperature=args.temperature,
121 |             top_p=args.top_p,
122 |             num_beams=args.num_beams,
123 |             max_new_tokens=args.max_new_tokens,
124 |             use_cache=True,
125 |         )
126 | 
127 |     outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
128 |     print(outputs)
129 | 
130 | 
131 | if __name__ == "__main__":
132 |     parser = argparse.ArgumentParser()
133 |     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
134 |     parser.add_argument("--model-base", type=str, default=None)
135 |     parser.add_argument("--image-file", type=str, required=True)
136 |     parser.add_argument("--query", type=str, required=True)
137 |     parser.add_argument("--conv-mode", type=str, default=None)
138 |     parser.add_argument("--sep", type=str, default=",")
139 |     parser.add_argument("--temperature", type=float, default=0.2)
140 |     parser.add_argument("--top_p", type=float, default=None)
141 |     parser.add_argument("--num_beams", type=int, default=1)
142 |     parser.add_argument("--max_new_tokens", type=int, default=512)
143 |     args = parser.parse_args()
144 | 
145 |     eval_model(args)
146 | 
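
Note: eval_model can also be called programmatically with a namespace that mirrors the argparse defaults above. A sketch only: it assumes the llava package is importable, a CUDA GPU is available, and the checkpoint name and image path below are placeholders you would replace:

from argparse import Namespace
from llava.eval.run_llava import eval_model

args = Namespace(
    model_path="liuhaotian/llava-v1.5-7b",  # assumed checkpoint identifier
    model_base=None,
    image_file="path/to/your/image.jpg",    # placeholder image path
    query="Describe this image in detail.",
    conv_mode=None,
    sep=",",
    temperature=0.2,
    top_p=None,
    num_beams=1,
    max_new_tokens=512,
)
eval_model(args)  # prints the model's answer to stdout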


--------------------------------------------------------------------------------
/llava/eval/summarize_gpt_review.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import os
 3 | from collections import defaultdict
 4 | 
 5 | import numpy as np
 6 | 
 7 | import argparse
 8 | 
 9 | def parse_args():
10 |     parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11 |     parser.add_argument('-d', '--dir', default=None)
12 |     parser.add_argument('-v', '--version', default=None)
13 |     parser.add_argument('-s', '--select', nargs='*', default=None)
14 |     parser.add_argument('-f', '--files', nargs='*', default=[])
15 |     parser.add_argument('-i', '--ignore', nargs='*', default=[])
16 |     return parser.parse_args()
17 | 
18 | 
19 | if __name__ == '__main__':
20 |     args = parse_args()
21 | 
22 |     if args.ignore is not None:
23 |         args.ignore = [int(x) for x in args.ignore]
24 | 
25 |     if len(args.files) > 0:
26 |         review_files = args.files
27 |     else:
28 |         review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29 | 
30 |     for review_file in sorted(review_files):
31 |         config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32 |         if args.select is not None and any(x not in config for x in args.select):
33 |             continue
34 |         if '0613' in config:
35 |             version = '0613'
36 |         else:
37 |             version = '0314'
38 |         if args.version is not None and args.version != version:
39 |             continue
40 |         scores = defaultdict(list)
41 |         print(config)
42 |         with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43 |             for review_str in f:
44 |                 review = json.loads(review_str)
45 |                 if review['question_id'] in args.ignore:
46 |                     continue
47 |                 if 'category' in review:
48 |                     scores[review['category']].append(review['tuple'])
49 |                     scores['all'].append(review['tuple'])
50 |                 else:
51 |                     if 'tuple' in review:
52 |                         scores['all'].append(review['tuple'])
53 |                     else:
54 |                         scores['all'].append(review['score'])
55 |         for k, v in sorted(scores.items()):
56 |             stats = np.asarray(v).mean(0).tolist()
57 |             stats = [round(x, 3) for x in stats]
58 |             # print(k, stats, round(stats[1]/stats[0]*100, 1))
59 |             print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60 |         print('=================================')
61 | 
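
Each review line carries a `tuple` with GPT scores for Assistant 1 and Assistant 2; the loop above averages them per category and reports Assistant 2's score relative to Assistant 1 as a percentage, plus both averages scaled by 10. A small worked sketch of that arithmetic with made-up numbers:

import numpy as np

# Hypothetical reviews for one category: (assistant_1_score, assistant_2_score)
tuples = [(8.0, 7.0), (9.0, 8.5), (7.0, 7.5)]

stats = np.asarray(tuples).mean(0)               # array([8.0, 7.666...])
relative = round(stats[1] / stats[0] * 100, 1)   # Assistant 2 vs. Assistant 1, in percent
print(relative, round(stats[0] * 10, 1), round(stats[1] * 10, 1))
# -> 95.8 80.0 76.7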


--------------------------------------------------------------------------------
/llava/eval/table/model.jsonl:
--------------------------------------------------------------------------------
1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
6 | 


--------------------------------------------------------------------------------
/llava/eval/table/prompt.jsonl:
--------------------------------------------------------------------------------
1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"}
2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"}
3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"}
4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"}
5 | 
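
The `prompt_template` fields above are ordinary Python format strings. A minimal sketch of how one record might be filled in, using the field names from this JSONL; the question and answers are illustrative placeholders:

import json

with open("llava/eval/table/prompt.jsonl") as f:
    rule = json.loads(f.readline())  # prompt_id 1: general questions

content = rule["prompt_template"].format(
    question="What is shown in the image?",
    answer_1="A dog running on a beach.",
    answer_2="A cat sitting indoors.",
    prompt=rule["defaults"]["prompt"],
)
# `content` is the user message for the reviewer model, typically paired with
# rule["system_prompt"] as the system message.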


--------------------------------------------------------------------------------
/llava/eval/table/reviewer.jsonl:
--------------------------------------------------------------------------------
1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"}
2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"}
3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
5 | 


--------------------------------------------------------------------------------
/llava/eval/webpage/figures/alpaca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/eval/webpage/figures/alpaca.png


--------------------------------------------------------------------------------
/llava/eval/webpage/figures/bard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/eval/webpage/figures/bard.jpg


--------------------------------------------------------------------------------
/llava/eval/webpage/figures/chatgpt.svg:
--------------------------------------------------------------------------------
1 | <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 2406 2406"><path d="M1 578.4C1 259.5 259.5 1 578.4 1h1249.1c319 0 577.5 258.5 577.5 577.4V2406H578.4C259.5 2406 1 2147.5 1 1828.6V578.4z" fill="#74aa9c"/><path d="M1107.3 299.1c-198 0-373.9 127.3-435.2 315.3C544.8 640.6 434.9 720.2 370.5 833c-99.3 171.4-76.6 386.9 56.4 533.8-41.1 123.1-27 257.7 38.6 369.2 98.7 172 297.3 260.2 491.6 219.2 86.1 97 209.8 152.3 339.6 151.8 198 0 373.9-127.3 435.3-315.3 127.5-26.3 237.2-105.9 301-218.5 99.9-171.4 77.2-386.9-55.8-533.9v-.6c41.1-123.1 27-257.8-38.6-369.8-98.7-171.4-297.3-259.6-491-218.6-86.6-96.8-210.5-151.8-340.3-151.2zm0 117.5-.6.6c79.7 0 156.3 27.5 217.6 78.4-2.5 1.2-7.4 4.3-11 6.1L952.8 709.3c-18.4 10.4-29.4 30-29.4 51.4V1248l-155.1-89.4V755.8c-.1-187.1 151.6-338.9 339-339.2zm434.2 141.9c121.6-.2 234 64.5 294.7 169.8 39.2 68.6 53.9 148.8 40.4 226.5-2.5-1.8-7.3-4.3-10.4-6.1l-360.4-208.2c-18.4-10.4-41-10.4-59.4 0L1024 984.2V805.4L1372.7 604c51.3-29.7 109.5-45.4 168.8-45.5zM650 743.5v427.9c0 21.4 11 40.4 29.4 51.4l421.7 243-155.7 90L597.2 1355c-162-93.8-217.4-300.9-123.8-462.8C513.1 823.6 575.5 771 650 743.5zm807.9 106 348.8 200.8c162.5 93.7 217.6 300.6 123.8 462.8l.6.6c-39.8 68.6-102.4 121.2-176.5 148.2v-428c0-21.4-11-41-29.4-51.4l-422.3-243.7 155-89.3zM1201.7 997l177.8 102.8v205.1l-177.8 102.8-177.8-102.8v-205.1L1201.7 997zm279.5 161.6 155.1 89.4v402.2c0 187.3-152 339.2-339 339.2v-.6c-79.1 0-156.3-27.6-217-78.4 2.5-1.2 8-4.3 11-6.1l360.4-207.5c18.4-10.4 30-30 29.4-51.4l.1-486.8zM1380 1421.9v178.8l-348.8 200.8c-162.5 93.1-369.6 38-463.4-123.7h.6c-39.8-68-54-148.8-40.5-226.5 2.5 1.8 7.4 4.3 10.4 6.1l360.4 208.2c18.4 10.4 41 10.4 59.4 0l421.9-243.7z" fill="white"/></svg>


--------------------------------------------------------------------------------
/llava/eval/webpage/figures/llama.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/eval/webpage/figures/llama.jpg


--------------------------------------------------------------------------------
/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg:
--------------------------------------------------------------------------------
1 | <svg xmlns="http://www.w3.org/2000/svg" height="48" viewBox="0 96 960 960" width="48"><path d="m762.846 947.614-124.77-124.769-88 88-30.306-30.692q-16.616-16.231-16.616-40.077 0-23.846 16.616-40.461L708 611.385q16.23-16.231 40.076-16.231t40.462 16.231l30.307 30.691-88 88 124.154 124.77q8.615 8.615 8.615 20.23 0 11.616-8.615 20.231l-51.692 52.307q-8.615 9-20.231 9-11.615 0-20.23-9Zm97.153-624.076L412.768 771.153l27.847 28.077q16.231 16.616 16.231 40.462 0 23.846-16.231 40.077l-30.691 30.691-88-88-124.77 124.769q-8.615 9-20.23 9-11.616 0-20.231-9l-52.307-52.307q-9-8.615-9-20.23 0-11.616 9-20.231l124.769-124.769-88-88L171.847 611q16.231-16.23 40.077-16.23 23.846 0 40.461 16.23l28.462 28.232 447.615-447.231h131.537v131.537ZM323.846 483.769l33.769-34.154 34.154-34.153-34.154 34.153-33.769 34.154Zm-31.999 31.999-191.846-192.23V192.001h131.537l191.461 191.846-31.23 31.615-179.077-178.077h-67.307v67.307l178.461 179.077-31.999 31.999Zm87.691 222.77 435.077-433.846v-67.307h-67.307L312.231 670.846l67.307 67.692Zm0 0L346.385 704l-34.154-33.154L346.385 704l33.153 34.538Z"/></svg>


--------------------------------------------------------------------------------
/llava/eval/webpage/figures/vicuna.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/eval/webpage/figures/vicuna.jpeg


--------------------------------------------------------------------------------
/llava/eval/webpage/styles.css:
--------------------------------------------------------------------------------
  1 | body {
  2 |     font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
  3 |     background-color: #f8f9fa;
  4 | }
  5 | 
  6 | .navbar-dark .navbar-nav .nav-link {
  7 |     color: #f1cf68;
  8 |     font-size: 1.1rem;
  9 |     padding: 0.5rem 0.6rem;
 10 | }
 11 | 
 12 | .card-header {
 13 |     font-weight: bold;
 14 | }
 15 | 
 16 | .card {
 17 |     box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
 18 |     transition: 0.3s;
 19 | }
 20 | 
 21 | .card:hover {
 22 |     box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
 23 | }
 24 | 
 25 | button {
 26 |     transition: background-color 0.3s;
 27 | }
 28 | 
 29 | button:hover {
 30 |     background-color: #007bff;
 31 | }
 32 | 
 33 | @media (max-width: 767px) {
 34 |     .form-row .form-group {
 35 |         margin-bottom: 10px;
 36 |     }
 37 | }
 38 | 
 39 | /* Extra styles */
 40 | 
 41 | .expandable-card .card-text-container {
 42 |     max-height: 200px;
 43 |     overflow-y: hidden;
 44 |     position: relative;
 45 | }
 46 | 
 47 | .expandable-card.expanded .card-text-container {
 48 |     max-height: none;
 49 | }
 50 | 
 51 | .expand-btn {
 52 |     position: relative;
 53 |     display: none;
 54 |     background-color: rgba(255, 255, 255, 0.8);
 55 |     color: #510c75;
 56 |     border-color: transparent;
 57 | }
 58 | 
 59 | .expand-btn:hover {
 60 |     background-color: rgba(200, 200, 200, 0.8);
 61 |     text-decoration: none;
 62 |     border-color: transparent;
 63 |     color: #510c75;
 64 | }
 65 | 
 66 | .expand-btn:focus {
 67 |     outline: none;
 68 |     text-decoration: none;
 69 | }
 70 | 
 71 | .expandable-card:not(.expanded) .card-text-container:after {
 72 |     content: "";
 73 |     position: absolute;
 74 |     bottom: 0;
 75 |     left: 0;
 76 |     width: 100%;
 77 |     height: 90px;
 78 |     background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
 79 | }
 80 | 
 81 | .expandable-card:not(.expanded) .expand-btn {
 82 |     margin-top: -40px;
 83 | }
 84 | 
 85 | .card-body {
 86 |     padding-bottom: 5px;
 87 | }
 88 | 
 89 | .vertical-flex-layout {
 90 |     justify-content: center;
 91 |     align-items: center;
 92 |     height: 100%;
 93 |     display: flex;
 94 |     flex-direction: column;
 95 |     gap: 5px;
 96 | }
 97 | 
 98 | .figure-img {
 99 |     max-width: 100%;
100 |     height: auto;
101 | }
102 | 
103 | .adjustable-font-size {
104 |     font-size: calc(0.5rem + 2vw);
105 | }
106 | 


--------------------------------------------------------------------------------
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 |     from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
3 |     from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
4 |     from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
5 | except Exception:  # optional language-model backends may fail to import
6 |     pass
7 | 


--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Usage:
 3 | python3 -m llava.model.apply_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/vicuna-7b --delta-path lmsys/vicuna-7b-delta
 4 | """
 5 | import argparse
 6 | 
 7 | import torch
 8 | from tqdm import tqdm
 9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 | 
12 | 
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 |     print("Loading base model")
15 |     base = AutoModelForCausalLM.from_pretrained(
16 |         base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | 
18 |     print("Loading delta")
19 |     delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 |     delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 | 
22 |     print("Applying delta")
23 |     for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 |         if name not in base.state_dict():
25 |             assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 |             continue
27 |         if param.data.shape == base.state_dict()[name].shape:
28 |             param.data += base.state_dict()[name]
29 |         else:
30 |             assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 |                 f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 |             bparam = base.state_dict()[name]
33 |             param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 | 
35 |     print("Saving target model")
36 |     delta.save_pretrained(target_model_path)
37 |     delta_tokenizer.save_pretrained(target_model_path)
38 | 
39 | 
40 | if __name__ == "__main__":
41 |     parser = argparse.ArgumentParser()
42 |     parser.add_argument("--base-model-path", type=str, required=True)
43 |     parser.add_argument("--target-model-path", type=str, required=True)
44 |     parser.add_argument("--delta-path", type=str, required=True)
45 | 
46 |     args = parser.parse_args()
47 | 
48 |     apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 | 


--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Usage:
 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
 4 | """
 5 | import argparse
 6 | 
 7 | import torch
 8 | from transformers import AutoTokenizer, AutoModelForCausalLM
 9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 | 
12 | 
13 | def consolidate_ckpt(src_path, dst_path):
14 |     print("Loading model")
15 |     auto_upgrade(src_path)
16 |     src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |     src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 |     src_model.save_pretrained(dst_path)
19 |     src_tokenizer.save_pretrained(dst_path)
20 | 
21 | 
22 | if __name__ == "__main__":
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument("--src", type=str, required=True)
25 |     parser.add_argument("--dst", type=str, required=True)
26 | 
27 |     args = parser.parse_args()
28 | 
29 |     consolidate_ckpt(args.src, args.dst)
30 | 


--------------------------------------------------------------------------------
/llava/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
 1 | #    Copyright 2023 Haotian Liu
 2 | #
 3 | #    Licensed under the Apache License, Version 2.0 (the "License");
 4 | #    you may not use this file except in compliance with the License.
 5 | #    You may obtain a copy of the License at
 6 | #
 7 | #        http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | #    Unless required by applicable law or agreed to in writing, software
10 | #    distributed under the License is distributed on an "AS IS" BASIS,
11 | #    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | #    See the License for the specific language governing permissions and
13 | #    limitations under the License.
14 | 
15 | 
16 | from typing import Optional, Tuple
17 | 
18 | import torch
19 | 
20 | from transformers import AutoConfig, AutoModelForCausalLM, \
21 |                          MptConfig, MptForCausalLM, MptModel
22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23 | 
24 | 
25 | class LlavaMptConfig(MptConfig):
26 |     model_type = "llava_mpt"
27 | 
28 | 
29 | class LlavaMptModel(LlavaMetaModel, MptModel):
30 |     config_class = LlavaMptConfig
31 | 
32 |     def __init__(self, config: MptConfig):
33 |         config.hidden_size = config.d_model
34 |         super(LlavaMptModel, self).__init__(config)
35 |     
36 |     def embed_tokens(self, x):
37 |         return self.wte(x)
38 | 
39 | 
40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41 |     config_class = LlavaMptConfig
42 |     supports_gradient_checkpointing = True
43 | 
44 |     def __init__(self, config):
45 |         super(MptForCausalLM, self).__init__(config)
46 | 
47 |         self.transformer = LlavaMptModel(config)
48 |         self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 | 
50 |         # Initialize weights and apply final processing
51 |         self.post_init()
52 | 
53 |     def get_model(self):
54 |         return self.transformer
55 | 
56 |     def _set_gradient_checkpointing(self, module, value=False):
57 |         if isinstance(module, LlavaMptModel):
58 |             module.gradient_checkpointing = value
59 | 
60 |     def forward(
61 |         self,
62 |         input_ids: Optional[torch.LongTensor] = None,
63 |         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64 |         attention_mask: Optional[torch.Tensor] = None,
65 |         inputs_embeds: Optional[torch.Tensor] = None,
66 |         labels: Optional[torch.Tensor] = None,
67 |         use_cache: Optional[bool] = None,
68 |         output_attentions: Optional[bool] = None,
69 |         output_hidden_states: Optional[bool] = None,
70 |         return_dict: Optional[bool] = None,
71 |         images=None):
72 | 
73 |         input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74 |         
75 |         return super().forward(
76 |             input_ids,
77 |             past_key_values=past_key_values,
78 |             attention_mask=attention_mask,
79 |             inputs_embeds=inputs_embeds,
80 |             labels=labels,
81 |             use_cache=use_cache,
82 |             output_attentions=output_attentions,
83 |             output_hidden_states=output_hidden_states,
84 |             return_dict=return_dict,
85 |         )
86 | 
87 |     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88 |         images = kwargs.pop("images", None)
89 |         _inputs = super().prepare_inputs_for_generation(
90 |             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91 |         )
92 |         _inputs['images'] = images
93 |         return _inputs
94 | 
95 | 
96 | AutoConfig.register("llava_mpt", LlavaMptConfig)
97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
98 | 


--------------------------------------------------------------------------------
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Usage:
 3 | python3 -m llava.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
 4 | """
 5 | import argparse
 6 | 
 7 | import torch
 8 | from tqdm import tqdm
 9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 | 
12 | 
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 |     print("Loading base model")
15 |     base = AutoModelForCausalLM.from_pretrained(
16 |         base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | 
18 |     print("Loading target model")
19 |     auto_upgrade(target_model_path)
20 |     target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 | 
22 |     print("Calculating delta")
23 |     for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 |         if name not in base.state_dict():
25 |             assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 |             continue
27 |         if param.data.shape == base.state_dict()[name].shape:
28 |             param.data -= base.state_dict()[name]
29 |         else:
30 |             assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 |             bparam = base.state_dict()[name]
32 |             param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 | 
34 |     print("Saving delta")
35 |     if hub_repo_id:
36 |         kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 |     else:
38 |         kwargs = {}
39 |     target.save_pretrained(delta_path, **kwargs)
40 |     target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 |     target_tokenizer.save_pretrained(delta_path, **kwargs)
42 | 
43 | 
44 | if __name__ == "__main__":
45 |     parser = argparse.ArgumentParser()
46 |     parser.add_argument("--base-model-path", type=str, required=True)
47 |     parser.add_argument("--target-model-path", type=str, required=True)
48 |     parser.add_argument("--delta-path", type=str, required=True)
49 |     parser.add_argument("--hub-repo-id", type=str, default=None)
50 |     args = parser.parse_args()
51 | 
52 |     make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 | 
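
`make_delta` stores `target - base` for every shared parameter, and `apply_delta` earlier in this listing adds the base back, so the two scripts are inverses on the weights they both touch. A toy sketch of that round trip on a single tensor, purely illustrative:

import torch

base = torch.randn(4, 4)
target = torch.randn(4, 4)

delta = target - base      # what make_delta writes out
recovered = base + delta   # what apply_delta reconstructs
assert torch.allclose(recovered, target)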


--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
 3 | 
 4 | 
 5 | def build_vision_tower(vision_tower_cfg, **kwargs):
 6 |     vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
 7 |     is_absolute_path_exists = os.path.exists(vision_tower)
 8 |     use_s2 = getattr(vision_tower_cfg, 's2', False)
 9 |     if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
10 |         if use_s2:
11 |             return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
12 |         else:
13 |             return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
14 | 
15 |     raise ValueError(f'Unknown vision tower: {vision_tower}')
16 | 


--------------------------------------------------------------------------------
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import re
 4 | 
 5 | 
 6 | class IdentityMap(nn.Module):
 7 |     def __init__(self):
 8 |         super().__init__()
 9 | 
10 |     def forward(self, x, *args, **kwargs):
11 |         return x
12 | 
13 |     @property
14 |     def config(self):
15 |         return {"mm_projector_type": 'identity'}
16 | 
17 | 
18 | class SimpleResBlock(nn.Module):
19 |     def __init__(self, channels):
20 |         super().__init__()
21 |         self.pre_norm = nn.LayerNorm(channels)
22 | 
23 |         self.proj = nn.Sequential(
24 |             nn.Linear(channels, channels),
25 |             nn.GELU(),
26 |             nn.Linear(channels, channels)
27 |         )
28 |     def forward(self, x):
29 |         x = self.pre_norm(x)
30 |         return x + self.proj(x)
31 | 
32 | 
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 |     projector_type = getattr(config, 'mm_projector_type', 'linear')
35 | 
36 |     if projector_type == 'linear':
37 |         return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 | 
39 |     mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 |     if mlp_gelu_match:
41 |         mlp_depth = int(mlp_gelu_match.group(1))
42 |         modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 |         for _ in range(1, mlp_depth):
44 |             modules.append(nn.GELU())
45 |             modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 |         return nn.Sequential(*modules)
47 | 
48 |     if projector_type == 'identity':
49 |         return IdentityMap()
50 | 
51 |     raise ValueError(f'Unknown projector type: {projector_type}')
52 | 
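
For reference, a `mm_projector_type` of `'mlp2x_gelu'` matches the regex above with depth 2, so the projector expands to Linear -> GELU -> Linear. A minimal sketch with hypothetical feature sizes:

import torch.nn as nn

# Hypothetical sizes: 1024-d vision features projected to a 4096-d LLM hidden size.
mm_hidden_size, hidden_size = 1024, 4096

projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)  # what build_vision_projector produces for mm_projector_type='mlp2x_gelu'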


--------------------------------------------------------------------------------
/llava/model/utils.py:
--------------------------------------------------------------------------------
 1 | from transformers import AutoConfig
 2 | 
 3 | 
 4 | def auto_upgrade(config):
 5 |     cfg = AutoConfig.from_pretrained(config)
 6 |     if 'llava' in config and 'llava' not in cfg.model_type:
 7 |         assert cfg.model_type == 'llama'
 8 |         print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
 9 |         print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 |         confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 |         if confirm.lower() in ["y", "yes"]:
12 |             print("Upgrading checkpoint...")
13 |             assert len(cfg.architectures) == 1
14 |             setattr(cfg.__class__, "model_type", "llava")
15 |             cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 |             cfg.save_pretrained(config)
17 |             print("Checkpoint upgraded.")
18 |         else:
19 |             print("Checkpoint upgrade aborted.")
20 |             exit(1)
21 | 


--------------------------------------------------------------------------------
/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/serve/__init__.py


--------------------------------------------------------------------------------
/llava/serve/cli.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import torch
  3 | 
  4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
  5 | from llava.conversation import conv_templates, SeparatorStyle
  6 | from llava.model.builder import load_pretrained_model
  7 | from llava.utils import disable_torch_init
  8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
  9 | 
 10 | from PIL import Image
 11 | 
 12 | import requests
 13 | from PIL import Image
 14 | from io import BytesIO
 15 | from transformers import TextStreamer
 16 | 
 17 | 
 18 | def load_image(image_file):
 19 |     if image_file.startswith('http://') or image_file.startswith('https://'):
 20 |         response = requests.get(image_file)
 21 |         image = Image.open(BytesIO(response.content)).convert('RGB')
 22 |     else:
 23 |         image = Image.open(image_file).convert('RGB')
 24 |     return image
 25 | 
 26 | 
 27 | def main(args):
 28 |     # Model
 29 |     disable_torch_init()
 30 | 
 31 |     model_name = get_model_name_from_path(args.model_path)
 32 |     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
 33 | 
 34 |     if "llama-2" in model_name.lower():
 35 |         conv_mode = "llava_llama_2"
 36 |     elif "mistral" in model_name.lower():
 37 |         conv_mode = "mistral_instruct"
 38 |     elif "v1.6-34b" in model_name.lower():
 39 |         conv_mode = "chatml_direct"
 40 |     elif "v1" in model_name.lower():
 41 |         conv_mode = "llava_v1"
 42 |     elif "mpt" in model_name.lower():
 43 |         conv_mode = "mpt"
 44 |     else:
 45 |         conv_mode = "llava_v0"
 46 | 
 47 |     if args.conv_mode is not None and conv_mode != args.conv_mode:
 48 |         print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
 49 |     else:
 50 |         args.conv_mode = conv_mode
 51 | 
 52 |     conv = conv_templates[args.conv_mode].copy()
 53 |     if "mpt" in model_name.lower():
 54 |         roles = ('user', 'assistant')
 55 |     else:
 56 |         roles = conv.roles
 57 | 
 58 |     image = load_image(args.image_file)
 59 |     image_size = image.size
 60 |     # Similar operation in model_worker.py
 61 |     image_tensor = process_images([image], image_processor, model.config)
 62 |     if type(image_tensor) is list:
 63 |         image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
 64 |     else:
 65 |         image_tensor = image_tensor.to(model.device, dtype=torch.float16)
 66 | 
 67 |     while True:
 68 |         try:
 69 |             inp = input(f"{roles[0]}: ")
 70 |         except EOFError:
 71 |             inp = ""
 72 |         if not inp:
 73 |             print("exit...")
 74 |             break
 75 | 
 76 |         print(f"{roles[1]}: ", end="")
 77 | 
 78 |         if image is not None:
 79 |             # first message
 80 |             if model.config.mm_use_im_start_end:
 81 |                 inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
 82 |             else:
 83 |                 inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
 84 |             image = None
 85 |         
 86 |         conv.append_message(conv.roles[0], inp)
 87 |         conv.append_message(conv.roles[1], None)
 88 |         prompt = conv.get_prompt()
 89 | 
 90 |         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
 91 |         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
 92 |         keywords = [stop_str]
 93 |         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 94 | 
 95 |         with torch.inference_mode():
 96 |             output_ids = model.generate(
 97 |                 input_ids,
 98 |                 images=image_tensor,
 99 |                 image_sizes=[image_size],
100 |                 do_sample=True if args.temperature > 0 else False,
101 |                 temperature=args.temperature,
102 |                 max_new_tokens=args.max_new_tokens,
103 |                 streamer=streamer,
104 |                 use_cache=True)
105 | 
106 |         outputs = tokenizer.decode(output_ids[0]).strip()
107 |         conv.messages[-1][-1] = outputs
108 | 
109 |         if args.debug:
110 |             print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
111 | 
112 | 
113 | if __name__ == "__main__":
114 |     parser = argparse.ArgumentParser()
115 |     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
116 |     parser.add_argument("--model-base", type=str, default=None)
117 |     parser.add_argument("--image-file", type=str, required=True)
118 |     parser.add_argument("--device", type=str, default="cuda")
119 |     parser.add_argument("--conv-mode", type=str, default=None)
120 |     parser.add_argument("--temperature", type=float, default=0.2)
121 |     parser.add_argument("--max-new-tokens", type=int, default=512)
122 |     parser.add_argument("--load-8bit", action="store_true")
123 |     parser.add_argument("--load-4bit", action="store_true")
124 |     parser.add_argument("--debug", action="store_true")
125 |     args = parser.parse_args()
126 |     main(args)
127 | 


--------------------------------------------------------------------------------
/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/serve/examples/extreme_ironing.jpg


--------------------------------------------------------------------------------
/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/DiffBIR/3f8b5304245ba0ffd6f1995e51c27c3c72d7618f/llava/serve/examples/waterview.jpg


--------------------------------------------------------------------------------
/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Manually register workers.
 3 | 
 4 | Usage:
 5 | python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
 6 | """
 7 | 
 8 | import argparse
 9 | 
10 | import requests
11 | 
12 | if __name__ == "__main__":
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument("--controller-address", type=str)
15 |     parser.add_argument("--worker-name", type=str)
16 |     parser.add_argument("--check-heart-beat", action="store_true")
17 |     args = parser.parse_args()
18 | 
19 |     url = args.controller_address + "/register_worker"
20 |     data = {
21 |         "worker_name": args.worker_name,
22 |         "check_heart_beat": args.check_heart_beat,
23 |         "worker_status": None,
24 |     }
25 |     r = requests.post(url, json=data)
26 |     assert r.status_code == 200
27 | 


--------------------------------------------------------------------------------
/llava/serve/test_message.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import json
 3 | 
 4 | import requests
 5 | 
 6 | from llava.conversation import default_conversation
 7 | 
 8 | 
 9 | def main():
10 |     if args.worker_address:
11 |         worker_addr = args.worker_address
12 |     else:
13 |         controller_addr = args.controller_address
14 |         ret = requests.post(controller_addr + "/refresh_all_workers")
15 |         ret = requests.post(controller_addr + "/list_models")
16 |         models = ret.json()["models"]
17 |         models.sort()
18 |         print(f"Models: {models}")
19 | 
20 |         ret = requests.post(controller_addr + "/get_worker_address",
21 |             json={"model": args.model_name})
22 |         worker_addr = ret.json()["address"]
23 |         print(f"worker_addr: {worker_addr}")
24 | 
25 |     if worker_addr == "":
26 |         return
27 | 
28 |     conv = default_conversation.copy()
29 |     conv.append_message(conv.roles[0], args.message)
30 |     prompt = conv.get_prompt()
31 | 
32 |     headers = {"User-Agent": "LLaVA Client"}
33 |     pload = {
34 |         "model": args.model_name,
35 |         "prompt": prompt,
36 |         "max_new_tokens": args.max_new_tokens,
37 |         "temperature": 0.7,
38 |         "stop": conv.sep,
39 |     }
40 |     response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 |             json=pload, stream=True)
42 | 
43 |     print(prompt.replace(conv.sep, "\n"), end="")
44 |     for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 |         if chunk:
46 |             data = json.loads(chunk.decode("utf-8"))
47 |             output = data["text"].split(conv.sep)[-1]
48 |             print(output, end="\r")
49 |     print("")
50 | 
51 | 
52 | if __name__ == "__main__":
53 |     parser = argparse.ArgumentParser()
54 |     parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 |     parser.add_argument("--worker-address", type=str)
56 |     parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 |     parser.add_argument("--max-new-tokens", type=int, default=32)
58 |     parser.add_argument("--message", type=str, default=
59 |         "Tell me a story with more than 1000 words.")
60 |     args = parser.parse_args()
61 | 
62 |     main()
63 | 


--------------------------------------------------------------------------------
/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
  1 | from typing import Optional, Tuple
  2 | import warnings
  3 | 
  4 | import torch
  5 | 
  6 | import transformers
  7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
  8 | 
  9 | try:
 10 |     from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
 11 | except ImportError:
 12 |     from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
 13 | from flash_attn.bert_padding import unpad_input, pad_input
 14 | 
 15 | 
 16 | def forward(
 17 |     self,
 18 |     hidden_states: torch.Tensor,
 19 |     attention_mask: Optional[torch.Tensor] = None,
 20 |     position_ids: Optional[torch.Tensor] = None,
 21 |     past_key_value: Optional[Tuple[torch.Tensor]] = None,
 22 |     output_attentions: bool = False,
 23 |     use_cache: bool = False,
 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 25 |     if output_attentions:
 26 |         warnings.warn(
 27 |             "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
 28 |         )
 29 | 
 30 |     bsz, q_len, _ = hidden_states.size()
 31 | 
 32 |     query_states = (
 33 |         self.q_proj(hidden_states)
 34 |         .view(bsz, q_len, self.num_heads, self.head_dim)
 35 |         .transpose(1, 2)
 36 |     )
 37 |     key_states = (
 38 |         self.k_proj(hidden_states)
 39 |         .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
 40 |         .transpose(1, 2)
 41 |     )
 42 |     value_states = (
 43 |         self.v_proj(hidden_states)
 44 |         .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
 45 |         .transpose(1, 2)
 46 |     )  # shape: (b, num_heads, s, head_dim)
 47 | 
 48 |     kv_seq_len = key_states.shape[-2]
 49 |     if past_key_value is not None:
 50 |         kv_seq_len += past_key_value[0].shape[-2]
 51 | 
 52 |     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 53 |     query_states, key_states = apply_rotary_pos_emb(
 54 |         query_states, key_states, cos, sin, position_ids
 55 |     )
 56 | 
 57 |     if past_key_value is not None:
 58 |         # reuse k, v
 59 |         key_states = torch.cat([past_key_value[0], key_states], dim=2)
 60 |         value_states = torch.cat([past_key_value[1], value_states], dim=2)
 61 | 
 62 |     past_key_value = (key_states, value_states) if use_cache else None
 63 | 
 64 |     # repeat k/v heads if n_kv_heads < n_heads
 65 |     key_states = repeat_kv(key_states, self.num_key_value_groups)
 66 |     value_states = repeat_kv(value_states, self.num_key_value_groups)
 67 | 
 68 |     # Transform the data into the format required by flash attention
 69 |     qkv = torch.stack([query_states, key_states, value_states], dim=2)
 70 |     qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]
 71 |     key_padding_mask = attention_mask
 72 | 
 73 |     if key_padding_mask is None:
 74 |         qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
 75 |         cu_q_lens = torch.arange(
 76 |             0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
 77 |         )
 78 |         max_s = q_len
 79 |         output = flash_attn_unpadded_qkvpacked_func(
 80 |             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
 81 |         )
 82 |         output = output.view(bsz, q_len, -1)
 83 |     else:
 84 |         qkv = qkv.reshape(bsz, q_len, -1)
 85 |         qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
 86 |         qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
 87 |         output_unpad = flash_attn_unpadded_qkvpacked_func(
 88 |             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
 89 |         )
 90 |         output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
 91 |         output = pad_input(output_unpad, indices, bsz, q_len)
 92 | 
 93 |     return self.o_proj(output), None, past_key_value
 94 | 
 95 | 
 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
 97 | # requires the attention mask to be the same as the key_padding_mask
 98 | def _prepare_decoder_attention_mask(
 99 |     self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 |     # [bsz, seq_len]
102 |     return attention_mask
103 | 
104 | 
105 | def replace_llama_attn_with_flash_attn():
106 |     cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 |     if cuda_major < 8:
108 |         warnings.warn(
109 |             "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
110 |             "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 |         )
112 |     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 |         _prepare_decoder_attention_mask
114 |     )
115 |     transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116 | 
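
The patch above is meant to be applied before any LLaMA model is constructed, mirroring how `train_xformers.py` later in this listing wires in its xformers variant. A minimal usage sketch, assuming flash-attn is installed and using a hypothetical checkpoint path:

from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Patch LlamaAttention.forward (and the decoder attention-mask hook) first.
replace_llama_attn_with_flash_attn()

from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")  # hypothetical path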


--------------------------------------------------------------------------------
/llava/train/llama_xformers_attn_monkey_patch.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
  3 | """
  4 | 
  5 | import logging
  6 | import math
  7 | from typing import Optional, Tuple
  8 | 
  9 | import torch
 10 | import transformers.models.llama.modeling_llama
 11 | from torch import nn
 12 | 
 13 | try:
 14 |     import xformers.ops
 15 | except ImportError:
 16 |     logging.error("xformers not found! Please install it before trying to use it.")
 17 | 
 18 | 
 19 | def replace_llama_attn_with_xformers_attn():
 20 |     transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
 21 | 
 22 | 
 23 | def xformers_forward(
 24 |     self,
 25 |     hidden_states: torch.Tensor,
 26 |     attention_mask: Optional[torch.Tensor] = None,
 27 |     position_ids: Optional[torch.LongTensor] = None,
 28 |     past_key_value: Optional[Tuple[torch.Tensor]] = None,
 29 |     output_attentions: bool = False,
 30 |     use_cache: bool = False,
 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 32 |     # pylint: disable=duplicate-code
 33 |     bsz, q_len, _ = hidden_states.size()
 34 | 
 35 |     query_states = (
 36 |         self.q_proj(hidden_states)
 37 |         .view(bsz, q_len, self.num_heads, self.head_dim)
 38 |         .transpose(1, 2)
 39 |     )
 40 |     key_states = (
 41 |         self.k_proj(hidden_states)
 42 |         .view(bsz, q_len, self.num_heads, self.head_dim)
 43 |         .transpose(1, 2)
 44 |     )
 45 |     value_states = (
 46 |         self.v_proj(hidden_states)
 47 |         .view(bsz, q_len, self.num_heads, self.head_dim)
 48 |         .transpose(1, 2)
 49 |     )
 50 | 
 51 |     kv_seq_len = key_states.shape[-2]
 52 |     if past_key_value is not None:
 53 |         kv_seq_len += past_key_value[0].shape[-2]
 54 |     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 55 |     (
 56 |         query_states,
 57 |         key_states,
 58 |     ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
 59 |         query_states, key_states, cos, sin, position_ids
 60 |     )
 61 |     # [bsz, nh, t, hd]
 62 | 
 63 |     if past_key_value is not None:
 64 |         # reuse k, v, self_attention
 65 |         key_states = torch.cat([past_key_value[0], key_states], dim=2)
 66 |         value_states = torch.cat([past_key_value[1], value_states], dim=2)
 67 | 
 68 |     past_key_value = (key_states, value_states) if use_cache else None
 69 | 
 70 |     # We only apply xformers optimizations if we don't need to output the whole attention matrix
 71 |     if not output_attentions:
 72 |         query_states = query_states.transpose(1, 2)
 73 |         key_states = key_states.transpose(1, 2)
 74 |         value_states = value_states.transpose(1, 2)
 75 | 
 76 |         # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
 77 |         # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
 78 |         if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
 79 |             # input and output should be of form (bsz, q_len, num_heads, head_dim)
 80 |             attn_output = xformers.ops.memory_efficient_attention(
 81 |                 query_states, key_states, value_states, attn_bias=None
 82 |             )
 83 |         else:
 84 |             # input and output should be of form (bsz, q_len, num_heads, head_dim)
 85 |             attn_output = xformers.ops.memory_efficient_attention(
 86 |                 query_states,
 87 |                 key_states,
 88 |                 value_states,
 89 |                 attn_bias=xformers.ops.LowerTriangularMask(),
 90 |             )
 91 |         attn_weights = None
 92 |     else:
 93 |         attn_weights = torch.matmul(
 94 |             query_states, key_states.transpose(2, 3)
 95 |         ) / math.sqrt(self.head_dim)
 96 | 
 97 |         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 98 |             raise ValueError(
 99 |                 f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
100 |                 f" {attn_weights.size()}"
101 |             )
102 | 
103 |         if attention_mask is not None:
104 |             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105 |                 raise ValueError(
106 |                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107 |                 )
108 |             attn_weights = attn_weights + attention_mask
109 |             attn_weights = torch.max(
110 |                 attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111 |             )
112 | 
113 |         # upcast attention to fp32
114 |         attn_weights = nn.functional.softmax(
115 |             attn_weights, dim=-1, dtype=torch.float32
116 |         ).to(query_states.dtype)
117 |         attn_output = torch.matmul(attn_weights, value_states)
118 | 
119 |         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120 |             raise ValueError(
121 |                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122 |                 f" {attn_output.size()}"
123 |             )
124 | 
125 |         attn_output = attn_output.transpose(1, 2)
126 | 
127 |     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128 |     attn_output = self.o_proj(attn_output)
129 |     return attn_output, attn_weights, past_key_value
130 | 


--------------------------------------------------------------------------------
/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | from llava.train.train import train
2 | 
3 | if __name__ == "__main__":
4 |     train(attn_implementation="flash_attention_2")
5 | 


--------------------------------------------------------------------------------
/llava/train/train_xformers.py:
--------------------------------------------------------------------------------
 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
 2 | 
 3 | # Need to call this before importing transformers.
 4 | from llava.train.llama_xformers_attn_monkey_patch import (
 5 |     replace_llama_attn_with_xformers_attn,
 6 | )
 7 | 
 8 | replace_llama_attn_with_xformers_attn()
 9 | 
10 | from llava.train.train import train
11 | 
12 | if __name__ == "__main__":
13 |     train()
14 | 


--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
  1 | import datetime
  2 | import logging
  3 | import logging.handlers
  4 | import os
  5 | import sys
  6 | 
  7 | import requests
  8 | 
  9 | from llava.constants import LOGDIR
 10 | 
 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
 13 | 
 14 | handler = None
 15 | 
 16 | 
 17 | def build_logger(logger_name, logger_filename):
 18 |     global handler
 19 | 
 20 |     formatter = logging.Formatter(
 21 |         fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
 22 |         datefmt="%Y-%m-%d %H:%M:%S",
 23 |     )
 24 | 
 25 |     # Set the format of root handlers
 26 |     if not logging.getLogger().handlers:
 27 |         logging.basicConfig(level=logging.INFO)
 28 |     logging.getLogger().handlers[0].setFormatter(formatter)
 29 | 
 30 |     # Redirect stdout and stderr to loggers
 31 |     stdout_logger = logging.getLogger("stdout")
 32 |     stdout_logger.setLevel(logging.INFO)
 33 |     sl = StreamToLogger(stdout_logger, logging.INFO)
 34 |     sys.stdout = sl
 35 | 
 36 |     stderr_logger = logging.getLogger("stderr")
 37 |     stderr_logger.setLevel(logging.ERROR)
 38 |     sl = StreamToLogger(stderr_logger, logging.ERROR)
 39 |     sys.stderr = sl
 40 | 
 41 |     # Get logger
 42 |     logger = logging.getLogger(logger_name)
 43 |     logger.setLevel(logging.INFO)
 44 | 
 45 |     # Add a file handler for all loggers
 46 |     if handler is None:
 47 |         os.makedirs(LOGDIR, exist_ok=True)
 48 |         filename = os.path.join(LOGDIR, logger_filename)
 49 |         handler = logging.handlers.TimedRotatingFileHandler(
 50 |             filename, when='D', utc=True, encoding='UTF-8')
 51 |         handler.setFormatter(formatter)
 52 | 
 53 |         for name, item in logging.root.manager.loggerDict.items():
 54 |             if isinstance(item, logging.Logger):
 55 |                 item.addHandler(handler)
 56 | 
 57 |     return logger
 58 | 
 59 | 
 60 | class StreamToLogger(object):
 61 |     """
 62 |     Fake file-like stream object that redirects writes to a logger instance.
 63 |     """
 64 |     def __init__(self, logger, log_level=logging.INFO):
 65 |         self.terminal = sys.stdout
 66 |         self.logger = logger
 67 |         self.log_level = log_level
 68 |         self.linebuf = ''
 69 | 
 70 |     def __getattr__(self, attr):
 71 |         return getattr(self.terminal, attr)
 72 | 
 73 |     def write(self, buf):
 74 |         temp_linebuf = self.linebuf + buf
 75 |         self.linebuf = ''
 76 |         for line in temp_linebuf.splitlines(True):
 77 |             # From the io.TextIOWrapper docs:
 78 |             #   On output, if newline is None, any '\n' characters written
 79 |             #   are translated to the system default line separator.
 80 |             # By default sys.stdout.write() expects '\n' newlines and then
 81 |             # translates them so this is still cross platform.
 82 |             if line[-1] == '\n':
 83 |                 self.logger.log(self.log_level, line.rstrip())
 84 |             else:
 85 |                 self.linebuf += line
 86 | 
 87 |     def flush(self):
 88 |         if self.linebuf != '':
 89 |             self.logger.log(self.log_level, self.linebuf.rstrip())
 90 |         self.linebuf = ''
 91 | 
 92 | 
 93 | def disable_torch_init():
 94 |     """
 95 |     Disable the redundant torch default initialization to accelerate model creation.
 96 |     """
 97 |     import torch
 98 |     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
 99 |     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 | 
101 | 
102 | def violates_moderation(text):
103 |     """
104 |     Check whether the text violates OpenAI moderation API.
105 |     """
106 |     url = "https://api.openai.com/v1/moderations"
107 |     headers = {"Content-Type": "application/json",
108 |                "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 |     text = text.replace("\n", "")
110 |     import json  # json.dumps escapes quotes/backslashes in user text, unlike manual string building
111 |     data = json.dumps({"input": text}).encode("utf-8")
112 |     try:
113 |         ret = requests.post(url, headers=headers, data=data, timeout=5)
114 |         flagged = ret.json()["results"][0]["flagged"]
115 |     except requests.exceptions.RequestException as e:
116 |         flagged = False
117 |     except KeyError as e:
118 |         flagged = False
119 | 
120 |     return flagged
121 | 
122 | 
123 | def pretty_print_semaphore(semaphore):
124 |     if semaphore is None:
125 |         return "None"
126 |     return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 | 
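
A minimal sketch of wiring this module up; the logger name and log filename below are placeholders rather than values taken from this repository:

    from llava.utils import build_logger, disable_torch_init

    # build_logger also replaces sys.stdout / sys.stderr with StreamToLogger,
    # so plain print() output lands in the rotating log file under LOGDIR too.
    logger = build_logger("demo", "demo.log")
    logger.info("written to the console handler and to LOGDIR/demo.log")
    print("also captured via the stdout redirection above")

    # Skip PyTorch's default Linear/LayerNorm init before loading a checkpoint.
    disable_torch_init()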


--------------------------------------------------------------------------------
/ram/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference import inference_tag2text, inference_ram, inference_ram_openset
2 | from .transform import get_transform
3 | 


--------------------------------------------------------------------------------
/ram/configs/finetune.yaml:
--------------------------------------------------------------------------------
 1 | train_file: [
 2 |             'datasets/train/coco_train_rmcocodev_ram.json',
 3 |              ]
 4 | image_path_root: ""
 5 | 
 6 | # vision backbone size: 'swin_b' (base) or 'swin_l' (large)
 7 | vit: 'swin_l'
 8 | vit_grad_ckpt: False
 9 | vit_ckpt_layer: 0
10 | 
11 | image_size: 384
12 | batch_size: 26
13 | 
14 | # optimizer
15 | weight_decay: 0.05
16 | init_lr: 5e-06
17 | min_lr: 0
18 | max_epoch: 2
19 | warmup_steps: 3000
20 | 
21 | class_num: 4585
22 | 
23 | 
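
A quick sketch of reading a config like this, assuming plain PyYAML (not necessarily the loader the training scripts use). Note the YAML 1.1 quirk: bare scientific notation such as `init_lr: 5e-06` has no decimal point, so PyYAML loads it as a string and it should be cast to float before use:

    import yaml

    config = yaml.safe_load(open("ram/configs/finetune.yaml"))
    print(config["vit"], config["image_size"], config["batch_size"])  # swin_l 384 26
    print(config["train_file"], config["class_num"])                  # annotation JSON list, 4585
    print(float(config["init_lr"]), config["warmup_steps"])           # 5e-06 3000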


--------------------------------------------------------------------------------
/ram/configs/finetune_tag2text.yaml:
--------------------------------------------------------------------------------
 1 | train_file: [
 2 |             'datasets/train/coco_train_rmcocodev_ram.json',
 3 |              ]
 4 | image_path_root: ""
 5 | 
 6 | # vision backbone size: 'swin_b' (base) or 'swin_l' (large)
 7 | vit: 'swin_b'
 8 | vit_grad_ckpt: False
 9 | vit_ckpt_layer: 0
10 | 
11 | image_size: 384
12 | batch_size: 36
13 | 
14 | # optimizer
15 | weight_decay: 0.05
16 | init_lr: 5e-06
17 | min_lr: 0
18 | max_epoch: 2
19 | warmup_steps: 3000
20 | 
21 | class_num: 4585
22 | 
23 | 


--------------------------------------------------------------------------------
/ram/configs/med_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "architectures": [
 3 |       "BertModel"
 4 |     ],
 5 |     "attention_probs_dropout_prob": 0.1,
 6 |     "hidden_act": "gelu",
 7 |     "hidden_dropout_prob": 0.1,
 8 |     "hidden_size": 768,
 9 |     "initializer_range": 0.02,
10 |     "intermediate_size": 3072,
11 |     "layer_norm_eps": 1e-12,
12 |     "max_position_embeddings": 512,
13 |     "model_type": "bert",
14 |     "num_attention_heads": 12,
15 |     "num_hidden_layers": 12,
16 |     "pad_token_id": 0,
17 |     "type_vocab_size": 2,
18 |     "vocab_size": 30524,
19 |     "encoder_width": 768,
20 |     "add_cross_attention": true   
21 |   }
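
This is the BERT-style text encoder/decoder config. One way to inspect it is Hugging Face's BertConfig, which keeps the RAM-specific extra key (encoder_width) as a plain attribute; whether the models construct their config through this exact class is not shown in this excerpt:

    from transformers import BertConfig

    cfg = BertConfig.from_json_file("ram/configs/med_config.json")
    print(cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size)  # 768 12 30524
    print(cfg.encoder_width, cfg.add_cross_attention)              # 768 True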


--------------------------------------------------------------------------------
/ram/configs/pretrain.yaml:
--------------------------------------------------------------------------------
 1 | train_file: [
 2 |             'datasets/train/coco_train_rmcocodev_ram.json',
 3 |             'datasets/train/vg_ram.json',
 4 |             'datasets/train/sbu_ram.json',
 5 |             'datasets/train/cc3m_train_ram.json',
 6 |             'datasets/train/cc3m_val_ram.json',
 7 |             'datasets/train/cc12m_ram.json',
 8 |              ]
 9 | image_path_root: ""
10 | 
11 | # vision backbone size: 'swin_b' (base) or 'swin_l' (large)
12 | vit: 'swin_l'
13 | vit_grad_ckpt: False
14 | vit_ckpt_layer: 0
15 | 
16 | image_size: 224
17 | batch_size: 52
18 | 
19 | # optimizer
20 | weight_decay: 0.05
21 | init_lr: 1e-4
22 | min_lr: 5e-7
23 | warmup_lr: 5e-7
24 | lr_decay_rate: 0.9
25 | max_epoch: 5
26 | warmup_steps: 3000
27 | 
28 | class_num: 4585
29 | 
30 | 


--------------------------------------------------------------------------------
/ram/configs/pretrain_tag2text.yaml:
--------------------------------------------------------------------------------
 1 | train_file: [
 2 |             'datasets/train/coco_train_rmcocodev_ram.json',
 3 |             'datasets/train/vg_ram.json',
 4 |             'datasets/train/sbu_ram.json',
 5 |             'datasets/train/cc3m_train_ram.json',
 6 |             'datasets/train/cc3m_val_ram.json',
 7 |             'datasets/train/cc12m_ram.json',
 8 |              ]
 9 | image_path_root: ""
10 | 
11 | # vision backbone size: 'swin_b' (base) or 'swin_l' (large)
12 | vit: 'swin_b'
13 | vit_grad_ckpt: False
14 | vit_ckpt_layer: 0
15 | 
16 | image_size: 224
17 | batch_size: 80
18 | 
19 | # optimizer
20 | weight_decay: 0.05
21 | init_lr: 1e-4
22 | min_lr: 5e-7
23 | warmup_lr: 5e-7
24 | lr_decay_rate: 0.9
25 | max_epoch: 5
26 | warmup_steps: 3000
27 | 
28 | class_num: 4585
29 | 
30 | 


--------------------------------------------------------------------------------
/ram/configs/q2l_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "architectures": [
 3 |       "BertModel"
 4 |     ],
 5 |     "attention_probs_dropout_prob": 0.1,
 6 |     "hidden_act": "gelu",
 7 |     "hidden_dropout_prob": 0.1,
 8 |     "hidden_size": 768,
 9 |     "initializer_range": 0.02,
10 |     "intermediate_size": 3072,
11 |     "layer_norm_eps": 1e-12,
12 |     "max_position_embeddings": 512,
13 |     "model_type": "bert",
14 |     "num_attention_heads": 4,
15 |     "num_hidden_layers": 2,
16 |     "pad_token_id": 0,
17 |     "type_vocab_size": 2,
18 |     "vocab_size": 30522,
19 |     "encoder_width": 768,
20 |     "add_cross_attention": true,
21 |     "add_tag_cross_attention": false
22 |   }


--------------------------------------------------------------------------------
/ram/configs/swin/config_swinB_224.json:
--------------------------------------------------------------------------------
1 | {
2 |     "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3 |     "vision_width": 1024,
4 |     "image_res": 224,
5 |     "window_size": 7,
6 |     "embed_dim": 128,
7 |     "depths": [ 2, 2, 18, 2 ],
8 |     "num_heads": [ 4, 8, 16, 32 ]
9 |   }


--------------------------------------------------------------------------------
/ram/configs/swin/config_swinB_384.json:
--------------------------------------------------------------------------------
1 | {
2 |     "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3 |     "vision_width": 1024,
4 |     "image_res": 384,
5 |     "window_size": 12,
6 |     "embed_dim": 128,
7 |     "depths": [ 2, 2, 18, 2 ],
8 |     "num_heads": [ 4, 8, 16, 32 ]
9 |   }


--------------------------------------------------------------------------------
/ram/configs/swin/config_swinL_224.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "ckpt": "pretrain_model/swin_large_patch4_window7_224_22k.pth",
 3 |     "vision_width": 1536,
 4 |     "image_res": 224,
 5 |     "window_size": 7,
 6 |     "embed_dim": 192,
 7 |     "depths": [ 2, 2, 18, 2 ],
 8 |     "num_heads": [ 6, 12, 24, 48 ]
 9 |   }
10 | 


--------------------------------------------------------------------------------
/ram/configs/swin/config_swinL_384.json:
--------------------------------------------------------------------------------
1 | {
2 |     "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth",
3 |     "vision_width": 1536,
4 |     "image_res": 384,
5 |     "window_size": 12,
6 |     "embed_dim": 192,
7 |     "depths": [ 2, 2, 18, 2 ],
8 |     "num_heads": [ 6, 12, 24, 48 ]
9 |   }


--------------------------------------------------------------------------------
/ram/data/__init__.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch.utils.data import DataLoader
 3 | from torchvision import transforms
 4 | from torchvision.transforms.functional import InterpolationMode
 5 | 
 6 | from .dataset import pretrain_dataset, finetune_dataset
 7 | from .randaugment import RandomAugment
 8 | 
 9 | def create_dataset(dataset, config, min_scale=0.5):
10 |     
11 |     normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
12 | 
13 |     transform_train = transforms.Compose([                        
14 |             transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
15 |             transforms.RandomHorizontalFlip(),
16 |             RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
17 |                                               'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),     
18 |             transforms.ToTensor(),
19 |             normalize,
20 |         ])        
21 | 
22 |     transform_inputsize_224 = transforms.Compose([                        
23 |             transforms.RandomResizedCrop(224,scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
24 |             transforms.RandomHorizontalFlip(),
25 |             RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
26 |                                               'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),     
27 |             transforms.ToTensor(),
28 |             normalize,
29 |         ])     
30 | 
31 |     if dataset=='pretrain':
32 |         dataset = pretrain_dataset(config['train_file'], transform_train, class_num=config['class_num'], root=config['image_path_root'])              
33 |         return dataset  
34 | 
35 |     elif dataset=='finetune':   
36 |         dataset = finetune_dataset(config['train_file'], transform_train, transform_inputsize_224, class_num=config['class_num'], root=config['image_path_root'])              
37 |         return dataset  
38 |     else:
39 |         raise ValueError(f"create_dataset: unknown dataset type '{dataset}'")
40 | 
41 | 
42 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
43 |     samplers = []
44 |     for dataset,shuffle in zip(datasets,shuffles):
45 |         sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
46 |         samplers.append(sampler)
47 |     return samplers     
48 | 
49 | 
50 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
51 |     loaders = []
52 |     for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
53 |         if is_train:
54 |             shuffle = (sampler is None)
55 |             drop_last = True
56 |         else:
57 |             shuffle = False
58 |             drop_last = False
59 |         loader = DataLoader(
60 |             dataset,
61 |             batch_size=bs,
62 |             num_workers=n_worker,
63 |             pin_memory=True,
64 |             sampler=sampler,
65 |             shuffle=shuffle,
66 |             collate_fn=collate_fn,
67 |             drop_last=drop_last,
68 |         )              
69 |         loaders.append(loader)
70 |     return loaders    
71 | 
72 | 
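
The sampler/loader wiring that create_sampler and create_loader perform, shown standalone with a toy TensorDataset so it runs without the RAM annotation files or an initialized process group (rank and world size are passed explicitly here):

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.randn(100, 3, 224, 224), torch.zeros(100, dtype=torch.long))

    # create_sampler: one DistributedSampler per dataset
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

    # create_loader: shuffle stays False once a sampler is supplied; drop_last=True for training
    loader = DataLoader(dataset, batch_size=8, num_workers=0, pin_memory=True,
                        sampler=sampler, shuffle=False, collate_fn=None, drop_last=True)

    for images, labels in loader:
        print(images.shape)  # this "rank" sees 50 of the 100 samples -> 6 full batches of 8
        break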


--------------------------------------------------------------------------------
/ram/data/dataset.py:
--------------------------------------------------------------------------------
  1 | import json
  2 | import os
  3 | import random
  4 | 
  5 | from torch.utils.data import Dataset
  6 | 
  7 | from PIL import Image
  8 | from PIL import ImageFile
  9 | ImageFile.LOAD_TRUNCATED_IMAGES = True
 10 | Image.MAX_IMAGE_PIXELS = None
 11 | 
 12 | from .utils import pre_caption
 13 | 
 14 | 
 15 | import torch
 16 | import numpy as np
 17 | 
 18 | class pretrain_dataset(Dataset):
 19 |     def __init__(self, ann_file, transform, class_num = 4585, root = ''): 
 20 | 
 21 |         self.ann = []
 22 |         for f in ann_file:
 23 |             print('loading '+f)
 24 |             ann = json.load(open(f,'r'))
 25 |             self.ann += ann
 26 |             
 27 |         self.transform = transform
 28 |         self.class_num = class_num
 29 |         self.root = root
 30 | 
 31 |     
 32 |     def __len__(self):
 33 |         return len(self.ann)
 34 |     
 35 |     def __getitem__(self, index):    
 36 |         
 37 |         ann = self.ann[index]   
 38 | 
 39 |         image_path_use = os.path.join(self.root, ann['image_path'])
 40 |         image = Image.open(image_path_use).convert('RGB')   
 41 |         image = self.transform(image)
 42 | 
 43 |         # required for tag2text support
 44 |         if ann.get('union_label_id') is not None:
 45 |             num = ann['union_label_id'] 
 46 |             image_tag = np.zeros([self.class_num]) 
 47 |             image_tag[num] = 1 
 48 |             image_tag = torch.tensor(image_tag, dtype = torch.long)
 49 |         else:
 50 |             image_tag = None
 51 | 
 52 |         caption_index = np.random.randint(0, len(ann['caption']))
 53 | 
 54 |         caption = pre_caption(ann['caption'][caption_index],30)
 55 | 
 56 |         num = ann['parse_label_id'][caption_index]
 57 |         parse_tag = np.zeros([self.class_num])
 58 |         parse_tag[num] = 1
 59 |         parse_tag = torch.tensor(parse_tag, dtype = torch.long)
 60 | 
 61 |         return image, caption, image_tag, parse_tag
 62 |     
 63 | 
 64 | class finetune_dataset(Dataset):
 65 |     def __init__(self, ann_file, transform, transform_224, class_num = 4585, root = ''): 
 66 | 
 67 |         self.ann = []
 68 |         for f in ann_file:
 69 |             print('loading '+f)
 70 |             ann = json.load(open(f,'r'))
 71 |             self.ann += ann
 72 |             
 73 |         self.transform = transform
 74 |         self.transform_224 = transform_224
 75 |         self.class_num = class_num
 76 |         self.root = root
 77 | 
 78 |     
 79 |     def __len__(self):
 80 |         return len(self.ann)
 81 |     
 82 |     def __getitem__(self, index):    
 83 |         
 84 |         ann = self.ann[index]   
 85 | 
 86 |         image_path_use = os.path.join(self.root, ann['image_path'])
 87 |         image = Image.open(image_path_use).convert('RGB')   
 88 |         image = self.transform(image)
 89 | 
 90 |         image_224 = Image.open(image_path_use).convert('RGB')  
 91 |         image_224 = self.transform_224(image_224)
 92 | 
 93 |         # required for tag2text support
 94 |         if ann.get('union_label_id') is not None:
 95 |             num = ann['union_label_id'] 
 96 |             image_tag = np.zeros([self.class_num]) 
 97 |             image_tag[num] = 1 
 98 |             image_tag = torch.tensor(image_tag, dtype = torch.long)
 99 |         else:
100 |             image_tag = None
101 | 
102 |         caption_index = np.random.randint(0, len(ann['caption']))
103 | 
104 |         caption = pre_caption(ann['caption'][caption_index],30)
105 | 
106 |         num = ann['parse_label_id'][caption_index]
107 |         parse_tag = np.zeros([self.class_num])
108 |         parse_tag[num] = 1
109 |         parse_tag = torch.tensor(parse_tag, dtype = torch.long)
110 | 
111 |         return image, image_224, caption, image_tag, parse_tag
112 |     
113 | 


--------------------------------------------------------------------------------
/ram/data/utils.py:
--------------------------------------------------------------------------------
  1 | import re
  2 | import json
  3 | import os
  4 | 
  5 | import torch
  6 | import torch.distributed as dist
  7 | 
  8 | import utils
  9 | 
 10 | def pre_caption(caption,max_words=50):
 11 |     caption = re.sub(
 12 |         r"([.!\"()*#:;~])",       
 13 |         ' ',
 14 |         caption.lower(),
 15 |     )
 16 |     caption = re.sub(
 17 |         r"\s{2,}",
 18 |         ' ',
 19 |         caption,
 20 |     )
 21 |     caption = caption.rstrip('\n') 
 22 |     caption = caption.strip(' ')
 23 | 
 24 |     #truncate caption
 25 |     caption_words = caption.split(' ')
 26 |     if len(caption_words)>max_words:
 27 |         caption = ' '.join(caption_words[:max_words])
 28 |             
 29 |     return caption
 30 | 
 31 | def pre_question(question,max_ques_words=50):
 32 |     question = re.sub(
 33 |         r"([.!\"()*#:;~])",
 34 |         '',
 35 |         question.lower(),
 36 |     ) 
 37 |     question = question.rstrip(' ')
 38 |     
 39 |     #truncate question
 40 |     question_words = question.split(' ')
 41 |     if len(question_words)>max_ques_words:
 42 |         question = ' '.join(question_words[:max_ques_words])
 43 |             
 44 |     return question
 45 | 
 46 | 
 47 | def save_result(result, result_dir, filename, remove_duplicate=''):
 48 |     result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
 49 |     final_result_file = os.path.join(result_dir, '%s.json'%filename)
 50 |     
 51 |     json.dump(result,open(result_file,'w'))
 52 | 
 53 |     dist.barrier()
 54 | 
 55 |     if utils.is_main_process():   
 56 |         # combine results from all processes
 57 |         result = []
 58 | 
 59 |         for rank in range(utils.get_world_size()):
 60 |             result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
 61 |             res = json.load(open(result_file,'r'))
 62 |             result += res
 63 | 
 64 |         if remove_duplicate:
 65 |             result_new = []
 66 |             id_list = []    
 67 |             for res in result:
 68 |                 if res[remove_duplicate] not in id_list:
 69 |                     id_list.append(res[remove_duplicate])
 70 |                     result_new.append(res)
 71 |             result = result_new             
 72 |                 
 73 |         json.dump(result,open(final_result_file,'w'))            
 74 |         print('result file saved to %s'%final_result_file)
 75 | 
 76 |     return final_result_file
 77 | 
 78 | 
 79 | 
 80 | from pycocotools.coco import COCO
 81 | from pycocoevalcap.eval import COCOEvalCap
 82 | from torchvision.datasets.utils import download_url
 83 | 
 84 | def coco_caption_eval(coco_gt_root, results_file, split):
 85 |     urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
 86 |             'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
 87 |     filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}    
 88 |     
 89 |     download_url(urls[split],coco_gt_root)
 90 |     annotation_file = os.path.join(coco_gt_root,filenames[split])
 91 |     
 92 |     # create coco object and coco_result object
 93 |     coco = COCO(annotation_file)
 94 |     coco_result = coco.loadRes(results_file)
 95 | 
 96 |     # create coco_eval object by taking coco and coco_result
 97 |     coco_eval = COCOEvalCap(coco, coco_result)
 98 | 
 99 |     # To restrict evaluation to a subset of images, uncomment the line below;
100 |     # leave it commented out when evaluating the full validation set.
101 |     # coco_eval.params['image_id'] = coco_result.getImgIds()
102 | 
103 | 
104 |     # evaluate results
105 |     # SPICE will take a few minutes the first time, but speeds up due to caching
106 |     coco_eval.evaluate()
107 | 
108 |     # print output evaluation scores
109 |     for metric, score in coco_eval.eval.items():
110 |         print(f'{metric}: {score:.3f}')
111 |     
112 |     return coco_eval
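
The caption normalization pre_caption applies, inlined here so the example runs without this module's pycocotools / pycocoevalcap imports:

    import re

    def pre_caption(caption, max_words=50):
        caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())   # strip punctuation
        caption = re.sub(r"\s{2,}", ' ', caption)                    # collapse whitespace
        caption = caption.rstrip('\n').strip(' ')
        words = caption.split(' ')
        return ' '.join(words[:max_words]) if len(words) > max_words else caption

    print(pre_caption('A dog (brown) plays fetch!  "Good boy."', max_words=30))
    # -> a dog brown plays fetch good boy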


--------------------------------------------------------------------------------
/ram/inference.py:
--------------------------------------------------------------------------------
 1 | '''
 2 |  * The Inference of RAM and Tag2Text Models
 3 |  * Written by Xinyu Huang
 4 | '''
 5 | import torch
 6 | 
 7 | 
 8 | def inference_tag2text(image, model, input_tag="None"):
 9 | 
10 |     with torch.no_grad():
11 |         caption, tag_predict = model.generate(image,
12 |                                               tag_input=None,
13 |                                               max_length=50,
14 |                                               return_tag_predict=True)
15 | 
16 |     if input_tag == '' or input_tag == 'none' or input_tag == 'None':
17 |         return tag_predict[0], None, caption[0]
18 | 
19 |     # If user input specified tags:
20 |     else:
21 |         input_tag_list = []
22 |         input_tag_list.append(input_tag.replace(',', ' | '))
23 | 
24 |         with torch.no_grad():
25 |             caption, input_tag = model.generate(image,
26 |                                                 tag_input=input_tag_list,
27 |                                                 max_length=50,
28 |                                                 return_tag_predict=True)
29 | 
30 |         return tag_predict[0], input_tag[0], caption[0]
31 | 
32 | 
33 | def inference_ram(image, model):
34 | 
35 |     with torch.no_grad():
36 |         tags, tags_chinese = model.generate_tag(image)
37 | 
38 |     return tags[0],tags_chinese[0]
39 | 
40 | 
41 | def inference_ram_openset(image, model):
42 | 
43 |     with torch.no_grad():
44 |         tags = model.generate_tag_openset(image)
45 | 
46 |     return tags[0]
47 | 


--------------------------------------------------------------------------------
/ram/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ram_plus import ram_plus
2 | from .ram import ram
3 | from .tag2text import tag2text
4 | 


--------------------------------------------------------------------------------
/ram/transform.py:
--------------------------------------------------------------------------------
 1 | from torchvision.transforms import Normalize, Compose, Resize, ToTensor
 2 | 
 3 | 
 4 | def convert_to_rgb(image):
 5 |     return image.convert("RGB")
 6 | 
 7 | def get_transform(image_size=384):
 8 |     return Compose([
 9 |         convert_to_rgb,
10 |         Resize((image_size, image_size)),
11 |         ToTensor(),
12 |         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
13 |     ])
14 | 
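
Putting the pieces together: a minimal tagging run with get_transform and inference_ram. The checkpoint path and the keyword arguments to the ram factory follow its upstream recognize-anything usage and are assumptions here, since the model code itself is not part of this excerpt:

    import torch
    from PIL import Image

    from ram import get_transform, inference_ram
    from ram.models import ram

    device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = get_transform(image_size=384)

    # Assumed local checkpoint path; download ram_swin_large_14m.pth separately.
    model = ram(pretrained="weights/ram_swin_large_14m.pth", image_size=384, vit="swin_l")
    model.eval().to(device)

    image = transform(Image.open("inputs/demo/bsr/14.jpg")).unsqueeze(0).to(device)
    tags, tags_chinese = inference_ram(image, model)
    print(tags)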


--------------------------------------------------------------------------------
/ram/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import get_mAP, get_PR
2 | from .openset_utils import build_openset_label_embedding, build_openset_llm_label_embedding
3 | 


--------------------------------------------------------------------------------
/ram/utils/metrics.py:
--------------------------------------------------------------------------------
  1 | from typing import List, Tuple
  2 | 
  3 | import numpy as np
  4 | from numpy import ndarray
  5 | 
  6 | 
  7 | def get_mAP(
  8 |     preds: ndarray,
  9 |     gt_file: str,
 10 |     taglist: List[str]
 11 | ) -> Tuple[float, ndarray]:
 12 |     assert preds.shape[1] == len(taglist)
 13 | 
 14 |     # When mapping categories from test datasets to our system, there might be
 15 |     # multiple vs one situation due to different semantic definitions of tags.
 16 |     # So there can be duplicate tags in `taglist`. This special case is taken
 17 |     # into account.
 18 |     tag2idxs = {}
 19 |     for idx, tag in enumerate(taglist):
 20 |         if tag not in tag2idxs:
 21 |             tag2idxs[tag] = []
 22 |         tag2idxs[tag].append(idx)
 23 | 
 24 |     # build targets
 25 |     targets = np.zeros_like(preds)
 26 |     with open(gt_file, "r") as f:
 27 |         lines = [line.strip("\n").split(",") for line in f.readlines()]
 28 |     assert len(lines) == targets.shape[0]
 29 |     for i, line in enumerate(lines):
 30 |         for tag in line[1:]:
 31 |             targets[i, tag2idxs[tag]] = 1.0
 32 | 
 33 |     # compute average precision for each class
 34 |     APs = np.zeros(preds.shape[1])
 35 |     for k in range(preds.shape[1]):
 36 |         APs[k] = _average_precision(preds[:, k], targets[:, k])
 37 | 
 38 |     return APs.mean(), APs
 39 | 
 40 | 
 41 | def _average_precision(output: ndarray, target: ndarray) -> float:
 42 |     epsilon = 1e-8
 43 | 
 44 |     # sort examples
 45 |     indices = output.argsort()[::-1]
 46 |     # Computes prec@i
 47 |     total_count_ = np.cumsum(np.ones((len(output), 1)))
 48 | 
 49 |     target_ = target[indices]
 50 |     ind = target_ == 1
 51 |     pos_count_ = np.cumsum(ind)
 52 |     total = pos_count_[-1]
 53 |     pos_count_[np.logical_not(ind)] = 0
 54 |     pp = pos_count_ / total_count_
 55 |     precision_at_i_ = np.sum(pp)
 56 |     precision_at_i = precision_at_i_ / (total + epsilon)
 57 | 
 58 |     return precision_at_i
 59 | 
 60 | 
 61 | def get_PR(
 62 |     pred_file: str,
 63 |     gt_file: str,
 64 |     taglist: List[str]
 65 | ) -> Tuple[float, float, ndarray, ndarray]:
 66 |     # When mapping categories from test datasets to our system, there might be
 67 |     # multiple vs one situation due to different semantic definitions of tags.
 68 |     # So there can be duplicate tags in `taglist`. This special case is taken
 69 |     # into account.
 70 |     tag2idxs = {}
 71 |     for idx, tag in enumerate(taglist):
 72 |         if tag not in tag2idxs:
 73 |             tag2idxs[tag] = []
 74 |         tag2idxs[tag].append(idx)
 75 | 
 76 |     # build preds
 77 |     with open(pred_file, "r", encoding="utf-8") as f:
 78 |         lines = [line.strip().split(",") for line in f.readlines()]
 79 |     preds = np.zeros((len(lines), len(tag2idxs)), dtype=bool)
 80 |     for i, line in enumerate(lines):
 81 |         for tag in line[1:]:
 82 |             preds[i, tag2idxs[tag]] = True
 83 | 
 84 |     # build targets
 85 |     with open(gt_file, "r", encoding="utf-8") as f:
 86 |         lines = [line.strip().split(",") for line in f.readlines()]
 87 |     targets = np.zeros((len(lines), len(tag2idxs)), dtype=bool)
 88 |     for i, line in enumerate(lines):
 89 |         for tag in line[1:]:
 90 |             targets[i, tag2idxs[tag]] = True
 91 | 
 92 |     assert preds.shape == targets.shape
 93 | 
 94 |     # calculate P and R
 95 |     TPs = ( preds &  targets).sum(axis=0)  # noqa: E201, E222
 96 |     FPs = ( preds & ~targets).sum(axis=0)  # noqa: E201, E222
 97 |     FNs = (~preds &  targets).sum(axis=0)  # noqa: E201, E222
 98 |     eps = 1.e-9
 99 |     Ps = TPs / (TPs + FPs + eps)
100 |     Rs = TPs / (TPs + FNs + eps)
101 | 
102 |     return Ps.mean(), Rs.mean(), Ps, Rs
103 | 
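
The per-tag precision/recall computation from get_PR, run on a toy prediction/ground-truth pair built in memory rather than parsed from the comma-separated files ("<image_id>,<tag>,<tag>,..."):

    import numpy as np

    # two images, taglist = ["cat", "dog"]
    preds   = np.array([[True,  False],   # img1 predicted: cat
                        [True,  True ]])  # img2 predicted: cat, dog
    targets = np.array([[True,  True ],   # img1 ground truth: cat, dog
                        [False, True ]])  # img2 ground truth: dog

    TPs = ( preds &  targets).sum(axis=0)
    FPs = ( preds & ~targets).sum(axis=0)
    FNs = (~preds &  targets).sum(axis=0)
    eps = 1e-9
    Ps, Rs = TPs / (TPs + FPs + eps), TPs / (TPs + FNs + eps)
    print(Ps.mean(), Rs.mean())  # ~0.75 0.75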


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | --extra-index-url https://download.pytorch.org/whl/cu118
 2 | torch==2.2.2+cu118
 3 | torchvision==0.17.2+cu118
 4 | torchaudio==2.2.2+cu118
 5 | xformers==0.0.25.post1+cu118
 6 | omegaconf==2.3.0
 7 | accelerate==0.28.0
 8 | einops==0.7.0
 9 | opencv_python==4.9.0.80
10 | scipy==1.12.0
11 | ftfy==6.2.0
12 | regex==2023.12.25
13 | python-dateutil==2.9.0.post0
14 | timm==0.9.16
15 | pytorch-lightning==2.2.1 # only for loading pretrained SD weights
16 | tensorboard==2.16.2 # for tensorboard event visualization
17 | protobuf==4.25.3 # for tensorboard
18 | lpips==0.1.4
19 | facexlib==0.3.0
20 | gradio==4.43.0
21 | polars==1.12.0
22 | torchsde==0.2.6
23 | bitsandbytes==0.44.1
24 | 
25 | # requirements for llava
26 | transformers==4.37.2
27 | tokenizers==0.15.1
28 | sentencepiece==0.1.99
29 | 
30 | # requirements for ram
31 | fairscale==0.4.4
32 | 


--------------------------------------------------------------------------------