├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── JiuTian.pdf
│   ├── LION-6Examples.jpg
│   ├── LION-CapVQA.jpg
│   ├── LION-Examples.jpg
│   ├── LION-Image-level.jpg
│   ├── LION-Introduction.jpg
│   ├── LION-MMBench.jpg
│   ├── LION-Method.jpg
│   ├── LION-POPE.jpg
│   ├── LION-REC.jpg
│   ├── LION-Region-level.jpg
│   ├── LION-Score.jpg
│   ├── LION_logo.png
│   └── model.jpg
├── common
│   └── registry.py
├── configs
│   └── models
│       ├── lion_flant5xl.yaml
│       └── lion_flant5xxl.yaml
├── images
│   ├── COCO_train2014_000000024935.jpg
│   └── COCO_train2014_000000533220.jpg
├── models
│   ├── Qformer.py
│   ├── __init__.py
│   ├── base_model.py
│   ├── eva_vit.py
│   ├── lion_adapters.py
│   ├── lion_t5.py
│   └── modeling_t5.py
├── playground.ipynb
├── preprocessors
│   └── lion_preprocessors.py
├── ram
│   ├── __init__.py
│   ├── configs
│   │   ├── finetune.yaml
│   │   ├── finetune_tag2text.yaml
│   │   ├── med_config.json
│   │   ├── pretrain.yaml
│   │   ├── pretrain_tag2text.yaml
│   │   ├── q2l_config.json
│   │   └── swin
│   │       ├── config_swinB_224.json
│   │       ├── config_swinB_384.json
│   │       ├── config_swinL_224.json
│   │       └── config_swinL_384.json
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── ram_tag_list.txt
│   │   ├── ram_tag_list_chinese.txt
│   │   ├── ram_tag_list_threshold.txt
│   │   ├── randaugment.py
│   │   ├── tag2text_ori_tag_list.txt
│   │   ├── tag_list.txt
│   │   └── utils.py
│   ├── inference.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── bert.py
│   │   ├── ram.py
│   │   ├── ram_plus.py
│   │   ├── swin_transformer.py
│   │   ├── tag2text.py
│   │   ├── utils.py
│   │   └── vit.py
│   ├── transform.py
│   └── utils
│       ├── __init__.py
│       ├── metrics.py
│       └── openset_utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
112 | .pdm.toml
113 | .pdm-python
114 | .pdm-build/
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Rui Shao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | LION: Empowering Multimodal Large Language Model with Dual-Level Visual Knowledge
5 |
6 |
13 |
14 | School of Computer Science and Technology, Harbin Institute of Technology, Shenzhen
15 | *Corresponding author
16 |
17 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2024
18 |
19 | [[Paper]](https://arxiv.org/abs/2311.11860) [[Project Page]](https://rshaojimmy.github.io/Projects/JiuTian-LION) [[Video(YouTube)]](https://www.youtube.com/watch?v=YzJ5MZFS5RA) [[Video(bilibili)]](https://www.bilibili.com/video/BV1kH4y1y7UR/)
20 |
21 | :fire: Details will be released. Stay tuned :beers: :+1:
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | ## If you find this work useful for your research, please kindly cite our paper and star our repo.
33 |
34 | ## Updates
35 | - [07/2024] Code and checkpoints are released.
36 | - [02/2024] LION has been accepted by CVPR 2024.
37 | - [11/2023] [Arxiv paper](https://arxiv.org/abs/2311.11860) released.
38 | - [11/2023] [Project page](https://rshaojimmy.github.io/Projects/JiuTian-LION) released.
39 |
40 | ## Introduction
41 |
42 | This is the GitHub repository of *LION: Empowering Multimodal Large Language Model with Dual-Level Visual Knowledge*. In this work, we enhance MLLMs by integrating fine-grained spatial-aware visual knowledge and high-level semantic visual evidence, boosting capabilities and alleviating hallucinations.
43 |
44 | The framework of the proposed LION model:
45 |
46 |
47 |
48 |
49 |
50 | ## Installation
51 |
52 | ### Download
53 | ```bash
54 | git clone https://github.com/JiuTian-VL/JiuTian-LION.git
55 | cd JiuTian-LION
56 | ```
57 |
58 | ### Environment
59 |
60 | ```bash
61 | conda create -n LION python=3.12
62 | conda activate LION
63 | conda install pip
64 | pip install -r requirements.txt
65 | ```
66 |
67 | ## Checkpoints
68 |
69 | | Version | Checkpoint |
70 | | --- | --- |
71 | | LION-FlanT5-XL| [daybreaksly/LION-FlanT5-XL](https://huggingface.co/daybreaksly/LION-FlanT5-XL) |
72 | | LION-FlanT5-XXL| [daybreaksly/LION-FlanT5-XXL](https://huggingface.co/daybreaksly/LION-FlanT5-XXL) |
73 |
74 | ## Usage
75 |
76 | ### Prepare models
77 |
78 | 1. Download the pre-trained ViT model [eva_vit_g](https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth).
79 | 2. Download the pre-trained RAM model [ram_swin_large_14m](https://huggingface.co/spaces/xinyu1205/Recognize_Anything-Tag2Text/blob/main/ram_swin_large_14m.pth).
80 | 3. Download the pre-trained FlanT5 model [FlanT5-XL](https://huggingface.co/google/flan-t5-xl).
81 | 4. Download the pre-trained BERT model [bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased).
82 | 5. Fill the paths to these models into the corresponding fields of the config file `configs/models/lion_flant5xl.yaml`.
83 |
84 | ### Inference
85 |
86 | We provide inference examples for **Image-Level** and **Region-Level** tasks in `playground.ipynb`.
87 |
88 | ## Evaluation results
89 |
90 | For image-level tasks, we focus on image captioning and Visual Question Answering (VQA). For region-level tasks, we evaluate LION on three REC datasets: RefCOCO, RefCOCO+ and RefCOCOg. The results, detailed in Tables 1 and 2, highlight LION's superior performance compared to baseline models.
91 |
92 | 
93 |
94 | 
95 | 
96 |
97 | We further evaluate LION on an object hallucination benchmark ([POPE](https://github.com/AoiDragon/POPE)) and the most popular MLLM benchmark ([MMBench](https://mmbench.opencompass.org.cn/home)). The results in Tables 1 and 2 show that LION performs strongly across a wide range of skills and demonstrates strong resistance to hallucinations, particularly in the popular and adversarial settings of POPE.
98 |
99 | 
100 | 
101 |
102 | ## Qualitative Comparison
103 |
104 | 
105 | 
106 | 
107 |
108 | ## More Examples
109 | 
110 |
111 | ## Citation
112 |
113 | If you find this work useful for your research, please kindly cite our paper:
114 | ```
115 | @inproceedings{chen2024lion,
116 | title={LION: Empowering Multimodal Large Language Model with Dual-Level Visual Knowledge},
117 | author={Chen, Gongwei and Shen, Leyang and Shao, Rui and Deng, Xiang and Nie, Liqiang},
118 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
119 | year={2024}
120 | }
121 | ```
122 |
--------------------------------------------------------------------------------
/assets/JiuTian.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/JiuTian.pdf
--------------------------------------------------------------------------------
/assets/LION-6Examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-6Examples.jpg
--------------------------------------------------------------------------------
/assets/LION-CapVQA.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-CapVQA.jpg
--------------------------------------------------------------------------------
/assets/LION-Examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Examples.jpg
--------------------------------------------------------------------------------
/assets/LION-Image-level.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Image-level.jpg
--------------------------------------------------------------------------------
/assets/LION-Introduction.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Introduction.jpg
--------------------------------------------------------------------------------
/assets/LION-MMBench.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-MMBench.jpg
--------------------------------------------------------------------------------
/assets/LION-Method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Method.jpg
--------------------------------------------------------------------------------
/assets/LION-POPE.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-POPE.jpg
--------------------------------------------------------------------------------
/assets/LION-REC.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-REC.jpg
--------------------------------------------------------------------------------
/assets/LION-Region-level.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Region-level.jpg
--------------------------------------------------------------------------------
/assets/LION-Score.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Score.jpg
--------------------------------------------------------------------------------
/assets/LION_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION_logo.png
--------------------------------------------------------------------------------
/assets/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/model.jpg
--------------------------------------------------------------------------------
/common/registry.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | class Registry:
4 | root = os.path.expanduser("~")
5 | mapping = {
6 | "builder_name_mapping": {},
7 | "processor_name_mapping": {},
8 | "model_name_mapping": {},
9 | "evaluator_name_mapping": {}
10 | }
11 |
12 | @classmethod
13 | def register_builder(cls, name):
14 | r"""Register a dataset builder to registry with key 'name'
15 |
16 | Args:
17 | name: Key with which the dataset builder will be registered.
18 | """
19 |
20 | def wrap(builder_func):
21 | if name in cls.mapping["builder_name_mapping"]:
22 | raise KeyError(
23 | "Name '{}' already registered for {}.".format(
24 | name, cls.mapping["builder_name_mapping"][name]
25 | )
26 | )
27 | cls.mapping["builder_name_mapping"][name] = builder_func
28 | return builder_func
29 |
30 | return wrap
31 |
32 | @classmethod
33 | def register_evaluator(cls, name):
34 | r"""Register a task evaluator to registry with key 'name'
35 |
36 | Args:
37 | name: Key with which the task evaluator will be registered.
38 | """
39 |
40 | def wrap(eval_func):
41 | if name in cls.mapping["evaluator_name_mapping"]:
42 | raise KeyError(
43 | "Name '{}' already registered for {}.".format(
44 | name, cls.mapping["evaluator_name_mapping"][name]
45 | )
46 | )
47 | cls.mapping["evaluator_name_mapping"][name] = eval_func
48 | return eval_func
49 |
50 | return wrap
51 |
52 | @classmethod
53 | def register_model(cls, name):
54 | r"""Register a model to registry with key 'name'
55 |
56 | Args:
57 | name: Key with which the model will be registered.
58 | """
59 |
60 | def wrap(model_cls):
61 | from models import BaseModel
62 |
63 | assert issubclass(
64 | model_cls, BaseModel
65 | ), "All models must inherit BaseModel class"
66 | if name in cls.mapping["model_name_mapping"]:
67 | raise KeyError(
68 | "Name '{}' already registered for {}.".format(
69 | name, cls.mapping["model_name_mapping"][name]
70 | )
71 | )
72 | cls.mapping["model_name_mapping"][name] = model_cls
73 | return model_cls
74 |
75 | return wrap
76 |
77 | @classmethod
78 | def register_processor(cls, name):
79 | r"""Register a processor to registry with key 'name'
80 |
81 | Args:
82 | name: Key with which the processor will be registered.
83 | """
84 |
85 | def wrap(processor_cls):
86 | if name in cls.mapping["processor_name_mapping"]:
87 | raise KeyError(
88 | "Name '{}' already registered for {}.".format(
89 | name, cls.mapping["processor_name_mapping"][name]
90 | )
91 | )
92 | cls.mapping["processor_name_mapping"][name] = processor_cls
93 | return processor_cls
94 |
95 | return wrap
96 |
97 | @classmethod
98 | def get_builder_func(cls, name):
99 | return cls.mapping["builder_name_mapping"].get(name, None)
100 |
101 | @classmethod
102 | def get_evaluator_func(cls, name):
103 | return cls.mapping["evaluator_name_mapping"].get(name, None)
104 |
105 | @classmethod
106 | def get_model_class(cls, name):
107 | return cls.mapping["model_name_mapping"].get(name, None)
108 |
109 | @classmethod
110 | def get_processor_class(cls, name):
111 | return cls.mapping["processor_name_mapping"].get(name, None)
112 |
113 | @classmethod
114 | def list_models(cls):
115 | return sorted(cls.mapping["model_name_mapping"].keys())
116 |
117 | @classmethod
118 | def list_processors(cls):
119 | return sorted(cls.mapping["processor_name_mapping"].keys())
120 |
121 | @classmethod
122 | def list_datasets(cls):
123 | return sorted(cls.mapping["builder_name_mapping"].keys())
124 |
125 | @classmethod
126 | def list_evaluators(cls):
127 | return sorted(cls.mapping["evaluator_name_mapping"].keys())
128 |
129 | registry = Registry()
--------------------------------------------------------------------------------
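The registry above is how the repo wires names such as `lion_t5` to classes (see `@registry.register_model("lion_t5")` in `models/lion_t5.py`). A minimal sketch of the pattern, assuming the repo root is on `PYTHONPATH`; the `ToyProcessor` class and the name `toy_processor` are made up for illustration:

```python
from common.registry import registry

# Register a class under a name via the decorator factory...
@registry.register_processor("toy_processor")
class ToyProcessor:
    def __call__(self, x):
        return x

# ...then resolve it back by name anywhere else in the code base.
ProcessorCls = registry.get_processor_class("toy_processor")
processor = ProcessorCls()
print(registry.list_processors())  # includes 'toy_processor'
```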
/configs/models/lion_flant5xl.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | bert_model: "/path/to/bert-base-uncased/"
3 | vit_model: "/path/to/eva_vit_g.pth"
4 | llm_model: "/path/to/flan-t5-xl/"
5 | ram_model: "/path/to/ram_swin_large_14m.pth"
6 | visual_input: "ALL"
7 | enable_semantic_tags: True
8 |
9 | load_pretrained: True
10 | pretrained: /path/to/LION-FlanT5-XL.pth
11 |
--------------------------------------------------------------------------------
/configs/models/lion_flant5xxl.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | bert_model: "/path/to/bert-base-uncased/"
3 | vit_model: "/path/to/eva_vit_g.pth"
4 | llm_model: "/path/to/flan-t5-xxl/"
5 | ram_model: "/path/to/ram_swin_large_14m.pth"
6 | visual_input: "ALL"
7 | enable_semantic_tags: True
8 |
9 | load_pretrained: True
10 | pretrained: /path/to/LION-FlanT5-XXL.pth
11 |
--------------------------------------------------------------------------------
/images/COCO_train2014_000000024935.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/images/COCO_train2014_000000024935.jpg
--------------------------------------------------------------------------------
/images/COCO_train2014_000000533220.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/images/COCO_train2014_000000533220.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from common.registry import registry
2 | from models.base_model import BaseModel
3 | from models.lion_t5 import LIONT5InstructAdapter
4 |
5 | __all__ = [
6 | "LIONT5InstructAdapter"
7 | ]
8 |
9 | def load_model(name, model_type, is_eval=False, device="cpu"):
10 | """
11 | Load supported models.
12 |
13 | Args:
14 | name (str): name of the model.
15 | model_type (str): type of the model.
16 | is_eval (bool): whether the model is in eval mode. Default: False.
17 | device (str): device to use. Default: "cpu".
18 |
19 | Returns:
20 | model (torch.nn.Module): model.
21 | """
22 |
23 | model = registry.get_model_class(name).from_pretrained(model_type=model_type)
24 |
25 | if is_eval:
26 | model.eval()
27 |
28 | if device == "cpu":
29 | model = model.float()
30 |
31 | return model.to(device)
--------------------------------------------------------------------------------
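A minimal usage sketch for `load_model` above, assuming the paths in `configs/models/lion_flant5xl.yaml` have been filled in as described in the README:

```python
import torch
from models import load_model

device = "cuda" if torch.cuda.is_available() else "cpu"
# "lion_t5" is the registered model name; "flant5xl" selects its default config file.
model = load_model("lion_t5", "flant5xl", is_eval=True, device=device)
print(type(model).__name__)  # LIONT5InstructAdapter
```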
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import logging
5 | from omegaconf import OmegaConf
6 |
7 | class BaseModel(nn.Module):
8 | """Base class for models."""
9 |
10 | def __init__(self):
11 | super().__init__()
12 |
13 | @property
14 | def device(self):
15 | return list(self.parameters())[0].device
16 |
17 | def load_checkpoint_from_config(self, cfg, **kwargs):
18 | """
19 | Load checkpoint as specified in the config file.
20 |
21 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
22 | When loading the pretrained model, each task-specific architecture may define their
23 | own load_from_pretrained() method.
24 | """
25 | load_pretrained = cfg.get("load_pretrained", True)
26 | if load_pretrained:
27 | # load pre-trained weights
28 | pretrain_path = cfg.get("pretrained", None)
29 | assert pretrain_path, "pretrained path is not specified in the config file"
30 | print("loading pretrain: ", pretrain_path)
31 | msg = self.load_checkpoint(filename=pretrain_path, **kwargs)
32 | print("unexpected_keys: ", msg.unexpected_keys)
33 |
34 | def load_checkpoint(self, filename):
35 | """
36 | Load from a finetuned checkpoint.
37 |
38 | This should expect no mismatch in the model keys and the checkpoint keys.
39 | """
40 |
41 | if os.path.isfile(filename):
42 | checkpoint = torch.load(filename, map_location="cpu")
43 | else:
44 | raise RuntimeError("checkpoint url or path is invalid")
45 |
46 | if "model" in checkpoint.keys():
47 | state_dict = checkpoint["model"]
48 | else:
49 | state_dict = checkpoint
50 |
51 | msg = self.load_state_dict(state_dict, strict=False)
52 |
53 | # logging.info("Missing keys {}".format(msg.missing_keys))
54 | logging.info(f"Unexpected keys: {msg.unexpected_keys}")
55 | logging.info(f"load checkpoint from: {filename}")
56 |
57 | return msg
58 |
59 | @classmethod
60 | def default_config_path(cls, model_type):
61 | assert (
62 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
63 | ), "Unknown model type {}".format(model_type)
64 | return cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]
65 |
66 | def counting_training_parameters(self):
67 | total = 0.
68 | trainable_names = []
69 | all = 0.
70 | for name, param in self.named_parameters():
71 | if param.requires_grad:
72 | total += param.nelement()
73 | trainable_names.append(name)
74 | all += param.nelement()
75 | logging.info(trainable_names)
76 | logging.info(' + Number of trainable params: %.2fM' % (total / 1e6))
77 | logging.info('Number of all params: %.2fM' % (all / 1e6))
78 | return total
79 |
80 | @classmethod
81 | def from_config(cls, cfg):
82 | raise NotImplementedError()
83 |
84 | @classmethod
85 | def from_pretrained(cls, model_type):
86 | """
87 | Build a pretrained model from default configuration file, specified by model_type.
88 |
89 | Args:
90 | - model_type (str): model type, specifying architecture and checkpoints.
91 |
92 | Returns:
93 | - model (nn.Module): pretrained or finetuned model, depending on the configuration.
94 | """
95 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
96 | model = cls.from_config(model_cfg)
97 |
98 | return model
99 |
100 |
101 | def get_optimizer_params(self, weight_decay, lr_scale=1):
102 | p_wd, p_non_wd = [], []
103 | for n, p in self.named_parameters():
104 | if not p.requires_grad:
105 | continue # frozen weights
106 | if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
107 | p_non_wd.append(p)
108 | else:
109 | p_wd.append(p)
110 | optim_params = [
111 | {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale},
112 | {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale},
113 | ]
114 | return optim_params
115 |
--------------------------------------------------------------------------------
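The training loop itself is not included in this repo, but `get_optimizer_params` is written to feed a standard optimizer; a sketch under that assumption (the `lr_scale` entry is only carried along in each param group and must be applied by the caller):

```python
import torch
from models import load_model

model = load_model("lion_t5", "flant5xl")  # any BaseModel subclass works here

# Two param groups: weight-decayed matrices, and decay-free biases/norm parameters.
param_groups = model.get_optimizer_params(weight_decay=0.05, lr_scale=1)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)

# Optionally fold the per-group lr_scale into the learning rate.
for group in optimizer.param_groups:
    group["lr"] *= group.get("lr_scale", 1)
```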
/models/eva_vit.py:
--------------------------------------------------------------------------------
1 | # Based on EVA, BEIT, timm and DeiT code bases
2 | # https://github.com/baaivision/EVA
3 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4 | # https://github.com/microsoft/unilm/tree/master/beit
5 | # https://github.com/facebookresearch/deit/
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------'
8 | import math
9 | from functools import partial
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.utils.checkpoint as checkpoint
15 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16 | from timm.models.registry import register_model
17 |
18 |
19 | def _cfg(url='', **kwargs):
20 | return {
21 | 'url': url,
22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
23 | 'crop_pct': .9, 'interpolation': 'bicubic',
24 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
25 | **kwargs
26 | }
27 |
28 |
29 | class DropPath(nn.Module):
30 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
31 | """
32 | def __init__(self, drop_prob=None):
33 | super(DropPath, self).__init__()
34 | self.drop_prob = drop_prob
35 |
36 | def forward(self, x):
37 | return drop_path(x, self.drop_prob, self.training)
38 |
39 | def extra_repr(self) -> str:
40 | return 'p={}'.format(self.drop_prob)
41 |
42 |
43 | class Mlp(nn.Module):
44 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
45 | super().__init__()
46 | out_features = out_features or in_features
47 | hidden_features = hidden_features or in_features
48 | self.fc1 = nn.Linear(in_features, hidden_features)
49 | self.act = act_layer()
50 | self.fc2 = nn.Linear(hidden_features, out_features)
51 | self.drop = nn.Dropout(drop)
52 |
53 | def forward(self, x):
54 | x = self.fc1(x)
55 | x = self.act(x)
56 | # x = self.drop(x)
57 | # commented out above to follow the original BERT implementation
58 | x = self.fc2(x)
59 | x = self.drop(x)
60 | return x
61 |
62 |
63 | class Attention(nn.Module):
64 | def __init__(
65 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
66 | proj_drop=0., window_size=None, attn_head_dim=None):
67 | super().__init__()
68 | self.num_heads = num_heads
69 | head_dim = dim // num_heads
70 | if attn_head_dim is not None:
71 | head_dim = attn_head_dim
72 | all_head_dim = head_dim * self.num_heads
73 | self.scale = qk_scale or head_dim ** -0.5
74 |
75 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
76 | if qkv_bias:
77 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
78 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
79 | else:
80 | self.q_bias = None
81 | self.v_bias = None
82 |
83 | if window_size:
84 | self.window_size = window_size
85 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
86 | self.relative_position_bias_table = nn.Parameter(
87 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
88 | # cls to token & token 2 cls & cls to cls
89 |
90 | # get pair-wise relative position index for each token inside the window
91 | coords_h = torch.arange(window_size[0])
92 | coords_w = torch.arange(window_size[1])
93 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
94 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
95 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
96 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
97 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
98 | relative_coords[:, :, 1] += window_size[1] - 1
99 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
100 | relative_position_index = \
101 | torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
102 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103 | relative_position_index[0, 0:] = self.num_relative_distance - 3
104 | relative_position_index[0:, 0] = self.num_relative_distance - 2
105 | relative_position_index[0, 0] = self.num_relative_distance - 1
106 |
107 | self.register_buffer("relative_position_index", relative_position_index)
108 | else:
109 | self.window_size = None
110 | self.relative_position_bias_table = None
111 | self.relative_position_index = None
112 |
113 | self.attn_drop = nn.Dropout(attn_drop)
114 | self.proj = nn.Linear(all_head_dim, dim)
115 | self.proj_drop = nn.Dropout(proj_drop)
116 |
117 | def forward(self, x, rel_pos_bias=None):
118 | B, N, C = x.shape
119 | qkv_bias = None
120 | if self.q_bias is not None:
121 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
122 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
123 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
124 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
125 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
126 |
127 | q = q * self.scale
128 | attn = (q @ k.transpose(-2, -1))
129 |
130 | if self.relative_position_bias_table is not None:
131 | relative_position_bias = \
132 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
133 | self.window_size[0] * self.window_size[1] + 1,
134 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
135 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
136 | attn = attn + relative_position_bias.unsqueeze(0)
137 |
138 | if rel_pos_bias is not None:
139 | attn = attn + rel_pos_bias
140 |
141 | attn = attn.softmax(dim=-1)
142 | attn = self.attn_drop(attn)
143 |
144 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
145 | x = self.proj(x)
146 | x = self.proj_drop(x)
147 | return x
148 |
149 |
150 | class Block(nn.Module):
151 |
152 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
153 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
154 | window_size=None, attn_head_dim=None):
155 | super().__init__()
156 | self.norm1 = norm_layer(dim)
157 | self.attn = Attention(
158 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
159 | attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
160 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
161 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
162 | self.norm2 = norm_layer(dim)
163 | mlp_hidden_dim = int(dim * mlp_ratio)
164 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
165 |
166 | if init_values is not None and init_values > 0:
167 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
168 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169 | else:
170 | self.gamma_1, self.gamma_2 = None, None
171 |
172 | def forward(self, x, rel_pos_bias=None):
173 | if self.gamma_1 is None:
174 | x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
175 | x = x + self.drop_path(self.mlp(self.norm2(x)))
176 | else:
177 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
178 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
179 | return x
180 |
181 |
182 | class PatchEmbed(nn.Module):
183 | """ Image to Patch Embedding
184 | """
185 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
186 | super().__init__()
187 | img_size = to_2tuple(img_size)
188 | patch_size = to_2tuple(patch_size)
189 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
190 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
191 | self.img_size = img_size
192 | self.patch_size = patch_size
193 | self.num_patches = num_patches
194 |
195 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
196 |
197 | def forward(self, x, **kwargs):
198 | B, C, H, W = x.shape
199 | # FIXME look at relaxing size constraints
200 | assert H == self.img_size[0] and W == self.img_size[1], \
201 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
202 | x = self.proj(x).flatten(2).transpose(1, 2)
203 | return x
204 |
205 |
206 | class RelativePositionBias(nn.Module):
207 |
208 | def __init__(self, window_size, num_heads):
209 | super().__init__()
210 | self.window_size = window_size
211 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
212 | self.relative_position_bias_table = nn.Parameter(
213 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
214 | # cls to token & token 2 cls & cls to cls
215 |
216 | # get pair-wise relative position index for each token inside the window
217 | coords_h = torch.arange(window_size[0])
218 | coords_w = torch.arange(window_size[1])
219 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
220 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
221 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
222 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
223 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
224 | relative_coords[:, :, 1] += window_size[1] - 1
225 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
226 | relative_position_index = \
227 | torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
228 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
229 | relative_position_index[0, 0:] = self.num_relative_distance - 3
230 | relative_position_index[0:, 0] = self.num_relative_distance - 2
231 | relative_position_index[0, 0] = self.num_relative_distance - 1
232 |
233 | self.register_buffer("relative_position_index", relative_position_index)
234 |
235 | # trunc_normal_(self.relative_position_bias_table, std=.02)
236 |
237 | def forward(self):
238 | relative_position_bias = \
239 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
240 | self.window_size[0] * self.window_size[1] + 1,
241 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
242 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
243 |
244 |
245 | class VisionTransformer(nn.Module):
246 | """ Vision Transformer with support for patch or hybrid CNN input stage
247 | """
248 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
249 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
250 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
251 | use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
252 | use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
253 | super().__init__()
254 | self.image_size = img_size
255 | self.num_classes = num_classes
256 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
257 |
258 | self.patch_embed = PatchEmbed(
259 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
260 | num_patches = self.patch_embed.num_patches
261 |
262 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
263 | if use_abs_pos_emb:
264 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
265 | else:
266 | self.pos_embed = None
267 | self.pos_drop = nn.Dropout(p=drop_rate)
268 |
269 | if use_shared_rel_pos_bias:
270 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
271 | else:
272 | self.rel_pos_bias = None
273 | self.use_checkpoint = use_checkpoint
274 |
275 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
276 | self.use_rel_pos_bias = use_rel_pos_bias
277 | self.blocks = nn.ModuleList([
278 | Block(
279 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
280 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
281 | init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
282 | for i in range(depth)])
283 | # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
284 | # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
285 | # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
286 |
287 | if self.pos_embed is not None:
288 | trunc_normal_(self.pos_embed, std=.02)
289 | trunc_normal_(self.cls_token, std=.02)
290 | # trunc_normal_(self.mask_token, std=.02)
291 | # if isinstance(self.head, nn.Linear):
292 | # trunc_normal_(self.head.weight, std=.02)
293 | self.apply(self._init_weights)
294 | self.fix_init_weight()
295 | # if isinstance(self.head, nn.Linear):
296 | # self.head.weight.data.mul_(init_scale)
297 | # self.head.bias.data.mul_(init_scale)
298 |
299 | def fix_init_weight(self):
300 | def rescale(param, layer_id):
301 | param.div_(math.sqrt(2.0 * layer_id))
302 |
303 | for layer_id, layer in enumerate(self.blocks):
304 | rescale(layer.attn.proj.weight.data, layer_id + 1)
305 | rescale(layer.mlp.fc2.weight.data, layer_id + 1)
306 |
307 | def _init_weights(self, m):
308 | if isinstance(m, nn.Linear):
309 | trunc_normal_(m.weight, std=.02)
310 | if isinstance(m, nn.Linear) and m.bias is not None:
311 | nn.init.constant_(m.bias, 0)
312 | elif isinstance(m, nn.LayerNorm):
313 | nn.init.constant_(m.bias, 0)
314 | nn.init.constant_(m.weight, 1.0)
315 |
316 | def get_classifier(self):
317 | return self.head
318 |
319 | def reset_classifier(self, num_classes, global_pool=''):
320 | self.num_classes = num_classes
321 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
322 |
323 | def forward_features(self, x, return_intermediate=False):
324 | x = self.patch_embed(x)
325 | batch_size, seq_len, _ = x.size()
326 |
327 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
328 | x = torch.cat((cls_tokens, x), dim=1)
329 | if self.pos_embed is not None:
330 | x = x + self.pos_embed
331 | x = self.pos_drop(x)
332 |
333 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
334 | intermediate = []
335 | for idx,blk in enumerate(self.blocks):
336 | if self.use_checkpoint:
337 | x = checkpoint.checkpoint(blk, x, rel_pos_bias)
338 | else:
339 | x = blk(x, rel_pos_bias)
340 | if return_intermediate:
341 | intermediate.append(x)
342 | if return_intermediate:
343 | return x, intermediate
344 | return x
345 | # x = self.norm(x)
346 |
347 | # if self.fc_norm is not None:
348 | # t = x[:, 1:, :]
349 | # return self.fc_norm(t.mean(1))
350 | # else:
351 | # return x[:, 0]
352 |
353 | def forward(self, x, return_intermediate=False):
354 | x = self.forward_features(x, return_intermediate)
355 | # x = self.head(x)
356 | return x
357 |
358 | def get_intermediate_layers(self, x):
359 | x = self.patch_embed(x)
360 | batch_size, seq_len, _ = x.size()
361 |
362 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
363 | x = torch.cat((cls_tokens, x), dim=1)
364 | if self.pos_embed is not None:
365 | x = x + self.pos_embed
366 | x = self.pos_drop(x)
367 |
368 | features = []
369 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
370 | for blk in self.blocks:
371 | x = blk(x, rel_pos_bias)
372 | features.append(x)
373 |
374 | return features
375 |
376 | def get_num_layer(self, var_name=""):
377 | if var_name in ("cls_token", "mask_token", "pos_embed"):
378 | return 0
379 | elif var_name.startswith("patch_embed"):
380 | return 0
381 | elif var_name.startswith("rel_pos_bias"):
382 | return len(self.blocks) - 1
383 | elif var_name.startswith("blocks"):
384 | layer_id = int(var_name.split('.')[1])
385 | return layer_id + 1
386 | else:
387 | return len(self.blocks)
388 |
389 |
390 | def interpolate_pos_embed(model, checkpoint_model):
391 | if 'pos_embed' in checkpoint_model:
392 | pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
393 | embedding_size = pos_embed_checkpoint.shape[-1]
394 | num_patches = model.patch_embed.num_patches
395 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
396 | # height (== width) for the checkpoint position embedding
397 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
398 | # height (== width) for the new position embedding
399 | new_size = int(num_patches ** 0.5)
400 | # class_token and dist_token are kept unchanged
401 | if orig_size != new_size:
402 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
403 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
404 | # only the position tokens are interpolated
405 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
406 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
407 | pos_tokens = torch.nn.functional.interpolate(
408 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
409 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
410 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
411 | checkpoint_model['pos_embed'] = new_pos_embed
412 |
413 |
414 | def convert_weights_to_fp16(model: nn.Module):
415 | """Convert applicable model parameters to fp16"""
416 |
417 | def _convert_weights_to_fp16(l):
418 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
419 | l.weight.data = l.weight.data.half()
420 | if l.bias is not None:
421 | l.bias.data = l.bias.data.half()
422 |
423 | # if isinstance(l, (nn.MultiheadAttention, Attention)):
424 | # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
425 | # tensor = getattr(l, attr)
426 | # if tensor is not None:
427 | # tensor.data = tensor.data.half()
428 |
429 | model.apply(_convert_weights_to_fp16)
430 |
431 |
432 | def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16",path="/home/ubuntu/Models/eva_vit_g.pth",depth=39):
433 | print("ViT Depth:", depth)
434 | model = VisionTransformer(
435 | img_size=img_size,
436 | patch_size=14,
437 | use_mean_pooling=False,
438 | embed_dim=1408,
439 | depth=depth,
440 | num_heads=1408//88,
441 | mlp_ratio=4.3637,
442 | qkv_bias=True,
443 | drop_path_rate=drop_path_rate,
444 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
445 | use_checkpoint=use_checkpoint,
446 | )
447 | state_dict = torch.load(path, map_location="cpu")
448 | interpolate_pos_embed(model,state_dict)
449 |
450 | incompatible_keys = model.load_state_dict(state_dict, strict=False)
451 |
452 | if precision == "fp16":
453 | convert_weights_to_fp16(model)
454 | return model
--------------------------------------------------------------------------------
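A minimal sketch of calling `create_eva_vit_g`, assuming `eva_vit_g.pth` has been downloaded (see the README). With a 224x224 input and patch size 14, the encoder returns 1 + (224/14)^2 = 257 tokens of width 1408:

```python
import torch
from models.eva_vit import create_eva_vit_g

# precision="fp32" skips the fp16 conversion so the float32 dummy input below works as-is.
vit = create_eva_vit_g(img_size=224, precision="fp32", path="/path/to/eva_vit_g.pth")
vit.eval()

with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 257, 1408])
```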
/models/lion_adapters.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from models.modeling_t5 import T5LayerFF
7 |
8 | import torch
9 | import torch.nn as nn
10 | from transformers import BertLayer, BertConfig
11 | from transformers.models.bert.modeling_bert import BertOutput, BertSelfOutput
12 |
13 | class FusionAdapter(nn.Module):
14 | def __init__(
15 | self,
16 | num_blocks: int = 2,
17 | dim: int = 1408,
18 | num_heads: int = 16,
19 | ):
20 | super().__init__()
21 | config = BertConfig(
22 | hidden_size=dim,
23 | num_attention_heads=num_heads
24 | )
25 | config.add_cross_attention = True
26 | config.is_decoder = True
27 | self.config = config
28 | self.blocks = nn.ModuleList([BertLayer(config) for _ in range(num_blocks)])
29 | self.apply(self._init_weights)
30 |
31 | def forward(self, hidden_states, encoder_hidden_states):
32 | if isinstance(encoder_hidden_states, list):
33 | assert len(encoder_hidden_states) == len(self.blocks)
34 | for idx,block in enumerate(self.blocks):
35 | hidden_states = block(
36 | hidden_states,
37 | encoder_hidden_states=encoder_hidden_states[idx],
38 | )[0]
39 | else:
40 | for idx,block in enumerate(self.blocks):
41 | hidden_states = block(
42 | hidden_states,
43 | encoder_hidden_states=encoder_hidden_states,
44 | )[0]
45 | return hidden_states
46 |
47 | def _init_weights(self, module):
48 | """ Initialize the weights """
49 | if isinstance(module, (nn.Linear, nn.Embedding)):
50 | # Slightly different from the TF version which uses truncated_normal for initialization
51 | # cf https://github.com/pytorch/pytorch/pull/5617
52 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
53 | elif isinstance(module, nn.LayerNorm):
54 | module.bias.data.zero_()
55 | module.weight.data.fill_(1.0)
56 | if isinstance(module, nn.Linear) and module.bias is not None:
57 | module.bias.data.zero_()
58 | if isinstance(module, BertSelfOutput) or isinstance(module, BertOutput):
59 | module.dense.weight.data.zero_()
60 | module.dense.bias.data.zero_()
61 |
62 | class Adapter(nn.Module):
63 | def __init__(self,
64 | d_model=None,
65 | bottleneck=64,
66 | dropout=0.0,
67 | init_option="lora",
68 | adapter_scalar="learnable_scalar",
69 | adapter_layernorm_option="none"):
70 | super().__init__()
71 | self.n_embd = d_model
72 | self.down_size = bottleneck
73 |
74 | #_before
75 | self.adapter_layernorm_option = adapter_layernorm_option
76 |
77 | self.adapter_layer_norm_before = None
78 | if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
79 | self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)
80 |
81 | if adapter_scalar == "learnable_scalar":
82 | self.scale = nn.Parameter(torch.ones(1))
83 | else:
84 | self.scale = float(adapter_scalar)
85 |
86 | self.down_proj = nn.Linear(self.n_embd, self.down_size)
87 | self.non_linear_func = nn.ReLU()
88 | self.up_proj = nn.Linear(self.down_size, self.n_embd)
89 |
90 | self.dropout = dropout
91 | if init_option == "bert":
92 | raise NotImplementedError
93 | elif init_option == "lora":
94 | with torch.no_grad():
95 | nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
96 | nn.init.zeros_(self.up_proj.weight)
97 | nn.init.zeros_(self.down_proj.bias)
98 | nn.init.zeros_(self.up_proj.bias)
99 |
100 | def forward(self, x):
101 | if self.adapter_layernorm_option == 'in':
102 | x = self.adapter_layer_norm_before(x)
103 |
104 | down = self.down_proj(x)
105 | down = self.non_linear_func(down)
106 | down = nn.functional.dropout(down, p=self.dropout, training=self.training)
107 | up = self.up_proj(down)
108 |
109 | up = up * self.scale
110 |
111 | if self.adapter_layernorm_option == 'out':
112 | up = self.adapter_layer_norm_before(up)
113 |
114 | return up
115 |
116 | class AdapterRouter(nn.Module):
117 | def __init__(self,
118 | d_model,
119 | bottleneck=64,
120 | dropout=0.0,
121 | init_option="lora",
122 | adapter_scalar="learnable_scalar",
123 | adapter_layernorm_option="none",
124 | num_adapters = 2,
125 | ):
126 | super().__init__()
127 | self.adapters = nn.ModuleList([])
128 | for _ in range(num_adapters):
129 | self.adapters.append(Adapter(
130 | d_model=d_model,
131 | bottleneck=bottleneck,
132 | dropout=dropout,
133 | init_option=init_option,
134 | adapter_scalar=adapter_scalar,
135 | adapter_layernorm_option=adapter_layernorm_option
136 | ))
137 | if num_adapters > 1:
138 | self.router_ratio1 = nn.Parameter(
139 | torch.tensor([[1],[0]],dtype=torch.float32).repeat(1,d_model)
140 | )
141 | self.router_ratio2 = nn.Parameter(
142 | torch.tensor([[0],[1]],dtype=torch.float32).repeat(1,d_model)
143 | )
144 | self.num_adapters = num_adapters
145 | self.router_idx = None
146 |
147 | def forward(self, x):
148 | assert self.router_idx in [0,1]
149 | output1 = self.adapters[0](x)
150 | if self.num_adapters == 1:
151 | return output1
152 |
153 | output2 = self.adapters[1](x)
154 | ratio1 = self.router_ratio1[self.router_idx]
155 | ratio2 = self.router_ratio2[self.router_idx]
156 | return output1 * ratio1 + output2 * ratio2
157 |
158 | def forward_ffn_t5(self, hidden_states):
159 | adapt_hidden_states = self.adapter(hidden_states)
160 | forwarded_states = self.layer_norm(hidden_states)
161 | forwarded_states = self.DenseReluDense(forwarded_states)
162 | hidden_states = hidden_states + self.dropout(forwarded_states)
163 | return hidden_states + adapt_hidden_states
164 |
165 | def set_adapter_t5(model: nn.Module, d_model: int, n: int, bottleneck: int = 64):
166 | for c in model.children():
167 | if isinstance(c, T5LayerFF):
168 | c.adapter = AdapterRouter(d_model=d_model, bottleneck=bottleneck, num_adapters=n)
169 | bound_method = forward_ffn_t5.__get__(c, c.__class__)
170 | setattr(c, 'forward', bound_method)
171 | elif len(list(c.children())) > 0:
172 | set_adapter_t5(c, d_model, n, bottleneck)
173 |
174 | def set_router_idx(model: nn.Module, idx: int):
175 | for c in model.children():
176 | if isinstance(c, AdapterRouter):
177 | c.router_idx = idx
178 | elif len(list(c.children())) > 0:
179 | set_router_idx(c, idx)
--------------------------------------------------------------------------------
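A minimal sketch of the routing behaviour in `AdapterRouter`: `router_idx` picks a row of the two ratio parameters, which are initialised to (1, 0) and (0, 1), so index 0 starts out using only the first adapter and index 1 only the second, with both ratios learnable:

```python
import torch
from models.lion_adapters import AdapterRouter

router = AdapterRouter(d_model=2048, bottleneck=64, num_adapters=2)
router.router_idx = 0            # normally set model-wide via set_router_idx()
x = torch.randn(2, 16, 2048)     # (batch, sequence, d_model)
out = router(x)                  # weighted mix of the two adapter outputs, same shape as x
print(out.shape)                 # torch.Size([2, 16, 2048])
```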
/models/lion_t5.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import logging
3 | import string
4 | from typing import Literal, Union, List
5 | from PIL.Image import Image
6 |
7 | import torch
8 | import torch.nn as nn
9 | from icecream import ic
10 | from torch.cuda.amp import autocast as autocast
11 | from transformers import BertTokenizer, T5TokenizerFast
12 | from transformers.modeling_outputs import BaseModelOutput
13 |
14 | from common.registry import registry
15 | from models.base_model import BaseModel
16 | from models.eva_vit import create_eva_vit_g
17 | from models.lion_adapters import FusionAdapter, set_adapter_t5, set_router_idx
18 | from models.modeling_t5 import T5Config, T5ForConditionalGeneration
19 | from models.Qformer import BertConfig, BertLMHeadModel
20 | from ram import get_transform
21 | from ram.models import ram
22 |
23 |
24 | class LayerNorm(nn.LayerNorm):
25 | """Subclass torch's LayerNorm to handle fp16."""
26 |
27 | def forward(self, x: torch.Tensor):
28 | orig_type = x.dtype
29 | ret = super().forward(x.type(torch.float32))
30 | return ret.type(orig_type)
31 |
32 | def disabled_train(self, mode=True):
33 | """Overwrite model.train with this function to make sure train/eval mode
34 | does not change anymore."""
35 | return self
36 |
37 | @registry.register_model("lion_t5")
38 | class LIONT5InstructAdapter(BaseModel):
39 | """
40 | LION T5 model.
41 | Supported model types:
42 | - flant5xl
43 | - flant5xxl
44 | Usage:
45 | >>> from models import load_model
46 | >>> model = load_model("lion_t5", "flant5xl")
47 | """
48 |
49 | PRETRAINED_MODEL_CONFIG_DICT = {
50 | "flant5xl": "configs/models/lion_flant5xl.yaml",
51 | "flant5xxl": "configs/models/lion_flant5xxl.yaml",
52 | }
53 |
54 | def __init__(
55 | self,
56 | bert_model,
57 | vit_model,
58 | llm_model,
59 | ram_model,
60 | max_txt_len=128,
61 | max_output_txt_len=128,
62 | visual_input: Literal["ALL", "QFORMER", "AGGREGATOR"] = "ALL",
63 | enable_semantic_tags=True,
64 | boost_lr_scale=1,
65 |
66 | img_size=224,
67 | drop_path_rate=0,
68 | use_grad_checkpoint=False,
69 | vit_precision="fp16",
70 | freeze_vit=True,
71 | num_query_token=32,
72 | ):
73 | super().__init__()
74 | assert bert_model is not None, "The path for bert model is not provided."
75 | assert vit_model is not None, "The path for vit model is not provided."
76 | assert llm_model is not None, "The path for llm model is not provided."
77 | assert visual_input in ["ALL", "QFORMER", "AGGREGATOR"], f"Invalid visual input type: {visual_input}."
78 | self.bert_model = bert_model
79 | self.visual_input = visual_input
80 | self.enable_semantic_tags = enable_semantic_tags
81 | self.max_txt_len = max_txt_len
82 | self.max_output_txt_len = max_output_txt_len
83 | self.boost_lr_scale = boost_lr_scale
84 | self.ram_path = ram_model
85 | self.ram_model = None
86 | logging.info(f"visual_input: {visual_input}")
87 |
88 | print("Loading VIT")
89 | self.visual_encoder = self._init_vision_encoder(
90 | vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
91 | )
92 | if freeze_vit:
93 | for name, param in self.visual_encoder.named_parameters():
94 | param.requires_grad = False
95 | self.visual_encoder = self.visual_encoder.eval()
96 | self.visual_encoder.train = disabled_train
97 | logging.info("freeze vision encoder")
98 | print("Loading VIT Done")
99 |
100 | self._init_llm(llm_model)
101 |
102 | if self.visual_input != "AGGREGATOR":
103 | print("Loading QFormer")
104 | self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model, truncation_side="left")
105 | self.bert_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
106 | self.Qformer, self.query_tokens = self._init_Qformer(
107 | bert_model, num_query_token, self.visual_encoder.num_features
108 | )
109 | self.Qformer.resize_token_embeddings(len(self.bert_tokenizer))
110 | self.Qformer.cls = None
111 | self.t5_proj = nn.Linear(
112 | self.Qformer.config.hidden_size, self.t5_model.config.hidden_size
113 | )
114 | self.ln_vision = LayerNorm(self.visual_encoder.num_features)
115 | print("Loading QFormer Done")
116 |
117 | if self.visual_input != "QFORMER":
118 | print("Loading Vision Aggregator")
119 | self.ln_adapter = LayerNorm(self.visual_encoder.num_features)
120 | self.adapter_proj = nn.Sequential(
121 | nn.Linear(self.visual_encoder.num_features, self.visual_encoder.num_features * 4),
122 | nn.GELU(),
123 | nn.Linear(self.visual_encoder.num_features * 4, self.t5_model.config.hidden_size),
124 | )
125 | self.fusion_adapter = FusionAdapter(num_blocks=2,dim=self.visual_encoder.num_features)
126 | print("Loading Vision Aggregator Done")
127 |
128 | if self.enable_semantic_tags:
129 | tag_sp_token = ""
130 | self.tag_softPrompt_id = self.t5_tokenizer.convert_tokens_to_ids(tag_sp_token)
131 | self.tag_prompt = "According to , you are allowed to use or partially use the following tags: [{}]. "
132 | self.soft_prompt_hint = nn.Parameter(torch.zeros(self.t5_model.config.hidden_size))
133 | self.soft_prompt_hint.data.normal_(mean=0.0, std=self.t5_model.config.initializer_factor)
134 | logging.info(f"boost_lr_scale:{boost_lr_scale}")
135 |
136 | def maybe_autocast(self, dtype=torch.bfloat16):
137 | # if on cpu, don't use autocast
138 | # if on gpu, use autocast with the provided dtype (defaults to torch.bfloat16 here)
139 | enable_autocast = self.device != torch.device("cpu")
140 |
141 | if enable_autocast:
142 | return torch.cuda.amp.autocast(dtype=dtype)
143 | else:
144 | return contextlib.nullcontext()
145 |
146 | def _init_vision_encoder(
147 | self, model_path, img_size, drop_path_rate, use_grad_checkpoint, precision
148 | ):
149 | print("Using normal vit")
150 | visual_encoder = create_eva_vit_g(
151 | img_size, drop_path_rate, use_grad_checkpoint, precision, model_path
152 | )
153 | return visual_encoder
154 |
155 | def _init_Qformer(self, bert_model, num_query_token, vision_width, cross_attention_freq=2):
156 | encoder_config = BertConfig.from_pretrained(bert_model)
157 | encoder_config.encoder_width = vision_width
158 | encoder_config.add_cross_attention = True
159 | encoder_config.cross_attention_freq = cross_attention_freq
160 | encoder_config.query_length = num_query_token
161 | Qformer = BertLMHeadModel(config=encoder_config)
162 | query_tokens = nn.Parameter(
163 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
164 | )
165 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
166 | return Qformer, query_tokens
167 |
168 | def _init_llm(self, llm_model):
169 | print("Loading LLM")
170 | self.t5_tokenizer = T5TokenizerFast.from_pretrained(llm_model, truncation_side='left')
171 | self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(llm_model, truncation_side='right')
172 |
173 | llm_config = T5Config.from_pretrained(llm_model)
174 | llm_config.dense_act_fn = "gelu"
175 | self.t5_model = T5ForConditionalGeneration.from_pretrained(
176 | llm_model, config=llm_config, torch_dtype=torch.bfloat16,
177 | )
178 | set_adapter_t5(self.t5_model, self.t5_model.config.d_model, n=2 if self.visual_input=="ALL" else 1, bottleneck=64)
179 |
180 | for name, param in self.t5_model.named_parameters():
181 | if "adapter" in name:
182 | if "router_ratio" in name and self.visual_input != "ALL":
183 | param.requires_grad = False
184 | else:
185 | param.requires_grad = True
186 | else:
187 | param.requires_grad = False
188 | print("Loading LLM Done")
189 |
190 | def _init_ram(self):
191 | if self.ram_model is None:
192 | print("Loading RAM Model For Tag Generation")
193 | self.ram_model = ram(pretrained=self.ram_path, image_size=384, vit="swin_l", text_encoder_type=self.bert_model).cuda()
194 | self.ram_processor = get_transform()
195 | print("Loading RAM Model Done")
196 |
197 | def generate_tags(self, images:Union[List[Image], Image]) -> List[str]:
198 | """
199 | Generate tags for provided images.
200 |
201 | Args:
202 | images (Image or List[Image])
203 | Returns:
204 | tags (List[str])
205 | """
206 |
207 | self._init_ram()
208 | if isinstance(images, Image):
209 | images = [images]
210 | images = torch.stack([self.ram_processor(img) for img in images]).to(self.device)
211 | tags = self.ram_model.generate_tag(images, threshold=0.85)[0]
212 | return [t.replace(" |",",") for t in tags]
213 |
214 | def _insert_tags(self, samples, prompt):
215 | if self.enable_semantic_tags:
216 | assert self.tag_prompt is not None, "Please provide Tags prompt."
217 | if "tags" not in samples:
218 | samples = self._generate_tags(samples)
219 | prompt = [self.tag_prompt.format(tags) + tin for tags, tin in zip(samples["tags"], prompt)]
220 | return prompt
221 |
222 | def _insert_softTagHint(self, samples, input_tokens, inputs_embeds):
223 | if self.enable_semantic_tags:
224 | bs = inputs_embeds.size(0)
225 | sp_embeds = self.soft_prompt_hint.expand(bs, -1).to(inputs_embeds.dtype)
226 | sp_index = (input_tokens.input_ids == self.tag_softPrompt_id).nonzero(as_tuple=True)
227 | inputs_embeds[sp_index] = sp_embeds
228 | return inputs_embeds
229 |
230 | def get_optimizer_params(self, weight_decay, lr_scale=1):
231 | p_wd, p_non_wd = [], []
232 | p_boost, p_boost_non_wd = [], []
233 | boost_name = []
234 | for n, p in self.named_parameters():
235 | if not p.requires_grad:
236 | continue # frozen weights
237 | if "fusion_adapter" in n:
238 | if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
239 | p_boost_non_wd.append(p)
240 | else:
241 | p_boost.append(p)
242 | boost_name.append(n)
243 | else:
244 | if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
245 | p_non_wd.append(p)
246 | else:
247 | p_wd.append(p)
248 | optim_params = [
249 | {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale},
250 | {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale},
251 | {"params": p_boost, "weight_decay": weight_decay, "lr_scale": lr_scale*self.boost_lr_scale},
252 | {"params": p_boost_non_wd, "weight_decay": 0, "lr_scale": lr_scale*self.boost_lr_scale},
253 | ]
254 | logging.info(f"boost params:{boost_name}")
255 | return optim_params
256 |
257 | def encode_img(self, image, question):
258 | with self.maybe_autocast(dtype=torch.bfloat16):
259 | image_embeds, intermediate = self.visual_encoder(image, return_intermediate=True)
260 | if self.visual_input != "QFORMER":
261 | adapter_embeds = self.ln_adapter(self.fusion_adapter(intermediate[38], [intermediate[28],intermediate[18]]))
262 | adapter_embeds = self.adapter_proj(adapter_embeds)
263 | if self.visual_input != "AGGREGATOR":
264 | image_embeds = self.ln_vision(image_embeds)
265 | if self.visual_input != "AGGREGATOR":
266 | query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
267 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
268 |
269 | text_Qformer = self.bert_tokenizer(
270 | question,
271 | padding='longest',
272 | truncation=True,
273 | max_length=self.max_txt_len,
274 | return_tensors="pt",
275 | ).to(image.device)
276 | query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
277 | Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
278 |
279 | query_output = self.Qformer.bert(
280 | text_Qformer.input_ids,
281 | attention_mask=Qformer_atts,
282 | query_embeds=query_tokens,
283 | encoder_hidden_states=image_embeds,
284 | encoder_attention_mask=image_atts,
285 | return_dict=True,
286 | )
287 | input_qformer = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
288 |
289 | match self.visual_input:
290 | case "AGGREGATOR":
291 | inputs_t5 = adapter_embeds
292 | case "QFORMER":
293 | inputs_t5 = input_qformer
294 | case "ALL":
295 | inputs_t5 = torch.cat([input_qformer, adapter_embeds],dim=1)
296 | case _:
297 | raise NotImplementedError(f"Visual input type {self.visual_input} is not supported.")
298 |
299 | atts_t5 = torch.ones(inputs_t5.size()[:-1],dtype=torch.long).to(image.device)
300 |
301 | return inputs_t5, atts_t5
302 |
303 | def forward(self, samples):
304 | img_embeds, img_atts = self.encode_img(samples["image"], samples["question"])
305 | prompt = self._insert_tags(samples, samples["question"])
306 | with self.maybe_autocast(dtype=torch.bfloat16):
307 | input_tokens = self.t5_tokenizer(
308 | prompt,
309 | padding="longest",
310 | truncation=True,
311 | max_length=self.max_txt_len,
312 | return_tensors="pt",
313 | ).to(self.device)
314 | output_tokens = self.t5_output_tokenizer(
315 | samples["answer"],
316 | padding="longest",
317 | truncation=True,
318 | max_length=self.max_output_txt_len,
319 | return_tensors="pt",
320 | ).to(self.device)
321 |
322 | targets = output_tokens.input_ids.masked_fill(
323 | output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100
324 | )
325 |
326 | text_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
327 | text_embeds = self._insert_softTagHint(samples, input_tokens, text_embeds)
328 | text_atts = input_tokens.attention_mask
329 |
330 | input_embeds = torch.cat([img_embeds, text_embeds], dim=1)
331 | encoder_atts = torch.cat([img_atts, text_atts], dim=1)
332 |
333 | set_router_idx(self.t5_model, int(samples["category"][0] != "region_level"))
334 | outputs = self.t5_model(
335 | inputs_embeds=input_embeds,
336 | attention_mask=encoder_atts,
337 | decoder_attention_mask=output_tokens.attention_mask,
338 | return_dict=True,
339 | labels=targets,
340 | )
341 | loss = outputs.loss
342 |
343 | return {"loss": loss}
344 |
345 | @torch.no_grad()
346 | def generate(
347 | self,
348 | samples,
349 | use_nucleus_sampling=False,
350 | num_beams=5,
351 | max_length=256,
352 | min_length=1,
353 | top_p=0.9,
354 | repetition_penalty=1.5,
355 | length_penalty=1.0,
356 | num_captions=1,
357 | temperature=1,
358 | ):
359 | img_embeds, img_atts = self.encode_img(samples["image"].to(self.device), samples["question"])
360 | prompt = self._insert_tags(samples, samples["question"])
361 | input_tokens = self.t5_tokenizer(
362 | prompt,
363 | padding="longest",
364 | return_tensors="pt"
365 | ).to(self.device)
366 |
367 | with self.maybe_autocast(dtype=torch.bfloat16):
368 | text_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
369 | text_embeds = self._insert_softTagHint(samples, input_tokens, text_embeds)
370 | text_atts = input_tokens.attention_mask
371 |
372 | inputs_embeds = torch.cat([img_embeds, text_embeds], dim=1)
373 | input_atts = torch.cat([img_atts, text_atts], dim=1)
374 | set_router_idx(self.t5_model, int(samples.get("category") != "region_level"))
375 | outputs = self.t5_model.generate(
376 | inputs_embeds=inputs_embeds,
377 | attention_mask=input_atts,
378 | do_sample=use_nucleus_sampling,
379 | top_p=top_p,
380 | temperature=temperature,
381 | num_beams=num_beams,
382 | max_new_tokens=max_length,
383 | min_length=min_length,
384 | repetition_penalty=repetition_penalty,
385 | length_penalty=length_penalty,
386 | num_return_sequences=num_captions,
387 | )
388 | output_text = self.t5_tokenizer.batch_decode(
389 | outputs, skip_special_tokens=True
390 | )
391 |
392 | return output_text
393 |
394 | @classmethod
395 | def from_config(cls, cfg):
396 | bert_model = cfg.get("bert_model")
397 | vit_model = cfg.get("vit_model")
398 | llm_model = cfg.get("llm_model")
399 | ram_model = cfg.get("ram_model")
400 |
401 | max_txt_len = cfg.get("max_txt_len", 128)
402 | max_output_txt_len = cfg.get("max_output_txt_len", 128)
403 | visual_input = cfg.get("visual_input", "ALL")
404 | enable_semantic_tags = cfg.get("enable_semantic_tags", True)
405 | boost_lr_scale = cfg.get("boost_lr_scale", 1.0)
406 |
407 | img_size = cfg.get("image_size", 224)
408 | drop_path_rate = cfg.get("drop_path_rate", 0)
409 | use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
410 | vit_precision = cfg.get("vit_precision", "fp16")
411 | freeze_vit = cfg.get("freeze_vit", True)
412 | num_query_token = cfg.get("num_query_token", 32)
413 |
414 | model = cls(
415 | bert_model=bert_model,
416 | vit_model=vit_model,
417 | llm_model=llm_model,
418 | ram_model=ram_model,
419 | max_txt_len=max_txt_len,
420 | max_output_txt_len=max_output_txt_len,
421 | visual_input=visual_input,
422 | enable_semantic_tags=enable_semantic_tags,
423 | boost_lr_scale=boost_lr_scale,
424 |
425 | img_size=img_size,
426 | drop_path_rate=drop_path_rate,
427 | use_grad_checkpoint=use_grad_checkpoint,
428 | vit_precision=vit_precision,
429 | freeze_vit=freeze_vit,
430 | num_query_token=num_query_token,
431 | )
432 |
433 | model.load_checkpoint_from_config(cfg)
434 |
435 | return model
436 |
--------------------------------------------------------------------------------
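
A minimal inference sketch for the LION-T5 model class above. Assumptions: the repo root is on PYTHONPATH, a `model` instance has already been built via its `from_config` classmethod (the model type registered as "lion_t5") with valid local weight paths and moved to GPU, and the question text and tag string below are purely illustrative.

from PIL import Image
from preprocessors.lion_preprocessors import ImageEvalProcessor

# Preprocess one of the sample images shipped with the repo.
processor = ImageEvalProcessor(image_size=224)
raw = Image.open("images/COCO_train2014_000000024935.jpg").convert("RGB")

samples = {
    "image": processor(raw).unsqueeze(0).to(model.device),  # [1, 3, 224, 224]
    "question": ["Describe the image in detail."],          # one prompt per image
    # Supplying "tags" skips RAM-based tag generation; omit the key to let the model call RAM itself.
    "tags": ["person, surfboard, wave, ocean"],
}
answers = model.generate(samples, num_beams=5, max_length=256)
print(answers[0])
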
/preprocessors/lion_preprocessors.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from torchvision import transforms
3 | from torchvision.transforms.functional import InterpolationMode
4 |
5 | from common.registry import registry
6 |
7 |
8 | class BaseProcessor:
9 | def __init__(self, mean=None, std=None):
10 | if mean is None:
11 | mean = (0.48145466, 0.4578275, 0.40821073)
12 | if std is None:
13 | std = (0.26862954, 0.26130258, 0.27577711)
14 |
15 | self.normalize = transforms.Normalize(mean, std)
16 |
17 |
18 | @registry.register_processor("eval")
19 | class ImageEvalProcessor(BaseProcessor):
20 | def __init__(self, image_size=224, mean=None, std=None):
21 | super().__init__(mean=mean, std=std)
22 |
23 | self.transform = transforms.Compose(
24 | [
25 | transforms.Resize(
26 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
27 | ),
28 | transforms.ToTensor(),
29 | self.normalize,
30 | ]
31 | )
32 |
33 | def __call__(self, item):
34 | return self.transform(item)
35 |
36 | @classmethod
37 | def from_config(cls, cfg=None):
38 | if cfg is None:
39 | cfg = OmegaConf.create()
40 |
41 | image_size = cfg.get("image_size", 224)
42 |
43 | mean = cfg.get("mean", None)
44 | std = cfg.get("std", None)
45 |
46 | return cls(image_size=image_size, mean=mean, std=std)
47 |
48 |
49 | @registry.register_processor("train")
50 | class ImageTrainProcessor(BaseProcessor):
51 | def __init__(
52 | self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0
53 | ):
54 | super().__init__(mean=mean, std=std)
55 |
56 | self.transform = transforms.Compose(
57 | [
58 | transforms.RandomResizedCrop(
59 | image_size,
60 | scale=(min_scale, max_scale),
61 | interpolation=InterpolationMode.BICUBIC,
62 | ),
63 | transforms.RandomHorizontalFlip(),
64 | transforms.ToTensor(),
65 | self.normalize,
66 | ]
67 | )
68 |
69 | def __call__(self, item):
70 | return self.transform(item)
71 |
72 | @classmethod
73 | def from_config(cls, cfg=None):
74 | if cfg is None:
75 | cfg = OmegaConf.create()
76 |
77 | image_size = cfg.get("image_size", 224)
78 |
79 | mean = cfg.get("mean", None)
80 | std = cfg.get("std", None)
81 |
82 | min_scale = cfg.get("min_scale", 0.5)
83 | max_scale = cfg.get("max_scale", 1.0)
84 |
85 | return cls(
86 | image_size=image_size,
87 | mean=mean,
88 | std=std,
89 | min_scale=min_scale,
90 | max_scale=max_scale,
91 | )
--------------------------------------------------------------------------------
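
A quick usage sketch for the two processors above, assuming the repo root is on PYTHONPATH and that the default CLIP normalization constants are the ones the LION vision encoder expects; the sample image is one shipped with the repo.

from omegaconf import OmegaConf
from PIL import Image
from preprocessors.lion_preprocessors import ImageEvalProcessor, ImageTrainProcessor

# Both processors map a PIL image to a normalized float tensor of shape [3, H, W].
train_proc = ImageTrainProcessor.from_config(OmegaConf.create({"image_size": 224, "min_scale": 0.5}))
eval_proc = ImageEvalProcessor(image_size=224)

img = Image.open("images/COCO_train2014_000000533220.jpg").convert("RGB")
print(train_proc(img).shape, eval_proc(img).shape)  # torch.Size([3, 224, 224]) for both
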
/ram/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference import inference_tag2text, inference_ram, inference_ram_openset
2 | from .transform import get_transform
3 |
--------------------------------------------------------------------------------
/ram/configs/finetune.yaml:
--------------------------------------------------------------------------------
1 | train_file: [
2 | 'datasets/train/coco_train_rmcocodev_ram.json',
3 | ]
4 | image_path_root: ""
5 |
6 | # vision backbone; swin_b, swin_l, or a plain ViT size (base or large)
7 | vit: 'swin_l'
8 | vit_grad_ckpt: False
9 | vit_ckpt_layer: 0
10 |
11 | image_size: 384
12 | batch_size: 26
13 |
14 | # optimizer
15 | weight_decay: 0.05
16 | init_lr: 5e-06
17 | min_lr: 0
18 | max_epoch: 2
19 | warmup_steps: 3000
20 |
21 | class_num: 4585
22 |
23 |
--------------------------------------------------------------------------------
/ram/configs/finetune_tag2text.yaml:
--------------------------------------------------------------------------------
1 | train_file: [
2 | 'datasets/train/coco_train_rmcocodev_ram.json',
3 | ]
4 | image_path_root: ""
5 |
6 | # vision backbone; swin_b, swin_l, or a plain ViT size (base or large)
7 | vit: 'swin_b'
8 | vit_grad_ckpt: False
9 | vit_ckpt_layer: 0
10 |
11 | image_size: 384
12 | batch_size: 36
13 |
14 | # optimizer
15 | weight_decay: 0.05
16 | init_lr: 5e-06
17 | min_lr: 0
18 | max_epoch: 2
19 | warmup_steps: 3000
20 |
21 | class_num: 4585
22 |
23 |
--------------------------------------------------------------------------------
/ram/configs/med_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertModel"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30524,
19 | "encoder_width": 768,
20 | "add_cross_attention": true
21 | }
--------------------------------------------------------------------------------
/ram/configs/pretrain.yaml:
--------------------------------------------------------------------------------
1 | train_file: [
2 | 'datasets/train/coco_train_rmcocodev_ram.json',
3 | 'datasets/train/vg_ram.json',
4 | 'datasets/train/sbu_ram.json',
5 | 'datasets/train/cc3m_train_ram.json',
6 | 'datasets/train/cc3m_val_ram.json',
7 | 'datasets/train/cc12m_ram.json',
8 | ]
9 | image_path_root: ""
10 |
11 | # vision backbone; swin_b, swin_l, or a plain ViT size (base or large)
12 | vit: 'swin_l'
13 | vit_grad_ckpt: False
14 | vit_ckpt_layer: 0
15 |
16 | image_size: 224
17 | batch_size: 52
18 |
19 | # optimizer
20 | weight_decay: 0.05
21 | init_lr: 1e-4
22 | min_lr: 5e-7
23 | warmup_lr: 5e-7
24 | lr_decay_rate: 0.9
25 | max_epoch: 5
26 | warmup_steps: 3000
27 |
28 | class_num: 4585
29 |
30 |
--------------------------------------------------------------------------------
/ram/configs/pretrain_tag2text.yaml:
--------------------------------------------------------------------------------
1 | train_file: [
2 | 'datasets/train/coco_train_rmcocodev_ram.json',
3 | 'datasets/train/vg_ram.json',
4 | 'datasets/train/sbu_ram.json',
5 | 'datasets/train/cc3m_train_ram.json',
6 | 'datasets/train/cc3m_val_ram.json',
7 | 'datasets/train/cc12m_ram.json',
8 | ]
9 | image_path_root: ""
10 |
11 | # vision backbone; swin_b, swin_l, or a plain ViT size (base or large)
12 | vit: 'swin_b'
13 | vit_grad_ckpt: False
14 | vit_ckpt_layer: 0
15 |
16 | image_size: 224
17 | batch_size: 80
18 |
19 | # optimizer
20 | weight_decay: 0.05
21 | init_lr: 1e-4
22 | min_lr: 5e-7
23 | warmup_lr: 5e-7
24 | lr_decay_rate: 0.9
25 | max_epoch: 5
26 | warmup_steps: 3000
27 |
28 | class_num: 4585
29 |
30 |
--------------------------------------------------------------------------------
/ram/configs/q2l_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertModel"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 4,
15 | "num_hidden_layers": 2,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522,
19 | "encoder_width": 768,
20 | "add_cross_attention": true,
21 | "add_tag_cross_attention": false
22 | }
--------------------------------------------------------------------------------
/ram/configs/swin/config_swinB_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3 | "vision_width": 1024,
4 | "image_res": 224,
5 | "window_size": 7,
6 | "embed_dim": 128,
7 | "depths": [ 2, 2, 18, 2 ],
8 | "num_heads": [ 4, 8, 16, 32 ]
9 | }
--------------------------------------------------------------------------------
/ram/configs/swin/config_swinB_384.json:
--------------------------------------------------------------------------------
1 | {
2 | "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3 | "vision_width": 1024,
4 | "image_res": 384,
5 | "window_size": 12,
6 | "embed_dim": 128,
7 | "depths": [ 2, 2, 18, 2 ],
8 | "num_heads": [ 4, 8, 16, 32 ]
9 | }
--------------------------------------------------------------------------------
/ram/configs/swin/config_swinL_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "ckpt": "pretrain_model/swin_large_patch4_window7_224_22k.pth",
3 | "vision_width": 1536,
4 | "image_res": 224,
5 | "window_size": 7,
6 | "embed_dim": 192,
7 | "depths": [ 2, 2, 18, 2 ],
8 | "num_heads": [ 6, 12, 24, 48 ]
9 | }
10 |
--------------------------------------------------------------------------------
/ram/configs/swin/config_swinL_384.json:
--------------------------------------------------------------------------------
1 | {
2 | "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth",
3 | "vision_width": 1536,
4 | "image_res": 384,
5 | "window_size": 12,
6 | "embed_dim": 192,
7 | "depths": [ 2, 2, 18, 2 ],
8 | "num_heads": [ 6, 12, 24, 48 ]
9 | }
--------------------------------------------------------------------------------
/ram/data/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torchvision import transforms
4 | from torchvision.transforms.functional import InterpolationMode
5 |
6 | from .dataset import pretrain_dataset, finetune_dataset
7 | from .randaugment import RandomAugment
8 |
9 | def create_dataset(dataset, config, min_scale=0.5):
10 |
11 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
12 |
13 | transform_train = transforms.Compose([
14 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
15 | transforms.RandomHorizontalFlip(),
16 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
17 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
18 | transforms.ToTensor(),
19 | normalize,
20 | ])
21 |
22 | transform_inputsize_224 = transforms.Compose([
23 | transforms.RandomResizedCrop(224,scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
24 | transforms.RandomHorizontalFlip(),
25 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
26 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
27 | transforms.ToTensor(),
28 | normalize,
29 | ])
30 |
31 | if dataset=='pretrain':
32 | dataset = pretrain_dataset(config['train_file'], transform_train, class_num=config['class_num'], root=config['image_path_root'])
33 | return dataset
34 |
35 | elif dataset=='finetune':
36 | dataset = finetune_dataset(config['train_file'], transform_train, transform_inputsize_224, class_num=config['class_num'], root=config['image_path_root'])
37 | return dataset
38 |
39 |
40 |
41 |
42 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
43 | samplers = []
44 | for dataset,shuffle in zip(datasets,shuffles):
45 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
46 | samplers.append(sampler)
47 | return samplers
48 |
49 |
50 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
51 | loaders = []
52 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
53 | if is_train:
54 | shuffle = (sampler is None)
55 | drop_last = True
56 | else:
57 | shuffle = False
58 | drop_last = False
59 | loader = DataLoader(
60 | dataset,
61 | batch_size=bs,
62 | num_workers=n_worker,
63 | pin_memory=True,
64 | sampler=sampler,
65 | shuffle=shuffle,
66 | collate_fn=collate_fn,
67 | drop_last=drop_last,
68 | )
69 | loaders.append(loader)
70 | return loaders
71 |
72 |
--------------------------------------------------------------------------------
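
A single-GPU wiring sketch for the three helpers above, assuming the annotation JSON files listed in ram/configs/pretrain.yaml exist locally; with no distributed training, the sampler entry is simply None.

from omegaconf import OmegaConf
from ram.data import create_dataset, create_loader

config = OmegaConf.load("ram/configs/pretrain.yaml")
dataset = create_dataset("pretrain", config)

# One list entry per dataset; a None sampler lets the DataLoader shuffle on its own.
loaders = create_loader([dataset], samplers=[None], batch_size=[config["batch_size"]],
                        num_workers=[4], is_trains=[True], collate_fns=[None])
train_loader = loaders[0]
# Each item yields (image, caption, image_tag, parse_tag) as produced by pretrain_dataset.
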
/ram/data/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | from torch.utils.data import Dataset
6 |
7 | from PIL import Image
8 | from PIL import ImageFile
9 | ImageFile.LOAD_TRUNCATED_IMAGES = True
10 | Image.MAX_IMAGE_PIXELS = None
11 |
12 | from .utils import pre_caption
13 | import os,glob
14 |
15 | import torch
16 | import numpy as np
17 |
18 | class pretrain_dataset(Dataset):
19 | def __init__(self, ann_file, transform, class_num = 4585, root = ''):
20 |
21 | self.ann = []
22 | for f in ann_file:
23 | print('loading '+f)
24 | ann = json.load(open(f,'r'))
25 | self.ann += ann
26 |
27 | self.transform = transform
28 | self.class_num = class_num
29 | self.root = root
30 |
31 |
32 | def __len__(self):
33 | return len(self.ann)
34 |
35 | def __getitem__(self, index):
36 |
37 | ann = self.ann[index]
38 |
39 | image_path_use = os.path.join(self.root, ann['image_path'])
40 | image = Image.open(image_path_use).convert('RGB')
41 | image = self.transform(image)
42 |
43 | # required for tag2text support
44 | if ann.get('union_label_id') is not None:
45 | num = ann['union_label_id']
46 | image_tag = np.zeros([self.class_num])
47 | image_tag[num] = 1
48 | image_tag = torch.tensor(image_tag, dtype = torch.long)
49 | else:
50 | image_tag = None
51 |
52 | caption_index = np.random.randint(0, len(ann['caption']))
53 |
54 | caption = pre_caption(ann['caption'][caption_index],30)
55 |
56 | num = ann['parse_label_id'][caption_index]
57 | parse_tag = np.zeros([self.class_num])
58 | parse_tag[num] = 1
59 | parse_tag = torch.tensor(parse_tag, dtype = torch.long)
60 |
61 | return image, caption, image_tag, parse_tag
62 |
63 |
64 | class finetune_dataset(Dataset):
65 | def __init__(self, ann_file, transform, transform_224, class_num = 4585, root = ''):
66 |
67 | self.ann = []
68 | for f in ann_file:
69 | print('loading '+f)
70 | ann = json.load(open(f,'r'))
71 | self.ann += ann
72 |
73 | self.transform = transform
74 | self.transform_224 = transform_224
75 | self.class_num = class_num
76 | self.root = root
77 |
78 |
79 | def __len__(self):
80 | return len(self.ann)
81 |
82 | def __getitem__(self, index):
83 |
84 | ann = self.ann[index]
85 |
86 | image_path_use = os.path.join(self.root, ann['image_path'])
87 | image = Image.open(image_path_use).convert('RGB')
88 | image = self.transform(image)
89 |
90 | image_224 = Image.open(image_path_use).convert('RGB')
91 | image_224 = self.transform_224(image_224)
92 |
93 | # required for tag2text support
94 | if ann.get('union_label_id') is not None:
95 | num = ann['union_label_id']
96 | image_tag = np.zeros([self.class_num])
97 | image_tag[num] = 1
98 | image_tag = torch.tensor(image_tag, dtype = torch.long)
99 | else:
100 | image_tag = None
101 |
102 | caption_index = np.random.randint(0, len(ann['caption']))
103 |
104 | caption = pre_caption(ann['caption'][caption_index],30)
105 |
106 | num = ann['parse_label_id'][caption_index]
107 | parse_tag = np.zeros([self.class_num])
108 | parse_tag[num] = 1
109 | parse_tag = torch.tensor(parse_tag, dtype = torch.long)
110 |
111 | return image, image_224, caption, image_tag, parse_tag
112 |
113 |
--------------------------------------------------------------------------------
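
For orientation, here is the shape of a single annotation record that pretrain_dataset and finetune_dataset expect, reconstructed from the __getitem__ methods above; the field values are purely illustrative and not taken from any real annotation file.

example_ann = {
    "image_path": "train2014/COCO_train2014_000000024935.jpg",  # joined with the `root` argument
    "caption": ["a man rides a wave on a surfboard"],           # one or more captions per image
    "parse_label_id": [[17, 402]],                              # per-caption tag ids parsed from that caption
    "union_label_id": [17, 402, 1033],                          # optional image-level tag ids (multi-hot target)
}
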
/ram/data/randaugment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | ## aug functions
6 | def identity_func(img):
7 | return img
8 |
9 |
10 | def autocontrast_func(img, cutoff=0):
11 | '''
12 | same output as PIL.ImageOps.autocontrast
13 | '''
14 | n_bins = 256
15 |
16 | def tune_channel(ch):
17 | n = ch.size
18 | cut = cutoff * n // 100
19 | if cut == 0:
20 | high, low = ch.max(), ch.min()
21 | else:
22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23 | low = np.argwhere(np.cumsum(hist) > cut)
24 | low = 0 if low.shape[0] == 0 else low[0]
25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27 | if high <= low:
28 | table = np.arange(n_bins)
29 | else:
30 | scale = (n_bins - 1) / (high - low)
31 | offset = -low * scale
32 | table = np.arange(n_bins) * scale + offset
33 | table[table < 0] = 0
34 | table[table > n_bins - 1] = n_bins - 1
35 | table = table.clip(0, 255).astype(np.uint8)
36 | return table[ch]
37 |
38 | channels = [tune_channel(ch) for ch in cv2.split(img)]
39 | out = cv2.merge(channels)
40 | return out
41 |
42 |
43 | def equalize_func(img):
44 | '''
45 | same output as PIL.ImageOps.equalize
46 | PIL's implementation is different from cv2.equalize
47 | '''
48 | n_bins = 256
49 |
50 | def tune_channel(ch):
51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52 | non_zero_hist = hist[hist != 0].reshape(-1)
53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54 | if step == 0: return ch
55 | n = np.empty_like(hist)
56 | n[0] = step // 2
57 | n[1:] = hist[:-1]
58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
59 | return table[ch]
60 |
61 | channels = [tune_channel(ch) for ch in cv2.split(img)]
62 | out = cv2.merge(channels)
63 | return out
64 |
65 |
66 | def rotate_func(img, degree, fill=(0, 0, 0)):
67 | '''
68 | like PIL, rotate by degree, not radians
69 | '''
70 | H, W = img.shape[0], img.shape[1]
71 | center = W / 2, H / 2
72 | M = cv2.getRotationMatrix2D(center, degree, 1)
73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
74 | return out
75 |
76 |
77 | def solarize_func(img, thresh=128):
78 | '''
79 | same output as PIL.ImageOps.solarize
80 | '''
81 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
82 | table = table.clip(0, 255).astype(np.uint8)
83 | out = table[img]
84 | return out
85 |
86 |
87 | def color_func(img, factor):
88 | '''
89 | same output as PIL.ImageEnhance.Color
90 | '''
91 | ## implementation according to PIL definition, quite slow
92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
93 | # out = blend(degenerate, img, factor)
94 | # M = (
95 | # np.eye(3) * factor
96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
97 | # )[np.newaxis, np.newaxis, :]
98 | M = (
99 | np.float32([
100 | [0.886, -0.114, -0.114],
101 | [-0.587, 0.413, -0.587],
102 | [-0.299, -0.299, 0.701]]) * factor
103 | + np.float32([[0.114], [0.587], [0.299]])
104 | )
105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
106 | return out
107 |
108 |
109 | def contrast_func(img, factor):
110 | """
111 | same output as PIL.ImageEnhance.Contrast
112 | """
113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
114 | table = np.array([(
115 | el - mean) * factor + mean
116 | for el in range(256)
117 | ]).clip(0, 255).astype(np.uint8)
118 | out = table[img]
119 | return out
120 |
121 |
122 | def brightness_func(img, factor):
123 | '''
124 | same output as PIL.ImageEnhance.Brightness
125 | '''
126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
127 | out = table[img]
128 | return out
129 |
130 |
131 | def sharpness_func(img, factor):
132 | '''
133 | The differences between this result and PIL's are all on the 4 boundaries; the center
134 | areas are the same
135 | '''
136 | kernel = np.ones((3, 3), dtype=np.float32)
137 | kernel[1][1] = 5
138 | kernel /= 13
139 | degenerate = cv2.filter2D(img, -1, kernel)
140 | if factor == 0.0:
141 | out = degenerate
142 | elif factor == 1.0:
143 | out = img
144 | else:
145 | out = img.astype(np.float32)
146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
148 | out = out.astype(np.uint8)
149 | return out
150 |
151 |
152 | def shear_x_func(img, factor, fill=(0, 0, 0)):
153 | H, W = img.shape[0], img.shape[1]
154 | M = np.float32([[1, factor, 0], [0, 1, 0]])
155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
156 | return out
157 |
158 |
159 | def translate_x_func(img, offset, fill=(0, 0, 0)):
160 | '''
161 | same output as PIL.Image.transform
162 | '''
163 | H, W = img.shape[0], img.shape[1]
164 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
166 | return out
167 |
168 |
169 | def translate_y_func(img, offset, fill=(0, 0, 0)):
170 | '''
171 | same output as PIL.Image.transform
172 | '''
173 | H, W = img.shape[0], img.shape[1]
174 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
176 | return out
177 |
178 |
179 | def posterize_func(img, bits):
180 | '''
181 | same output as PIL.ImageOps.posterize
182 | '''
183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
184 | return out
185 |
186 |
187 | def shear_y_func(img, factor, fill=(0, 0, 0)):
188 | H, W = img.shape[0], img.shape[1]
189 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
191 | return out
192 |
193 |
194 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
195 | replace = np.array(replace, dtype=np.uint8)
196 | H, W = img.shape[0], img.shape[1]
197 | rh, rw = np.random.random(2)
198 | pad_size = pad_size // 2
199 | ch, cw = int(rh * H), int(rw * W)
200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
202 | out = img.copy()
203 | out[x1:x2, y1:y2, :] = replace
204 | return out
205 |
206 |
207 | ### level to args
208 | def enhance_level_to_args(MAX_LEVEL):
209 | def level_to_args(level):
210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
211 | return level_to_args
212 |
213 |
214 | def shear_level_to_args(MAX_LEVEL, replace_value):
215 | def level_to_args(level):
216 | level = (level / MAX_LEVEL) * 0.3
217 | if np.random.random() > 0.5: level = -level
218 | return (level, replace_value)
219 |
220 | return level_to_args
221 |
222 |
223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
224 | def level_to_args(level):
225 | level = (level / MAX_LEVEL) * float(translate_const)
226 | if np.random.random() > 0.5: level = -level
227 | return (level, replace_value)
228 |
229 | return level_to_args
230 |
231 |
232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
233 | def level_to_args(level):
234 | level = int((level / MAX_LEVEL) * cutout_const)
235 | return (level, replace_value)
236 |
237 | return level_to_args
238 |
239 |
240 | def solarize_level_to_args(MAX_LEVEL):
241 | def level_to_args(level):
242 | level = int((level / MAX_LEVEL) * 256)
243 | return (level, )
244 | return level_to_args
245 |
246 |
247 | def none_level_to_args(level):
248 | return ()
249 |
250 |
251 | def posterize_level_to_args(MAX_LEVEL):
252 | def level_to_args(level):
253 | level = int((level / MAX_LEVEL) * 4)
254 | return (level, )
255 | return level_to_args
256 |
257 |
258 | def rotate_level_to_args(MAX_LEVEL, replace_value):
259 | def level_to_args(level):
260 | level = (level / MAX_LEVEL) * 30
261 | if np.random.random() < 0.5:
262 | level = -level
263 | return (level, replace_value)
264 |
265 | return level_to_args
266 |
267 |
268 | func_dict = {
269 | 'Identity': identity_func,
270 | 'AutoContrast': autocontrast_func,
271 | 'Equalize': equalize_func,
272 | 'Rotate': rotate_func,
273 | 'Solarize': solarize_func,
274 | 'Color': color_func,
275 | 'Contrast': contrast_func,
276 | 'Brightness': brightness_func,
277 | 'Sharpness': sharpness_func,
278 | 'ShearX': shear_x_func,
279 | 'TranslateX': translate_x_func,
280 | 'TranslateY': translate_y_func,
281 | 'Posterize': posterize_func,
282 | 'ShearY': shear_y_func,
283 | }
284 |
285 | translate_const = 10
286 | MAX_LEVEL = 10
287 | replace_value = (128, 128, 128)
288 | arg_dict = {
289 | 'Identity': none_level_to_args,
290 | 'AutoContrast': none_level_to_args,
291 | 'Equalize': none_level_to_args,
292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
293 | 'Solarize': solarize_level_to_args(MAX_LEVEL),
294 | 'Color': enhance_level_to_args(MAX_LEVEL),
295 | 'Contrast': enhance_level_to_args(MAX_LEVEL),
296 | 'Brightness': enhance_level_to_args(MAX_LEVEL),
297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL),
298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
299 | 'TranslateX': translate_level_to_args(
300 | translate_const, MAX_LEVEL, replace_value
301 | ),
302 | 'TranslateY': translate_level_to_args(
303 | translate_const, MAX_LEVEL, replace_value
304 | ),
305 | 'Posterize': posterize_level_to_args(MAX_LEVEL),
306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
307 | }
308 |
309 |
310 | class RandomAugment(object):
311 |
312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
313 | self.N = N
314 | self.M = M
315 | self.isPIL = isPIL
316 | if augs:
317 | self.augs = augs
318 | else:
319 | self.augs = list(arg_dict.keys())
320 |
321 | def get_random_ops(self):
322 | sampled_ops = np.random.choice(self.augs, self.N)
323 | return [(op, 0.5, self.M) for op in sampled_ops]
324 |
325 | def __call__(self, img):
326 | if self.isPIL:
327 | img = np.array(img)
328 | ops = self.get_random_ops()
329 | for name, prob, level in ops:
330 | if np.random.random() > prob:
331 | continue
332 | args = arg_dict[name](level)
333 | img = func_dict[name](img, *args)
334 | return img
335 |
336 |
337 | if __name__ == '__main__':
338 | a = RandomAugment()
339 | img = np.random.randn(32, 32, 3)
340 | a(img)
--------------------------------------------------------------------------------
/ram/data/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import os
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 | import utils
9 |
10 | def pre_caption(caption,max_words=50):
11 | caption = re.sub(
12 | r"([.!\"()*#:;~])",
13 | ' ',
14 | caption.lower(),
15 | )
16 | caption = re.sub(
17 | r"\s{2,}",
18 | ' ',
19 | caption,
20 | )
21 | caption = caption.rstrip('\n')
22 | caption = caption.strip(' ')
23 |
24 | #truncate caption
25 | caption_words = caption.split(' ')
26 | if len(caption_words)>max_words:
27 | caption = ' '.join(caption_words[:max_words])
28 |
29 | return caption
30 |
31 | def pre_question(question,max_ques_words=50):
32 | question = re.sub(
33 | r"([.!\"()*#:;~])",
34 | '',
35 | question.lower(),
36 | )
37 | question = question.rstrip(' ')
38 |
39 | #truncate question
40 | question_words = question.split(' ')
41 | if len(question_words)>max_ques_words:
42 | question = ' '.join(question_words[:max_ques_words])
43 |
44 | return question
45 |
46 |
47 | def save_result(result, result_dir, filename, remove_duplicate=''):
48 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
49 | final_result_file = os.path.join(result_dir, '%s.json'%filename)
50 |
51 | json.dump(result,open(result_file,'w'))
52 |
53 | dist.barrier()
54 |
55 | if utils.is_main_process():
56 | # combine results from all processes
57 | result = []
58 |
59 | for rank in range(utils.get_world_size()):
60 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
61 | res = json.load(open(result_file,'r'))
62 | result += res
63 |
64 | if remove_duplicate:
65 | result_new = []
66 | id_list = []
67 | for res in result:
68 | if res[remove_duplicate] not in id_list:
69 | id_list.append(res[remove_duplicate])
70 | result_new.append(res)
71 | result = result_new
72 |
73 | json.dump(result,open(final_result_file,'w'))
74 | print('result file saved to %s'%final_result_file)
75 |
76 | return final_result_file
77 |
78 |
79 |
80 | from pycocotools.coco import COCO
81 | from pycocoevalcap.eval import COCOEvalCap
82 | from torchvision.datasets.utils import download_url
83 |
84 | def coco_caption_eval(coco_gt_root, results_file, split):
85 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
86 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
87 | filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
88 |
89 | download_url(urls[split],coco_gt_root)
90 | annotation_file = os.path.join(coco_gt_root,filenames[split])
91 |
92 | # create coco object and coco_result object
93 | coco = COCO(annotation_file)
94 | coco_result = coco.loadRes(results_file)
95 |
96 | # create coco_eval object by taking coco and coco_result
97 | coco_eval = COCOEvalCap(coco, coco_result)
98 |
99 | # evaluate on a subset of images by setting
100 | # coco_eval.params['image_id'] = coco_result.getImgIds()
101 | # please remove this line when evaluating the full validation set
102 | # coco_eval.params['image_id'] = coco_result.getImgIds()
103 |
104 | # evaluate results
105 | # SPICE will take a few minutes the first time, but speeds up due to caching
106 | coco_eval.evaluate()
107 |
108 | # print output evaluation scores
109 | for metric, score in coco_eval.eval.items():
110 | print(f'{metric}: {score:.3f}')
111 |
112 | return coco_eval
--------------------------------------------------------------------------------
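
A quick illustration of the two text-cleaning helpers above, with expected outputs shown in comments. Assumption: the module-level `import utils` at the top of this file (used only by save_result) resolves in your environment; if it does not, the two helpers can be copied out unchanged.

from ram.data.utils import pre_caption, pre_question

print(pre_caption("A man! Riding a (big) wave.", max_words=30))
# -> "a man riding a big wave"
print(pre_question("Is the WAVE big?"))
# -> "is the wave big?"   (note: '?' is not in the stripped character set)
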
/ram/inference.py:
--------------------------------------------------------------------------------
1 | '''
2 | * The Inference of RAM and Tag2Text Models
3 | * Written by Xinyu Huang
4 | '''
5 | import torch
6 |
7 |
8 | def inference_tag2text(image, model, input_tag="None"):
9 |
10 | with torch.no_grad():
11 | caption, tag_predict = model.generate(image,
12 | tag_input=None,
13 | max_length=50,
14 | return_tag_predict=True)
15 |
16 | if input_tag == '' or input_tag == 'none' or input_tag == 'None':
17 | return tag_predict[0], None, caption[0]
18 |
19 | # If user input specified tags:
20 | else:
21 | input_tag_list = []
22 | input_tag_list.append(input_tag.replace(',', ' | '))
23 |
24 | with torch.no_grad():
25 | caption, input_tag = model.generate(image,
26 | tag_input=input_tag_list,
27 | max_length=50,
28 | return_tag_predict=True)
29 |
30 | return tag_predict[0], input_tag[0], caption[0]
31 |
32 |
33 | def inference_ram(image, model):
34 |
35 | with torch.no_grad():
36 | tags, tags_chinese = model.generate_tag(image)
37 |
38 | return tags[0],tags_chinese[0]
39 |
40 |
41 | def inference_ram_openset(image, model):
42 |
43 | with torch.no_grad():
44 | tags = model.generate_tag_openset(image)
45 |
46 | return tags[0]
47 |
--------------------------------------------------------------------------------
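
An end-to-end tagging sketch built from the wrappers above. Assumptions: a RAM checkpoint has been downloaded to the placeholder path below (ram/models/ram.py references ram_swin_large_14m.pth) and bert-base-uncased serves as the text encoder, mirroring how lion_t5._init_ram constructs RAM.

import torch
from PIL import Image
from ram import get_transform, inference_ram
from ram.models import ram

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ram(pretrained="path/to/ram_swin_large_14m.pth", image_size=384,
            vit="swin_l", text_encoder_type="bert-base-uncased").eval().to(device)

transform = get_transform()
image = transform(Image.open("images/COCO_train2014_000000024935.jpg").convert("RGB"))
tags, tags_chinese = inference_ram(image.unsqueeze(0).to(device), model)
print(tags)  # e.g. "beach | surfboard | wave" (tags joined with ' | ' by generate_tag)
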
/ram/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ram_plus import ram_plus
2 | from .ram import ram
3 | from .tag2text import tag2text
4 |
--------------------------------------------------------------------------------
/ram/models/ram.py:
--------------------------------------------------------------------------------
1 | '''
2 | * The Recognize Anything Model (RAM)
3 | * Written by Xinyu Huang
4 | '''
5 | import json
6 | import warnings
7 |
8 | import numpy as np
9 | import torch
10 | from torch import nn
11 |
12 | from .bert import BertConfig, BertLMHeadModel, BertModel
13 | from .swin_transformer import SwinTransformer
14 | from .utils import *
15 | import torch.nn.functional as F
16 |
17 | warnings.filterwarnings("ignore")
18 |
19 |
20 | class RAM(nn.Module):
21 | def __init__(self,
22 | med_config=f'{CONFIG_PATH}/configs/med_config.json',
23 | image_size=384,
24 | text_encoder_type='bert-base-uncased',
25 | vit='base',
26 | vit_grad_ckpt=False,
27 | vit_ckpt_layer=0,
28 | prompt='a picture of ',
29 | threshold=0.68,
30 | delete_tag_index=[],
31 | tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt',
32 | tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt',
33 | stage='eval'):
34 | r""" The Recognize Anything Model (RAM) inference module.
35 | RAM is a strong image tagging model, which can recognize any common category with high accuracy.
36 | Described in the paper "Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/
37 |
38 | Args:
39 | med_config (str): path for the mixture of encoder-decoder model's configuration file
40 | image_size (int): input image size
41 | vit (str): model size of vision transformer
42 | threshold (float): tagging threshold
43 | delete_tag_index (list): delete some tags that may disturb captioning
44 | """
45 | super().__init__()
46 |
47 | # create image encoder
48 | if vit == 'swin_b':
49 | if image_size == 224:
50 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
51 | elif image_size == 384:
52 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
53 | vision_config = read_json(vision_config_path)
54 | assert image_size == vision_config['image_res']
55 | # assert config['patch_size'] == 32
56 | vision_width = vision_config['vision_width']
57 |
58 | self.visual_encoder = SwinTransformer(
59 | img_size=vision_config['image_res'],
60 | patch_size=4,
61 | in_chans=3,
62 | embed_dim=vision_config['embed_dim'],
63 | depths=vision_config['depths'],
64 | num_heads=vision_config['num_heads'],
65 | window_size=vision_config['window_size'],
66 | mlp_ratio=4.,
67 | qkv_bias=True,
68 | drop_rate=0.0,
69 | drop_path_rate=0.1,
70 | ape=False,
71 | patch_norm=True,
72 | use_checkpoint=False)
73 |
74 | if stage == 'train_from_scratch':
75 | # download from https://github.com/microsoft/Swin-Transformer
76 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model']
77 |
78 | for k in list(state_dict.keys()):
79 | if 'relative_position_bias_table' in k:
80 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2
81 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k)
82 | elif ('relative_position_index' in k) or ('attn_mask' in k):
83 | del state_dict[k]
84 |
85 | print("### Load Vision Backbone", vit)
86 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False)
87 | print("missing_keys: ", msg.missing_keys)
88 | print("unexpected_keys: ", msg.unexpected_keys)
89 |
90 | elif vit == 'swin_l':
91 | if image_size == 224:
92 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
93 | elif image_size == 384:
94 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
95 | vision_config = read_json(vision_config_path)
96 | assert image_size == vision_config['image_res']
97 | # assert config['patch_size'] == 32
98 | vision_width = vision_config['vision_width']
99 |
100 | self.visual_encoder = SwinTransformer(
101 | img_size=vision_config['image_res'],
102 | patch_size=4,
103 | in_chans=3,
104 | embed_dim=vision_config['embed_dim'],
105 | depths=vision_config['depths'],
106 | num_heads=vision_config['num_heads'],
107 | window_size=vision_config['window_size'],
108 | mlp_ratio=4.,
109 | qkv_bias=True,
110 | drop_rate=0.0,
111 | drop_path_rate=0.1,
112 | ape=False,
113 | patch_norm=True,
114 | use_checkpoint=False)
115 |
116 | if stage == 'train_from_scratch':
117 | # download from https://github.com/microsoft/Swin-Transformer
118 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model']
119 |
120 | for k in list(state_dict.keys()):
121 | if 'relative_position_bias_table' in k:
122 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2
123 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k)
124 | elif ('relative_position_index' in k) or ('attn_mask' in k):
125 | del state_dict[k]
126 |
127 | print("### Load Vision Backbone", vit)
128 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False)
129 | print("missing_keys: ", msg.missing_keys)
130 | print("unexpected_keys: ", msg.unexpected_keys)
131 |
132 | else:
133 | self.visual_encoder, vision_width = create_vit(
134 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
135 |
136 | # create tokenizer
137 | self.tokenizer = init_tokenizer(text_encoder_type)
138 |
139 | # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
140 | # create image-tag interaction encoder
141 | encoder_config = BertConfig.from_json_file(med_config)
142 | encoder_config.encoder_width = 512
143 | self.tag_encoder = BertModel(config=encoder_config,
144 | add_pooling_layer=False)
145 |
146 | # create image-tag-text decoder
147 | decoder_config = BertConfig.from_json_file(med_config)
148 | self.text_decoder = BertLMHeadModel(config=decoder_config)
149 |
150 | self.delete_tag_index = delete_tag_index
151 | self.prompt = prompt
152 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
153 |
154 | # load tag list
155 | self.tag_list = self.load_tag_list(tag_list)
156 | self.tag_list_chinese = self.load_tag_list(tag_list_chinese)
157 |
158 | # create image-tag recognition decoder
159 | self.threshold = threshold
160 | self.num_class = len(self.tag_list)
161 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
162 | q2l_config.encoder_width = 512
163 | self.tagging_head = BertModel(config=q2l_config,
164 | add_pooling_layer=False)
165 | self.tagging_head.resize_token_embeddings(len(self.tokenizer))
166 |
167 | if stage == 'train_from_scratch':
168 | self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/frozen_tag_embedding/ram_tag_embedding_class_4585.pth',map_location='cpu').float())
169 | else:
170 | # when evaluating with a pretrained RAM model, the label embeddings are loaded directly from ram_swin_large_14m.pth
171 | self.label_embed = nn.Parameter(torch.zeros(self.num_class, q2l_config.encoder_width))
172 |
173 | if q2l_config.hidden_size != 512:
174 | self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
175 | else:
176 | self.wordvec_proj = nn.Identity()
177 |
178 | self.fc = nn.Linear(q2l_config.hidden_size, 1)
179 |
180 | self.del_selfattention()
181 |
182 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7,
183 | gamma_pos=0,
184 | clip=0.05)
185 |
186 | # share the weights of the lowest 2 layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
187 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
188 | ' ')
189 | self.image_proj = nn.Linear(vision_width, 512)
190 | # self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth',map_location='cpu').float())
191 |
192 | # adjust thresholds for some tags
193 | self.class_threshold = torch.ones(self.num_class) * self.threshold
194 | ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt'
195 | with open(ram_class_threshold_path, 'r', encoding='utf-8') as f:
196 | ram_class_threshold = [float(s.strip()) for s in f]
197 | for key,value in enumerate(ram_class_threshold):
198 | self.class_threshold[key] = value
199 |
200 | def load_tag_list(self, tag_list_file):
201 | with open(tag_list_file, 'r', encoding="utf-8") as f:
202 | tag_list = f.read().splitlines()
203 | tag_list = np.array(tag_list)
204 | return tag_list
205 |
206 | # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
207 | def del_selfattention(self):
208 | del self.tagging_head.embeddings
209 | for layer in self.tagging_head.encoder.layer:
210 | del layer.attention
211 |
212 | def forward(self, image, caption, image_tag, parse_tag, clip_feature):
213 | """
214 | call function as forward
215 |
216 | Args:
217 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384
218 | caption: type: list[string] len: batch_size
219 | tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0
220 |
221 | Returns:
222 | loss: type: torch.Tensor
223 | """
224 |
225 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
226 |
227 | image_embeds = self.image_proj(self.visual_encoder(image))
228 | image_atts = torch.ones(image_embeds.size()[:-1],
229 | dtype=torch.long).to(image.device)
230 |
231 | ##================= Distillation from CLIP ================##
232 | image_cls_embeds = image_embeds[:, 0, :]
233 | image_spatial_embeds = image_embeds[:, 1:, :]
234 |
235 | loss_dis = F.l1_loss(image_cls_embeds, clip_feature)
236 |
237 | ##================= Image Tagging ================##
238 | bs = image_embeds.shape[0]
239 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
240 |
241 | tagging_embed = self.tagging_head(
242 | encoder_embeds=label_embed,
243 | encoder_hidden_states=image_embeds,
244 | encoder_attention_mask=image_atts,
245 | return_dict=False,
246 | mode='tagging',
247 | )
248 |
249 | logits = self.fc(tagging_embed[0]).squeeze(-1)
250 |
251 | loss_tag = self.tagging_loss_function(logits, image_tag)
252 |
253 | ##================= Image-Tag-Text Generation ================##
254 | tag = parse_tag.cpu().numpy()
255 | tag_input = []
256 | for b in range(bs):
257 | index = np.argwhere(tag[b] == 1)
258 | token = self.tag_list[index].squeeze(axis=1)
259 | tag_input.append(' | '.join(token))
260 |
261 | # tokenize input tags
262 | tag_input_tokenzier = self.tokenizer(tag_input,
263 | padding='max_length',
264 | truncation=True,
265 | max_length=40,
266 | return_tensors="pt").to(
267 | image.device)
268 | encoder_input_ids = tag_input_tokenzier.input_ids
269 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
270 |
271 | # put input tag into image-tag interaction encoder to interact with image embeddings
272 | output_tagembedding = self.tag_encoder(
273 | encoder_input_ids,
274 | attention_mask=tag_input_tokenzier.attention_mask,
275 | encoder_hidden_states=image_embeds,
276 | encoder_attention_mask=image_atts,
277 | return_dict=True,
278 | )
279 |
280 | text = self.tokenizer(caption,
281 | padding='longest',
282 | truncation=True,
283 | max_length=40,
284 | return_tensors="pt").to(
285 | image.device)
286 |
287 | decoder_input_ids = text.input_ids
288 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id
289 |
290 | decoder_targets = decoder_input_ids.masked_fill(
291 | decoder_input_ids == self.tokenizer.pad_token_id, -100)
292 | decoder_targets[:,:self.prompt_length] = -100
293 |
294 | decoder_output = self.text_decoder(decoder_input_ids,
295 | attention_mask = text.attention_mask,
296 | encoder_hidden_states = output_tagembedding.last_hidden_state,
297 | encoder_attention_mask = None,
298 | labels = decoder_targets,
299 | return_dict = True,
300 | )
301 |
302 | loss_t2t = decoder_output.loss
303 |
304 | return loss_t2t, loss_tag, loss_dis
305 |
306 | def generate_tag(self,
307 | image,
308 | threshold=0.68,
309 | tag_input=None,
310 | ):
311 |
312 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
313 |
314 | image_embeds = self.image_proj(self.visual_encoder(image))
315 | image_atts = torch.ones(image_embeds.size()[:-1],
316 | dtype=torch.long).to(image.device)
317 |
318 | # recognize image tags using the image-tag recognition decoder
319 | image_cls_embeds = image_embeds[:, 0, :]
320 | image_spatial_embeds = image_embeds[:, 1:, :]
321 |
322 | bs = image_spatial_embeds.shape[0]
323 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
324 | tagging_embed = self.tagging_head(
325 | encoder_embeds=label_embed,
326 | encoder_hidden_states=image_embeds,
327 | encoder_attention_mask=image_atts,
328 | return_dict=False,
329 | mode='tagging',
330 | )
331 |
332 | logits = self.fc(tagging_embed[0]).squeeze(-1)
333 |
334 | targets = torch.where(
335 | torch.sigmoid(logits) > self.class_threshold.to(image.device),
336 | torch.tensor(1.0).to(image.device),
337 | torch.zeros(self.num_class).to(image.device))
338 |
339 | tag = targets.cpu().numpy()
340 | tag[:,self.delete_tag_index] = 0
341 | tag_output = []
342 | tag_output_chinese = []
343 | for b in range(bs):
344 | index = np.argwhere(tag[b] == 1)
345 | token = self.tag_list[index].squeeze(axis=1)
346 | tag_output.append(' | '.join(token))
347 | token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
348 | tag_output_chinese.append(' | '.join(token_chinese))
349 |
350 |
351 | return tag_output, tag_output_chinese
352 |
353 | def generate_tag_openset(self,
354 | image,
355 | threshold=0.68,
356 | tag_input=None,
357 | ):
358 |
359 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
360 |
361 | image_embeds = self.image_proj(self.visual_encoder(image))
362 | image_atts = torch.ones(image_embeds.size()[:-1],
363 | dtype=torch.long).to(image.device)
364 |
365 | # recognize image tags using the image-tag recognition decoder
366 | image_cls_embeds = image_embeds[:, 0, :]
367 | image_spatial_embeds = image_embeds[:, 1:, :]
368 |
369 | bs = image_spatial_embeds.shape[0]
370 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
371 | tagging_embed = self.tagging_head(
372 | encoder_embeds=label_embed,
373 | encoder_hidden_states=image_embeds,
374 | encoder_attention_mask=image_atts,
375 | return_dict=False,
376 | mode='tagging',
377 | )
378 |
379 | logits = self.fc(tagging_embed[0]).squeeze(-1)
380 |
381 | targets = torch.where(
382 | torch.sigmoid(logits) > self.class_threshold.to(image.device),
383 | torch.tensor(1.0).to(image.device),
384 | torch.zeros(self.num_class).to(image.device))
385 |
386 | tag = targets.cpu().numpy()
387 | tag[:,self.delete_tag_index] = 0
388 | tag_output = []
389 | for b in range(bs):
390 | index = np.argwhere(tag[b] == 1)
391 | token = self.tag_list[index].squeeze(axis=1)
392 | tag_output.append(' | '.join(token))
393 |
394 | return tag_output
395 |
396 |
397 | # load RAM pretrained model parameters
398 | def ram(pretrained='', **kwargs):
399 | model = RAM(**kwargs)
400 | if pretrained:
401 | if kwargs['vit'] == 'swin_b':
402 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
403 | elif kwargs['vit'] == 'swin_l':
404 | model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs)
405 | else:
406 | model, msg = load_checkpoint(model, pretrained)
407 | print('vit:', kwargs['vit'])
408 | # print('msg', msg)
409 | return model
410 |
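411 | 
412 | if __name__ == "__main__":
413 |     # Minimal usage sketch for the loader above. The checkpoint and image
414 |     # paths below are assumptions for illustration, not files shipped here.
415 |     from PIL import Image
416 | 
417 |     from ram.transform import get_transform
418 | 
419 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
420 |     transform = get_transform(image_size=384)
421 |     model = ram(pretrained='ram_swin_large_14m.pth',
422 |                 image_size=384, vit='swin_l').eval().to(device)
423 |     image = transform(Image.open('images/demo.jpg')).unsqueeze(0).to(device)
424 |     with torch.no_grad():
425 |         english_tags, chinese_tags = model.generate_tag(image)
426 |     print(english_tags[0])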
--------------------------------------------------------------------------------
/ram/models/ram_plus.py:
--------------------------------------------------------------------------------
1 | '''
2 | * The Recognize Anything Plus Model (RAM++)
3 | * Written by Xinyu Huang
4 | '''
5 | import json
6 | import warnings
7 |
8 | import numpy as np
9 | import torch
10 | from torch import nn
11 |
12 | import torch.nn.functional as F
13 | from .bert import BertConfig, BertLMHeadModel, BertModel
14 | from .swin_transformer import SwinTransformer
15 | from .utils import *
16 |
17 | warnings.filterwarnings("ignore")
18 |
19 |
20 |
21 | class RAM_plus(nn.Module):
22 | def __init__(self,
23 | med_config=f'{CONFIG_PATH}/configs/med_config.json',
24 | image_size=384,
25 | text_encoder_type='bert-base-uncased',
26 | vit='base',
27 | vit_grad_ckpt=False,
28 | vit_ckpt_layer=0,
29 | threshold=0.68,
30 | delete_tag_index=[],
31 | tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt',
32 | tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt',
33 | stage='eval'):
34 | r""" The Recognize Anything Plus Model (RAM++) inference module.
35 | RAM++ is a strong image tagging model, which can recognize any category with high accuracy using tag categories.
36 | Described in the paper "Open-Set Image Tagging with Multi-Grained Text Supervision" https://arxiv.org/abs/2310.15200
37 |
38 | Args:
39 | med_config (str): path for the mixture of encoder-decoder model's configuration file
40 | image_size (int): input image size
41 | vit (str): model size of vision transformer
42 |             threshold (float): tagging threshold
43 | delete_tag_index (list): delete some tags that may disturb captioning
44 | """
45 | super().__init__()
46 |
47 | # create image encoder
48 | if vit == 'swin_b':
49 | if image_size == 224:
50 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
51 | elif image_size == 384:
52 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
53 | vision_config = read_json(vision_config_path)
54 | assert image_size == vision_config['image_res']
55 | # assert config['patch_size'] == 32
56 | vision_width = vision_config['vision_width']
57 |
58 | self.visual_encoder = SwinTransformer(
59 | img_size=vision_config['image_res'],
60 | patch_size=4,
61 | in_chans=3,
62 | embed_dim=vision_config['embed_dim'],
63 | depths=vision_config['depths'],
64 | num_heads=vision_config['num_heads'],
65 | window_size=vision_config['window_size'],
66 | mlp_ratio=4.,
67 | qkv_bias=True,
68 | drop_rate=0.0,
69 | drop_path_rate=0.1,
70 | ape=False,
71 | patch_norm=True,
72 | use_checkpoint=False)
73 |
74 | if stage == 'train_from_scratch':
75 | # download from https://github.com/microsoft/Swin-Transformer
76 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model']
77 |
78 | for k in list(state_dict.keys()):
79 | if 'relative_position_bias_table' in k:
80 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2
81 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k)
82 | elif ('relative_position_index' in k) or ('attn_mask' in k):
83 | del state_dict[k]
84 |
85 | print("### Load Vision Backbone", vit)
86 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False)
87 | print("missing_keys: ", msg.missing_keys)
88 | print("unexpected_keys: ", msg.unexpected_keys)
89 |
90 | elif vit == 'swin_l':
91 | if image_size == 224:
92 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
93 | elif image_size == 384:
94 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
95 | vision_config = read_json(vision_config_path)
96 | assert image_size == vision_config['image_res']
97 | # assert config['patch_size'] == 32
98 | vision_width = vision_config['vision_width']
99 |
100 | self.visual_encoder = SwinTransformer(
101 | img_size=vision_config['image_res'],
102 | patch_size=4,
103 | in_chans=3,
104 | embed_dim=vision_config['embed_dim'],
105 | depths=vision_config['depths'],
106 | num_heads=vision_config['num_heads'],
107 | window_size=vision_config['window_size'],
108 | mlp_ratio=4.,
109 | qkv_bias=True,
110 | drop_rate=0.0,
111 | drop_path_rate=0.1,
112 | ape=False,
113 | patch_norm=True,
114 | use_checkpoint=False)
115 |
116 | if stage == 'train_from_scratch':
117 | # download from https://github.com/microsoft/Swin-Transformer
118 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model']
119 |
120 | for k in list(state_dict.keys()):
121 | if 'relative_position_bias_table' in k:
122 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2
123 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k)
124 | elif ('relative_position_index' in k) or ('attn_mask' in k):
125 | del state_dict[k]
126 |
127 | print("### Load Vision Backbone", vit)
128 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False)
129 | print("missing_keys: ", msg.missing_keys)
130 | print("unexpected_keys: ", msg.unexpected_keys)
131 |
132 | else:
133 | self.visual_encoder, vision_width = create_vit(
134 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
135 |
136 |         # create tokenizer
137 | self.tokenizer = init_tokenizer(text_encoder_type)
138 |
139 | self.delete_tag_index = delete_tag_index
140 |
141 | # load tag list
142 | self.tag_list = self.load_tag_list(tag_list)
143 | self.tag_list_chinese = self.load_tag_list(tag_list_chinese)
144 |
145 | # create image-tag recognition decoder
146 | self.threshold = threshold
147 | self.num_class = len(self.tag_list)
148 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
149 | q2l_config.encoder_width = 512
150 | self.tagging_head = BertModel(config=q2l_config,
151 | add_pooling_layer=False)
152 | self.tagging_head.resize_token_embeddings(len(self.tokenizer))
153 |
154 | if stage == 'train_from_scratch':
155 | self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/frozen_tag_embedding/ram_plus_tag_embedding_class_4585_des_51.pth',map_location='cpu').float())
156 | else:
157 |             # when evaluating with the pretrained RAM++ model, the label embedding is loaded directly from ram_plus_swin_large_14m.pth
158 | self.label_embed = nn.Parameter(torch.zeros(self.num_class * 51, q2l_config.encoder_width))
159 |
160 | if q2l_config.hidden_size != 512:
161 | self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
162 | else:
163 | self.wordvec_proj = nn.Identity()
164 |
165 | self.fc = nn.Linear(q2l_config.hidden_size, 1)
166 |
167 | self.del_selfattention()
168 |
169 | self.image_proj = nn.Linear(vision_width, 512)
170 |
171 | # adjust thresholds for some tags
172 | self.class_threshold = torch.ones(self.num_class) * self.threshold
173 | ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt'
174 | with open(ram_class_threshold_path, 'r', encoding='utf-8') as f:
175 | ram_class_threshold = [float(s.strip()) for s in f]
176 | for key,value in enumerate(ram_class_threshold):
177 | self.class_threshold[key] = value
178 |
179 | self.reweight_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
180 |
181 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7,
182 | gamma_pos=0,
183 | clip=0.05)
184 |
185 | self.text_alignment_loss_function = AsymmetricLoss(gamma_neg=4,
186 | gamma_pos=0,
187 | clip=0.05)
188 |
189 | def load_tag_list(self, tag_list_file):
190 | with open(tag_list_file, 'r', encoding="utf-8") as f:
191 | tag_list = f.read().splitlines()
192 | tag_list = np.array(tag_list)
193 | return tag_list
194 |
195 |     # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
196 | def del_selfattention(self):
197 | del self.tagging_head.embeddings
198 | for layer in self.tagging_head.encoder.layer:
199 | del layer.attention
200 |
201 | def forward(self, image, caption, image_tag, clip_feature, batch_text_embed):
202 |         """
203 |         Forward pass for RAM++ training.
204 |         Args:
205 |             image: torch.Tensor, shape batch_size * 3 * 384 * 384
206 |             caption: list[string], length batch_size
207 |             image_tag: torch.Tensor, shape batch_size * num_class; positive samples are 1.0, negative samples are 0.0
208 |             clip_feature: torch.Tensor, CLIP image embedding used as the distillation target
209 |             batch_text_embed: torch.Tensor, text embeddings of the batch captions used for image-text alignment
210 |         Returns:
211 |             loss_tag, loss_dis, loss_alignment: torch.Tensor
212 |         """
213 |
214 | image_embeds = self.image_proj(self.visual_encoder(image))
215 | image_atts = torch.ones(image_embeds.size()[:-1],
216 | dtype=torch.long).to(image.device)
217 |
218 | ##================= Distillation from CLIP ================##
219 | image_cls_embeds = image_embeds[:, 0, :]
220 | image_spatial_embeds = image_embeds[:, 1:, :]
221 |
222 | loss_dis = F.l1_loss(image_cls_embeds, clip_feature)
223 |
224 | ###===========multi tag des reweight==============###
225 | bs = image_embeds.shape[0]
226 |
227 | des_per_class = int(self.label_embed.shape[0] / self.num_class)
228 |
229 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
230 | reweight_scale = self.reweight_scale.exp()
231 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t())
232 | logits_per_image = logits_per_image.view(bs, -1,des_per_class)
233 |
234 | weight_normalized = F.softmax(logits_per_image, dim=2)
235 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype)
236 |
237 | for i in range(bs):
238 | reshaped_value = self.label_embed.view(-1, des_per_class, 512)
239 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value
240 | label_embed_reweight[i] = product.sum(dim=1)
241 |
242 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight))
243 |
244 | ##================= Image Tagging ================##
245 |
246 | tagging_embed = self.tagging_head(
247 | encoder_embeds=label_embed,
248 | encoder_hidden_states=image_embeds,
249 | encoder_attention_mask=image_atts,
250 | return_dict=False,
251 | mode='tagging',
252 | )
253 |
254 | logits = self.fc(tagging_embed[0]).squeeze(-1)
255 |
256 | loss_tag = self.tagging_loss_function(logits, image_tag)
257 |
258 | ##================= Image-text Alignment ================##
259 |
260 | batch_text_embed = torch.nn.functional.relu(self.wordvec_proj(batch_text_embed.to(self.label_embed.dtype)))
261 | batch_text_embed = batch_text_embed.unsqueeze(0).repeat(bs, 1, 1)
262 | alignment_embedding = self.tagging_head(
263 | encoder_embeds=batch_text_embed,
264 | encoder_hidden_states=image_embeds,
265 | encoder_attention_mask=image_atts,
266 | return_dict=False,
267 | mode='tagging',
268 | )
269 | alignment_logits = self.fc(alignment_embedding[0]).squeeze(-1)
270 |
271 | with torch.no_grad():
272 | alignment_targets = torch.zeros(alignment_logits.size()).to(image.device)
273 | alignment_targets.fill_diagonal_(1)
274 |
275 | loss_alignment = self.text_alignment_loss_function(alignment_logits,alignment_targets)
276 |
277 | return loss_tag, loss_dis, loss_alignment
278 |
279 |
280 | def generate_tag(self,
281 | image
282 | ):
283 |
284 | image_embeds = self.image_proj(self.visual_encoder(image))
285 | image_atts = torch.ones(image_embeds.size()[:-1],
286 | dtype=torch.long).to(image.device)
287 |
288 | image_cls_embeds = image_embeds[:, 0, :]
289 | image_spatial_embeds = image_embeds[:, 1:, :]
290 |
291 | bs = image_spatial_embeds.shape[0]
292 |
293 | des_per_class = int(self.label_embed.shape[0] / self.num_class)
294 |
295 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
296 | reweight_scale = self.reweight_scale.exp()
297 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t())
298 | logits_per_image = logits_per_image.view(bs, -1,des_per_class)
299 |
300 | weight_normalized = F.softmax(logits_per_image, dim=2)
301 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype)
302 |
303 | for i in range(bs):
304 |             # reshape the label embeddings here, then combine them via broadcasting
305 | reshaped_value = self.label_embed.view(-1, des_per_class, 512)
306 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value
307 | label_embed_reweight[i] = product.sum(dim=1)
308 |
309 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight))
310 |
311 |         # recognize image tags using the alignment decoder
312 | tagging_embed = self.tagging_head(
313 | encoder_embeds=label_embed,
314 | encoder_hidden_states=image_embeds,
315 | encoder_attention_mask=image_atts,
316 | return_dict=False,
317 | mode='tagging',
318 | )
319 |
320 | logits = self.fc(tagging_embed[0]).squeeze(-1)
321 |
322 | targets = torch.where(
323 | torch.sigmoid(logits) > self.class_threshold.to(image.device),
324 | torch.tensor(1.0).to(image.device),
325 | torch.zeros(self.num_class).to(image.device))
326 |
327 | tag = targets.cpu().numpy()
328 | tag[:,self.delete_tag_index] = 0
329 | tag_output = []
330 | tag_output_chinese = []
331 | for b in range(bs):
332 | index = np.argwhere(tag[b] == 1)
333 | token = self.tag_list[index].squeeze(axis=1)
334 | tag_output.append(' | '.join(token))
335 | token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
336 | tag_output_chinese.append(' | '.join(token_chinese))
337 |
338 |
339 | return tag_output, tag_output_chinese
340 |
341 | def generate_tag_openset(self,
342 | image,
343 | threshold=0.68,
344 | tag_input=None,
345 | ):
346 |
347 | image_embeds = self.image_proj(self.visual_encoder(image))
348 | image_atts = torch.ones(image_embeds.size()[:-1],
349 | dtype=torch.long).to(image.device)
350 |
351 | image_cls_embeds = image_embeds[:, 0, :]
352 | image_spatial_embeds = image_embeds[:, 1:, :]
353 |
354 | bs = image_spatial_embeds.shape[0]
355 |
356 | des_per_class = int(self.label_embed.shape[0] / self.num_class)
357 |
358 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
359 | reweight_scale = self.reweight_scale.exp()
360 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t())
361 | logits_per_image = logits_per_image.view(bs, -1,des_per_class)
362 |
363 | weight_normalized = F.softmax(logits_per_image, dim=2)
364 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype)
365 |
366 | for i in range(bs):
367 |             # reshape the label embeddings here, then combine them via broadcasting
368 | reshaped_value = self.label_embed.view(-1, des_per_class, 512)
369 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value
370 | label_embed_reweight[i] = product.sum(dim=1)
371 |
372 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight))
373 |
374 |         # recognize image tags using the alignment decoder
375 | tagging_embed = self.tagging_head(
376 | encoder_embeds=label_embed,
377 | encoder_hidden_states=image_embeds,
378 | encoder_attention_mask=image_atts,
379 | return_dict=False,
380 | mode='tagging',
381 | )
382 |
383 | logits = self.fc(tagging_embed[0]).squeeze(-1)
384 |
385 | targets = torch.where(
386 | torch.sigmoid(logits) > self.class_threshold.to(image.device),
387 | torch.tensor(1.0).to(image.device),
388 | torch.zeros(self.num_class).to(image.device))
389 |
390 | tag = targets.cpu().numpy()
391 | tag[:,self.delete_tag_index] = 0
392 | tag_output = []
393 | for b in range(bs):
394 | index = np.argwhere(tag[b] == 1)
395 | token = self.tag_list[index].squeeze(axis=1)
396 | tag_output.append(' | '.join(token))
397 |
398 | return tag_output
399 |
400 |
401 | # load RAM++ pretrained model parameters
402 | def ram_plus(pretrained='', **kwargs):
403 | model = RAM_plus(**kwargs)
404 | if pretrained:
405 | if kwargs['vit'] == 'swin_b':
406 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
407 | elif kwargs['vit'] == 'swin_l':
408 | model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs)
409 | else:
410 | model, msg = load_checkpoint(model, pretrained)
411 | print('vit:', kwargs['vit'])
412 | # print('msg', msg)
413 | return model
414 |
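415 | 
416 | if __name__ == "__main__":
417 |     # Minimal usage sketch for the loader above, mirroring the comment in
418 |     # __init__ that evaluation loads ram_plus_swin_large_14m.pth. The image
419 |     # path is an assumption for illustration.
420 |     from PIL import Image
421 | 
422 |     from ram.transform import get_transform
423 | 
424 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
425 |     transform = get_transform(image_size=384)
426 |     model = ram_plus(pretrained='ram_plus_swin_large_14m.pth',
427 |                      image_size=384, vit='swin_l').eval().to(device)
428 |     image = transform(Image.open('images/demo.jpg')).unsqueeze(0).to(device)
429 |     with torch.no_grad():
430 |         english_tags, chinese_tags = model.generate_tag(image)
431 |     print(english_tags[0])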
--------------------------------------------------------------------------------
/ram/models/tag2text.py:
--------------------------------------------------------------------------------
1 | '''
2 | * The Tag2Text Model
3 | * Written by Xinyu Huang
4 | '''
5 | import numpy as np
6 | import json
7 | import torch
8 | import warnings
9 |
10 | from torch import nn
11 | from .bert import BertConfig, BertModel, BertLMHeadModel
12 | from .swin_transformer import SwinTransformer
13 |
14 | from .utils import *
15 |
16 | warnings.filterwarnings("ignore")
17 |
18 |
19 | class Tag2Text(nn.Module):
20 |
21 | def __init__(self,
22 | med_config=f'{CONFIG_PATH}/configs/med_config.json',
23 | image_size=384,
24 | text_encoder_type='bert-base-uncased',
25 | vit='base',
26 | vit_grad_ckpt=False,
27 | vit_ckpt_layer=0,
28 | prompt='a picture of ',
29 | threshold=0.68,
30 | delete_tag_index=[127,2961, 3351, 3265, 3338, 3355, 3359],
31 | tag_list=f'{CONFIG_PATH}/data/tag2text_ori_tag_list.txt',
32 | stage='eval'):
33 | r""" Tag2Text inference module, both captioning and tagging are included.
34 | Tag2Text is an efficient and controllable vision-language pre-training framework.
35 | Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657
36 |
37 | Args:
38 | med_config (str): path for the mixture of encoder-decoder model's configuration file
39 | image_size (int): input image size
40 | vit (str): model size of vision transformer
41 |             threshold (float): tagging threshold
42 | delete_tag_index (list): delete some tags that may disturb captioning
43 | """
44 | super().__init__()
45 |
46 | # create image encoder
47 | if vit == 'swin_b':
48 | if image_size == 224:
49 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
50 | elif image_size == 384:
51 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
52 | vision_config = read_json(vision_config_path)
53 | assert image_size == vision_config['image_res']
54 | # assert config['patch_size'] == 32
55 | vision_width = vision_config['vision_width']
56 |
57 | self.visual_encoder = SwinTransformer(
58 | img_size=vision_config['image_res'],
59 | patch_size=4,
60 | in_chans=3,
61 | embed_dim=vision_config['embed_dim'],
62 | depths=vision_config['depths'],
63 | num_heads=vision_config['num_heads'],
64 | window_size=vision_config['window_size'],
65 | mlp_ratio=4.,
66 | qkv_bias=True,
67 | drop_rate=0.0,
68 | drop_path_rate=0.1,
69 | ape=False,
70 | patch_norm=True,
71 | use_checkpoint=False)
72 |
73 | if stage == 'train_from_scratch':
74 | # download from https://github.com/microsoft/Swin-Transformer
75 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model']
76 |
77 | for k in list(state_dict.keys()):
78 | if 'relative_position_bias_table' in k:
79 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2
80 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k)
81 | elif ('relative_position_index' in k) or ('attn_mask' in k):
82 | del state_dict[k]
83 |
84 | print("### Load Vision Backbone", vit)
85 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False)
86 | print("missing_keys: ", msg.missing_keys)
87 | print("unexpected_keys: ", msg.unexpected_keys)
88 |
89 | else:
90 | self.visual_encoder, vision_width = create_vit(
91 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
92 |
93 |         # create tokenizer
94 | self.tokenizer = init_tokenizer(text_encoder_type)
95 |
96 |         # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
97 | # create image-tag interaction encoder
98 | encoder_config = BertConfig.from_json_file(med_config)
99 | encoder_config.encoder_width = vision_width
100 | self.tag_encoder = BertModel(config=encoder_config,
101 | add_pooling_layer=False)
102 |
103 | # create image-tag-text decoder
104 | decoder_config = BertConfig.from_json_file(med_config)
105 | self.text_decoder = BertLMHeadModel(config=decoder_config)
106 |
107 | # delete some tags that may disturb captioning
108 | # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
109 | self.delete_tag_index = delete_tag_index
110 | self.prompt = prompt
111 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
112 |
113 | # load tag list
114 | self.tag_list = self.load_tag_list(tag_list)
115 |
116 | # create image-tag recognition decoder
117 | self.threshold = threshold
118 | self.num_class = len(self.tag_list)
119 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
120 | q2l_config.encoder_width = vision_width
121 | self.tagging_head = BertModel(config=q2l_config,
122 | add_pooling_layer=False)
123 | self.tagging_head.resize_token_embeddings(len(self.tokenizer))
124 | self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
125 | self.fc = GroupWiseLinear(self.num_class,
126 | q2l_config.hidden_size,
127 | bias=True)
128 | self.del_selfattention()
129 |
130 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7,
131 | gamma_pos=0,
132 | clip=0.05)
133 |
134 |         # share the weights of the lowest two layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
135 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
136 | ' ')
137 |
138 | # adjust thresholds for some tags
139 | # default threshold: 0.68
140 | # 2701: "person"; 2828: "man"; 1167: "woman";
141 |         tag_threshold = {2701: 0.7, 2828: 0.7, 1167: 0.7}
142 |         self.class_threshold = torch.ones(self.num_class) * self.threshold
143 |         for key, value in tag_threshold.items():
144 | self.class_threshold[key] = value
145 |
146 | def load_tag_list(self, tag_list_file):
147 | with open(tag_list_file, 'r') as f:
148 | tag_list = f.read().splitlines()
149 | tag_list = np.array(tag_list)
150 | return tag_list
151 |
152 |     # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
153 | def del_selfattention(self):
154 | del self.tagging_head.embeddings
155 | for layer in self.tagging_head.encoder.layer:
156 | del layer.attention
157 |
158 |
159 | def forward(self, image, caption, tag):
160 | """
161 |         Forward pass for Tag2Text training.
162 |
163 | Args:
164 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384
165 | caption: type: list[string] len: batch_size
166 | tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0
167 |
168 | Returns:
169 | loss: type: torch.Tensor
170 | """
171 |
172 | image_embeds = self.visual_encoder(image)
173 | image_atts = torch.ones(image_embeds.size()[:-1],
174 | dtype=torch.long).to(image.device)
175 |
176 | ##================= Image Tagging ================##
177 | bs = image_embeds.shape[0]
178 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
179 |
180 | tagging_embed = self.tagging_head(
181 | encoder_embeds=label_embed,
182 | encoder_hidden_states=image_embeds,
183 | encoder_attention_mask=image_atts,
184 | return_dict=False,
185 | mode='tagging',
186 | )
187 |
188 | logits = self.fc(tagging_embed[0])
189 |
190 | loss_tag = self.tagging_loss_function(logits, tag)
191 |
192 | ##================= Image-Tag-Text Generation ================##
193 | tag = tag.cpu().numpy()
194 | tag_input = []
195 | for b in range(bs):
196 | index = np.argwhere(tag[b] == 1)
197 | token = self.tag_list[index].squeeze(axis=1)
198 | tag_input.append(' | '.join(token))
199 |
200 |         # tokenize the input tags
201 | tag_input_tokenzier = self.tokenizer(tag_input,
202 | padding='max_length',
203 | truncation=True,
204 | max_length=40,
205 | return_tensors="pt").to(
206 | image.device)
207 | encoder_input_ids = tag_input_tokenzier.input_ids
208 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
209 |
210 | # put input tag into image-tag interaction encoder to interact with image embeddings
211 | output_tagembedding = self.tag_encoder(
212 | encoder_input_ids,
213 | attention_mask=tag_input_tokenzier.attention_mask,
214 | encoder_hidden_states=image_embeds,
215 | encoder_attention_mask=image_atts,
216 | return_dict=True,
217 | )
218 |
219 | text = self.tokenizer(caption,
220 | padding='longest',
221 | truncation=True,
222 | max_length=40,
223 | return_tensors="pt").to(
224 | image.device)
225 |
226 | decoder_input_ids = text.input_ids
227 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id
228 |
229 | decoder_targets = decoder_input_ids.masked_fill(
230 | decoder_input_ids == self.tokenizer.pad_token_id, -100)
231 | decoder_targets[:,:self.prompt_length] = -100
232 |
233 | decoder_output = self.text_decoder(decoder_input_ids,
234 | attention_mask = text.attention_mask,
235 | encoder_hidden_states = output_tagembedding.last_hidden_state,
236 | encoder_attention_mask = None,
237 | labels = decoder_targets,
238 | return_dict = True,
239 | )
240 |
241 | loss_t2t = decoder_output.loss
242 |
243 | return loss_t2t, loss_tag
244 |
245 |
246 | def generate(self,
247 | image,
248 | sample=False,
249 | num_beams=3,
250 | max_length=30,
251 | min_length=10,
252 | top_p=0.9,
253 | repetition_penalty=1.0,
254 | tag_input=None,
255 | return_tag_predict=False):
256 |
257 | image_embeds = self.visual_encoder(image)
258 | image_atts = torch.ones(image_embeds.size()[:-1],
259 | dtype=torch.long).to(image.device)
260 |
261 |         # if the user did not specify tags, recognize image tags using the image-tag recognition decoder
262 |         if tag_input is None:
263 |
264 | bs = image_embeds.shape[0]
265 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
266 | tagging_embed = self.tagging_head(
267 | encoder_embeds=label_embed,
268 | encoder_hidden_states=image_embeds,
269 | encoder_attention_mask=image_atts,
270 | return_dict=False,
271 | mode='tagging',
272 | )
273 |
274 | logits = self.fc(tagging_embed[0])
275 |
276 | targets = torch.where(
277 | torch.sigmoid(logits) > self.class_threshold.to(image.device),
278 | torch.tensor(1.0).to(image.device),
279 | torch.zeros(self.num_class).to(image.device))
280 |
281 | tag = targets.cpu().numpy()
282 |
283 | # delete some tags that may disturb captioning
284 | tag[:, self.delete_tag_index] = 0
285 |
286 | tag_input = []
287 | for b in range(bs):
288 | index = np.argwhere(tag[b] == 1)
289 | token = self.tag_list[index].squeeze(axis=1)
290 | tag_input.append(' | '.join(token))
291 |
292 | tag_output = tag_input
293 |
294 |         # beam search for text generation (default)
295 | if not sample:
296 | image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
297 | tag_input_temp = []
298 | for tag in tag_input:
299 | for i in range(num_beams):
300 | tag_input_temp.append(tag)
301 | tag_input = tag_input_temp
302 |
303 | image_atts = torch.ones(image_embeds.size()[:-1],
304 | dtype=torch.long).to(image.device)
305 |
306 |         # tokenize the input tags
307 | tag_input_tokenzier = self.tokenizer(tag_input,
308 | padding='max_length',
309 | truncation=True,
310 | max_length=40,
311 | return_tensors="pt").to(
312 | image.device)
313 | encoder_input_ids = tag_input_tokenzier.input_ids
314 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
315 |
316 | # put input tag into image-tag interaction encoder to interact with image embeddings
317 | output_tagembedding = self.tag_encoder(
318 | encoder_input_ids,
319 | attention_mask=tag_input_tokenzier.attention_mask,
320 | encoder_hidden_states=image_embeds,
321 | encoder_attention_mask=image_atts,
322 | return_dict=True,
323 | )
324 |
325 |         # prompt trick for better captioning, following BLIP
326 | prompt = [self.prompt] * image.size(0)
327 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
328 | image.device)
329 | input_ids[:, 0] = self.tokenizer.bos_token_id
330 | input_ids = input_ids[:, :-1]
331 |
332 | if sample:
333 | # nucleus sampling
334 | model_kwargs = {
335 | "encoder_hidden_states": output_tagembedding.last_hidden_state,
336 | "encoder_attention_mask": None
337 | }
338 | outputs = self.text_decoder.generate(
339 | input_ids=input_ids,
340 | max_length=max_length,
341 | min_length=min_length,
342 | do_sample=True,
343 | top_p=top_p,
344 | num_return_sequences=1,
345 | eos_token_id=self.tokenizer.sep_token_id,
346 | pad_token_id=self.tokenizer.pad_token_id,
347 | repetition_penalty=1.1,
348 | **model_kwargs)
349 | else:
350 | # beam search (default)
351 | model_kwargs = {
352 | "encoder_hidden_states": output_tagembedding.last_hidden_state,
353 | "encoder_attention_mask": None
354 | }
355 | outputs = self.text_decoder.generate(
356 | input_ids=input_ids,
357 | max_length=max_length,
358 | min_length=min_length,
359 | num_beams=num_beams,
360 | eos_token_id=self.tokenizer.sep_token_id,
361 | pad_token_id=self.tokenizer.pad_token_id,
362 | repetition_penalty=repetition_penalty,
363 | **model_kwargs)
364 |
365 | captions = []
366 | for output in outputs:
367 | caption = self.tokenizer.decode(output, skip_special_tokens=True)
368 | captions.append(caption[len(self.prompt):])
369 |         if return_tag_predict:
370 | return captions, tag_output
371 | return captions
372 |
373 |
374 | # load Tag2Text pretrained model parameters
375 | def tag2text(pretrained='', **kwargs):
376 | model = Tag2Text(**kwargs)
377 | if pretrained:
378 | if kwargs['vit'] == 'swin_b':
379 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
380 | else:
381 | model, msg = load_checkpoint(model, pretrained)
382 | print('vit:', kwargs['vit'])
383 | # print('msg', msg)
384 | return model
385 |
386 |
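387 | if __name__ == "__main__":
388 |     # Minimal usage sketch for the loader above: tag an image, then caption it
389 |     # conditioned on the predicted tags. The checkpoint and image paths are
390 |     # assumptions for illustration.
391 |     from PIL import Image
392 | 
393 |     from ram.transform import get_transform
394 | 
395 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
396 |     transform = get_transform(image_size=384)
397 |     model = tag2text(pretrained='tag2text_swin_14m.pth',
398 |                      image_size=384, vit='swin_b').eval().to(device)
399 |     image = transform(Image.open('images/demo.jpg')).unsqueeze(0).to(device)
400 |     with torch.no_grad():
401 |         captions, tags = model.generate(image, num_beams=3, return_tag_predict=True)
402 |     print(tags[0])
403 |     print(captions[0])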
--------------------------------------------------------------------------------
/ram/models/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import math
5 |
6 | from torch import nn
7 | from typing import List
8 | from transformers import BertTokenizer
9 | from urllib.parse import urlparse
10 | from timm.models.hub import download_cached_file
11 | from .vit import VisionTransformer, interpolate_pos_embed
12 | from .swin_transformer import interpolate_relative_pos_embed
13 | from pathlib import Path
14 | CONFIG_PATH=(Path(__file__).resolve().parents[1])
15 |
16 | def read_json(rpath):
17 | with open(rpath, 'r') as f:
18 | return json.load(f)
19 |
20 |
21 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module,
22 | base_model_prefix: str, skip_key: str):
23 | uninitialized_encoder_weights: List[str] = []
24 | if decoder.__class__ != encoder.__class__:
25 |         print(  # no logger is configured in this module
26 |             f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
27 |         )
28 |
29 | def tie_encoder_to_decoder_recursively(
30 | decoder_pointer: nn.Module,
31 | encoder_pointer: nn.Module,
32 | module_name: str,
33 | uninitialized_encoder_weights: List[str],
34 | skip_key: str,
35 | depth=0,
36 | ):
37 | assert isinstance(decoder_pointer, nn.Module) and isinstance(
38 | encoder_pointer, nn.Module
39 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
40 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
41 | assert hasattr(encoder_pointer, "weight")
42 | encoder_pointer.weight = decoder_pointer.weight
43 | if hasattr(decoder_pointer, "bias"):
44 | assert hasattr(encoder_pointer, "bias")
45 | encoder_pointer.bias = decoder_pointer.bias
46 | print(module_name + ' is tied')
47 | return
48 |
49 | encoder_modules = encoder_pointer._modules
50 | decoder_modules = decoder_pointer._modules
51 | if len(decoder_modules) > 0:
52 | assert (
53 | len(encoder_modules) > 0
54 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
55 |
56 | all_encoder_weights = set([
57 | module_name + "/" + sub_name
58 | for sub_name in encoder_modules.keys()
59 | ])
60 | encoder_layer_pos = 0
61 | for name, module in decoder_modules.items():
62 | if name.isdigit():
63 | encoder_name = str(int(name) + encoder_layer_pos)
64 | decoder_name = name
65 | if not isinstance(
66 | decoder_modules[decoder_name],
67 | type(encoder_modules[encoder_name])) and len(
68 | encoder_modules) != len(decoder_modules):
69 |                     # this can happen if the name corresponds to a position in a module list of layers
70 | # in this case the decoder has added a cross-attention that the encoder does not have
71 | # thus skip this step and subtract one layer pos from encoder
72 | encoder_layer_pos -= 1
73 | continue
74 | elif name not in encoder_modules:
75 | continue
76 | elif depth > 500:
77 | raise ValueError(
78 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
79 | )
80 | else:
81 | decoder_name = encoder_name = name
82 | tie_encoder_to_decoder_recursively(
83 | decoder_modules[decoder_name],
84 | encoder_modules[encoder_name],
85 | module_name + "/" + name,
86 | uninitialized_encoder_weights,
87 | skip_key,
88 | depth=depth + 1,
89 | )
90 | all_encoder_weights.remove(module_name + "/" + encoder_name)
91 |
92 | uninitialized_encoder_weights += list(all_encoder_weights)
93 |
94 | # tie weights recursively
95 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix,
96 | uninitialized_encoder_weights, skip_key)
97 |
98 |
99 | class GroupWiseLinear(nn.Module):
100 | # could be changed to:
101 | # output = torch.einsum('ijk,zjk->ij', x, self.W)
102 | # or output = torch.einsum('ijk,jk->ij', x, self.W[0])
103 | def __init__(self, num_class, hidden_dim, bias=True):
104 | super().__init__()
105 | self.num_class = num_class
106 | self.hidden_dim = hidden_dim
107 | self.bias = bias
108 |
109 | self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
110 | if bias:
111 | self.b = nn.Parameter(torch.Tensor(1, num_class))
112 | self.reset_parameters()
113 |
114 | def reset_parameters(self):
115 | stdv = 1. / math.sqrt(self.W.size(2))
116 | for i in range(self.num_class):
117 | self.W[0][i].data.uniform_(-stdv, stdv)
118 | if self.bias:
119 | for i in range(self.num_class):
120 | self.b[0][i].data.uniform_(-stdv, stdv)
121 |
122 | def forward(self, x):
123 | # x: B,K,d
124 | x = (self.W * x).sum(-1)
125 | if self.bias:
126 | x = x + self.b
127 | return x
128 |
129 |
130 | def init_tokenizer(text_encoder_type='bert-base-uncased'):
131 | tokenizer = BertTokenizer.from_pretrained(text_encoder_type)
132 | tokenizer.add_special_tokens({'bos_token': '[DEC]'})
133 | tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
134 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
135 | return tokenizer
136 |
137 |
138 | def create_vit(vit,
139 | image_size,
140 | use_grad_checkpointing=False,
141 | ckpt_layer=0,
142 | drop_path_rate=0):
143 |
144 | assert vit in ['base', 'large'], "vit parameter must be base or large"
145 | if vit == 'base':
146 | vision_width = 768
147 | visual_encoder = VisionTransformer(
148 | img_size=image_size,
149 | patch_size=16,
150 | embed_dim=vision_width,
151 | depth=12,
152 | num_heads=12,
153 | use_grad_checkpointing=use_grad_checkpointing,
154 | ckpt_layer=ckpt_layer,
155 | drop_path_rate=0 or drop_path_rate)
156 | elif vit == 'large':
157 | vision_width = 1024
158 | visual_encoder = VisionTransformer(
159 | img_size=image_size,
160 | patch_size=16,
161 | embed_dim=vision_width,
162 | depth=24,
163 | num_heads=16,
164 | use_grad_checkpointing=use_grad_checkpointing,
165 | ckpt_layer=ckpt_layer,
166 | drop_path_rate=0.1 or drop_path_rate)
167 | return visual_encoder, vision_width
168 |
169 |
170 | def is_url(url_or_filename):
171 | parsed = urlparse(url_or_filename)
172 | return parsed.scheme in ("http", "https")
173 |
174 |
175 | def load_checkpoint(model, url_or_filename):
176 | if is_url(url_or_filename):
177 | cached_file = download_cached_file(url_or_filename,
178 | check_hash=False,
179 | progress=True)
180 | checkpoint = torch.load(cached_file, map_location='cpu')
181 | elif os.path.isfile(url_or_filename):
182 | checkpoint = torch.load(url_or_filename, map_location='cpu')
183 | else:
184 | raise RuntimeError('checkpoint url or path is invalid')
185 |
186 | state_dict = checkpoint['model']
187 |
188 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
189 | state_dict['visual_encoder.pos_embed'], model.visual_encoder)
190 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
191 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
192 | state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
193 | for key in model.state_dict().keys():
194 | if key in state_dict.keys():
195 | if state_dict[key].shape != model.state_dict()[key].shape:
196 | del state_dict[key]
197 |
198 | msg = model.load_state_dict(state_dict, strict=False)
199 | print('load checkpoint from %s' % url_or_filename)
200 | return model, msg
201 |
202 |
203 | def load_checkpoint_swinbase(model, url_or_filename, kwargs):
204 | if kwargs['image_size'] == 224:
205 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
206 | elif kwargs['image_size'] == 384:
207 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
208 | window_size = read_json(vision_config_path)['window_size']
209 | print('--------------')
210 | print(url_or_filename)
211 | print('--------------')
212 | if is_url(url_or_filename):
213 | cached_file = download_cached_file(url_or_filename,
214 | check_hash=False,
215 | progress=True)
216 | checkpoint = torch.load(cached_file, map_location='cpu')
217 | elif os.path.isfile(url_or_filename):
218 | checkpoint = torch.load(url_or_filename, map_location='cpu')
219 | else:
220 | raise RuntimeError('checkpoint url or path is invalid')
221 |
222 | state_dict = checkpoint['model']
223 |
224 | for k in list(state_dict.keys()):
225 | if 'relative_position_bias_table' in k:
226 | dst_num_pos = (2 * window_size - 1)**2
227 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
228 | dst_num_pos,
229 | param_name=k)
230 | elif ('relative_position_index' in k) or ('attn_mask' in k):
231 | del state_dict[k]
232 | elif "vision_multi" in k:
233 | state_dict[k.replace("vision_multi",
234 | "tagging_head")] = state_dict.pop(k)
235 |
236 | msg = model.load_state_dict(state_dict, strict=False)
237 | print('load checkpoint from %s' % url_or_filename)
238 | return model, msg
239 |
240 |
241 | def load_checkpoint_swinlarge(model, url_or_filename, kwargs):
242 | if kwargs['image_size'] == 224:
243 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
244 | elif kwargs['image_size'] == 384:
245 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
246 | window_size = read_json(vision_config_path)['window_size']
247 | print('--------------')
248 | print(url_or_filename)
249 | print('--------------')
250 | if is_url(url_or_filename):
251 | cached_file = download_cached_file(url_or_filename,
252 | check_hash=False,
253 | progress=True)
254 | checkpoint = torch.load(cached_file, map_location='cpu')
255 | elif os.path.isfile(url_or_filename):
256 | checkpoint = torch.load(url_or_filename, map_location='cpu')
257 | else:
258 | raise RuntimeError('checkpoint url or path is invalid')
259 |
260 | state_dict = checkpoint['model']
261 |
262 | for k in list(state_dict.keys()):
263 | if 'relative_position_bias_table' in k:
264 | dst_num_pos = (2 * window_size - 1)**2
265 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
266 | dst_num_pos,
267 | param_name=k)
268 | elif ('relative_position_index' in k) or ('attn_mask' in k):
269 | del state_dict[k]
270 | elif "vision_multi" in k:
271 | state_dict[k.replace("vision_multi",
272 | "tagging_head")] = state_dict.pop(k)
273 |
274 | msg = model.load_state_dict(state_dict, strict=False)
275 | print('load checkpoint from %s' % url_or_filename)
276 | return model, msg
277 |
278 |
279 | # Tagging loss function
280 | # copy from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
281 | class AsymmetricLoss(nn.Module):
282 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
283 | super(AsymmetricLoss, self).__init__()
284 |
285 | self.gamma_neg = gamma_neg
286 | self.gamma_pos = gamma_pos
287 | self.clip = clip
288 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
289 | self.eps = eps
290 |
291 | def forward(self, x, y):
292 |         """
293 | Parameters
294 | ----------
295 | x: input logits
296 | y: targets (multi-label binarized vector)
297 | """
298 |
299 | # Calculating Probabilities
300 | x_sigmoid = torch.sigmoid(x)
301 | xs_pos = x_sigmoid
302 | xs_neg = 1 - x_sigmoid
303 |
304 | # Asymmetric Clipping
305 | if self.clip is not None and self.clip > 0:
306 | xs_neg = (xs_neg + self.clip).clamp(max=1)
307 |
308 | # Basic CE calculation
309 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
310 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
311 | loss = los_pos + los_neg
312 |
313 | # Asymmetric Focusing
314 | if self.gamma_neg > 0 or self.gamma_pos > 0:
315 | if self.disable_torch_grad_focal_loss:
316 | torch.set_grad_enabled(False)
317 | pt0 = xs_pos * y
318 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
319 | pt = pt0 + pt1
320 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
321 | one_sided_w = torch.pow(1 - pt, one_sided_gamma)
322 | if self.disable_torch_grad_focal_loss:
323 | torch.set_grad_enabled(True)
324 | loss *= one_sided_w
325 |
326 | return -loss.sum()
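327 | 
328 | 
329 | if __name__ == "__main__":
330 |     # Minimal sanity-check sketch (illustrative only) for the two heads above.
331 |     # 1) GroupWiseLinear matches the einsum form noted in its class comment.
332 |     torch.manual_seed(0)
333 |     head = GroupWiseLinear(num_class=5, hidden_dim=8, bias=True)
334 |     x = torch.randn(2, 5, 8)  # B, K, d
335 |     einsum_out = torch.einsum('ijk,zjk->ij', x, head.W) + head.b
336 |     assert torch.allclose(head(x), einsum_out, atol=1e-6)
337 | 
338 |     # 2) AsymmetricLoss on random logits against a random multi-label target.
339 |     loss_fn = AsymmetricLoss(gamma_neg=7, gamma_pos=0, clip=0.05)
340 |     logits = torch.randn(4, 10)
341 |     targets = torch.randint(0, 2, (4, 10)).float()
342 |     print('asymmetric loss:', loss_fn(logits, targets).item())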
--------------------------------------------------------------------------------
/ram/models/vit.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | * Based on timm code base
8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from functools import partial
15 |
16 | from timm.models.vision_transformer import _cfg, PatchEmbed, resize_pos_embed
17 | from timm.models.registry import register_model
18 | from timm.models.layers import trunc_normal_, DropPath
19 | from timm.models.helpers import named_apply, adapt_input_conv
20 |
21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22 |
23 | class Mlp(nn.Module):
24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25 | """
26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x):
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
43 |
44 | class Attention(nn.Module):
45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46 | super().__init__()
47 | self.num_heads = num_heads
48 | head_dim = dim // num_heads
49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50 | self.scale = qk_scale or head_dim ** -0.5
51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52 | self.attn_drop = nn.Dropout(attn_drop)
53 | self.proj = nn.Linear(dim, dim)
54 | self.proj_drop = nn.Dropout(proj_drop)
55 | self.attn_gradients = None
56 | self.attention_map = None
57 |
58 | def save_attn_gradients(self, attn_gradients):
59 | self.attn_gradients = attn_gradients
60 |
61 | def get_attn_gradients(self):
62 | return self.attn_gradients
63 |
64 | def save_attention_map(self, attention_map):
65 | self.attention_map = attention_map
66 |
67 | def get_attention_map(self):
68 | return self.attention_map
69 |
70 | def forward(self, x, register_hook=False):
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74 |
75 | attn = (q @ k.transpose(-2, -1)) * self.scale
76 | attn = attn.softmax(dim=-1)
77 | attn = self.attn_drop(attn)
78 |
79 | if register_hook:
80 | self.save_attention_map(attn)
81 | attn.register_hook(self.save_attn_gradients)
82 |
83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84 | x = self.proj(x)
85 | x = self.proj_drop(x)
86 | return x
87 |
88 |
89 | class Block(nn.Module):
90 |
91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93 | super().__init__()
94 | self.norm1 = norm_layer(dim)
95 | self.attn = Attention(
96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99 | self.norm2 = norm_layer(dim)
100 | mlp_hidden_dim = int(dim * mlp_ratio)
101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102 |
103 | if use_grad_checkpointing:
104 | self.attn = checkpoint_wrapper(self.attn)
105 | self.mlp = checkpoint_wrapper(self.mlp)
106 |
107 | def forward(self, x, register_hook=False):
108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109 | x = x + self.drop_path(self.mlp(self.norm2(x)))
110 | return x
111 |
112 |
113 | class VisionTransformer(nn.Module):
114 | """ Vision Transformer
115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116 | https://arxiv.org/abs/2010.11929
117 | """
118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121 | use_grad_checkpointing=False, ckpt_layer=0):
122 | """
123 | Args:
124 | img_size (int, tuple): input image size
125 | patch_size (int, tuple): patch size
126 | in_chans (int): number of input channels
127 | num_classes (int): number of classes for classification head
128 | embed_dim (int): embedding dimension
129 | depth (int): depth of transformer
130 | num_heads (int): number of attention heads
131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132 | qkv_bias (bool): enable bias for qkv if True
133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135 | drop_rate (float): dropout rate
136 | attn_drop_rate (float): attention dropout rate
137 | drop_path_rate (float): stochastic depth rate
138 | norm_layer: (nn.Module): normalization layer
139 | """
140 | super().__init__()
141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143 |
144 | self.patch_embed = PatchEmbed(
145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146 |
147 | num_patches = self.patch_embed.num_patches
148 |
149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151 | self.pos_drop = nn.Dropout(p=drop_rate)
152 |
153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154 | self.blocks = nn.ModuleList([
155 | Block(
156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159 | )
160 | for i in range(depth)])
161 | self.norm = norm_layer(embed_dim)
162 |
163 | trunc_normal_(self.pos_embed, std=.02)
164 | trunc_normal_(self.cls_token, std=.02)
165 | self.apply(self._init_weights)
166 |
167 | def _init_weights(self, m):
168 | if isinstance(m, nn.Linear):
169 | trunc_normal_(m.weight, std=.02)
170 | if isinstance(m, nn.Linear) and m.bias is not None:
171 | nn.init.constant_(m.bias, 0)
172 | elif isinstance(m, nn.LayerNorm):
173 | nn.init.constant_(m.bias, 0)
174 | nn.init.constant_(m.weight, 1.0)
175 |
176 | @torch.jit.ignore
177 | def no_weight_decay(self):
178 | return {'pos_embed', 'cls_token'}
179 |
180 | def forward(self, x, register_blk=-1):
181 | B = x.shape[0]
182 | x = self.patch_embed(x)
183 |
184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185 | x = torch.cat((cls_tokens, x), dim=1)
186 |
187 | x = x + self.pos_embed[:,:x.size(1),:]
188 | x = self.pos_drop(x)
189 |
190 | for i,blk in enumerate(self.blocks):
191 | x = blk(x, register_blk==i)
192 | x = self.norm(x)
193 |
194 | return x
195 |
196 | @torch.jit.ignore()
197 | def load_pretrained(self, checkpoint_path, prefix=''):
198 | _load_weights(self, checkpoint_path, prefix)
199 |
200 |
201 | @torch.no_grad()
202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204 | """
205 | import numpy as np
206 |
207 | def _n2p(w, t=True):
208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209 | w = w.flatten()
210 | if t:
211 | if w.ndim == 4:
212 | w = w.transpose([3, 2, 0, 1])
213 | elif w.ndim == 3:
214 | w = w.transpose([2, 0, 1])
215 | elif w.ndim == 2:
216 | w = w.transpose([1, 0])
217 | return torch.from_numpy(w)
218 |
219 | w = np.load(checkpoint_path)
220 | if not prefix and 'opt/target/embedding/kernel' in w:
221 | prefix = 'opt/target/'
222 |
223 | if hasattr(model.patch_embed, 'backbone'):
224 | # hybrid
225 | backbone = model.patch_embed.backbone
226 | stem_only = not hasattr(backbone, 'stem')
227 | stem = backbone if stem_only else backbone.stem
228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231 | if not stem_only:
232 | for i, stage in enumerate(backbone.stages):
233 | for j, block in enumerate(stage.blocks):
234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235 | for r in range(3):
236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239 | if block.downsample is not None:
240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244 | else:
245 | embed_conv_w = adapt_input_conv(
246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247 | model.patch_embed.proj.weight.copy_(embed_conv_w)
248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251 | if pos_embed_w.shape != model.pos_embed.shape:
252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254 | model.pos_embed.copy_(pos_embed_w)
255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263 | for i, block in enumerate(model.blocks.children()):
264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268 | block.attn.qkv.weight.copy_(torch.cat([
269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270 | block.attn.qkv.bias.copy_(torch.cat([
271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274 | for r in range(2):
275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279 |
280 |
281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282 | # interpolate position embedding
283 | embedding_size = pos_embed_checkpoint.shape[-1]
284 | num_patches = visual_encoder.patch_embed.num_patches
285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286 | # height (== width) for the checkpoint position embedding
287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288 | # height (== width) for the new position embedding
289 | new_size = int(num_patches ** 0.5)
290 |
291 | if orig_size!=new_size:
292 | # class_token and dist_token are kept unchanged
293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294 | # only the position tokens are interpolated
295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297 | pos_tokens = torch.nn.functional.interpolate(
298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302 |
303 | return new_pos_embed
304 | else:
305 | return pos_embed_checkpoint
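306 | 
307 | 
308 | if __name__ == "__main__":
309 |     # Minimal sketch of interpolate_pos_embed: resize a 224-resolution position
310 |     # embedding (14x14 patches + 1 cls token at patch_size=16) for a 384 model.
311 |     # The tiny depth-1 model below is only for illustration.
312 |     vit_384 = VisionTransformer(img_size=384, patch_size=16, embed_dim=768,
313 |                                 depth=1, num_heads=12)
314 |     pos_embed_224 = torch.zeros(1, 14 * 14 + 1, 768)
315 |     new_pos_embed = interpolate_pos_embed(pos_embed_224, vit_384)
316 |     print(new_pos_embed.shape)  # torch.Size([1, 577, 768])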
--------------------------------------------------------------------------------
/ram/transform.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Normalize, Compose, Resize, ToTensor
2 |
3 |
4 | def convert_to_rgb(image):
5 | return image.convert("RGB")
6 |
7 | def get_transform(image_size=384):
8 | return Compose([
9 | convert_to_rgb,
10 | Resize((image_size, image_size)),
11 | ToTensor(),
12 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
13 | ])
14 |
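15 | 
16 | if __name__ == "__main__":
17 |     # Minimal usage sketch: the image path is an assumption for illustration.
18 |     from PIL import Image
19 | 
20 |     transform = get_transform(image_size=384)
21 |     tensor = transform(Image.open('images/demo.jpg'))
22 |     print(tensor.shape)  # torch.Size([3, 384, 384])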
--------------------------------------------------------------------------------
/ram/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import get_mAP, get_PR
2 | from .openset_utils import build_openset_label_embedding, build_openset_llm_label_embedding
3 |
--------------------------------------------------------------------------------
/ram/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import numpy as np
4 | from numpy import ndarray
5 |
6 |
7 | def get_mAP(
8 | preds: ndarray,
9 | gt_file: str,
10 | taglist: List[str]
11 | ) -> Tuple[float, ndarray]:
12 | assert preds.shape[1] == len(taglist)
13 |
14 |     # When mapping categories from test datasets to our tag system, several
15 |     # dataset categories may collapse onto a single tag because the semantic
16 |     # definitions differ, so `taglist` can contain duplicate tags. This
17 |     # special case is handled below.
18 | tag2idxs = {}
19 | for idx, tag in enumerate(taglist):
20 | if tag not in tag2idxs:
21 | tag2idxs[tag] = []
22 | tag2idxs[tag].append(idx)
23 |
24 | # build targets
25 | targets = np.zeros_like(preds)
26 | with open(gt_file, "r") as f:
27 | lines = [line.strip("\n").split(",") for line in f.readlines()]
28 | assert len(lines) == targets.shape[0]
29 | for i, line in enumerate(lines):
30 | for tag in line[1:]:
31 | targets[i, tag2idxs[tag]] = 1.0
32 |
33 | # compute average precision for each class
34 | APs = np.zeros(preds.shape[1])
35 | for k in range(preds.shape[1]):
36 | APs[k] = _average_precision(preds[:, k], targets[:, k])
37 |
38 | return APs.mean(), APs
39 |
40 |
41 | def _average_precision(output: ndarray, target: ndarray) -> float:
42 | epsilon = 1e-8
43 |
44 | # sort examples
45 | indices = output.argsort()[::-1]
46 | # Computes prec@i
47 | total_count_ = np.cumsum(np.ones((len(output), 1)))
48 |
49 | target_ = target[indices]
50 | ind = target_ == 1
51 | pos_count_ = np.cumsum(ind)
52 | total = pos_count_[-1]
53 | pos_count_[np.logical_not(ind)] = 0
54 | pp = pos_count_ / total_count_
55 | precision_at_i_ = np.sum(pp)
56 | precision_at_i = precision_at_i_ / (total + epsilon)
57 |
58 | return precision_at_i
59 |
60 |
61 | def get_PR(
62 | pred_file: str,
63 | gt_file: str,
64 | taglist: List[str]
65 | ) -> Tuple[float, float, ndarray, ndarray]:
66 |     # When mapping categories from test datasets to our tag system, several
67 |     # dataset categories may collapse onto a single tag because the semantic
68 |     # definitions differ, so `taglist` can contain duplicate tags. This
69 |     # special case is handled below.
70 | tag2idxs = {}
71 | for idx, tag in enumerate(taglist):
72 | if tag not in tag2idxs:
73 | tag2idxs[tag] = []
74 | tag2idxs[tag].append(idx)
75 |
76 | # build preds
77 | with open(pred_file, "r", encoding="utf-8") as f:
78 | lines = [line.strip().split(",") for line in f.readlines()]
79 |     preds = np.zeros((len(lines), len(taglist)), dtype=bool)  # width follows taglist so duplicate-tag indices stay in range
80 | for i, line in enumerate(lines):
81 | for tag in line[1:]:
82 | preds[i, tag2idxs[tag]] = True
83 |
84 | # build targets
85 | with open(gt_file, "r", encoding="utf-8") as f:
86 | lines = [line.strip().split(",") for line in f.readlines()]
87 |     targets = np.zeros((len(lines), len(taglist)), dtype=bool)
88 | for i, line in enumerate(lines):
89 | for tag in line[1:]:
90 | targets[i, tag2idxs[tag]] = True
91 |
92 | assert preds.shape == targets.shape
93 |
94 | # calculate P and R
95 | TPs = ( preds & targets).sum(axis=0) # noqa: E201, E222
96 | FPs = ( preds & ~targets).sum(axis=0) # noqa: E201, E222
97 | FNs = (~preds & targets).sum(axis=0) # noqa: E201, E222
98 | eps = 1.e-9
99 | Ps = TPs / (TPs + FPs + eps)
100 | Rs = TPs / (TPs + FNs + eps)
101 |
102 | return Ps.mean(), Rs.mean(), Ps, Rs
103 |
--------------------------------------------------------------------------------
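Both metrics expect CSV-style files whose first field is an image identifier and whose remaining fields are tags. A hedged toy example of `get_mAP` with two images and two tags (the file name and scores are made up for illustration):

    import numpy as np
    from ram.utils import get_mAP

    taglist = ["cat", "dog"]
    preds = np.array([[0.9, 0.1],      # one row of model scores per image
                      [0.2, 0.8]])

    with open("toy_gt.csv", "w") as f:  # hypothetical ground-truth file
        f.write("img0,cat\nimg1,dog\n")

    mAP, APs = get_mAP(preds, "toy_gt.csv", taglist)
    print(round(mAP, 3))                # 1.0 for this perfectly ranked toy case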
/ram/utils/openset_utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import torch
5 | import torch.nn as nn
6 | from clip import clip
7 |
8 |
9 | def article(name):
10 | return "an" if name[0] in "aeiou" else "a"
11 |
12 |
13 | def processed_name(name, rm_dot=False):
14 | # _ for lvis
15 | # / for obj365
16 | res = name.replace("_", " ").replace("/", " or ").lower()
17 | if rm_dot:
18 | res = res.rstrip(".")
19 | return res
20 |
21 |
22 | single_template = ["a photo of a {}."]
23 |
24 | multiple_templates = [
25 | "There is {article} {} in the scene.",
26 | "There is the {} in the scene.",
27 | "a photo of {article} {} in the scene.",
28 | "a photo of the {} in the scene.",
29 | "a photo of one {} in the scene.",
30 | "itap of {article} {}.",
31 | "itap of my {}.", # itap: I took a picture of
32 | "itap of the {}.",
33 | "a photo of {article} {}.",
34 | "a photo of my {}.",
35 | "a photo of the {}.",
36 | "a photo of one {}.",
37 | "a photo of many {}.",
38 | "a good photo of {article} {}.",
39 | "a good photo of the {}.",
40 | "a bad photo of {article} {}.",
41 | "a bad photo of the {}.",
42 | "a photo of a nice {}.",
43 | "a photo of the nice {}.",
44 | "a photo of a cool {}.",
45 | "a photo of the cool {}.",
46 | "a photo of a weird {}.",
47 | "a photo of the weird {}.",
48 | "a photo of a small {}.",
49 | "a photo of the small {}.",
50 | "a photo of a large {}.",
51 | "a photo of the large {}.",
52 | "a photo of a clean {}.",
53 | "a photo of the clean {}.",
54 | "a photo of a dirty {}.",
55 | "a photo of the dirty {}.",
56 | "a bright photo of {article} {}.",
57 | "a bright photo of the {}.",
58 | "a dark photo of {article} {}.",
59 | "a dark photo of the {}.",
60 | "a photo of a hard to see {}.",
61 | "a photo of the hard to see {}.",
62 | "a low resolution photo of {article} {}.",
63 | "a low resolution photo of the {}.",
64 | "a cropped photo of {article} {}.",
65 | "a cropped photo of the {}.",
66 | "a close-up photo of {article} {}.",
67 | "a close-up photo of the {}.",
68 | "a jpeg corrupted photo of {article} {}.",
69 | "a jpeg corrupted photo of the {}.",
70 | "a blurry photo of {article} {}.",
71 | "a blurry photo of the {}.",
72 | "a pixelated photo of {article} {}.",
73 | "a pixelated photo of the {}.",
74 | "a black and white photo of the {}.",
75 | "a black and white photo of {article} {}.",
76 | "a plastic {}.",
77 | "the plastic {}.",
78 | "a toy {}.",
79 | "the toy {}.",
80 | "a plushie {}.",
81 | "the plushie {}.",
82 | "a cartoon {}.",
83 | "the cartoon {}.",
84 | "an embroidered {}.",
85 | "the embroidered {}.",
86 | "a painting of the {}.",
87 | "a painting of a {}.",
88 | ]
89 |
90 |
91 | openimages_rare_unseen = ['Aerial photography',
92 | 'Aircraft engine',
93 | 'Ale',
94 | 'Aloe',
95 | 'Amphibian',
96 | 'Angling',
97 | 'Anole',
98 | 'Antique car',
99 | 'Arcade game',
100 | 'Arthropod',
101 | 'Assault rifle',
102 | 'Athletic shoe',
103 | 'Auto racing',
104 | 'Backlighting',
105 | 'Bagpipes',
106 | 'Ball game',
107 | 'Barbecue chicken',
108 | 'Barechested',
109 | 'Barquentine',
110 | 'Beef tenderloin',
111 | 'Billiard room',
112 | 'Billiards',
113 | 'Bird of prey',
114 | 'Black swan',
115 | 'Black-and-white',
116 | 'Blond',
117 | 'Boating',
118 | 'Bonbon',
119 | 'Bottled water',
120 | 'Bouldering',
121 | 'Bovine',
122 | 'Bratwurst',
123 | 'Breadboard',
124 | 'Briefs',
125 | 'Brisket',
126 | 'Brochette',
127 | 'Calabaza',
128 | 'Camera operator',
129 | 'Canola',
130 | 'Childbirth',
131 | 'Chordophone',
132 | 'Church bell',
133 | 'Classical sculpture',
134 | 'Close-up',
135 | 'Cobblestone',
136 | 'Coca-cola',
137 | 'Combat sport',
138 | 'Comics',
139 | 'Compact car',
140 | 'Computer speaker',
141 | 'Cookies and crackers',
142 | 'Coral reef fish',
143 | 'Corn on the cob',
144 | 'Cosmetics',
145 | 'Crocodilia',
146 | 'Digital camera',
147 | 'Dishware',
148 | 'Divemaster',
149 | 'Dobermann',
150 | 'Dog walking',
151 | 'Domestic rabbit',
152 | 'Domestic short-haired cat',
153 | 'Double-decker bus',
154 | 'Drums',
155 | 'Electric guitar',
156 | 'Electric piano',
157 | 'Electronic instrument',
158 | 'Equestrianism',
159 | 'Equitation',
160 | 'Erinaceidae',
161 | 'Extreme sport',
162 | 'Falafel',
163 | 'Figure skating',
164 | 'Filling station',
165 | 'Fire apparatus',
166 | 'Firearm',
167 | 'Flatbread',
168 | 'Floristry',
169 | 'Forklift truck',
170 | 'Freight transport',
171 | 'Fried food',
172 | 'Fried noodles',
173 | 'Frigate',
174 | 'Frozen yogurt',
175 | 'Frying',
176 | 'Full moon',
177 | 'Galleon',
178 | 'Glacial landform',
179 | 'Gliding',
180 | 'Go-kart',
181 | 'Goats',
182 | 'Grappling',
183 | 'Great white shark',
184 | 'Gumbo',
185 | 'Gun turret',
186 | 'Hair coloring',
187 | 'Halter',
188 | 'Headphones',
189 | 'Heavy cruiser',
190 | 'Herding',
191 | 'High-speed rail',
192 | 'Holding hands',
193 | 'Horse and buggy',
194 | 'Horse racing',
195 | 'Hound',
196 | 'Hunting knife',
197 | 'Hurdling',
198 | 'Inflatable',
199 | 'Jackfruit',
200 | 'Jeans',
201 | 'Jiaozi',
202 | 'Junk food',
203 | 'Khinkali',
204 | 'Kitesurfing',
205 | 'Lawn game',
206 | 'Leaf vegetable',
207 | 'Lechon',
208 | 'Lifebuoy',
209 | 'Locust',
210 | 'Lumpia',
211 | 'Luxury vehicle',
212 | 'Machine tool',
213 | 'Medical imaging',
214 | 'Melee weapon',
215 | 'Microcontroller',
216 | 'Middle ages',
217 | 'Military person',
218 | 'Military vehicle',
219 | 'Milky way',
220 | 'Miniature Poodle',
221 | 'Modern dance',
222 | 'Molluscs',
223 | 'Monoplane',
224 | 'Motorcycling',
225 | 'Musical theatre',
226 | 'Narcissus',
227 | 'Nest box',
228 | 'Newsagent\'s shop',
229 | 'Nile crocodile',
230 | 'Nordic skiing',
231 | 'Nuclear power plant',
232 | 'Orator',
233 | 'Outdoor shoe',
234 | 'Parachuting',
235 | 'Pasta salad',
236 | 'Peafowl',
237 | 'Pelmeni',
238 | 'Perching bird',
239 | 'Performance car',
240 | 'Personal water craft',
241 | 'Pit bull',
242 | 'Plant stem',
243 | 'Pork chop',
244 | 'Portrait photography',
245 | 'Primate',
246 | 'Procyonidae',
247 | 'Prosciutto',
248 | 'Public speaking',
249 | 'Racewalking',
250 | 'Ramen',
251 | 'Rear-view mirror',
252 | 'Residential area',
253 | 'Ribs',
254 | 'Rice ball',
255 | 'Road cycling',
256 | 'Roller skating',
257 | 'Roman temple',
258 | 'Rowing',
259 | 'Rural area',
260 | 'Sailboat racing',
261 | 'Scaled reptile',
262 | 'Scuba diving',
263 | 'Senior citizen',
264 | 'Shallot',
265 | 'Shinto shrine',
266 | 'Shooting range',
267 | 'Siberian husky',
268 | 'Sledding',
269 | 'Soba',
270 | 'Solar energy',
271 | 'Sport climbing',
272 | 'Sport utility vehicle',
273 | 'Steamed rice',
274 | 'Stemware',
275 | 'Sumo',
276 | 'Surfing Equipment',
277 | 'Team sport',
278 | 'Touring car',
279 | 'Toy block',
280 | 'Trampolining',
281 | 'Underwater diving',
282 | 'Vegetarian food',
283 | 'Wallaby',
284 | 'Water polo',
285 | 'Watercolor paint',
286 | 'Whiskers',
287 | 'Wind wave',
288 | 'Woodwind instrument',
289 | 'Yakitori',
290 | 'Zeppelin']
291 |
292 |
293 | def build_openset_label_embedding(categories=None):
294 | if categories is None:
295 | categories = openimages_rare_unseen
296 | print("Creating pretrained CLIP model")
297 | model, _ = clip.load("ViT-B/16")
298 | templates = multiple_templates
299 |
300 | run_on_gpu = torch.cuda.is_available()
301 |
302 | with torch.no_grad():
303 | openset_label_embedding = []
304 | for category in categories:
305 | texts = [
306 | template.format(
307 | processed_name(category, rm_dot=True), article=article(category)
308 | )
309 | for template in templates
310 | ]
311 | texts = [
312 | "This is " + text if text.startswith("a") or text.startswith("the") else text
313 | for text in texts
314 | ]
315 | texts = clip.tokenize(texts) # tokenize
316 | if run_on_gpu:
317 | texts = texts.cuda()
318 | model = model.cuda()
319 | text_embeddings = model.encode_text(texts)
320 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
321 | text_embedding = text_embeddings.mean(dim=0)
322 | text_embedding /= text_embedding.norm()
323 | openset_label_embedding.append(text_embedding)
324 | openset_label_embedding = torch.stack(openset_label_embedding, dim=1)
325 | if run_on_gpu:
326 | openset_label_embedding = openset_label_embedding.cuda()
327 |
328 | openset_label_embedding = openset_label_embedding.t()
329 | return openset_label_embedding, categories
330 |
331 |
332 |
333 | import json
334 | from tqdm import tqdm
335 |
336 | def build_openset_llm_label_embedding(llm_tag_des):
337 | print("Creating pretrained CLIP model")
338 | model, _ = clip.load("ViT-B/16")
339 |
340 | categories = []
341 |
342 | run_on_gpu = torch.cuda.is_available()
343 |
344 | with torch.no_grad():
345 | openset_label_embedding = []
346 | for item in tqdm(llm_tag_des):
347 | category = list(item.keys())[0]
348 | des = list(item.values())[0]
349 |
350 | categories.append(category)
351 |
352 | texts = clip.tokenize(des, truncate=True) # tokenize
353 | if run_on_gpu:
354 | texts = texts.cuda()
355 | model = model.cuda()
356 | text_embeddings = model.encode_text(texts)
357 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
358 | # text_embedding = text_embeddings.mean(dim=0)
359 | # text_embedding /= text_embedding.norm()
360 | # openset_label_embedding.append(text_embedding)
361 | openset_label_embedding.append(text_embeddings)
362 | # openset_label_embedding = torch.stack(openset_label_embedding, dim=1)
363 | openset_label_embedding = torch.cat(openset_label_embedding, dim=0)
364 | if run_on_gpu:
365 | openset_label_embedding = openset_label_embedding.cuda()
366 |
367 | # openset_label_embedding = openset_label_embedding.t()
368 | return openset_label_embedding, categories
369 |
370 |
371 |
372 |
373 |
--------------------------------------------------------------------------------
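`build_openset_label_embedding` turns a tag list into CLIP text embeddings by averaging the prompt templates above per category, while `build_openset_llm_label_embedding` stacks per-description embeddings from LLM-generated tag descriptions. A minimal sketch for the template-based variant with a custom tag list (the tags are illustrative; the `clip` package and its ViT-B/16 weights must be available):

    from ram.utils import build_openset_label_embedding

    custom_tags = ["red panda", "espresso machine"]   # hypothetical open-set tags
    embeddings, tags = build_openset_label_embedding(custom_tags)
    print(embeddings.shape)   # expected (len(tags), 512) for CLIP ViT-B/16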
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.32.1
2 | einops==0.8.0
3 | gradio==4.38.1
4 | joblib==1.2.0
5 | numpy==1.26.4
6 | omegaconf==2.3.0
7 | packaging==24.1
8 | pandas==1.5.3
9 | peft==0.11.1
10 | Pillow==9.4.0
11 | pycocoevalcap==1.2
12 | pycocotools==2.0.8
13 | pynvml==11.5.0
14 | quantization==0.0.1
15 | Requests==2.32.3
16 | scikit_learn==1.5.1
17 | spacy==3.7.5
18 | timm==0.4.12
19 | torch==2.3.1
20 | torchmetrics==0.11.4
21 | torchvision==0.18.1
22 | tqdm==4.64.1
23 | transformers==4.42.4
24 | webdataset==0.2.86
25 | fairscale==0.4.4
--------------------------------------------------------------------------------
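The pinned versions above are typically installed with `pip install -r requirements.txt`. A quick sanity check of the core pins after installation might look like:

    import timm, torch, transformers

    print(torch.__version__)          # expected 2.3.1 per requirements.txt
    print(transformers.__version__)   # expected 4.42.4
    print(timm.__version__)           # expected 0.4.12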