├── .gitignore ├── LICENSE ├── README.md ├── assets ├── JiuTian.pdf ├── LION-6Examples.jpg ├── LION-CapVQA.jpg ├── LION-Examples.jpg ├── LION-Image-level.jpg ├── LION-Introduction.jpg ├── LION-MMBench.jpg ├── LION-Method.jpg ├── LION-POPE.jpg ├── LION-REC.jpg ├── LION-Region-level.jpg ├── LION-Score.jpg ├── LION_logo.png └── model.jpg ├── common └── registry.py ├── configs └── models │ ├── lion_flant5xl.yaml │ └── lion_flant5xxl.yaml ├── images ├── COCO_train2014_000000024935.jpg └── COCO_train2014_000000533220.jpg ├── models ├── Qformer.py ├── __init__.py ├── base_model.py ├── eva_vit.py ├── lion_adapters.py ├── lion_t5.py └── modeling_t5.py ├── playground.ipynb ├── preprocessors └── lion_preprocessors.py ├── ram ├── __init__.py ├── configs │ ├── finetune.yaml │ ├── finetune_tag2text.yaml │ ├── med_config.json │ ├── pretrain.yaml │ ├── pretrain_tag2text.yaml │ ├── q2l_config.json │ └── swin │ │ ├── config_swinB_224.json │ │ ├── config_swinB_384.json │ │ ├── config_swinL_224.json │ │ └── config_swinL_384.json ├── data │ ├── __init__.py │ ├── dataset.py │ ├── ram_tag_list.txt │ ├── ram_tag_list_chinese.txt │ ├── ram_tag_list_threshold.txt │ ├── randaugment.py │ ├── tag2text_ori_tag_list.txt │ ├── tag_list.txt │ └── utils.py ├── inference.py ├── models │ ├── __init__.py │ ├── bert.py │ ├── ram.py │ ├── ram_plus.py │ ├── swin_transformer.py │ ├── tag2text.py │ ├── utils.py │ └── vit.py ├── transform.py └── utils │ ├── __init__.py │ ├── metrics.py │ └── openset_utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Rui Shao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 |

LION: Empowering Multimodal Large Language Model with Dual-Level Visual Knowledge

5 |
6 |
7 | Gongwei Chen, 8 | Leyang Shen, 9 | Rui Shao*, 10 | Xiang Deng, 11 | Liqiang Nie* 12 |
13 | 14 | School of Computer Science and Technology, Harbin Institute of Technology, Shenzhen
15 | *Corresponding author 16 | 17 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2024 18 | 19 | [[Paper]](https://arxiv.org/abs/2311.11860) [[Project Page]](https://rshaojimmy.github.io/Projects/JiuTian-LION) [[Video(YouTube)]](https://www.youtube.com/watch?v=YzJ5MZFS5RA) [[Video(bilibili)]](https://www.bilibili.com/video/BV1kH4y1y7UR/) 20 | 21 | :fire: Details will be released. Stay tuned :beers: :+1: 22 | 23 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fwww.slywiki.cn&count_bg=%2379C83D&title_bg=%23555555&icon=github.svg&icon_color=%23E7E7E7&title=Visitors&edge_flat=false)](https://hits.seeyoufarm.com) 24 | 25 |
26 |
27 | 28 | 29 | 30 |
31 | 32 | ## If you find this work useful for your research, please kindly cite our paper and star our repo. 33 | 34 | ## Updates 35 | - [07/2024] Code and checkpoints are released. 36 | - [02/2024] LION has been accepted by CVPR 2024. 37 | - [11/2023] [arXiv paper](https://arxiv.org/abs/2311.11860) released. 38 | - [11/2023] [Project page](https://rshaojimmy.github.io/Projects/JiuTian-LION) released. 39 | 40 | ## Introduction 41 | 42 | This is the GitHub repository of *LION: Empowering Multimodal Large Language Model with Dual-Level Visual Knowledge*. In this work, we enhance MLLMs by integrating fine-grained spatial-aware visual knowledge and high-level semantic visual evidence, boosting their capabilities and alleviating hallucinations. 43 | 44 | The framework of the proposed LION model: 45 | 46 |
47 | 48 |
49 | 50 | ## Installation 51 | 52 | ### Download 53 | ```bash 54 | git clone https://github.com/JiuTian-VL/JiuTian-LION.git 55 | cd JiuTian-LION 56 | ``` 57 | 58 | ### Environment 59 | 60 | ```bash 61 | conda create -n LION python=3.12 62 | conda activate LION 63 | conda install pip 64 | pip install -r requirements.txt 65 | ``` 66 | 67 | ## Checkpoints 68 | 69 | | Version | Checkpoint | 70 | | --- | --- | 71 | | LION-FlanT5-XL | [daybreaksly/LION-FlanT5-XL](https://huggingface.co/daybreaksly/LION-FlanT5-XL) | 72 | | LION-FlanT5-XXL | [daybreaksly/LION-FlanT5-XXL](https://huggingface.co/daybreaksly/LION-FlanT5-XXL) | 73 | 74 | ## Usage 75 | 76 | ### Prepare models 77 | 78 | 1. Download the pre-trained ViT model [eva_vit_g](https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth). 79 | 2. Download the pre-trained RAM model [ram_swin_large_14m](https://huggingface.co/spaces/xinyu1205/Recognize_Anything-Tag2Text/blob/main/ram_swin_large_14m.pth). 80 | 3. Download the pre-trained FlanT5 model [FlanT5-XL](https://huggingface.co/google/flan-t5-xl). 81 | 4. Download the pre-trained BERT model [bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased). 82 | 5. Fill in the paths to these models at the corresponding locations in the config file `configs/models/lion_flant5xl.yaml`. 83 | 84 | ### Inference 85 | 86 | We provide inference examples for **Image-Level** and **Region-Level** tasks in `playground.ipynb`. A minimal model-loading sketch is also included after the main result tables below. 87 | 88 | ## Evaluation results 89 | 90 | For image-level tasks, we focus on image captioning and Visual Question Answering (VQA). For region-level tasks, we evaluate LION on three referring expression comprehension (REC) datasets: RefCOCO, RefCOCO+, and RefCOCOg. The results, detailed in Tables 1~2, highlight LION's superior performance compared to baseline models. 91 | 92 | ![Score](assets/LION-Score.jpg) 93 | 94 | ![Image-level](assets/LION-Image-level.jpg) 95 | ![Region-level](assets/LION-Region-level.jpg) 96 | 97 | We further evaluate LION on an object hallucination benchmark ([POPE](https://github.com/AoiDragon/POPE)) and a widely used MLLM benchmark ([MMBench](https://mmbench.opencompass.org.cn/home)). The results below show that LION performs strongly across various skills and also demonstrates strong resistance to hallucinations, particularly in the popular and adversarial settings of POPE.
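As a quick-start companion to `playground.ipynb`, the sketch below loads a pretrained LION checkpoint through the registry helper in `models/__init__.py`; the names `lion_t5`, `flant5xl`, and `flant5xxl` are the ones registered in `models/lion_t5.py`. The CUDA check is only an illustrative assumption, and the preprocessing/generation calls are intentionally left to the notebook.

```python
# Minimal loading sketch (not the authors' full pipeline). "lion_t5" is registered via
# @registry.register_model("lion_t5"); "flant5xl"/"flant5xxl" are the model types in
# PRETRAINED_MODEL_CONFIG_DICT. The paths in configs/models/lion_flant5xl.yaml must
# already be filled in (see "Prepare models" above).
import torch

from models import load_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("lion_t5", "flant5xl", is_eval=True, device=device)

# Image preprocessing and text generation are demonstrated in playground.ipynb;
# RAM tags for the semantic branch can be produced with model.generate_tags(pil_image).
```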
98 | 99 | ![MMBench](assets/LION-MMBench.jpg) 100 | ![POPE](assets/LION-POPE.jpg) 101 | 102 | ## Qualitative Comparison 103 | 104 | ![Qualitative Comparison](assets/LION-Examples.jpg) 105 | ![Qualitative Comparison](assets/LION-CapVQA.jpg) 106 | ![Qualitative Comparison](assets/LION-REC.jpg) 107 | 108 | ## More Examples 109 | ![Qualitative Comparison](assets/LION-6Examples.jpg) 110 | 111 | ## Citation 112 | 113 | If you find this work useful for your research, please kindly cite our paper: 114 | ``` 115 | @inproceedings{chen2024lion, 116 | title={LION: Empowering Multimodal Large Language Model with Dual-Level Visual Knowledge}, 117 | author={Chen, Gongwei and Shen, Leyang and Shao, Rui and Deng, Xiang and Nie, Liqiang}, 118 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 119 | year={2024} 120 | } 121 | ``` 122 | -------------------------------------------------------------------------------- /assets/JiuTian.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/JiuTian.pdf -------------------------------------------------------------------------------- /assets/LION-6Examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-6Examples.jpg -------------------------------------------------------------------------------- /assets/LION-CapVQA.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-CapVQA.jpg -------------------------------------------------------------------------------- /assets/LION-Examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Examples.jpg -------------------------------------------------------------------------------- /assets/LION-Image-level.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Image-level.jpg -------------------------------------------------------------------------------- /assets/LION-Introduction.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Introduction.jpg -------------------------------------------------------------------------------- /assets/LION-MMBench.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-MMBench.jpg -------------------------------------------------------------------------------- /assets/LION-Method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Method.jpg -------------------------------------------------------------------------------- /assets/LION-POPE.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-POPE.jpg -------------------------------------------------------------------------------- /assets/LION-REC.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-REC.jpg -------------------------------------------------------------------------------- /assets/LION-Region-level.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Region-level.jpg -------------------------------------------------------------------------------- /assets/LION-Score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION-Score.jpg -------------------------------------------------------------------------------- /assets/LION_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/LION_logo.png -------------------------------------------------------------------------------- /assets/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/assets/model.jpg -------------------------------------------------------------------------------- /common/registry.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Registry: 4 | root = os.path.expanduser("~") 5 | mapping = { 6 | "builder_name_mapping": {}, 7 | "processor_name_mapping": {}, 8 | "model_name_mapping": {}, 9 | "evaluator_name_mapping": {} 10 | } 11 | 12 | @classmethod 13 | def register_builder(cls, name): 14 | r"""Register a dataset builder to registry with key 'name' 15 | 16 | Args: 17 | name: Key with which the dataset builder will be registered. 18 | """ 19 | 20 | def wrap(builder_func): 21 | if name in cls.mapping["builder_name_mapping"]: 22 | raise KeyError( 23 | "Name '{}' already registered for {}.".format( 24 | name, cls.mapping["builder_name_mapping"][name] 25 | ) 26 | ) 27 | cls.mapping["builder_name_mapping"][name] = builder_func 28 | return builder_func 29 | 30 | return wrap 31 | 32 | @classmethod 33 | def register_evaluator(cls, name): 34 | r"""Register a task evaluator to registry with key 'name' 35 | 36 | Args: 37 | name: Key with which the task evaluator will be registered. 38 | """ 39 | 40 | def wrap(eval_func): 41 | if name in cls.mapping["evaluator_name_mapping"]: 42 | raise KeyError( 43 | "Name '{}' already registered for {}.".format( 44 | name, cls.mapping["evaluator_name_mapping"][name] 45 | ) 46 | ) 47 | cls.mapping["evaluator_name_mapping"][name] = eval_func 48 | return eval_func 49 | 50 | return wrap 51 | 52 | @classmethod 53 | def register_model(cls, name): 54 | r"""Register a model to registry with key 'name' 55 | 56 | Args: 57 | name: Key with which the model will be registered. 
58 | """ 59 | 60 | def wrap(model_cls): 61 | from models import BaseModel 62 | 63 | assert issubclass( 64 | model_cls, BaseModel 65 | ), "All models must inherit BaseModel class" 66 | if name in cls.mapping["model_name_mapping"]: 67 | raise KeyError( 68 | "Name '{}' already registered for {}.".format( 69 | name, cls.mapping["model_name_mapping"][name] 70 | ) 71 | ) 72 | cls.mapping["model_name_mapping"][name] = model_cls 73 | return model_cls 74 | 75 | return wrap 76 | 77 | @classmethod 78 | def register_processor(cls, name): 79 | r"""Register a processor to registry with key 'name' 80 | 81 | Args: 82 | name: Key with which the processor will be registered. 83 | """ 84 | 85 | def wrap(processor_cls): 86 | if name in cls.mapping["processor_name_mapping"]: 87 | raise KeyError( 88 | "Name '{}' already registered for {}.".format( 89 | name, cls.mapping["processor_name_mapping"][name] 90 | ) 91 | ) 92 | cls.mapping["processor_name_mapping"][name] = processor_cls 93 | return processor_cls 94 | 95 | return wrap 96 | 97 | @classmethod 98 | def get_builder_func(cls, name): 99 | return cls.mapping["builder_name_mapping"].get(name, None) 100 | 101 | @classmethod 102 | def get_evaluator_func(cls, name): 103 | return cls.mapping["evaluator_name_mapping"].get(name, None) 104 | 105 | @classmethod 106 | def get_model_class(cls, name): 107 | return cls.mapping["model_name_mapping"].get(name, None) 108 | 109 | @classmethod 110 | def get_processor_class(cls, name): 111 | return cls.mapping["processor_name_mapping"].get(name, None) 112 | 113 | @classmethod 114 | def list_models(cls): 115 | return sorted(cls.mapping["model_name_mapping"].keys()) 116 | 117 | @classmethod 118 | def list_processors(cls): 119 | return sorted(cls.mapping["processor_name_mapping"].keys()) 120 | 121 | @classmethod 122 | def list_datasets(cls): 123 | return sorted(cls.mapping["builder_name_mapping"].keys()) 124 | 125 | @classmethod 126 | def list_evaluators(cls): 127 | return sorted(cls.mapping["evaluator_name_mapping"].keys()) 128 | 129 | registry = Registry() -------------------------------------------------------------------------------- /configs/models/lion_flant5xl.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | bert_model: "/path/to/bert-base-uncased/" 3 | vit_model: "/path/to/eva_vit_g.pth" 4 | llm_model: "/path/to/flan-t5-xl/" 5 | ram_model: "/path/to/ram_swin_large_14m.pth" 6 | visual_input: "ALL" 7 | enable_semantic_tags: True 8 | 9 | load_pretrained: True 10 | pretrained: /path/to/LION-FlanT5-XL.pth 11 | -------------------------------------------------------------------------------- /configs/models/lion_flant5xxl.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | bert_model: "/path/to/bert-base-uncased/" 3 | vit_model: "/path/to/eva_vit_g.pth" 4 | llm_model: "/path/to/flan-t5-xxl/" 5 | ram_model: "/path/to/ram_swin_large_14m.pth" 6 | visual_input: "ALL" 7 | enable_semantic_tags: True 8 | 9 | load_pretrained: True 10 | pretrained: /path/to/LION-FlanT5-XXL.pth 11 | -------------------------------------------------------------------------------- /images/COCO_train2014_000000024935.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/images/COCO_train2014_000000024935.jpg -------------------------------------------------------------------------------- /images/COCO_train2014_000000533220.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiuTian-VL/JiuTian-LION/75784372152798a9f7e80bac0ca569549c63632a/images/COCO_train2014_000000533220.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from common.registry import registry 2 | from models.base_model import BaseModel 3 | from models.lion_t5 import LIONT5InstructAdapter 4 | 5 | __all__ = [ 6 | "LIONT5InstructAdapter" 7 | ] 8 | 9 | def load_model(name, model_type, is_eval=False, device="cpu"): 10 | """ 11 | Load supported models. 12 | 13 | Args: 14 | name (str): name of the model. 15 | model_type (str): type of the model. 16 | is_eval (bool): whether the model is in eval mode. Default: False. 17 | device (str): device to use. Default: "cpu". 18 | 19 | Returns: 20 | model (torch.nn.Module): model. 21 | """ 22 | 23 | model = registry.get_model_class(name).from_pretrained(model_type=model_type) 24 | 25 | if is_eval: 26 | model.eval() 27 | 28 | if device == "cpu": 29 | model = model.float() 30 | 31 | return model.to(device) -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import logging 5 | from omegaconf import OmegaConf 6 | 7 | class BaseModel(nn.Module): 8 | """Base class for models.""" 9 | 10 | def __init__(self): 11 | super().__init__() 12 | 13 | @property 14 | def device(self): 15 | return list(self.parameters())[0].device 16 | 17 | def load_checkpoint_from_config(self, cfg, **kwargs): 18 | """ 19 | Load checkpoint as specified in the config file. 20 | 21 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. 22 | When loading the pretrained model, each task-specific architecture may define their 23 | own load_from_pretrained() method. 24 | """ 25 | load_pretrained = cfg.get("load_pretrained", True) 26 | if load_pretrained: 27 | # load pre-trained weights 28 | pretrain_path = cfg.get("pretrained", None) 29 | assert pretrain_path, "pretrained path is not specified in the config file" 30 | print("loading pretrain: ", pretrain_path) 31 | msg = self.load_checkpoint(filename=pretrain_path, **kwargs) 32 | print("unexpected_keys: ", msg.unexpected_keys) 33 | 34 | def load_checkpoint(self, filename): 35 | """ 36 | Load from a finetuned checkpoint. 37 | 38 | This should expect no mismatch in the model keys and the checkpoint keys. 39 | """ 40 | 41 | if os.path.isfile(filename): 42 | checkpoint = torch.load(filename, map_location="cpu") 43 | else: 44 | raise RuntimeError("checkpoint url or path is invalid") 45 | 46 | if "model" in checkpoint.keys(): 47 | state_dict = checkpoint["model"] 48 | else: 49 | state_dict = checkpoint 50 | 51 | msg = self.load_state_dict(state_dict, strict=False) 52 | 53 | # logging.info("Missing keys {}".format(msg.missing_keys)) 54 | logging.info(f"Unexpected keys: {msg.unexpected_keys}") 55 | logging.info(f"load checkpoint from: {filename}") 56 | 57 | return msg 58 | 59 | @classmethod 60 | def default_config_path(cls, model_type): 61 | assert ( 62 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT 63 | ), "Unknown model type {}".format(model_type) 64 | return cls.PRETRAINED_MODEL_CONFIG_DICT[model_type] 65 | 66 | def counting_training_parameters(self): 67 | total = 0. 
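        # Tally trainable parameter counts (and the overall total); both are logged in millions (M) below.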
68 | trainable_names = [] 69 | all = 0. 70 | for name, param in self.named_parameters(): 71 | if param.requires_grad: 72 | total += param.nelement() 73 | trainable_names.append(name) 74 | all += param.nelement() 75 | logging.info(trainable_names) 76 | logging.info(' + Number of trainable params: %.2fM' % (total / 1e6)) 77 | logging.info('Number of all params: %.2fM' % (all / 1e6)) 78 | return total 79 | 80 | @classmethod 81 | def from_config(cls, cfg): 82 | raise NotImplementedError() 83 | 84 | @classmethod 85 | def from_pretrained(cls, model_type): 86 | """ 87 | Build a pretrained model from default configuration file, specified by model_type. 88 | 89 | Args: 90 | - model_type (str): model type, specifying architecture and checkpoints. 91 | 92 | Returns: 93 | - model (nn.Module): pretrained or finetuned model, depending on the configuration. 94 | """ 95 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model 96 | model = cls.from_config(model_cfg) 97 | 98 | return model 99 | 100 | 101 | def get_optimizer_params(self, weight_decay, lr_scale=1): 102 | p_wd, p_non_wd = [], [] 103 | for n, p in self.named_parameters(): 104 | if not p.requires_grad: 105 | continue # frozen weights 106 | if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: 107 | p_non_wd.append(p) 108 | else: 109 | p_wd.append(p) 110 | optim_params = [ 111 | {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale}, 112 | {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale}, 113 | ] 114 | return optim_params 115 | -------------------------------------------------------------------------------- /models/eva_vit.py: -------------------------------------------------------------------------------- 1 | # Based on EVA, BEIT, timm and DeiT code bases 2 | # https://github.com/baaivision/EVA 3 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 4 | # https://github.com/microsoft/unilm/tree/master/beit 5 | # https://github.com/facebookresearch/deit/ 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | import math 9 | from functools import partial 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.utils.checkpoint as checkpoint 15 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 16 | from timm.models.registry import register_model 17 | 18 | 19 | def _cfg(url='', **kwargs): 20 | return { 21 | 'url': url, 22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 23 | 'crop_pct': .9, 'interpolation': 'bicubic', 24 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 25 | **kwargs 26 | } 27 | 28 | 29 | class DropPath(nn.Module): 30 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
31 | """ 32 | def __init__(self, drop_prob=None): 33 | super(DropPath, self).__init__() 34 | self.drop_prob = drop_prob 35 | 36 | def forward(self, x): 37 | return drop_path(x, self.drop_prob, self.training) 38 | 39 | def extra_repr(self) -> str: 40 | return 'p={}'.format(self.drop_prob) 41 | 42 | 43 | class Mlp(nn.Module): 44 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 45 | super().__init__() 46 | out_features = out_features or in_features 47 | hidden_features = hidden_features or in_features 48 | self.fc1 = nn.Linear(in_features, hidden_features) 49 | self.act = act_layer() 50 | self.fc2 = nn.Linear(hidden_features, out_features) 51 | self.drop = nn.Dropout(drop) 52 | 53 | def forward(self, x): 54 | x = self.fc1(x) 55 | x = self.act(x) 56 | # x = self.drop(x) 57 | # commit this for the orignal BERT implement 58 | x = self.fc2(x) 59 | x = self.drop(x) 60 | return x 61 | 62 | 63 | class Attention(nn.Module): 64 | def __init__( 65 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 66 | proj_drop=0., window_size=None, attn_head_dim=None): 67 | super().__init__() 68 | self.num_heads = num_heads 69 | head_dim = dim // num_heads 70 | if attn_head_dim is not None: 71 | head_dim = attn_head_dim 72 | all_head_dim = head_dim * self.num_heads 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 76 | if qkv_bias: 77 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 78 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 79 | else: 80 | self.q_bias = None 81 | self.v_bias = None 82 | 83 | if window_size: 84 | self.window_size = window_size 85 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 86 | self.relative_position_bias_table = nn.Parameter( 87 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 88 | # cls to token & token 2 cls & cls to cls 89 | 90 | # get pair-wise relative position index for each token inside the window 91 | coords_h = torch.arange(window_size[0]) 92 | coords_w = torch.arange(window_size[1]) 93 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 94 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 95 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 96 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 97 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 98 | relative_coords[:, :, 1] += window_size[1] - 1 99 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 100 | relative_position_index = \ 101 | torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) 102 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 103 | relative_position_index[0, 0:] = self.num_relative_distance - 3 104 | relative_position_index[0:, 0] = self.num_relative_distance - 2 105 | relative_position_index[0, 0] = self.num_relative_distance - 1 106 | 107 | self.register_buffer("relative_position_index", relative_position_index) 108 | else: 109 | self.window_size = None 110 | self.relative_position_bias_table = None 111 | self.relative_position_index = None 112 | 113 | self.attn_drop = nn.Dropout(attn_drop) 114 | self.proj = nn.Linear(all_head_dim, dim) 115 | self.proj_drop = nn.Dropout(proj_drop) 116 | 117 | def forward(self, x, rel_pos_bias=None): 118 | B, N, C = x.shape 119 | qkv_bias = None 120 | if self.q_bias is not 
None: 121 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 122 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 123 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 124 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 125 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 126 | 127 | q = q * self.scale 128 | attn = (q @ k.transpose(-2, -1)) 129 | 130 | if self.relative_position_bias_table is not None: 131 | relative_position_bias = \ 132 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 133 | self.window_size[0] * self.window_size[1] + 1, 134 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 135 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 136 | attn = attn + relative_position_bias.unsqueeze(0) 137 | 138 | if rel_pos_bias is not None: 139 | attn = attn + rel_pos_bias 140 | 141 | attn = attn.softmax(dim=-1) 142 | attn = self.attn_drop(attn) 143 | 144 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 145 | x = self.proj(x) 146 | x = self.proj_drop(x) 147 | return x 148 | 149 | 150 | class Block(nn.Module): 151 | 152 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 153 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 154 | window_size=None, attn_head_dim=None): 155 | super().__init__() 156 | self.norm1 = norm_layer(dim) 157 | self.attn = Attention( 158 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 159 | attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) 160 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 161 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 162 | self.norm2 = norm_layer(dim) 163 | mlp_hidden_dim = int(dim * mlp_ratio) 164 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 165 | 166 | if init_values is not None and init_values > 0: 167 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 168 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 169 | else: 170 | self.gamma_1, self.gamma_2 = None, None 171 | 172 | def forward(self, x, rel_pos_bias=None): 173 | if self.gamma_1 is None: 174 | x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 175 | x = x + self.drop_path(self.mlp(self.norm2(x))) 176 | else: 177 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 178 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 179 | return x 180 | 181 | 182 | class PatchEmbed(nn.Module): 183 | """ Image to Patch Embedding 184 | """ 185 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 186 | super().__init__() 187 | img_size = to_2tuple(img_size) 188 | patch_size = to_2tuple(patch_size) 189 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 190 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 191 | self.img_size = img_size 192 | self.patch_size = patch_size 193 | self.num_patches = num_patches 194 | 195 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 196 | 197 | def forward(self, x, **kwargs): 198 | B, C, H, W = x.shape 199 | # FIXME look at relaxing size constraints 200 | assert H == self.img_size[0] and W == self.img_size[1], \ 201 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
202 | x = self.proj(x).flatten(2).transpose(1, 2) 203 | return x 204 | 205 | 206 | class RelativePositionBias(nn.Module): 207 | 208 | def __init__(self, window_size, num_heads): 209 | super().__init__() 210 | self.window_size = window_size 211 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 212 | self.relative_position_bias_table = nn.Parameter( 213 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 214 | # cls to token & token 2 cls & cls to cls 215 | 216 | # get pair-wise relative position index for each token inside the window 217 | coords_h = torch.arange(window_size[0]) 218 | coords_w = torch.arange(window_size[1]) 219 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 220 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 221 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 222 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 223 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 224 | relative_coords[:, :, 1] += window_size[1] - 1 225 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 226 | relative_position_index = \ 227 | torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) 228 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 229 | relative_position_index[0, 0:] = self.num_relative_distance - 3 230 | relative_position_index[0:, 0] = self.num_relative_distance - 2 231 | relative_position_index[0, 0] = self.num_relative_distance - 1 232 | 233 | self.register_buffer("relative_position_index", relative_position_index) 234 | 235 | # trunc_normal_(self.relative_position_bias_table, std=.02) 236 | 237 | def forward(self): 238 | relative_position_bias = \ 239 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 240 | self.window_size[0] * self.window_size[1] + 1, 241 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 242 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 243 | 244 | 245 | class VisionTransformer(nn.Module): 246 | """ Vision Transformer with support for patch or hybrid CNN input stage 247 | """ 248 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 249 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 250 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, 251 | use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, 252 | use_mean_pooling=True, init_scale=0.001, use_checkpoint=False): 253 | super().__init__() 254 | self.image_size = img_size 255 | self.num_classes = num_classes 256 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 257 | 258 | self.patch_embed = PatchEmbed( 259 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 260 | num_patches = self.patch_embed.num_patches 261 | 262 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 263 | if use_abs_pos_emb: 264 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 265 | else: 266 | self.pos_embed = None 267 | self.pos_drop = nn.Dropout(p=drop_rate) 268 | 269 | if use_shared_rel_pos_bias: 270 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) 271 | else: 272 | self.rel_pos_bias = None 273 | 
self.use_checkpoint = use_checkpoint 274 | 275 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 276 | self.use_rel_pos_bias = use_rel_pos_bias 277 | self.blocks = nn.ModuleList([ 278 | Block( 279 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 280 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 281 | init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) 282 | for i in range(depth)]) 283 | # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 284 | # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 285 | # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 286 | 287 | if self.pos_embed is not None: 288 | trunc_normal_(self.pos_embed, std=.02) 289 | trunc_normal_(self.cls_token, std=.02) 290 | # trunc_normal_(self.mask_token, std=.02) 291 | # if isinstance(self.head, nn.Linear): 292 | # trunc_normal_(self.head.weight, std=.02) 293 | self.apply(self._init_weights) 294 | self.fix_init_weight() 295 | # if isinstance(self.head, nn.Linear): 296 | # self.head.weight.data.mul_(init_scale) 297 | # self.head.bias.data.mul_(init_scale) 298 | 299 | def fix_init_weight(self): 300 | def rescale(param, layer_id): 301 | param.div_(math.sqrt(2.0 * layer_id)) 302 | 303 | for layer_id, layer in enumerate(self.blocks): 304 | rescale(layer.attn.proj.weight.data, layer_id + 1) 305 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 306 | 307 | def _init_weights(self, m): 308 | if isinstance(m, nn.Linear): 309 | trunc_normal_(m.weight, std=.02) 310 | if isinstance(m, nn.Linear) and m.bias is not None: 311 | nn.init.constant_(m.bias, 0) 312 | elif isinstance(m, nn.LayerNorm): 313 | nn.init.constant_(m.bias, 0) 314 | nn.init.constant_(m.weight, 1.0) 315 | 316 | def get_classifier(self): 317 | return self.head 318 | 319 | def reset_classifier(self, num_classes, global_pool=''): 320 | self.num_classes = num_classes 321 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 322 | 323 | def forward_features(self, x, return_intermediate=False): 324 | x = self.patch_embed(x) 325 | batch_size, seq_len, _ = x.size() 326 | 327 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 328 | x = torch.cat((cls_tokens, x), dim=1) 329 | if self.pos_embed is not None: 330 | x = x + self.pos_embed 331 | x = self.pos_drop(x) 332 | 333 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 334 | intermediate = [] 335 | for idx,blk in enumerate(self.blocks): 336 | if self.use_checkpoint: 337 | x = checkpoint.checkpoint(blk, x, rel_pos_bias) 338 | else: 339 | x = blk(x, rel_pos_bias) 340 | if return_intermediate: 341 | intermediate.append(x) 342 | if return_intermediate: 343 | return x, intermediate 344 | return x 345 | # x = self.norm(x) 346 | 347 | # if self.fc_norm is not None: 348 | # t = x[:, 1:, :] 349 | # return self.fc_norm(t.mean(1)) 350 | # else: 351 | # return x[:, 0] 352 | 353 | def forward(self, x, return_intermediate=False): 354 | x = self.forward_features(x, return_intermediate) 355 | # x = self.head(x) 356 | return x 357 | 358 | def get_intermediate_layers(self, x): 359 | x = self.patch_embed(x) 360 | batch_size, seq_len, _ = x.size() 361 | 362 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 363 | x = 
torch.cat((cls_tokens, x), dim=1) 364 | if self.pos_embed is not None: 365 | x = x + self.pos_embed 366 | x = self.pos_drop(x) 367 | 368 | features = [] 369 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 370 | for blk in self.blocks: 371 | x = blk(x, rel_pos_bias) 372 | features.append(x) 373 | 374 | return features 375 | 376 | def get_num_layer(self, var_name=""): 377 | if var_name in ("cls_token", "mask_token", "pos_embed"): 378 | return 0 379 | elif var_name.startswith("patch_embed"): 380 | return 0 381 | elif var_name.startswith("rel_pos_bias"): 382 | return len(self.blocks) - 1 383 | elif var_name.startswith("blocks"): 384 | layer_id = int(var_name.split('.')[1]) 385 | return layer_id + 1 386 | else: 387 | return len(self.blocks) 388 | 389 | 390 | def interpolate_pos_embed(model, checkpoint_model): 391 | if 'pos_embed' in checkpoint_model: 392 | pos_embed_checkpoint = checkpoint_model['pos_embed'].float() 393 | embedding_size = pos_embed_checkpoint.shape[-1] 394 | num_patches = model.patch_embed.num_patches 395 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 396 | # height (== width) for the checkpoint position embedding 397 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 398 | # height (== width) for the new position embedding 399 | new_size = int(num_patches ** 0.5) 400 | # class_token and dist_token are kept unchanged 401 | if orig_size != new_size: 402 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 403 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 404 | # only the position tokens are interpolated 405 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 406 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 407 | pos_tokens = torch.nn.functional.interpolate( 408 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 409 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 410 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 411 | checkpoint_model['pos_embed'] = new_pos_embed 412 | 413 | 414 | def convert_weights_to_fp16(model: nn.Module): 415 | """Convert applicable model parameters to fp16""" 416 | 417 | def _convert_weights_to_fp16(l): 418 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 419 | l.weight.data = l.weight.data.half() 420 | if l.bias is not None: 421 | l.bias.data = l.bias.data.half() 422 | 423 | # if isinstance(l, (nn.MultiheadAttention, Attention)): 424 | # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 425 | # tensor = getattr(l, attr) 426 | # if tensor is not None: 427 | # tensor.data = tensor.data.half() 428 | 429 | model.apply(_convert_weights_to_fp16) 430 | 431 | 432 | def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16",path="/home/ubuntu/Models/eva_vit_g.pth",depth=39): 433 | print("ViT Depth:", depth) 434 | model = VisionTransformer( 435 | img_size=img_size, 436 | patch_size=14, 437 | use_mean_pooling=False, 438 | embed_dim=1408, 439 | depth=depth, 440 | num_heads=1408//88, 441 | mlp_ratio=4.3637, 442 | qkv_bias=True, 443 | drop_path_rate=drop_path_rate, 444 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 445 | use_checkpoint=use_checkpoint, 446 | ) 447 | state_dict = torch.load(path, map_location="cpu") 448 | interpolate_pos_embed(model,state_dict) 449 | 450 | incompatible_keys = model.load_state_dict(state_dict, strict=False) 451 
| 452 | if precision == "fp16": 453 | convert_weights_to_fp16(model) 454 | return model -------------------------------------------------------------------------------- /models/lion_adapters.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.modeling_t5 import T5LayerFF 7 | 8 | import torch 9 | import torch.nn as nn 10 | from transformers import BertLayer, BertConfig 11 | from transformers.models.bert.modeling_bert import BertOutput, BertSelfOutput 12 | 13 | class FusionAdapter(nn.Module): 14 | def __init__( 15 | self, 16 | num_blocks: int = 2, 17 | dim: int = 1408, 18 | num_heads: int = 16, 19 | ): 20 | super().__init__() 21 | config = BertConfig( 22 | hidden_size=dim, 23 | num_attention_heads=num_heads 24 | ) 25 | config.add_cross_attention = True 26 | config.is_decoder = True 27 | self.config = config 28 | self.blocks = nn.ModuleList([BertLayer(config) for _ in range(num_blocks)]) 29 | self.apply(self._init_weights) 30 | 31 | def forward(self, hidden_states, encoder_hidden_states): 32 | if isinstance(encoder_hidden_states, list): 33 | assert len(encoder_hidden_states) == len(self.blocks) 34 | for idx,block in enumerate(self.blocks): 35 | hidden_states = block( 36 | hidden_states, 37 | encoder_hidden_states=encoder_hidden_states[idx], 38 | )[0] 39 | else: 40 | for idx,block in enumerate(self.blocks): 41 | hidden_states = block( 42 | hidden_states, 43 | encoder_hidden_states=encoder_hidden_states, 44 | )[0] 45 | return hidden_states 46 | 47 | def _init_weights(self, module): 48 | """ Initialize the weights """ 49 | if isinstance(module, (nn.Linear, nn.Embedding)): 50 | # Slightly different from the TF version which uses truncated_normal for initialization 51 | # cf https://github.com/pytorch/pytorch/pull/5617 52 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 53 | elif isinstance(module, nn.LayerNorm): 54 | module.bias.data.zero_() 55 | module.weight.data.fill_(1.0) 56 | if isinstance(module, nn.Linear) and module.bias is not None: 57 | module.bias.data.zero_() 58 | if isinstance(module, BertSelfOutput) or isinstance(module, BertOutput): 59 | module.dense.weight.data.zero_() 60 | module.dense.bias.data.zero_() 61 | 62 | class Adapter(nn.Module): 63 | def __init__(self, 64 | d_model=None, 65 | bottleneck=64, 66 | dropout=0.0, 67 | init_option="lora", 68 | adapter_scalar="learnable_scalar", 69 | adapter_layernorm_option="none"): 70 | super().__init__() 71 | self.n_embd = d_model 72 | self.down_size = bottleneck 73 | 74 | #_before 75 | self.adapter_layernorm_option = adapter_layernorm_option 76 | 77 | self.adapter_layer_norm_before = None 78 | if adapter_layernorm_option == "in" or adapter_layernorm_option == "out": 79 | self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd) 80 | 81 | if adapter_scalar == "learnable_scalar": 82 | self.scale = nn.Parameter(torch.ones(1)) 83 | else: 84 | self.scale = float(adapter_scalar) 85 | 86 | self.down_proj = nn.Linear(self.n_embd, self.down_size) 87 | self.non_linear_func = nn.ReLU() 88 | self.up_proj = nn.Linear(self.down_size, self.n_embd) 89 | 90 | self.dropout = dropout 91 | if init_option == "bert": 92 | raise NotImplementedError 93 | elif init_option == "lora": 94 | with torch.no_grad(): 95 | nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) 96 | nn.init.zeros_(self.up_proj.weight) 97 | nn.init.zeros_(self.down_proj.bias) 98 | nn.init.zeros_(self.up_proj.bias) 99 | 100 | def forward(self, 
x): 101 | if self.adapter_layernorm_option == 'in': 102 | x = self.adapter_layer_norm_before(x) 103 | 104 | down = self.down_proj(x) 105 | down = self.non_linear_func(down) 106 | down = nn.functional.dropout(down, p=self.dropout, training=self.training) 107 | up = self.up_proj(down) 108 | 109 | up = up * self.scale 110 | 111 | if self.adapter_layernorm_option == 'out': 112 | up = self.adapter_layer_norm_before(up) 113 | 114 | return up 115 | 116 | class AdapterRouter(nn.Module): 117 | def __init__(self, 118 | d_model, 119 | bottleneck=64, 120 | dropout=0.0, 121 | init_option="lora", 122 | adapter_scalar="learnable_scalar", 123 | adapter_layernorm_option="none", 124 | num_adapters = 2, 125 | ): 126 | super().__init__() 127 | self.adapters = nn.ModuleList([]) 128 | for _ in range(num_adapters): 129 | self.adapters.append(Adapter( 130 | d_model=d_model, 131 | bottleneck=bottleneck, 132 | dropout=dropout, 133 | init_option=init_option, 134 | adapter_scalar=adapter_scalar, 135 | adapter_layernorm_option=adapter_layernorm_option 136 | )) 137 | if num_adapters > 1: 138 | self.router_ratio1 = nn.Parameter( 139 | torch.tensor([[1],[0]],dtype=torch.float32).repeat(1,d_model) 140 | ) 141 | self.router_ratio2 = nn.Parameter( 142 | torch.tensor([[0],[1]],dtype=torch.float32).repeat(1,d_model) 143 | ) 144 | self.num_adapters = num_adapters 145 | self.router_idx = None 146 | 147 | def forward(self, x): 148 | assert self.router_idx in [0,1] 149 | output1 = self.adapters[0](x) 150 | if self.num_adapters == 1: 151 | return output1 152 | 153 | output2 = self.adapters[1](x) 154 | ratio1 = self.router_ratio1[self.router_idx] 155 | ratio2 = self.router_ratio2[self.router_idx] 156 | return output1 * ratio1 + output2 * ratio2 157 | 158 | def forward_ffn_t5(self, hidden_states): 159 | adapt_hidden_states = self.adapter(hidden_states) 160 | forwarded_states = self.layer_norm(hidden_states) 161 | forwarded_states = self.DenseReluDense(forwarded_states) 162 | hidden_states = hidden_states + self.dropout(forwarded_states) 163 | return hidden_states + adapt_hidden_states 164 | 165 | def set_adapter_t5(model: nn.Module, d_model: int, n: int, bottleneck: int = 64): 166 | for c in model.children(): 167 | if isinstance(c, T5LayerFF): 168 | c.adapter = AdapterRouter(d_model=d_model, bottleneck=bottleneck, num_adapters=n) 169 | bound_method = forward_ffn_t5.__get__(c, c.__class__) 170 | setattr(c, 'forward', bound_method) 171 | elif len(list(c.children())) > 0: 172 | set_adapter_t5(c, d_model, n, bottleneck) 173 | 174 | def set_router_idx(model: nn.Module, idx: int): 175 | for c in model.children(): 176 | if isinstance(c, AdapterRouter): 177 | c.router_idx = idx 178 | elif len(list(c.children())) > 0: 179 | set_router_idx(c, idx) -------------------------------------------------------------------------------- /models/lion_t5.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import string 4 | from typing import Literal, Union, List 5 | from PIL.Image import Image 6 | 7 | import torch 8 | import torch.nn as nn 9 | from icecream import ic 10 | from torch.cuda.amp import autocast as autocast 11 | from transformers import BertTokenizer, T5TokenizerFast 12 | from transformers.modeling_outputs import BaseModelOutput 13 | 14 | from common.registry import registry 15 | from models.base_model import BaseModel 16 | from models.eva_vit import create_eva_vit_g 17 | from models.lion_adapters import FusionAdapter, set_adapter_t5, set_router_idx 18 | from 
models.modeling_t5 import T5Config, T5ForConditionalGeneration 19 | from models.Qformer import BertConfig, BertLMHeadModel 20 | from ram import get_transform 21 | from ram.models import ram 22 | 23 | 24 | class LayerNorm(nn.LayerNorm): 25 | """Subclass torch's LayerNorm to handle fp16.""" 26 | 27 | def forward(self, x: torch.Tensor): 28 | orig_type = x.dtype 29 | ret = super().forward(x.type(torch.float32)) 30 | return ret.type(orig_type) 31 | 32 | def disabled_train(self, mode=True): 33 | """Overwrite model.train with this function to make sure train/eval mode 34 | does not change anymore.""" 35 | return self 36 | 37 | @registry.register_model("lion_t5") 38 | class LIONT5InstructAdapter(BaseModel): 39 | """ 40 | LION T5 model. 41 | Supported model types: 42 | - flant5xl 43 | - flant5xxl 44 | Usage: 45 | >>> from models import load_model 46 | >>> model = load_model("lion_t5", "flant5xl") 47 | """ 48 | 49 | PRETRAINED_MODEL_CONFIG_DICT = { 50 | "flant5xl": "configs/models/lion_flant5xl.yaml", 51 | "flant5xxl": "configs/models/lion_flant5xxl.yaml", 52 | } 53 | 54 | def __init__( 55 | self, 56 | bert_model, 57 | vit_model, 58 | llm_model, 59 | ram_model, 60 | max_txt_len=128, 61 | max_output_txt_len=128, 62 | visual_input: Literal["ALL", "QFORMER", "AGGREGATOR"] = "ALL", 63 | enable_semantic_tags=True, 64 | boost_lr_scale=1, 65 | 66 | img_size=224, 67 | drop_path_rate=0, 68 | use_grad_checkpoint=False, 69 | vit_precision="fp16", 70 | freeze_vit=True, 71 | num_query_token=32, 72 | ): 73 | super().__init__() 74 | assert bert_model is not None, "The path for bert model is not provided." 75 | assert vit_model is not None, "The path for vit model is not provided." 76 | assert llm_model is not None, "The path for llm model is not provided." 77 | assert visual_input in ["ALL", "QFORMER", "AGGREGATOR"], f"Invalid visual input type: {visual_input}." 
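        # visual_input selects which visual branch is fed to the T5 language model (see encode_img):
        #   "QFORMER"    -> only the Q-Former query tokens (projected by t5_proj),
        #   "AGGREGATOR" -> only the vision-aggregator tokens built from intermediate ViT layers,
        #   "ALL"        -> both token sequences concatenated along the sequence dimension.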
78 | self.bert_model = bert_model 79 | self.visual_input = visual_input 80 | self.enable_semantic_tags = enable_semantic_tags 81 | self.max_txt_len = max_txt_len 82 | self.max_output_txt_len = max_output_txt_len 83 | self.boost_lr_scale = boost_lr_scale 84 | self.ram_path = ram_model 85 | self.ram_model = None 86 | logging.info(f"visual_input: {visual_input}") 87 | 88 | print("Loading VIT") 89 | self.visual_encoder = self._init_vision_encoder( 90 | vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision 91 | ) 92 | if freeze_vit: 93 | for name, param in self.visual_encoder.named_parameters(): 94 | param.requires_grad = False 95 | self.visual_encoder = self.visual_encoder.eval() 96 | self.visual_encoder.train = disabled_train 97 | logging.info("freeze vision encoder") 98 | print("Loading VIT Done") 99 | 100 | self._init_llm(llm_model) 101 | 102 | if self.visual_input != "AGGREGATOR": 103 | print("Loading QFormer") 104 | self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model, truncation_side="left") 105 | self.bert_tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 106 | self.Qformer, self.query_tokens = self._init_Qformer( 107 | bert_model, num_query_token, self.visual_encoder.num_features 108 | ) 109 | self.Qformer.resize_token_embeddings(len(self.bert_tokenizer)) 110 | self.Qformer.cls = None 111 | self.t5_proj = nn.Linear( 112 | self.Qformer.config.hidden_size, self.t5_model.config.hidden_size 113 | ) 114 | self.ln_vision = LayerNorm(self.visual_encoder.num_features) 115 | print("Loading QFormer Done") 116 | 117 | if self.visual_input != "QFORMER": 118 | print("Loading Vision Aggregator") 119 | self.ln_adapter = LayerNorm(self.visual_encoder.num_features) 120 | self.adapter_proj = nn.Sequential( 121 | nn.Linear(self.visual_encoder.num_features, self.visual_encoder.num_features * 4), 122 | nn.GELU(), 123 | nn.Linear(self.visual_encoder.num_features * 4, self.t5_model.config.hidden_size), 124 | ) 125 | self.fusion_adapter = FusionAdapter(num_blocks=2,dim=self.visual_encoder.num_features) 126 | print("Loading Vision Aggregator Done") 127 | 128 | if self.enable_semantic_tags: 129 | tag_sp_token = "" 130 | self.tag_softPrompt_id = self.t5_tokenizer.convert_tokens_to_ids(tag_sp_token) 131 | self.tag_prompt = "According to , you are allowed to use or partially use the following tags: [{}]. 
" 132 | self.soft_prompt_hint = nn.Parameter(torch.zeros(self.t5_model.config.hidden_size)) 133 | self.soft_prompt_hint.data.normal_(mean=0.0, std=self.t5_model.config.initializer_factor) 134 | logging.info(f"boost_lr_scale:{boost_lr_scale}") 135 | 136 | def maybe_autocast(self, dtype=torch.bfloat16): 137 | # if on cpu, don't use autocast 138 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 139 | enable_autocast = self.device != torch.device("cpu") 140 | 141 | if enable_autocast: 142 | return torch.cuda.amp.autocast(dtype=dtype) 143 | else: 144 | return contextlib.nullcontext() 145 | 146 | def _init_vision_encoder( 147 | self, model_path, img_size, drop_path_rate, use_grad_checkpoint, precision 148 | ): 149 | print("Using normal vit") 150 | visual_encoder = create_eva_vit_g( 151 | img_size, drop_path_rate, use_grad_checkpoint, precision, model_path 152 | ) 153 | return visual_encoder 154 | 155 | def _init_Qformer(self, bert_model, num_query_token, vision_width, cross_attention_freq=2): 156 | encoder_config = BertConfig.from_pretrained(bert_model) 157 | encoder_config.encoder_width = vision_width 158 | encoder_config.add_cross_attention = True 159 | encoder_config.cross_attention_freq = cross_attention_freq 160 | encoder_config.query_length = num_query_token 161 | Qformer = BertLMHeadModel(config=encoder_config) 162 | query_tokens = nn.Parameter( 163 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 164 | ) 165 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 166 | return Qformer, query_tokens 167 | 168 | def _init_llm(self, llm_model): 169 | print("Loading LLM") 170 | self.t5_tokenizer = T5TokenizerFast.from_pretrained(llm_model, truncation_side='left') 171 | self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(llm_model, truncation_side='right') 172 | 173 | llm_config = T5Config.from_pretrained(llm_model) 174 | llm_config.dense_act_fn = "gelu" 175 | self.t5_model = T5ForConditionalGeneration.from_pretrained( 176 | llm_model, config=llm_config, torch_dtype=torch.bfloat16, 177 | ) 178 | set_adapter_t5(self.t5_model, self.t5_model.config.d_model, n=2 if self.visual_input=="ALL" else 1, bottleneck=64) 179 | 180 | for name, param in self.t5_model.named_parameters(): 181 | if "adapter" in name: 182 | if "router_ratio" in name and self.visual_input != "ALL": 183 | param.requires_grad = False 184 | else: 185 | param.requires_grad = True 186 | else: 187 | param.requires_grad = False 188 | print("Loading LLM Done") 189 | 190 | def _init_ram(self): 191 | if self.ram_model == None: 192 | print("Loading RAM Model For Tag Generation") 193 | self.ram_model = ram(pretrained=self.ram_path, image_size=384, vit="swin_l", text_encoder_type=self.bert_model).cuda() 194 | self.ram_processor = get_transform() 195 | print("Loading RAM Model Done") 196 | 197 | def generate_tags(self, images:Union[List[Image], Image]) -> List[str]: 198 | """ 199 | Generate tags for provided images. 200 | 201 | Args: 202 | images (Image or List[Image]) 203 | Returns: 204 | tags (List[str]) 205 | """ 206 | 207 | self._init_ram() 208 | if isinstance(images, Image): 209 | images = [images] 210 | images = torch.stack([self.ram_processor(img) for img in images]).to(self.device) 211 | tags = self.ram_model.generate_tag(images, threshold=0.85)[0] 212 | return [t.replace(" |",",") for t in tags] 213 | 214 | def _insert_tags(self, samples, prompt): 215 | if self.enable_semantic_tags: 216 | assert self.tag_prompt is not None, "Please provide Tags prompt." 
217 | if "tags" not in samples: 218 | samples = self._generate_tags(samples) 219 | prompt = [self.tag_prompt.format(tags) + tin for tags, tin in zip(samples["tags"], prompt)] 220 | return prompt 221 | 222 | def _insert_softTagHint(self, samples, input_tokens, inputs_embeds): 223 | if self.enable_semantic_tags: 224 | bs = inputs_embeds.size(0) 225 | sp_embeds = self.soft_prompt_hint.expand(bs, -1).to(inputs_embeds.dtype) 226 | sp_index = (input_tokens.input_ids == self.tag_softPrompt_id).nonzero(as_tuple=True) 227 | inputs_embeds[sp_index] = sp_embeds 228 | return inputs_embeds 229 | 230 | def get_optimizer_params(self, weight_decay, lr_scale=1): 231 | p_wd, p_non_wd = [], [] 232 | p_boost, p_boost_non_wd = [], [] 233 | boost_name = [] 234 | for n, p in self.named_parameters(): 235 | if not p.requires_grad: 236 | continue # frozen weights 237 | if "fusion_adapter" in n: 238 | if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: 239 | p_boost_non_wd.append(p) 240 | else: 241 | p_boost.append(p) 242 | boost_name.append(n) 243 | else: 244 | if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: 245 | p_non_wd.append(p) 246 | else: 247 | p_wd.append(p) 248 | optim_params = [ 249 | {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale}, 250 | {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale}, 251 | {"params": p_boost, "weight_decay": weight_decay, "lr_scale": lr_scale*self.boost_lr_scale}, 252 | {"params": p_boost_non_wd, "weight_decay": 0, "lr_scale": lr_scale*self.boost_lr_scale}, 253 | ] 254 | logging.info(f"boost params:{boost_name}") 255 | return optim_params 256 | 257 | def encode_img(self, image, question): 258 | with self.maybe_autocast(dtype=torch.bfloat16): 259 | image_embeds, intermediate = self.visual_encoder(image, return_intermediate=True) 260 | if self.visual_input != "QFORMER": 261 | adapter_embeds = self.ln_adapter(self.fusion_adapter(intermediate[38], [intermediate[28],intermediate[18]])) 262 | adapter_embeds = self.adapter_proj(adapter_embeds) 263 | if self.visual_input != "AGGREGATOR": 264 | image_embeds = self.ln_vision(image_embeds) 265 | if self.visual_input != "AGGREGATOR": 266 | query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) 267 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) 268 | 269 | text_Qformer = self.bert_tokenizer( 270 | question, 271 | padding='longest', 272 | truncation=True, 273 | max_length=self.max_txt_len, 274 | return_tensors="pt", 275 | ).to(image.device) 276 | query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) 277 | Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1) 278 | 279 | query_output = self.Qformer.bert( 280 | text_Qformer.input_ids, 281 | attention_mask=Qformer_atts, 282 | query_embeds=query_tokens, 283 | encoder_hidden_states=image_embeds, 284 | encoder_attention_mask=image_atts, 285 | return_dict=True, 286 | ) 287 | input_qformer = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:]) 288 | 289 | match self.visual_input: 290 | case "AGGREGATOR": 291 | inputs_t5 = adapter_embeds 292 | case "QFORMER": 293 | inputs_t5 = input_qformer 294 | case "ALL": 295 | inputs_t5 = torch.cat([input_qformer, adapter_embeds],dim=1) 296 | case _: 297 | raise NotImplementedError(f"Visual input type {self.visual_input} is not supported.") 298 | 299 | atts_t5 = torch.ones(inputs_t5.size()[:-1],dtype=torch.long).to(image.device) 300 | 301 | return inputs_t5, atts_t5 302 | 303 | def forward(self, 
samples): 304 | img_embeds, img_atts = self.encode_img(samples["image"], samples["question"]) 305 | prompt = self._insert_tags(samples, samples["question"]) 306 | with self.maybe_autocast(dtype=torch.bfloat16): 307 | input_tokens = self.t5_tokenizer( 308 | prompt, 309 | padding="longest", 310 | truncation=True, 311 | max_length=self.max_txt_len, 312 | return_tensors="pt", 313 | ).to(self.device) 314 | output_tokens = self.t5_output_tokenizer( 315 | samples["answer"], 316 | padding="longest", 317 | truncation=True, 318 | max_length=self.max_output_txt_len, 319 | return_tensors="pt", 320 | ).to(self.device) 321 | 322 | targets = output_tokens.input_ids.masked_fill( 323 | output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100 324 | ) 325 | 326 | text_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) 327 | text_embeds = self._insert_softTagHint(samples, input_tokens, text_embeds) 328 | text_atts = input_tokens.attention_mask 329 | 330 | input_embeds = torch.cat([img_embeds, text_embeds], dim=1) 331 | encoder_atts = torch.cat([img_atts, text_atts], dim=1) 332 | 333 | set_router_idx(self.t5_model, int(samples["category"][0] != "region_level")) 334 | outputs = self.t5_model( 335 | inputs_embeds=input_embeds, 336 | attention_mask=encoder_atts, 337 | decoder_attention_mask=output_tokens.attention_mask, 338 | return_dict=True, 339 | labels=targets, 340 | ) 341 | loss = outputs.loss 342 | 343 | return {"loss": loss} 344 | 345 | @torch.no_grad() 346 | def generate( 347 | self, 348 | samples, 349 | use_nucleus_sampling=False, 350 | num_beams=5, 351 | max_length=256, 352 | min_length=1, 353 | top_p=0.9, 354 | repetition_penalty=1.5, 355 | length_penalty=1.0, 356 | num_captions=1, 357 | temperature=1, 358 | ): 359 | img_embeds, img_atts = self.encode_img(samples["image"].to(self.device), samples["question"]) 360 | prompt = self._insert_tags(samples, samples["question"]) 361 | input_tokens = self.t5_tokenizer( 362 | prompt, 363 | padding="longest", 364 | return_tensors="pt" 365 | ).to(self.device) 366 | 367 | with self.maybe_autocast(dtype=torch.bfloat16): 368 | text_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) 369 | text_embeds = self._insert_softTagHint(samples, input_tokens, text_embeds) 370 | text_atts = input_tokens.attention_mask 371 | 372 | inputs_embeds = torch.cat([img_embeds, text_embeds], dim=1) 373 | input_atts = torch.cat([img_atts, text_atts], dim=1) 374 | set_router_idx(self.t5_model, int(samples.get("category") != "region_level")) 375 | outputs = self.t5_model.generate( 376 | inputs_embeds=inputs_embeds, 377 | attention_mask=input_atts, 378 | do_sample=use_nucleus_sampling, 379 | top_p=top_p, 380 | temperature=temperature, 381 | num_beams=num_beams, 382 | max_new_tokens=max_length, 383 | min_length=min_length, 384 | repetition_penalty=repetition_penalty, 385 | length_penalty=length_penalty, 386 | num_return_sequences=num_captions, 387 | ) 388 | output_text = self.t5_tokenizer.batch_decode( 389 | outputs, skip_special_tokens=True 390 | ) 391 | 392 | return output_text 393 | 394 | @classmethod 395 | def from_config(cls, cfg): 396 | bert_model = cfg.get("bert_model") 397 | vit_model = cfg.get("vit_model") 398 | llm_model = cfg.get("llm_model") 399 | ram_model = cfg.get("ram_model") 400 | 401 | max_txt_len = cfg.get("max_txt_len", 128) 402 | max_output_txt_len = cfg.get("max_output_txt_len", 128) 403 | visual_input = cfg.get("visual_input", "ALL") 404 | enable_semantic_tags = cfg.get("enable_semantic_tags", True) 405 | boost_lr_scale = 
cfg.get("boost_lr_scale", 1.0) 406 | 407 | img_size = cfg.get("image_size", 224) 408 | drop_path_rate = cfg.get("drop_path_rate", 0) 409 | use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) 410 | vit_precision = cfg.get("vit_precision", "fp16") 411 | freeze_vit = cfg.get("freeze_vit", True) 412 | num_query_token = cfg.get("num_query_token", 32) 413 | 414 | model = cls( 415 | bert_model=bert_model, 416 | vit_model=vit_model, 417 | llm_model=llm_model, 418 | ram_model=ram_model, 419 | max_txt_len=max_txt_len, 420 | max_output_txt_len=max_output_txt_len, 421 | visual_input=visual_input, 422 | enable_semantic_tags=enable_semantic_tags, 423 | boost_lr_scale=boost_lr_scale, 424 | 425 | img_size=img_size, 426 | drop_path_rate=drop_path_rate, 427 | use_grad_checkpoint=use_grad_checkpoint, 428 | vit_precision=vit_precision, 429 | freeze_vit=freeze_vit, 430 | num_query_token=num_query_token, 431 | ) 432 | 433 | model.load_checkpoint_from_config(cfg) 434 | 435 | return model 436 | -------------------------------------------------------------------------------- /preprocessors/lion_preprocessors.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from torchvision import transforms 3 | from torchvision.transforms.functional import InterpolationMode 4 | 5 | from common.registry import registry 6 | 7 | 8 | class BaseProcessor: 9 | def __init__(self, mean=None, std=None): 10 | if mean is None: 11 | mean = (0.48145466, 0.4578275, 0.40821073) 12 | if std is None: 13 | std = (0.26862954, 0.26130258, 0.27577711) 14 | 15 | self.normalize = transforms.Normalize(mean, std) 16 | 17 | 18 | @registry.register_processor("eval") 19 | class ImageEvalProcessor(BaseProcessor): 20 | def __init__(self, image_size=224, mean=None, std=None): 21 | super().__init__(mean=mean, std=std) 22 | 23 | self.transform = transforms.Compose( 24 | [ 25 | transforms.Resize( 26 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 27 | ), 28 | transforms.ToTensor(), 29 | self.normalize, 30 | ] 31 | ) 32 | 33 | def __call__(self, item): 34 | return self.transform(item) 35 | 36 | @classmethod 37 | def from_config(cls, cfg=None): 38 | if cfg is None: 39 | cfg = OmegaConf.create() 40 | 41 | image_size = cfg.get("image_size", 224) 42 | 43 | mean = cfg.get("mean", None) 44 | std = cfg.get("std", None) 45 | 46 | return cls(image_size=image_size, mean=mean, std=std) 47 | 48 | 49 | @registry.register_processor("train") 50 | class ImageTrainProcessor(BaseProcessor): 51 | def __init__( 52 | self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0 53 | ): 54 | super().__init__(mean=mean, std=std) 55 | 56 | self.transform = transforms.Compose( 57 | [ 58 | transforms.RandomResizedCrop( 59 | image_size, 60 | scale=(min_scale, max_scale), 61 | interpolation=InterpolationMode.BICUBIC, 62 | ), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ToTensor(), 65 | self.normalize, 66 | ] 67 | ) 68 | 69 | def __call__(self, item): 70 | return self.transform(item) 71 | 72 | @classmethod 73 | def from_config(cls, cfg=None): 74 | if cfg is None: 75 | cfg = OmegaConf.create() 76 | 77 | image_size = cfg.get("image_size", 224) 78 | 79 | mean = cfg.get("mean", None) 80 | std = cfg.get("std", None) 81 | 82 | min_scale = cfg.get("min_scale", 0.5) 83 | max_scale = cfg.get("max_scale", 1.0) 84 | 85 | return cls( 86 | image_size=image_size, 87 | mean=mean, 88 | std=std, 89 | min_scale=min_scale, 90 | max_scale=max_scale, 91 | ) 
-------------------------------------------------------------------------------- /ram/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import inference_tag2text, inference_ram, inference_ram_openset 2 | from .transform import get_transform 3 | -------------------------------------------------------------------------------- /ram/configs/finetune.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | ] 4 | image_path_root: "" 5 | 6 | # size of vit model; base or large 7 | vit: 'swin_l' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 384 12 | batch_size: 26 13 | 14 | # optimizer 15 | weight_decay: 0.05 16 | init_lr: 5e-06 17 | min_lr: 0 18 | max_epoch: 2 19 | warmup_steps: 3000 20 | 21 | class_num: 4585 22 | 23 | -------------------------------------------------------------------------------- /ram/configs/finetune_tag2text.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | ] 4 | image_path_root: "" 5 | 6 | # size of vit model; base or large 7 | vit: 'swin_b' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 384 12 | batch_size: 36 13 | 14 | # optimizer 15 | weight_decay: 0.05 16 | init_lr: 5e-06 17 | min_lr: 0 18 | max_epoch: 2 19 | warmup_steps: 3000 20 | 21 | class_num: 4585 22 | 23 | -------------------------------------------------------------------------------- /ram/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } -------------------------------------------------------------------------------- /ram/configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | 'datasets/train/vg_ram.json', 4 | 'datasets/train/sbu_ram.json', 5 | 'datasets/train/cc3m_train_ram.json', 6 | 'datasets/train/cc3m_val_ram.json', 7 | 'datasets/train/cc12m_ram.json', 8 | ] 9 | image_path_root: "" 10 | 11 | # size of vit model; base or large 12 | vit: 'swin_l' 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | 16 | image_size: 224 17 | batch_size: 52 18 | 19 | # optimizer 20 | weight_decay: 0.05 21 | init_lr: 1e-4 22 | min_lr: 5e-7 23 | warmup_lr: 5e-7 24 | lr_decay_rate: 0.9 25 | max_epoch: 5 26 | warmup_steps: 3000 27 | 28 | class_num: 4585 29 | 30 | -------------------------------------------------------------------------------- /ram/configs/pretrain_tag2text.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | 'datasets/train/vg_ram.json', 4 | 'datasets/train/sbu_ram.json', 5 | 'datasets/train/cc3m_train_ram.json', 6 | 'datasets/train/cc3m_val_ram.json', 7 | 'datasets/train/cc12m_ram.json', 
8 | ] 9 | image_path_root: "" 10 | 11 | # size of vit model; base or large 12 | vit: 'swin_b' 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | 16 | image_size: 224 17 | batch_size: 80 18 | 19 | # optimizer 20 | weight_decay: 0.05 21 | init_lr: 1e-4 22 | min_lr: 5e-7 23 | warmup_lr: 5e-7 24 | lr_decay_rate: 0.9 25 | max_epoch: 5 26 | warmup_steps: 3000 27 | 28 | class_num: 4585 29 | 30 | -------------------------------------------------------------------------------- /ram/configs/q2l_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 4, 15 | "num_hidden_layers": 2, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true, 21 | "add_tag_cross_attention": false 22 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinB_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", 3 | "vision_width": 1024, 4 | "image_res": 224, 5 | "window_size": 7, 6 | "embed_dim": 128, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 4, 8, 16, 32 ] 9 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinB_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", 3 | "vision_width": 1024, 4 | "image_res": 384, 5 | "window_size": 12, 6 | "embed_dim": 128, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 4, 8, 16, 32 ] 9 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinL_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_large_patch4_window7_224_22k.pth", 3 | "vision_width": 1536, 4 | "image_res": 224, 5 | "window_size": 7, 6 | "embed_dim": 192, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 6, 12, 24, 48 ] 9 | } 10 | -------------------------------------------------------------------------------- /ram/configs/swin/config_swinL_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth", 3 | "vision_width": 1536, 4 | "image_res": 384, 5 | "window_size": 12, 6 | "embed_dim": 192, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 6, 12, 24, 48 ] 9 | } -------------------------------------------------------------------------------- /ram/data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms.functional import InterpolationMode 5 | 6 | from .dataset import pretrain_dataset, finetune_dataset 7 | from .randaugment import RandomAugment 8 | 9 | def create_dataset(dataset, config, min_scale=0.5): 10 | 11 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 
0.27577711)) 12 | 13 | transform_train = transforms.Compose([ 14 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 15 | transforms.RandomHorizontalFlip(), 16 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 17 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 18 | transforms.ToTensor(), 19 | normalize, 20 | ]) 21 | 22 | transform_inputsize_224 = transforms.Compose([ 23 | transforms.RandomResizedCrop(224,scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 24 | transforms.RandomHorizontalFlip(), 25 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 26 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 27 | transforms.ToTensor(), 28 | normalize, 29 | ]) 30 | 31 | if dataset=='pretrain': 32 | dataset = pretrain_dataset(config['train_file'], transform_train, class_num=config['class_num'], root=config['image_path_root']) 33 | return dataset 34 | 35 | elif dataset=='finetune': 36 | dataset = finetune_dataset(config['train_file'], transform_train, transform_inputsize_224, class_num=config['class_num'], root=config['image_path_root']) 37 | return dataset 38 | 39 | 40 | 41 | 42 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 43 | samplers = [] 44 | for dataset,shuffle in zip(datasets,shuffles): 45 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 46 | samplers.append(sampler) 47 | return samplers 48 | 49 | 50 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 51 | loaders = [] 52 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 53 | if is_train: 54 | shuffle = (sampler is None) 55 | drop_last = True 56 | else: 57 | shuffle = False 58 | drop_last = False 59 | loader = DataLoader( 60 | dataset, 61 | batch_size=bs, 62 | num_workers=n_worker, 63 | pin_memory=True, 64 | sampler=sampler, 65 | shuffle=shuffle, 66 | collate_fn=collate_fn, 67 | drop_last=drop_last, 68 | ) 69 | loaders.append(loader) 70 | return loaders 71 | 72 | -------------------------------------------------------------------------------- /ram/data/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | Image.MAX_IMAGE_PIXELS = None 11 | 12 | from .utils import pre_caption 13 | import os,glob 14 | 15 | import torch 16 | import numpy as np 17 | 18 | class pretrain_dataset(Dataset): 19 | def __init__(self, ann_file, transform, class_num = 4585, root = ''): 20 | 21 | self.ann = [] 22 | for f in ann_file: 23 | print('loading '+f) 24 | ann = json.load(open(f,'r')) 25 | self.ann += ann 26 | 27 | self.transform = transform 28 | self.class_num = class_num 29 | self.root = root 30 | 31 | 32 | def __len__(self): 33 | return len(self.ann) 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.ann[index] 38 | 39 | image_path_use = os.path.join(self.root, ann['image_path']) 40 | image = Image.open(image_path_use).convert('RGB') 41 | image = self.transform(image) 42 | 43 | # required for tag2text support 44 | if ann.get('union_label_id') is not None: 45 | num = ann['union_label_id'] 46 | image_tag = 
np.zeros([self.class_num]) 47 | image_tag[num] = 1 48 | image_tag = torch.tensor(image_tag, dtype = torch.long) 49 | else: 50 | image_tag = None 51 | 52 | caption_index = np.random.randint(0, len(ann['caption'])) 53 | 54 | caption = pre_caption(ann['caption'][caption_index],30) 55 | 56 | num = ann['parse_label_id'][caption_index] 57 | parse_tag = np.zeros([self.class_num]) 58 | parse_tag[num] = 1 59 | parse_tag = torch.tensor(parse_tag, dtype = torch.long) 60 | 61 | return image, caption, image_tag, parse_tag 62 | 63 | 64 | class finetune_dataset(Dataset): 65 | def __init__(self, ann_file, transform, transform_224, class_num = 4585, root = ''): 66 | 67 | self.ann = [] 68 | for f in ann_file: 69 | print('loading '+f) 70 | ann = json.load(open(f,'r')) 71 | self.ann += ann 72 | 73 | self.transform = transform 74 | self.transform_224 = transform_224 75 | self.class_num = class_num 76 | self.root = root 77 | 78 | 79 | def __len__(self): 80 | return len(self.ann) 81 | 82 | def __getitem__(self, index): 83 | 84 | ann = self.ann[index] 85 | 86 | image_path_use = os.path.join(self.root, ann['image_path']) 87 | image = Image.open(image_path_use).convert('RGB') 88 | image = self.transform(image) 89 | 90 | image_224 = Image.open(image_path_use).convert('RGB') 91 | image_224 = self.transform_224(image_224) 92 | 93 | # required for tag2text support 94 | if ann.get('union_label_id') is not None: 95 | num = ann['union_label_id'] 96 | image_tag = np.zeros([self.class_num]) 97 | image_tag[num] = 1 98 | image_tag = torch.tensor(image_tag, dtype = torch.long) 99 | else: 100 | image_tag = None 101 | 102 | caption_index = np.random.randint(0, len(ann['caption'])) 103 | 104 | caption = pre_caption(ann['caption'][caption_index],30) 105 | 106 | num = ann['parse_label_id'][caption_index] 107 | parse_tag = np.zeros([self.class_num]) 108 | parse_tag[num] = 1 109 | parse_tag = torch.tensor(parse_tag, dtype = torch.long) 110 | 111 | return image, image_224, caption, image_tag, parse_tag 112 | 113 | -------------------------------------------------------------------------------- /ram/data/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, 
n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.posterize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Contrast 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences the this result and PIL are all on the 4 boundaries, the center 134 | areas are same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, 
flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 
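# How an augmentation "level" is turned into concrete arguments (see arg_dict below):
# with MAX_LEVEL = 10 and the level M = 5 used by create_dataset() in ram/data/__init__.py,
# Brightness/Sharpness get a factor of (5/10)*1.8 + 0.1 = 1.0, Rotate draws up to +/-(5/10)*30 = 15 degrees,
# ShearX/ShearY up to +/-(5/10)*0.3 = 0.15, and TranslateX/TranslateY up to +/-(5/10)*translate_const = 5 pixels.
# Each of the N sampled ops is then applied with probability 0.5 (see RandomAugment.__call__).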
286 | MAX_LEVEL = 10 287 | replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | if augs: 317 | self.augs = augs 318 | else: 319 | self.augs = list(arg_dict.keys()) 320 | 321 | def get_random_ops(self): 322 | sampled_ops = np.random.choice(self.augs, self.N) 323 | return [(op, 0.5, self.M) for op in sampled_ops] 324 | 325 | def __call__(self, img): 326 | if self.isPIL: 327 | img = np.array(img) 328 | ops = self.get_random_ops() 329 | for name, prob, level in ops: 330 | if np.random.random() > prob: 331 | continue 332 | args = arg_dict[name](level) 333 | img = func_dict[name](img, *args) 334 | return img 335 | 336 | 337 | if __name__ == '__main__': 338 | a = RandomAugment() 339 | img = np.random.randn(32, 32, 3) 340 | a(img) -------------------------------------------------------------------------------- /ram/data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import utils 9 | 10 | def pre_caption(caption,max_words=50): 11 | caption = re.sub( 12 | r"([.!\"()*#:;~])", 13 | ' ', 14 | caption.lower(), 15 | ) 16 | caption = re.sub( 17 | r"\s{2,}", 18 | ' ', 19 | caption, 20 | ) 21 | caption = caption.rstrip('\n') 22 | caption = caption.strip(' ') 23 | 24 | #truncate caption 25 | caption_words = caption.split(' ') 26 | if len(caption_words)>max_words: 27 | caption = ' '.join(caption_words[:max_words]) 28 | 29 | return caption 30 | 31 | def pre_question(question,max_ques_words=50): 32 | question = re.sub( 33 | r"([.!\"()*#:;~])", 34 | '', 35 | question.lower(), 36 | ) 37 | question = question.rstrip(' ') 38 | 39 | #truncate question 40 | question_words = question.split(' ') 41 | if len(question_words)>max_ques_words: 42 | question = ' '.join(question_words[:max_ques_words]) 43 | 44 | return question 45 | 46 | 47 | def save_result(result, result_dir, filename, remove_duplicate=''): 48 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 49 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 50 | 51 | json.dump(result,open(result_file,'w')) 52 | 53 | dist.barrier() 54 | 55 | if utils.is_main_process(): 56 | # combine results from all processes 57 | result = [] 58 | 59 | for rank in range(utils.get_world_size()): 60 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 61 | res = json.load(open(result_file,'r')) 62 | result += res 63 | 64 | if remove_duplicate: 65 | result_new = [] 66 | 
id_list = [] 67 | for res in result: 68 | if res[remove_duplicate] not in id_list: 69 | id_list.append(res[remove_duplicate]) 70 | result_new.append(res) 71 | result = result_new 72 | 73 | json.dump(result,open(final_result_file,'w')) 74 | print('result file saved to %s'%final_result_file) 75 | 76 | return final_result_file 77 | 78 | 79 | 80 | from pycocotools.coco import COCO 81 | from pycocoevalcap.eval import COCOEvalCap 82 | from torchvision.datasets.utils import download_url 83 | 84 | def coco_caption_eval(coco_gt_root, results_file, split): 85 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 86 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} 87 | filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'} 88 | 89 | download_url(urls[split],coco_gt_root) 90 | annotation_file = os.path.join(coco_gt_root,filenames[split]) 91 | 92 | # create coco object and coco_result object 93 | coco = COCO(annotation_file) 94 | coco_result = coco.loadRes(results_file) 95 | 96 | # create coco_eval object by taking coco and coco_result 97 | coco_eval = COCOEvalCap(coco, coco_result) 98 | 99 | # evaluate on a subset of images by setting 100 | # coco_eval.params['image_id'] = coco_result.getImgIds() 101 | # please remove this line when evaluating the full validation set 102 | # coco_eval.params['image_id'] = coco_result.getImgIds() 103 | 104 | # evaluate results 105 | # SPICE will take a few minutes the first time, but speeds up due to caching 106 | coco_eval.evaluate() 107 | 108 | # print output evaluation scores 109 | for metric, score in coco_eval.eval.items(): 110 | print(f'{metric}: {score:.3f}') 111 | 112 | return coco_eval -------------------------------------------------------------------------------- /ram/inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Inference of RAM and Tag2Text Models 3 | * Written by Xinyu Huang 4 | ''' 5 | import torch 6 | 7 | 8 | def inference_tag2text(image, model, input_tag="None"): 9 | 10 | with torch.no_grad(): 11 | caption, tag_predict = model.generate(image, 12 | tag_input=None, 13 | max_length=50, 14 | return_tag_predict=True) 15 | 16 | if input_tag == '' or input_tag == 'none' or input_tag == 'None': 17 | return tag_predict[0], None, caption[0] 18 | 19 | # If user input specified tags: 20 | else: 21 | input_tag_list = [] 22 | input_tag_list.append(input_tag.replace(',', ' | ')) 23 | 24 | with torch.no_grad(): 25 | caption, input_tag = model.generate(image, 26 | tag_input=input_tag_list, 27 | max_length=50, 28 | return_tag_predict=True) 29 | 30 | return tag_predict[0], input_tag[0], caption[0] 31 | 32 | 33 | def inference_ram(image, model): 34 | 35 | with torch.no_grad(): 36 | tags, tags_chinese = model.generate_tag(image) 37 | 38 | return tags[0],tags_chinese[0] 39 | 40 | 41 | def inference_ram_openset(image, model): 42 | 43 | with torch.no_grad(): 44 | tags = model.generate_tag_openset(image) 45 | 46 | return tags[0] 47 | -------------------------------------------------------------------------------- /ram/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ram_plus import ram_plus 2 | from .ram import ram 3 | from .tag2text import tag2text 4 | -------------------------------------------------------------------------------- /ram/models/ram.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Model (RAM) 3 | * Written by Xinyu Huang 4 | ''' 5 | import json 6 | import warnings 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from .bert import BertConfig, BertLMHeadModel, BertModel 13 | from .swin_transformer import SwinTransformer 14 | from .utils import * 15 | import torch.nn.functional as F 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | class RAM(nn.Module): 21 | def __init__(self, 22 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 23 | image_size=384, 24 | text_encoder_type='bert-base-uncased', 25 | vit='base', 26 | vit_grad_ckpt=False, 27 | vit_ckpt_layer=0, 28 | prompt='a picture of ', 29 | threshold=0.68, 30 | delete_tag_index=[], 31 | tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt', 32 | tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt', 33 | stage='eval'): 34 | r""" The Recognize Anything Model (RAM) inference module. 35 | RAM is a strong image tagging model, which can recognize any common category with high accuracy. 36 | Described in the paper " Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/ 37 | 38 | Args: 39 | med_config (str): path for the mixture of encoder-decoder model's configuration file 40 | image_size (int): input image size 41 | vit (str): model size of vision transformer 42 | threshold (int): tagging threshold 43 | delete_tag_index (list): delete some tags that may disturb captioning 44 | """ 45 | super().__init__() 46 | 47 | # create image encoder 48 | if vit == 'swin_b': 49 | if image_size == 224: 50 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 51 | elif image_size == 384: 52 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 53 | vision_config = read_json(vision_config_path) 54 | assert image_size == vision_config['image_res'] 55 | # assert config['patch_size'] == 32 56 | vision_width = vision_config['vision_width'] 57 | 58 | self.visual_encoder = SwinTransformer( 59 | img_size=vision_config['image_res'], 60 | patch_size=4, 61 | in_chans=3, 62 | embed_dim=vision_config['embed_dim'], 63 | depths=vision_config['depths'], 64 | num_heads=vision_config['num_heads'], 65 | window_size=vision_config['window_size'], 66 | mlp_ratio=4., 67 | qkv_bias=True, 68 | drop_rate=0.0, 69 | drop_path_rate=0.1, 70 | ape=False, 71 | patch_norm=True, 72 | use_checkpoint=False) 73 | 74 | if stage == 'train_from_scratch': 75 | # download from https://github.com/microsoft/Swin-Transformer 76 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 77 | 78 | for k in list(state_dict.keys()): 79 | if 'relative_position_bias_table' in k: 80 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 81 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 82 | elif ('relative_position_index' in k) or ('attn_mask' in k): 83 | del state_dict[k] 84 | 85 | print("### Load Vision Backbone", vit) 86 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False) 87 | print("missing_keys: ", msg.missing_keys) 88 | print("unexpected_keys: ", msg.unexpected_keys) 89 | 90 | elif vit == 'swin_l': 91 | if image_size == 224: 92 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 93 | elif image_size == 384: 94 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 95 | vision_config = read_json(vision_config_path) 96 
| assert image_size == vision_config['image_res'] 97 | # assert config['patch_size'] == 32 98 | vision_width = vision_config['vision_width'] 99 | 100 | self.visual_encoder = SwinTransformer( 101 | img_size=vision_config['image_res'], 102 | patch_size=4, 103 | in_chans=3, 104 | embed_dim=vision_config['embed_dim'], 105 | depths=vision_config['depths'], 106 | num_heads=vision_config['num_heads'], 107 | window_size=vision_config['window_size'], 108 | mlp_ratio=4., 109 | qkv_bias=True, 110 | drop_rate=0.0, 111 | drop_path_rate=0.1, 112 | ape=False, 113 | patch_norm=True, 114 | use_checkpoint=False) 115 | 116 | if stage == 'train_from_scratch': 117 | # download from https://github.com/microsoft/Swin-Transformer 118 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 119 | 120 | for k in list(state_dict.keys()): 121 | if 'relative_position_bias_table' in k: 122 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 123 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 124 | elif ('relative_position_index' in k) or ('attn_mask' in k): 125 | del state_dict[k] 126 | 127 | print("### Load Vision Backbone", vit) 128 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False) 129 | print("missing_keys: ", msg.missing_keys) 130 | print("unexpected_keys: ", msg.unexpected_keys) 131 | 132 | else: 133 | self.visual_encoder, vision_width = create_vit( 134 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 135 | 136 | # create tokenzier 137 | self.tokenizer = init_tokenizer(text_encoder_type) 138 | 139 | # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder 140 | # create image-tag interaction encoder 141 | encoder_config = BertConfig.from_json_file(med_config) 142 | encoder_config.encoder_width = 512 143 | self.tag_encoder = BertModel(config=encoder_config, 144 | add_pooling_layer=False) 145 | 146 | # create image-tag-text decoder 147 | decoder_config = BertConfig.from_json_file(med_config) 148 | self.text_decoder = BertLMHeadModel(config=decoder_config) 149 | 150 | self.delete_tag_index = delete_tag_index 151 | self.prompt = prompt 152 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 153 | 154 | # load tag list 155 | self.tag_list = self.load_tag_list(tag_list) 156 | self.tag_list_chinese = self.load_tag_list(tag_list_chinese) 157 | 158 | # create image-tag recognition decoder 159 | self.threshold = threshold 160 | self.num_class = len(self.tag_list) 161 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 162 | q2l_config.encoder_width = 512 163 | self.tagging_head = BertModel(config=q2l_config, 164 | add_pooling_layer=False) 165 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 166 | 167 | if stage == 'train_from_scratch': 168 | self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/frozen_tag_embedding/ram_tag_embedding_class_4585.pth',map_location='cpu').float()) 169 | else: 170 | # when eval with pretrained RAM model, directly load from ram_swin_large_14m.pth 171 | self.label_embed = nn.Parameter(torch.zeros(self.num_class, q2l_config.encoder_width)) 172 | 173 | if q2l_config.hidden_size != 512: 174 | self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) 175 | else: 176 | self.wordvec_proj = nn.Identity() 177 | 178 | self.fc = nn.Linear(q2l_config.hidden_size, 1) 179 | 180 | self.del_selfattention() 181 | 182 | self.tagging_loss_function = 
AsymmetricLoss(gamma_neg=7, 183 | gamma_pos=0, 184 | clip=0.05) 185 | 186 | # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" 187 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', 188 | ' ') 189 | self.image_proj = nn.Linear(vision_width, 512) 190 | # self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth',map_location='cpu').float()) 191 | 192 | # adjust thresholds for some tags 193 | self.class_threshold = torch.ones(self.num_class) * self.threshold 194 | ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt' 195 | with open(ram_class_threshold_path, 'r', encoding='utf-8') as f: 196 | ram_class_threshold = [float(s.strip()) for s in f] 197 | for key,value in enumerate(ram_class_threshold): 198 | self.class_threshold[key] = value 199 | 200 | def load_tag_list(self, tag_list_file): 201 | with open(tag_list_file, 'r', encoding="utf-8") as f: 202 | tag_list = f.read().splitlines() 203 | tag_list = np.array(tag_list) 204 | return tag_list 205 | 206 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 207 | def del_selfattention(self): 208 | del self.tagging_head.embeddings 209 | for layer in self.tagging_head.encoder.layer: 210 | del layer.attention 211 | 212 | def forward(self, image, caption, image_tag, parse_tag, clip_feature): 213 | """ 214 | call function as forward 215 | 216 | Args: 217 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384 218 | caption: type: list[string] len: batch_size 219 | tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0 220 | 221 | Returns: 222 | loss: type: torch.Tensor 223 | """ 224 | 225 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 226 | 227 | image_embeds = self.image_proj(self.visual_encoder(image)) 228 | image_atts = torch.ones(image_embeds.size()[:-1], 229 | dtype=torch.long).to(image.device) 230 | 231 | ##================= Distillation from CLIP ================## 232 | image_cls_embeds = image_embeds[:, 0, :] 233 | image_spatial_embeds = image_embeds[:, 1:, :] 234 | 235 | loss_dis = F.l1_loss(image_cls_embeds, clip_feature) 236 | 237 | ##================= Image Tagging ================## 238 | bs = image_embeds.shape[0] 239 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 240 | 241 | tagging_embed = self.tagging_head( 242 | encoder_embeds=label_embed, 243 | encoder_hidden_states=image_embeds, 244 | encoder_attention_mask=image_atts, 245 | return_dict=False, 246 | mode='tagging', 247 | ) 248 | 249 | logits = self.fc(tagging_embed[0]).squeeze(-1) 250 | 251 | loss_tag = self.tagging_loss_function(logits, image_tag) 252 | 253 | ##================= Image-Tag-Text Generation ================## 254 | tag = parse_tag.cpu().numpy() 255 | tag_input = [] 256 | for b in range(bs): 257 | index = np.argwhere(tag[b] == 1) 258 | token = self.tag_list[index].squeeze(axis=1) 259 | tag_input.append(' | '.join(token)) 260 | 261 | # tokenizer input tags 262 | tag_input_tokenzier = self.tokenizer(tag_input, 263 | padding='max_length', 264 | truncation=True, 265 | max_length=40, 266 | return_tensors="pt").to( 267 | image.device) 268 | encoder_input_ids = tag_input_tokenzier.input_ids 269 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 270 | 271 | # put input tag into image-tag interaction encoder to interact with image embeddings 272 | output_tagembedding = 
self.tag_encoder( 273 | encoder_input_ids, 274 | attention_mask=tag_input_tokenzier.attention_mask, 275 | encoder_hidden_states=image_embeds, 276 | encoder_attention_mask=image_atts, 277 | return_dict=True, 278 | ) 279 | 280 | text = self.tokenizer(caption, 281 | padding='longest', 282 | truncation=True, 283 | max_length=40, 284 | return_tensors="pt").to( 285 | image.device) 286 | 287 | decoder_input_ids = text.input_ids 288 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id 289 | 290 | decoder_targets = decoder_input_ids.masked_fill( 291 | decoder_input_ids == self.tokenizer.pad_token_id, -100) 292 | decoder_targets[:,:self.prompt_length] = -100 293 | 294 | decoder_output = self.text_decoder(decoder_input_ids, 295 | attention_mask = text.attention_mask, 296 | encoder_hidden_states = output_tagembedding.last_hidden_state, 297 | encoder_attention_mask = None, 298 | labels = decoder_targets, 299 | return_dict = True, 300 | ) 301 | 302 | loss_t2t = decoder_output.loss 303 | 304 | return loss_t2t, loss_tag, loss_dis 305 | 306 | def generate_tag(self, 307 | image, 308 | threshold=0.68, 309 | tag_input=None, 310 | ): 311 | 312 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 313 | 314 | image_embeds = self.image_proj(self.visual_encoder(image)) 315 | image_atts = torch.ones(image_embeds.size()[:-1], 316 | dtype=torch.long).to(image.device) 317 | 318 | # recognized image tags using image-tag recogntiion decoder 319 | image_cls_embeds = image_embeds[:, 0, :] 320 | image_spatial_embeds = image_embeds[:, 1:, :] 321 | 322 | bs = image_spatial_embeds.shape[0] 323 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 324 | tagging_embed = self.tagging_head( 325 | encoder_embeds=label_embed, 326 | encoder_hidden_states=image_embeds, 327 | encoder_attention_mask=image_atts, 328 | return_dict=False, 329 | mode='tagging', 330 | ) 331 | 332 | logits = self.fc(tagging_embed[0]).squeeze(-1) 333 | 334 | targets = torch.where( 335 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 336 | torch.tensor(1.0).to(image.device), 337 | torch.zeros(self.num_class).to(image.device)) 338 | 339 | tag = targets.cpu().numpy() 340 | tag[:,self.delete_tag_index] = 0 341 | tag_output = [] 342 | tag_output_chinese = [] 343 | for b in range(bs): 344 | index = np.argwhere(tag[b] == 1) 345 | token = self.tag_list[index].squeeze(axis=1) 346 | tag_output.append(' | '.join(token)) 347 | token_chinese = self.tag_list_chinese[index].squeeze(axis=1) 348 | tag_output_chinese.append(' | '.join(token_chinese)) 349 | 350 | 351 | return tag_output, tag_output_chinese 352 | 353 | def generate_tag_openset(self, 354 | image, 355 | threshold=0.68, 356 | tag_input=None, 357 | ): 358 | 359 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 360 | 361 | image_embeds = self.image_proj(self.visual_encoder(image)) 362 | image_atts = torch.ones(image_embeds.size()[:-1], 363 | dtype=torch.long).to(image.device) 364 | 365 | # recognized image tags using image-tag recogntiion decoder 366 | image_cls_embeds = image_embeds[:, 0, :] 367 | image_spatial_embeds = image_embeds[:, 1:, :] 368 | 369 | bs = image_spatial_embeds.shape[0] 370 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 371 | tagging_embed = self.tagging_head( 372 | encoder_embeds=label_embed, 373 | encoder_hidden_states=image_embeds, 374 | encoder_attention_mask=image_atts, 375 | return_dict=False, 376 | mode='tagging', 377 | ) 378 | 379 | logits = self.fc(tagging_embed[0]).squeeze(-1) 380 | 381 | targets = 
torch.where( 382 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 383 | torch.tensor(1.0).to(image.device), 384 | torch.zeros(self.num_class).to(image.device)) 385 | 386 | tag = targets.cpu().numpy() 387 | tag[:,self.delete_tag_index] = 0 388 | tag_output = [] 389 | for b in range(bs): 390 | index = np.argwhere(tag[b] == 1) 391 | token = self.tag_list[index].squeeze(axis=1) 392 | tag_output.append(' | '.join(token)) 393 | 394 | return tag_output 395 | 396 | 397 | # load RAM pretrained model parameters 398 | def ram(pretrained='', **kwargs): 399 | model = RAM(**kwargs) 400 | if pretrained: 401 | if kwargs['vit'] == 'swin_b': 402 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 403 | elif kwargs['vit'] == 'swin_l': 404 | model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs) 405 | else: 406 | model, msg = load_checkpoint(model, pretrained) 407 | print('vit:', kwargs['vit']) 408 | # print('msg', msg) 409 | return model 410 | -------------------------------------------------------------------------------- /ram/models/ram_plus.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Plus Model (RAM++) 3 | * Written by Xinyu Huang 4 | ''' 5 | import json 6 | import warnings 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | import torch.nn.functional as F 13 | from .bert import BertConfig, BertLMHeadModel, BertModel 14 | from .swin_transformer import SwinTransformer 15 | from .utils import * 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | 21 | class RAM_plus(nn.Module): 22 | def __init__(self, 23 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 24 | image_size=384, 25 | text_encoder_type='bert-base-uncased', 26 | vit='base', 27 | vit_grad_ckpt=False, 28 | vit_ckpt_layer=0, 29 | threshold=0.68, 30 | delete_tag_index=[], 31 | tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt', 32 | tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt', 33 | stage='eval'): 34 | r""" The Recognize Anything Plus Model (RAM++) inference module. 35 | RAM++ is a strong image tagging model, which can recognize any category with high accuracy using tag categories. 
36 | Described in the paper "Open-Set Image Tagging with Multi-Grained Text Supervision" https://arxiv.org/abs/2310.15200 37 | 38 | Args: 39 | med_config (str): path for the mixture of encoder-decoder model's configuration file 40 | image_size (int): input image size 41 | vit (str): model size of vision transformer 42 | threshold (int): tagging threshold 43 | delete_tag_index (list): delete some tags that may disturb captioning 44 | """ 45 | super().__init__() 46 | 47 | # create image encoder 48 | if vit == 'swin_b': 49 | if image_size == 224: 50 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 51 | elif image_size == 384: 52 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 53 | vision_config = read_json(vision_config_path) 54 | assert image_size == vision_config['image_res'] 55 | # assert config['patch_size'] == 32 56 | vision_width = vision_config['vision_width'] 57 | 58 | self.visual_encoder = SwinTransformer( 59 | img_size=vision_config['image_res'], 60 | patch_size=4, 61 | in_chans=3, 62 | embed_dim=vision_config['embed_dim'], 63 | depths=vision_config['depths'], 64 | num_heads=vision_config['num_heads'], 65 | window_size=vision_config['window_size'], 66 | mlp_ratio=4., 67 | qkv_bias=True, 68 | drop_rate=0.0, 69 | drop_path_rate=0.1, 70 | ape=False, 71 | patch_norm=True, 72 | use_checkpoint=False) 73 | 74 | if stage == 'train_from_scratch': 75 | # download from https://github.com/microsoft/Swin-Transformer 76 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 77 | 78 | for k in list(state_dict.keys()): 79 | if 'relative_position_bias_table' in k: 80 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 81 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 82 | elif ('relative_position_index' in k) or ('attn_mask' in k): 83 | del state_dict[k] 84 | 85 | print("### Load Vision Backbone", vit) 86 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False) 87 | print("missing_keys: ", msg.missing_keys) 88 | print("unexpected_keys: ", msg.unexpected_keys) 89 | 90 | elif vit == 'swin_l': 91 | if image_size == 224: 92 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 93 | elif image_size == 384: 94 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 95 | vision_config = read_json(vision_config_path) 96 | assert image_size == vision_config['image_res'] 97 | # assert config['patch_size'] == 32 98 | vision_width = vision_config['vision_width'] 99 | 100 | self.visual_encoder = SwinTransformer( 101 | img_size=vision_config['image_res'], 102 | patch_size=4, 103 | in_chans=3, 104 | embed_dim=vision_config['embed_dim'], 105 | depths=vision_config['depths'], 106 | num_heads=vision_config['num_heads'], 107 | window_size=vision_config['window_size'], 108 | mlp_ratio=4., 109 | qkv_bias=True, 110 | drop_rate=0.0, 111 | drop_path_rate=0.1, 112 | ape=False, 113 | patch_norm=True, 114 | use_checkpoint=False) 115 | 116 | if stage == 'train_from_scratch': 117 | # download from https://github.com/microsoft/Swin-Transformer 118 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 119 | 120 | for k in list(state_dict.keys()): 121 | if 'relative_position_bias_table' in k: 122 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 123 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 124 | elif ('relative_position_index' in k) or ('attn_mask' in k): 125 | del 
state_dict[k] 126 | 127 | print("### Load Vision Backbone", vit) 128 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False) 129 | print("missing_keys: ", msg.missing_keys) 130 | print("unexpected_keys: ", msg.unexpected_keys) 131 | 132 | else: 133 | self.visual_encoder, vision_width = create_vit( 134 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 135 | 136 | # create tokenzier 137 | self.tokenizer = init_tokenizer(text_encoder_type) 138 | 139 | self.delete_tag_index = delete_tag_index 140 | 141 | # load tag list 142 | self.tag_list = self.load_tag_list(tag_list) 143 | self.tag_list_chinese = self.load_tag_list(tag_list_chinese) 144 | 145 | # create image-tag recognition decoder 146 | self.threshold = threshold 147 | self.num_class = len(self.tag_list) 148 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 149 | q2l_config.encoder_width = 512 150 | self.tagging_head = BertModel(config=q2l_config, 151 | add_pooling_layer=False) 152 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 153 | 154 | if stage == 'train_from_scratch': 155 | self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/frozen_tag_embedding/ram_plus_tag_embedding_class_4585_des_51.pth',map_location='cpu').float()) 156 | else: 157 | # when eval with pretrained RAM++ model, directly load from ram_plus_swin_large_14m.pth 158 | self.label_embed = nn.Parameter(torch.zeros(self.num_class * 51, q2l_config.encoder_width)) 159 | 160 | if q2l_config.hidden_size != 512: 161 | self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) 162 | else: 163 | self.wordvec_proj = nn.Identity() 164 | 165 | self.fc = nn.Linear(q2l_config.hidden_size, 1) 166 | 167 | self.del_selfattention() 168 | 169 | self.image_proj = nn.Linear(vision_width, 512) 170 | 171 | # adjust thresholds for some tags 172 | self.class_threshold = torch.ones(self.num_class) * self.threshold 173 | ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt' 174 | with open(ram_class_threshold_path, 'r', encoding='utf-8') as f: 175 | ram_class_threshold = [float(s.strip()) for s in f] 176 | for key,value in enumerate(ram_class_threshold): 177 | self.class_threshold[key] = value 178 | 179 | self.reweight_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 180 | 181 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7, 182 | gamma_pos=0, 183 | clip=0.05) 184 | 185 | self.text_alignment_loss_function = AsymmetricLoss(gamma_neg=4, 186 | gamma_pos=0, 187 | clip=0.05) 188 | 189 | def load_tag_list(self, tag_list_file): 190 | with open(tag_list_file, 'r', encoding="utf-8") as f: 191 | tag_list = f.read().splitlines() 192 | tag_list = np.array(tag_list) 193 | return tag_list 194 | 195 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 196 | def del_selfattention(self): 197 | del self.tagging_head.embeddings 198 | for layer in self.tagging_head.encoder.layer: 199 | del layer.attention 200 | 201 | def forward(self, image, caption, image_tag, clip_feature, batch_text_embed): 202 | """ 203 | call function as forward 204 | 205 | Args: 206 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384 207 | caption: type: list[string] len: batch_size 208 | tag: type: torch.Tensor shape: batch * class_num (e.g. 
3429) value: positive sample is 1.0, negative sample is 0.0 209 | 210 | Returns: 211 | loss: type: torch.Tensor 212 | """ 213 | 214 | image_embeds = self.image_proj(self.visual_encoder(image)) 215 | image_atts = torch.ones(image_embeds.size()[:-1], 216 | dtype=torch.long).to(image.device) 217 | 218 | ##================= Distillation from CLIP ================## 219 | image_cls_embeds = image_embeds[:, 0, :] 220 | image_spatial_embeds = image_embeds[:, 1:, :] 221 | 222 | loss_dis = F.l1_loss(image_cls_embeds, clip_feature) 223 | 224 | ###===========multi tag des reweight==============### 225 | bs = image_embeds.shape[0] 226 | 227 | des_per_class = int(self.label_embed.shape[0] / self.num_class) 228 | 229 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True) 230 | reweight_scale = self.reweight_scale.exp() 231 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t()) 232 | logits_per_image = logits_per_image.view(bs, -1,des_per_class) 233 | 234 | weight_normalized = F.softmax(logits_per_image, dim=2) 235 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype) 236 | 237 | for i in range(bs): 238 | reshaped_value = self.label_embed.view(-1, des_per_class, 512) 239 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value 240 | label_embed_reweight[i] = product.sum(dim=1) 241 | 242 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight)) 243 | 244 | ##================= Image Tagging ================## 245 | 246 | tagging_embed = self.tagging_head( 247 | encoder_embeds=label_embed, 248 | encoder_hidden_states=image_embeds, 249 | encoder_attention_mask=image_atts, 250 | return_dict=False, 251 | mode='tagging', 252 | ) 253 | 254 | logits = self.fc(tagging_embed[0]).squeeze(-1) 255 | 256 | loss_tag = self.tagging_loss_function(logits, image_tag) 257 | 258 | ##================= Image-text Alignment ================## 259 | 260 | batch_text_embed = torch.nn.functional.relu(self.wordvec_proj(batch_text_embed.to(self.label_embed.dtype))) 261 | batch_text_embed = batch_text_embed.unsqueeze(0).repeat(bs, 1, 1) 262 | alignment_embedding = self.tagging_head( 263 | encoder_embeds=batch_text_embed, 264 | encoder_hidden_states=image_embeds, 265 | encoder_attention_mask=image_atts, 266 | return_dict=False, 267 | mode='tagging', 268 | ) 269 | alignment_logits = self.fc(alignment_embedding[0]).squeeze(-1) 270 | 271 | with torch.no_grad(): 272 | alignment_targets = torch.zeros(alignment_logits.size()).to(image.device) 273 | alignment_targets.fill_diagonal_(1) 274 | 275 | loss_alignment = self.text_alignment_loss_function(alignment_logits,alignment_targets) 276 | 277 | return loss_tag, loss_dis, loss_alignment 278 | 279 | 280 | def generate_tag(self, 281 | image 282 | ): 283 | 284 | image_embeds = self.image_proj(self.visual_encoder(image)) 285 | image_atts = torch.ones(image_embeds.size()[:-1], 286 | dtype=torch.long).to(image.device) 287 | 288 | image_cls_embeds = image_embeds[:, 0, :] 289 | image_spatial_embeds = image_embeds[:, 1:, :] 290 | 291 | bs = image_spatial_embeds.shape[0] 292 | 293 | des_per_class = int(self.label_embed.shape[0] / self.num_class) 294 | 295 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True) 296 | reweight_scale = self.reweight_scale.exp() 297 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t()) 298 | logits_per_image = logits_per_image.view(bs, -1,des_per_class) 299 | 300 | weight_normalized = 
F.softmax(logits_per_image, dim=2) 301 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype) 302 | 303 | for i in range(bs): 304 | # reshape value_ori here, then use broadcasting 305 | reshaped_value = self.label_embed.view(-1, des_per_class, 512) 306 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value 307 | label_embed_reweight[i] = product.sum(dim=1) 308 | 309 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight)) 310 | 311 | # recognized image tags using alignment decoder 312 | tagging_embed = self.tagging_head( 313 | encoder_embeds=label_embed, 314 | encoder_hidden_states=image_embeds, 315 | encoder_attention_mask=image_atts, 316 | return_dict=False, 317 | mode='tagging', 318 | ) 319 | 320 | logits = self.fc(tagging_embed[0]).squeeze(-1) 321 | 322 | targets = torch.where( 323 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 324 | torch.tensor(1.0).to(image.device), 325 | torch.zeros(self.num_class).to(image.device)) 326 | 327 | tag = targets.cpu().numpy() 328 | tag[:,self.delete_tag_index] = 0 329 | tag_output = [] 330 | tag_output_chinese = [] 331 | for b in range(bs): 332 | index = np.argwhere(tag[b] == 1) 333 | token = self.tag_list[index].squeeze(axis=1) 334 | tag_output.append(' | '.join(token)) 335 | token_chinese = self.tag_list_chinese[index].squeeze(axis=1) 336 | tag_output_chinese.append(' | '.join(token_chinese)) 337 | 338 | 339 | return tag_output, tag_output_chinese 340 | 341 | def generate_tag_openset(self, 342 | image, 343 | threshold=0.68, 344 | tag_input=None, 345 | ): 346 | 347 | image_embeds = self.image_proj(self.visual_encoder(image)) 348 | image_atts = torch.ones(image_embeds.size()[:-1], 349 | dtype=torch.long).to(image.device) 350 | 351 | image_cls_embeds = image_embeds[:, 0, :] 352 | image_spatial_embeds = image_embeds[:, 1:, :] 353 | 354 | bs = image_spatial_embeds.shape[0] 355 | 356 | des_per_class = int(self.label_embed.shape[0] / self.num_class) 357 | 358 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True) 359 | reweight_scale = self.reweight_scale.exp() 360 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t()) 361 | logits_per_image = logits_per_image.view(bs, -1,des_per_class) 362 | 363 | weight_normalized = F.softmax(logits_per_image, dim=2) 364 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype) 365 | 366 | for i in range(bs): 367 | # reshape value_ori here, then use broadcasting 368 | reshaped_value = self.label_embed.view(-1, des_per_class, 512) 369 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value 370 | label_embed_reweight[i] = product.sum(dim=1) 371 | 372 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight)) 373 | 374 | # recognized image tags using alignment decoder 375 | tagging_embed = self.tagging_head( 376 | encoder_embeds=label_embed, 377 | encoder_hidden_states=image_embeds, 378 | encoder_attention_mask=image_atts, 379 | return_dict=False, 380 | mode='tagging', 381 | ) 382 | 383 | logits = self.fc(tagging_embed[0]).squeeze(-1) 384 | 385 | targets = torch.where( 386 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 387 | torch.tensor(1.0).to(image.device), 388 | torch.zeros(self.num_class).to(image.device)) 389 | 390 | tag = targets.cpu().numpy() 391 | tag[:,self.delete_tag_index] = 0 392 | tag_output = [] 393 | for b in range(bs): 394 | index = np.argwhere(tag[b] == 1) 395 | token = 
self.tag_list[index].squeeze(axis=1) 396 | tag_output.append(' | '.join(token)) 397 | 398 | return tag_output 399 | 400 | 401 | # load RAM++ pretrained model parameters 402 | def ram_plus(pretrained='', **kwargs): 403 | model = RAM_plus(**kwargs) 404 | if pretrained: 405 | if kwargs['vit'] == 'swin_b': 406 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 407 | elif kwargs['vit'] == 'swin_l': 408 | model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs) 409 | else: 410 | model, msg = load_checkpoint(model, pretrained) 411 | print('vit:', kwargs['vit']) 412 | # print('msg', msg) 413 | return model 414 | -------------------------------------------------------------------------------- /ram/models/tag2text.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Tag2Text Model 3 | * Written by Xinyu Huang 4 | ''' 5 | import numpy as np 6 | import json 7 | import torch 8 | import warnings 9 | 10 | from torch import nn 11 | from .bert import BertConfig, BertModel, BertLMHeadModel 12 | from .swin_transformer import SwinTransformer 13 | 14 | from .utils import * 15 | 16 | warnings.filterwarnings("ignore") 17 | 18 | 19 | class Tag2Text(nn.Module): 20 | 21 | def __init__(self, 22 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 23 | image_size=384, 24 | text_encoder_type='bert-base-uncased', 25 | vit='base', 26 | vit_grad_ckpt=False, 27 | vit_ckpt_layer=0, 28 | prompt='a picture of ', 29 | threshold=0.68, 30 | delete_tag_index=[127,2961, 3351, 3265, 3338, 3355, 3359], 31 | tag_list=f'{CONFIG_PATH}/data/tag2text_ori_tag_list.txt', 32 | stage='eval'): 33 | r""" Tag2Text inference module, both captioning and tagging are included. 34 | Tag2Text is an efficient and controllable vision-language pre-training framework. 
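A minimal usage sketch, for orientation only (the checkpoint and image paths below are placeholders and device handling is omitted; tag2text and get_transform are the helpers defined in this repository):

            from PIL import Image
            import torch
            from ram.models.tag2text import tag2text
            from ram.transform import get_transform

            # build the preprocessing pipeline and the Tag2Text model
            transform = get_transform(image_size=384)
            model = tag2text(pretrained='path/to/tag2text_checkpoint.pth',
                             vit='swin_b', image_size=384)
            model.eval()

            # caption a single image and also return the recognized tags
            image = transform(Image.open('demo.jpg')).unsqueeze(0)
            with torch.no_grad():
                captions, tags = model.generate(image, return_tag_predict=True)

        With return_tag_predict=True, generate also returns the recognized tags that guided captioning; passing tag_input instead makes the guiding tags user-controllable.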
35 | Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657 36 | 37 | Args: 38 | med_config (str): path for the mixture of encoder-decoder model's configuration file 39 | image_size (int): input image size 40 | vit (str): model size of vision transformer 41 | threshold (int): tagging threshold 42 | delete_tag_index (list): delete some tags that may disturb captioning 43 | """ 44 | super().__init__() 45 | 46 | # create image encoder 47 | if vit == 'swin_b': 48 | if image_size == 224: 49 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 50 | elif image_size == 384: 51 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 52 | vision_config = read_json(vision_config_path) 53 | assert image_size == vision_config['image_res'] 54 | # assert config['patch_size'] == 32 55 | vision_width = vision_config['vision_width'] 56 | 57 | self.visual_encoder = SwinTransformer( 58 | img_size=vision_config['image_res'], 59 | patch_size=4, 60 | in_chans=3, 61 | embed_dim=vision_config['embed_dim'], 62 | depths=vision_config['depths'], 63 | num_heads=vision_config['num_heads'], 64 | window_size=vision_config['window_size'], 65 | mlp_ratio=4., 66 | qkv_bias=True, 67 | drop_rate=0.0, 68 | drop_path_rate=0.1, 69 | ape=False, 70 | patch_norm=True, 71 | use_checkpoint=False) 72 | 73 | if stage == 'train_from_scratch': 74 | # download from https://github.com/microsoft/Swin-Transformer 75 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 76 | 77 | for k in list(state_dict.keys()): 78 | if 'relative_position_bias_table' in k: 79 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 80 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 81 | elif ('relative_position_index' in k) or ('attn_mask' in k): 82 | del state_dict[k] 83 | 84 | print("### Load Vision Backbone", vit) 85 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False) 86 | print("missing_keys: ", msg.missing_keys) 87 | print("unexpected_keys: ", msg.unexpected_keys) 88 | 89 | else: 90 | self.visual_encoder, vision_width = create_vit( 91 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 92 | 93 | # create tokenzier 94 | self.tokenizer = init_tokenizer(text_encoder_type) 95 | 96 | # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder 97 | # create image-tag interaction encoder 98 | encoder_config = BertConfig.from_json_file(med_config) 99 | encoder_config.encoder_width = vision_width 100 | self.tag_encoder = BertModel(config=encoder_config, 101 | add_pooling_layer=False) 102 | 103 | # create image-tag-text decoder 104 | decoder_config = BertConfig.from_json_file(med_config) 105 | self.text_decoder = BertLMHeadModel(config=decoder_config) 106 | 107 | # delete some tags that may disturb captioning 108 | # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one" 109 | self.delete_tag_index = delete_tag_index 110 | self.prompt = prompt 111 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 112 | 113 | # load tag list 114 | self.tag_list = self.load_tag_list(tag_list) 115 | 116 | # create image-tag recognition decoder 117 | self.threshold = threshold 118 | self.num_class = len(self.tag_list) 119 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 120 | q2l_config.encoder_width = vision_width 121 | 
self.tagging_head = BertModel(config=q2l_config, 122 | add_pooling_layer=False) 123 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 124 | self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) 125 | self.fc = GroupWiseLinear(self.num_class, 126 | q2l_config.hidden_size, 127 | bias=True) 128 | self.del_selfattention() 129 | 130 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7, 131 | gamma_pos=0, 132 | clip=0.05) 133 | 134 | # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" 135 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', 136 | ' ') 137 | 138 | # adjust thresholds for some tags 139 | # default threshold: 0.68 140 | # 2701: "person"; 2828: "man"; 1167: "woman"; 141 | tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7} 142 | self.class_threshold = torch.ones(self.num_class) * self.threshold 143 | for key,value in tag_thrshold.items(): 144 | self.class_threshold[key] = value 145 | 146 | def load_tag_list(self, tag_list_file): 147 | with open(tag_list_file, 'r') as f: 148 | tag_list = f.read().splitlines() 149 | tag_list = np.array(tag_list) 150 | return tag_list 151 | 152 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 153 | def del_selfattention(self): 154 | del self.tagging_head.embeddings 155 | for layer in self.tagging_head.encoder.layer: 156 | del layer.attention 157 | 158 | 159 | def forward(self, image, caption, tag): 160 | """ 161 | call function as forward 162 | 163 | Args: 164 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384 165 | caption: type: list[string] len: batch_size 166 | tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0 167 | 168 | Returns: 169 | loss: type: torch.Tensor 170 | """ 171 | 172 | image_embeds = self.visual_encoder(image) 173 | image_atts = torch.ones(image_embeds.size()[:-1], 174 | dtype=torch.long).to(image.device) 175 | 176 | ##================= Image Tagging ================## 177 | bs = image_embeds.shape[0] 178 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 179 | 180 | tagging_embed = self.tagging_head( 181 | encoder_embeds=label_embed, 182 | encoder_hidden_states=image_embeds, 183 | encoder_attention_mask=image_atts, 184 | return_dict=False, 185 | mode='tagging', 186 | ) 187 | 188 | logits = self.fc(tagging_embed[0]) 189 | 190 | loss_tag = self.tagging_loss_function(logits, tag) 191 | 192 | ##================= Image-Tag-Text Generation ================## 193 | tag = tag.cpu().numpy() 194 | tag_input = [] 195 | for b in range(bs): 196 | index = np.argwhere(tag[b] == 1) 197 | token = self.tag_list[index].squeeze(axis=1) 198 | tag_input.append(' | '.join(token)) 199 | 200 | # tokenizer input tags 201 | tag_input_tokenzier = self.tokenizer(tag_input, 202 | padding='max_length', 203 | truncation=True, 204 | max_length=40, 205 | return_tensors="pt").to( 206 | image.device) 207 | encoder_input_ids = tag_input_tokenzier.input_ids 208 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 209 | 210 | # put input tag into image-tag interaction encoder to interact with image embeddings 211 | output_tagembedding = self.tag_encoder( 212 | encoder_input_ids, 213 | attention_mask=tag_input_tokenzier.attention_mask, 214 | encoder_hidden_states=image_embeds, 215 | encoder_attention_mask=image_atts, 216 | return_dict=True, 217 | ) 218 | 219 | text = self.tokenizer(caption, 220 | 
padding='longest', 221 | truncation=True, 222 | max_length=40, 223 | return_tensors="pt").to( 224 | image.device) 225 | 226 | decoder_input_ids = text.input_ids 227 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id 228 | 229 | decoder_targets = decoder_input_ids.masked_fill( 230 | decoder_input_ids == self.tokenizer.pad_token_id, -100) 231 | decoder_targets[:,:self.prompt_length] = -100 232 | 233 | decoder_output = self.text_decoder(decoder_input_ids, 234 | attention_mask = text.attention_mask, 235 | encoder_hidden_states = output_tagembedding.last_hidden_state, 236 | encoder_attention_mask = None, 237 | labels = decoder_targets, 238 | return_dict = True, 239 | ) 240 | 241 | loss_t2t = decoder_output.loss 242 | 243 | return loss_t2t, loss_tag 244 | 245 | 246 | def generate(self, 247 | image, 248 | sample=False, 249 | num_beams=3, 250 | max_length=30, 251 | min_length=10, 252 | top_p=0.9, 253 | repetition_penalty=1.0, 254 | tag_input=None, 255 | return_tag_predict=False): 256 | 257 | image_embeds = self.visual_encoder(image) 258 | image_atts = torch.ones(image_embeds.size()[:-1], 259 | dtype=torch.long).to(image.device) 260 | 261 | # if not user specified tags, recognized image tags using image-tag recogntiion decoder 262 | if tag_input == None: 263 | 264 | bs = image_embeds.shape[0] 265 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 266 | tagging_embed = self.tagging_head( 267 | encoder_embeds=label_embed, 268 | encoder_hidden_states=image_embeds, 269 | encoder_attention_mask=image_atts, 270 | return_dict=False, 271 | mode='tagging', 272 | ) 273 | 274 | logits = self.fc(tagging_embed[0]) 275 | 276 | targets = torch.where( 277 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 278 | torch.tensor(1.0).to(image.device), 279 | torch.zeros(self.num_class).to(image.device)) 280 | 281 | tag = targets.cpu().numpy() 282 | 283 | # delete some tags that may disturb captioning 284 | tag[:, self.delete_tag_index] = 0 285 | 286 | tag_input = [] 287 | for b in range(bs): 288 | index = np.argwhere(tag[b] == 1) 289 | token = self.tag_list[index].squeeze(axis=1) 290 | tag_input.append(' | '.join(token)) 291 | 292 | tag_output = tag_input 293 | 294 | # beam search for text generation(default) 295 | if not sample: 296 | image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) 297 | tag_input_temp = [] 298 | for tag in tag_input: 299 | for i in range(num_beams): 300 | tag_input_temp.append(tag) 301 | tag_input = tag_input_temp 302 | 303 | image_atts = torch.ones(image_embeds.size()[:-1], 304 | dtype=torch.long).to(image.device) 305 | 306 | # tokenizer input tags 307 | tag_input_tokenzier = self.tokenizer(tag_input, 308 | padding='max_length', 309 | truncation=True, 310 | max_length=40, 311 | return_tensors="pt").to( 312 | image.device) 313 | encoder_input_ids = tag_input_tokenzier.input_ids 314 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 315 | 316 | # put input tag into image-tag interaction encoder to interact with image embeddings 317 | output_tagembedding = self.tag_encoder( 318 | encoder_input_ids, 319 | attention_mask=tag_input_tokenzier.attention_mask, 320 | encoder_hidden_states=image_embeds, 321 | encoder_attention_mask=image_atts, 322 | return_dict=True, 323 | ) 324 | 325 | # prompt trick for better captioning, followed BLIP 326 | prompt = [self.prompt] * image.size(0) 327 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( 328 | image.device) 329 | input_ids[:, 0] = self.tokenizer.bos_token_id 330 | input_ids = 
input_ids[:, :-1] 331 | 332 | if sample: 333 | # nucleus sampling 334 | model_kwargs = { 335 | "encoder_hidden_states": output_tagembedding.last_hidden_state, 336 | "encoder_attention_mask": None 337 | } 338 | outputs = self.text_decoder.generate( 339 | input_ids=input_ids, 340 | max_length=max_length, 341 | min_length=min_length, 342 | do_sample=True, 343 | top_p=top_p, 344 | num_return_sequences=1, 345 | eos_token_id=self.tokenizer.sep_token_id, 346 | pad_token_id=self.tokenizer.pad_token_id, 347 | repetition_penalty=1.1, 348 | **model_kwargs) 349 | else: 350 | # beam search (default) 351 | model_kwargs = { 352 | "encoder_hidden_states": output_tagembedding.last_hidden_state, 353 | "encoder_attention_mask": None 354 | } 355 | outputs = self.text_decoder.generate( 356 | input_ids=input_ids, 357 | max_length=max_length, 358 | min_length=min_length, 359 | num_beams=num_beams, 360 | eos_token_id=self.tokenizer.sep_token_id, 361 | pad_token_id=self.tokenizer.pad_token_id, 362 | repetition_penalty=repetition_penalty, 363 | **model_kwargs) 364 | 365 | captions = [] 366 | for output in outputs: 367 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 368 | captions.append(caption[len(self.prompt):]) 369 | if return_tag_predict == True: 370 | return captions, tag_output 371 | return captions 372 | 373 | 374 | # load Tag2Text pretrained model parameters 375 | def tag2text(pretrained='', **kwargs): 376 | model = Tag2Text(**kwargs) 377 | if pretrained: 378 | if kwargs['vit'] == 'swin_b': 379 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 380 | else: 381 | model, msg = load_checkpoint(model, pretrained) 382 | print('vit:', kwargs['vit']) 383 | # print('msg', msg) 384 | return model 385 | 386 | -------------------------------------------------------------------------------- /ram/models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import math 5 | 6 | from torch import nn 7 | from typing import List 8 | from transformers import BertTokenizer 9 | from urllib.parse import urlparse 10 | from timm.models.hub import download_cached_file 11 | from .vit import interpolate_pos_embed 12 | from .swin_transformer import interpolate_relative_pos_embed 13 | from pathlib import Path 14 | CONFIG_PATH=(Path(__file__).resolve().parents[1]) 15 | 16 | def read_json(rpath): 17 | with open(rpath, 'r') as f: 18 | return json.load(f) 19 | 20 | 21 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, 22 | base_model_prefix: str, skip_key: str): 23 | uninitialized_encoder_weights: List[str] = [] 24 | if decoder.__class__ != encoder.__class__: 25 | logger.info( 26 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." 
27 | ) 28 | 29 | def tie_encoder_to_decoder_recursively( 30 | decoder_pointer: nn.Module, 31 | encoder_pointer: nn.Module, 32 | module_name: str, 33 | uninitialized_encoder_weights: List[str], 34 | skip_key: str, 35 | depth=0, 36 | ): 37 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 38 | encoder_pointer, nn.Module 39 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 40 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name: 41 | assert hasattr(encoder_pointer, "weight") 42 | encoder_pointer.weight = decoder_pointer.weight 43 | if hasattr(decoder_pointer, "bias"): 44 | assert hasattr(encoder_pointer, "bias") 45 | encoder_pointer.bias = decoder_pointer.bias 46 | print(module_name + ' is tied') 47 | return 48 | 49 | encoder_modules = encoder_pointer._modules 50 | decoder_modules = decoder_pointer._modules 51 | if len(decoder_modules) > 0: 52 | assert ( 53 | len(encoder_modules) > 0 54 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 55 | 56 | all_encoder_weights = set([ 57 | module_name + "/" + sub_name 58 | for sub_name in encoder_modules.keys() 59 | ]) 60 | encoder_layer_pos = 0 61 | for name, module in decoder_modules.items(): 62 | if name.isdigit(): 63 | encoder_name = str(int(name) + encoder_layer_pos) 64 | decoder_name = name 65 | if not isinstance( 66 | decoder_modules[decoder_name], 67 | type(encoder_modules[encoder_name])) and len( 68 | encoder_modules) != len(decoder_modules): 69 | # this can happen if the name corresponds to the position in a list module list of layers 70 | # in this case the decoder has added a cross-attention that the encoder does not have 71 | # thus skip this step and subtract one layer pos from encoder 72 | encoder_layer_pos -= 1 73 | continue 74 | elif name not in encoder_modules: 75 | continue 76 | elif depth > 500: 77 | raise ValueError( 78 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 79 | ) 80 | else: 81 | decoder_name = encoder_name = name 82 | tie_encoder_to_decoder_recursively( 83 | decoder_modules[decoder_name], 84 | encoder_modules[encoder_name], 85 | module_name + "/" + name, 86 | uninitialized_encoder_weights, 87 | skip_key, 88 | depth=depth + 1, 89 | ) 90 | all_encoder_weights.remove(module_name + "/" + encoder_name) 91 | 92 | uninitialized_encoder_weights += list(all_encoder_weights) 93 | 94 | # tie weights recursively 95 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, 96 | uninitialized_encoder_weights, skip_key) 97 | 98 | 99 | class GroupWiseLinear(nn.Module): 100 | # could be changed to: 101 | # output = torch.einsum('ijk,zjk->ij', x, self.W) 102 | # or output = torch.einsum('ijk,jk->ij', x, self.W[0]) 103 | def __init__(self, num_class, hidden_dim, bias=True): 104 | super().__init__() 105 | self.num_class = num_class 106 | self.hidden_dim = hidden_dim 107 | self.bias = bias 108 | 109 | self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim)) 110 | if bias: 111 | self.b = nn.Parameter(torch.Tensor(1, num_class)) 112 | self.reset_parameters() 113 | 114 | def reset_parameters(self): 115 | stdv = 1. 
/ math.sqrt(self.W.size(2)) 116 | for i in range(self.num_class): 117 | self.W[0][i].data.uniform_(-stdv, stdv) 118 | if self.bias: 119 | for i in range(self.num_class): 120 | self.b[0][i].data.uniform_(-stdv, stdv) 121 | 122 | def forward(self, x): 123 | # x: B,K,d 124 | x = (self.W * x).sum(-1) 125 | if self.bias: 126 | x = x + self.b 127 | return x 128 | 129 | 130 | def init_tokenizer(text_encoder_type='bert-base-uncased'): 131 | tokenizer = BertTokenizer.from_pretrained(text_encoder_type) 132 | tokenizer.add_special_tokens({'bos_token': '[DEC]'}) 133 | tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) 134 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 135 | return tokenizer 136 | 137 | 138 | def create_vit(vit, 139 | image_size, 140 | use_grad_checkpointing=False, 141 | ckpt_layer=0, 142 | drop_path_rate=0): 143 | 144 | assert vit in ['base', 'large'], "vit parameter must be base or large" 145 | if vit == 'base': 146 | vision_width = 768 147 | visual_encoder = VisionTransformer( 148 | img_size=image_size, 149 | patch_size=16, 150 | embed_dim=vision_width, 151 | depth=12, 152 | num_heads=12, 153 | use_grad_checkpointing=use_grad_checkpointing, 154 | ckpt_layer=ckpt_layer, 155 | drop_path_rate=0 or drop_path_rate) 156 | elif vit == 'large': 157 | vision_width = 1024 158 | visual_encoder = VisionTransformer( 159 | img_size=image_size, 160 | patch_size=16, 161 | embed_dim=vision_width, 162 | depth=24, 163 | num_heads=16, 164 | use_grad_checkpointing=use_grad_checkpointing, 165 | ckpt_layer=ckpt_layer, 166 | drop_path_rate=0.1 or drop_path_rate) 167 | return visual_encoder, vision_width 168 | 169 | 170 | def is_url(url_or_filename): 171 | parsed = urlparse(url_or_filename) 172 | return parsed.scheme in ("http", "https") 173 | 174 | 175 | def load_checkpoint(model, url_or_filename): 176 | if is_url(url_or_filename): 177 | cached_file = download_cached_file(url_or_filename, 178 | check_hash=False, 179 | progress=True) 180 | checkpoint = torch.load(cached_file, map_location='cpu') 181 | elif os.path.isfile(url_or_filename): 182 | checkpoint = torch.load(url_or_filename, map_location='cpu') 183 | else: 184 | raise RuntimeError('checkpoint url or path is invalid') 185 | 186 | state_dict = checkpoint['model'] 187 | 188 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed( 189 | state_dict['visual_encoder.pos_embed'], model.visual_encoder) 190 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 191 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed( 192 | state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m) 193 | for key in model.state_dict().keys(): 194 | if key in state_dict.keys(): 195 | if state_dict[key].shape != model.state_dict()[key].shape: 196 | del state_dict[key] 197 | 198 | msg = model.load_state_dict(state_dict, strict=False) 199 | print('load checkpoint from %s' % url_or_filename) 200 | return model, msg 201 | 202 | 203 | def load_checkpoint_swinbase(model, url_or_filename, kwargs): 204 | if kwargs['image_size'] == 224: 205 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 206 | elif kwargs['image_size'] == 384: 207 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 208 | window_size = read_json(vision_config_path)['window_size'] 209 | print('--------------') 210 | print(url_or_filename) 211 | print('--------------') 212 | if is_url(url_or_filename): 213 | cached_file = download_cached_file(url_or_filename, 214 | check_hash=False, 215 | 
progress=True) 216 | checkpoint = torch.load(cached_file, map_location='cpu') 217 | elif os.path.isfile(url_or_filename): 218 | checkpoint = torch.load(url_or_filename, map_location='cpu') 219 | else: 220 | raise RuntimeError('checkpoint url or path is invalid') 221 | 222 | state_dict = checkpoint['model'] 223 | 224 | for k in list(state_dict.keys()): 225 | if 'relative_position_bias_table' in k: 226 | dst_num_pos = (2 * window_size - 1)**2 227 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], 228 | dst_num_pos, 229 | param_name=k) 230 | elif ('relative_position_index' in k) or ('attn_mask' in k): 231 | del state_dict[k] 232 | elif "vision_multi" in k: 233 | state_dict[k.replace("vision_multi", 234 | "tagging_head")] = state_dict.pop(k) 235 | 236 | msg = model.load_state_dict(state_dict, strict=False) 237 | print('load checkpoint from %s' % url_or_filename) 238 | return model, msg 239 | 240 | 241 | def load_checkpoint_swinlarge(model, url_or_filename, kwargs): 242 | if kwargs['image_size'] == 224: 243 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 244 | elif kwargs['image_size'] == 384: 245 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 246 | window_size = read_json(vision_config_path)['window_size'] 247 | print('--------------') 248 | print(url_or_filename) 249 | print('--------------') 250 | if is_url(url_or_filename): 251 | cached_file = download_cached_file(url_or_filename, 252 | check_hash=False, 253 | progress=True) 254 | checkpoint = torch.load(cached_file, map_location='cpu') 255 | elif os.path.isfile(url_or_filename): 256 | checkpoint = torch.load(url_or_filename, map_location='cpu') 257 | else: 258 | raise RuntimeError('checkpoint url or path is invalid') 259 | 260 | state_dict = checkpoint['model'] 261 | 262 | for k in list(state_dict.keys()): 263 | if 'relative_position_bias_table' in k: 264 | dst_num_pos = (2 * window_size - 1)**2 265 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], 266 | dst_num_pos, 267 | param_name=k) 268 | elif ('relative_position_index' in k) or ('attn_mask' in k): 269 | del state_dict[k] 270 | elif "vision_multi" in k: 271 | state_dict[k.replace("vision_multi", 272 | "tagging_head")] = state_dict.pop(k) 273 | 274 | msg = model.load_state_dict(state_dict, strict=False) 275 | print('load checkpoint from %s' % url_or_filename) 276 | return model, msg 277 | 278 | 279 | # Tagging loss function 280 | # copy from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py 281 | class AsymmetricLoss(nn.Module): 282 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True): 283 | super(AsymmetricLoss, self).__init__() 284 | 285 | self.gamma_neg = gamma_neg 286 | self.gamma_pos = gamma_pos 287 | self.clip = clip 288 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 289 | self.eps = eps 290 | 291 | def forward(self, x, y): 292 | """" 293 | Parameters 294 | ---------- 295 | x: input logits 296 | y: targets (multi-label binarized vector) 297 | """ 298 | 299 | # Calculating Probabilities 300 | x_sigmoid = torch.sigmoid(x) 301 | xs_pos = x_sigmoid 302 | xs_neg = 1 - x_sigmoid 303 | 304 | # Asymmetric Clipping 305 | if self.clip is not None and self.clip > 0: 306 | xs_neg = (xs_neg + self.clip).clamp(max=1) 307 | 308 | # Basic CE calculation 309 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 310 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 311 | loss = los_pos + los_neg 312 | 313 | # 
Asymmetric Focusing 314 | if self.gamma_neg > 0 or self.gamma_pos > 0: 315 | if self.disable_torch_grad_focal_loss: 316 | torch.set_grad_enabled(False) 317 | pt0 = xs_pos * y 318 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 319 | pt = pt0 + pt1 320 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 321 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 322 | if self.disable_torch_grad_focal_loss: 323 | torch.set_grad_enabled(True) 324 | loss *= one_sided_w 325 | 326 | return -loss.sum() -------------------------------------------------------------------------------- /ram/models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | class Mlp(nn.Module): 24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 25 | """ 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 50 | self.scale = qk_scale or head_dim ** -0.5 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | self.attn_gradients = None 56 | self.attention_map = None 57 | 58 | def save_attn_gradients(self, attn_gradients): 59 | self.attn_gradients = attn_gradients 60 | 61 | def get_attn_gradients(self): 62 | return self.attn_gradients 63 | 64 | def save_attention_map(self, attention_map): 65 | self.attention_map = attention_map 66 | 67 | def get_attention_map(self): 68 | return self.attention_map 69 | 70 | def forward(self, x, register_hook=False): 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = 
self.attn_drop(attn) 78 | 79 | if register_hook: 80 | self.save_attention_map(attn) 81 | attn.register_hook(self.save_attn_gradients) 82 | 83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | if use_grad_checkpointing: 104 | self.attn = checkpoint_wrapper(self.attn) 105 | self.mlp = checkpoint_wrapper(self.mlp) 106 | 107 | def forward(self, x, register_hook=False): 108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | return x 111 | 112 | 113 | class VisionTransformer(nn.Module): 114 | """ Vision Transformer 115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 116 | https://arxiv.org/abs/2010.11929 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 121 | use_grad_checkpointing=False, ckpt_layer=0): 122 | """ 123 | Args: 124 | img_size (int, tuple): input image size 125 | patch_size (int, tuple): patch size 126 | in_chans (int): number of input channels 127 | num_classes (int): number of classes for classification head 128 | embed_dim (int): embedding dimension 129 | depth (int): depth of transformer 130 | num_heads (int): number of attention heads 131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 132 | qkv_bias (bool): enable bias for qkv if True 133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 135 | drop_rate (float): dropout rate 136 | attn_drop_rate (float): attention dropout rate 137 | drop_path_rate (float): stochastic depth rate 138 | norm_layer: (nn.Module): normalization layer 139 | """ 140 | super().__init__() 141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 151 | self.pos_drop = nn.Dropout(p=drop_rate) 152 | 153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 154 | self.blocks = nn.ModuleList([ 155 | Block( 156 | 
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 159 | ) 160 | for i in range(depth)]) 161 | self.norm = norm_layer(embed_dim) 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 1.0) 175 | 176 | @torch.jit.ignore 177 | def no_weight_decay(self): 178 | return {'pos_embed', 'cls_token'} 179 | 180 | def forward(self, x, register_blk=-1): 181 | B = x.shape[0] 182 | x = self.patch_embed(x) 183 | 184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 185 | x = torch.cat((cls_tokens, x), dim=1) 186 | 187 | x = x + self.pos_embed[:,:x.size(1),:] 188 | x = self.pos_drop(x) 189 | 190 | for i,blk in enumerate(self.blocks): 191 | x = blk(x, register_blk==i) 192 | x = self.norm(x) 193 | 194 | return x 195 | 196 | @torch.jit.ignore() 197 | def load_pretrained(self, checkpoint_path, prefix=''): 198 | _load_weights(self, checkpoint_path, prefix) 199 | 200 | 201 | @torch.no_grad() 202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 204 | """ 205 | import numpy as np 206 | 207 | def _n2p(w, t=True): 208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 209 | w = w.flatten() 210 | if t: 211 | if w.ndim == 4: 212 | w = w.transpose([3, 2, 0, 1]) 213 | elif w.ndim == 3: 214 | w = w.transpose([2, 0, 1]) 215 | elif w.ndim == 2: 216 | w = w.transpose([1, 0]) 217 | return torch.from_numpy(w) 218 | 219 | w = np.load(checkpoint_path) 220 | if not prefix and 'opt/target/embedding/kernel' in w: 221 | prefix = 'opt/target/' 222 | 223 | if hasattr(model.patch_embed, 'backbone'): 224 | # hybrid 225 | backbone = model.patch_embed.backbone 226 | stem_only = not hasattr(backbone, 'stem') 227 | stem = backbone if stem_only else backbone.stem 228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 231 | if not stem_only: 232 | for i, stage in enumerate(backbone.stages): 233 | for j, block in enumerate(stage.blocks): 234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 235 | for r in range(3): 236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 239 | if block.downsample is not None: 240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 244 | else: 245 | embed_conv_w = adapt_input_conv( 246 | model.patch_embed.proj.weight.shape[1], 
_n2p(w[f'{prefix}embedding/kernel'])) 247 | model.patch_embed.proj.weight.copy_(embed_conv_w) 248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 251 | if pos_embed_w.shape != model.pos_embed.shape: 252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 254 | model.pos_embed.copy_(pos_embed_w) 255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 257 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 263 | for i, block in enumerate(model.blocks.children()): 264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 268 | block.attn.qkv.weight.copy_(torch.cat([ 269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 270 | block.attn.qkv.bias.copy_(torch.cat([ 271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 274 | for r in range(2): 275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 279 | 280 | 281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 282 | # interpolate position embedding 283 | embedding_size = pos_embed_checkpoint.shape[-1] 284 | num_patches = visual_encoder.patch_embed.num_patches 285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 286 | # height (== width) for the checkpoint position embedding 287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 288 | # height (== width) for the new position embedding 289 | new_size = int(num_patches ** 0.5) 290 | 291 | if orig_size!=new_size: 292 | # class_token and dist_token are kept unchanged 293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 294 | # only the position tokens are interpolated 295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 297 | pos_tokens = torch.nn.functional.interpolate( 298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 299 | pos_tokens = pos_tokens.permute(0, 2, 3, 
1).flatten(1, 2) 300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 302 | 303 | return new_pos_embed 304 | else: 305 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /ram/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Normalize, Compose, Resize, ToTensor 2 | 3 | 4 | def convert_to_rgb(image): 5 | return image.convert("RGB") 6 | 7 | def get_transform(image_size=384): 8 | return Compose([ 9 | convert_to_rgb, 10 | Resize((image_size, image_size)), 11 | ToTensor(), 12 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 13 | ]) 14 | -------------------------------------------------------------------------------- /ram/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import get_mAP, get_PR 2 | from .openset_utils import build_openset_label_embedding, build_openset_llm_label_embedding 3 | -------------------------------------------------------------------------------- /ram/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | from numpy import ndarray 5 | 6 | 7 | def get_mAP( 8 | preds: ndarray, 9 | gt_file: str, 10 | taglist: List[str] 11 | ) -> Tuple[float, ndarray]: 12 | assert preds.shape[1] == len(taglist) 13 | 14 | # When mapping categories from test datasets to our system, there might be 15 | # multiple vs one situation due to different semantic definitions of tags. 16 | # So there can be duplicate tags in `taglist`. This special case is taken 17 | # into account. 18 | tag2idxs = {} 19 | for idx, tag in enumerate(taglist): 20 | if tag not in tag2idxs: 21 | tag2idxs[tag] = [] 22 | tag2idxs[tag].append(idx) 23 | 24 | # build targets 25 | targets = np.zeros_like(preds) 26 | with open(gt_file, "r") as f: 27 | lines = [line.strip("\n").split(",") for line in f.readlines()] 28 | assert len(lines) == targets.shape[0] 29 | for i, line in enumerate(lines): 30 | for tag in line[1:]: 31 | targets[i, tag2idxs[tag]] = 1.0 32 | 33 | # compute average precision for each class 34 | APs = np.zeros(preds.shape[1]) 35 | for k in range(preds.shape[1]): 36 | APs[k] = _average_precision(preds[:, k], targets[:, k]) 37 | 38 | return APs.mean(), APs 39 | 40 | 41 | def _average_precision(output: ndarray, target: ndarray) -> float: 42 | epsilon = 1e-8 43 | 44 | # sort examples 45 | indices = output.argsort()[::-1] 46 | # Computes prec@i 47 | total_count_ = np.cumsum(np.ones((len(output), 1))) 48 | 49 | target_ = target[indices] 50 | ind = target_ == 1 51 | pos_count_ = np.cumsum(ind) 52 | total = pos_count_[-1] 53 | pos_count_[np.logical_not(ind)] = 0 54 | pp = pos_count_ / total_count_ 55 | precision_at_i_ = np.sum(pp) 56 | precision_at_i = precision_at_i_ / (total + epsilon) 57 | 58 | return precision_at_i 59 | 60 | 61 | def get_PR( 62 | pred_file: str, 63 | gt_file: str, 64 | taglist: List[str] 65 | ) -> Tuple[float, float, ndarray, ndarray]: 66 | # When mapping categories from test datasets to our system, there might be 67 | # multiple vs one situation due to different semantic definitions of tags. 68 | # So there can be duplicate tags in `taglist`. This special case is taken 69 | # into account. 
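    # A hypothetical illustration of the duplicate-tag handling below (values are invented):
    #   taglist = ['cat', 'dog', 'cat']  ->  tag2idxs = {'cat': [0, 2], 'dog': [1]}
    # so a single predicted or ground-truth 'cat' marks every duplicated column at once.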
70 | tag2idxs = {} 71 | for idx, tag in enumerate(taglist): 72 | if tag not in tag2idxs: 73 | tag2idxs[tag] = [] 74 | tag2idxs[tag].append(idx) 75 | 76 | # build preds 77 | with open(pred_file, "r", encoding="utf-8") as f: 78 | lines = [line.strip().split(",") for line in f.readlines()] 79 | preds = np.zeros((len(lines), len(tag2idxs)), dtype=bool) 80 | for i, line in enumerate(lines): 81 | for tag in line[1:]: 82 | preds[i, tag2idxs[tag]] = True 83 | 84 | # build targets 85 | with open(gt_file, "r", encoding="utf-8") as f: 86 | lines = [line.strip().split(",") for line in f.readlines()] 87 | targets = np.zeros((len(lines), len(tag2idxs)), dtype=bool) 88 | for i, line in enumerate(lines): 89 | for tag in line[1:]: 90 | targets[i, tag2idxs[tag]] = True 91 | 92 | assert preds.shape == targets.shape 93 | 94 | # calculate P and R 95 | TPs = ( preds & targets).sum(axis=0) # noqa: E201, E222 96 | FPs = ( preds & ~targets).sum(axis=0) # noqa: E201, E222 97 | FNs = (~preds & targets).sum(axis=0) # noqa: E201, E222 98 | eps = 1.e-9 99 | Ps = TPs / (TPs + FPs + eps) 100 | Rs = TPs / (TPs + FNs + eps) 101 | 102 | return Ps.mean(), Rs.mean(), Ps, Rs 103 | -------------------------------------------------------------------------------- /ram/utils/openset_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | from clip import clip 7 | 8 | 9 | def article(name): 10 | return "an" if name[0] in "aeiou" else "a" 11 | 12 | 13 | def processed_name(name, rm_dot=False): 14 | # _ for lvis 15 | # / for obj365 16 | res = name.replace("_", " ").replace("/", " or ").lower() 17 | if rm_dot: 18 | res = res.rstrip(".") 19 | return res 20 | 21 | 22 | single_template = ["a photo of a {}."] 23 | 24 | multiple_templates = [ 25 | "There is {article} {} in the scene.", 26 | "There is the {} in the scene.", 27 | "a photo of {article} {} in the scene.", 28 | "a photo of the {} in the scene.", 29 | "a photo of one {} in the scene.", 30 | "itap of {article} {}.", 31 | "itap of my {}.", # itap: I took a picture of 32 | "itap of the {}.", 33 | "a photo of {article} {}.", 34 | "a photo of my {}.", 35 | "a photo of the {}.", 36 | "a photo of one {}.", 37 | "a photo of many {}.", 38 | "a good photo of {article} {}.", 39 | "a good photo of the {}.", 40 | "a bad photo of {article} {}.", 41 | "a bad photo of the {}.", 42 | "a photo of a nice {}.", 43 | "a photo of the nice {}.", 44 | "a photo of a cool {}.", 45 | "a photo of the cool {}.", 46 | "a photo of a weird {}.", 47 | "a photo of the weird {}.", 48 | "a photo of a small {}.", 49 | "a photo of the small {}.", 50 | "a photo of a large {}.", 51 | "a photo of the large {}.", 52 | "a photo of a clean {}.", 53 | "a photo of the clean {}.", 54 | "a photo of a dirty {}.", 55 | "a photo of the dirty {}.", 56 | "a bright photo of {article} {}.", 57 | "a bright photo of the {}.", 58 | "a dark photo of {article} {}.", 59 | "a dark photo of the {}.", 60 | "a photo of a hard to see {}.", 61 | "a photo of the hard to see {}.", 62 | "a low resolution photo of {article} {}.", 63 | "a low resolution photo of the {}.", 64 | "a cropped photo of {article} {}.", 65 | "a cropped photo of the {}.", 66 | "a close-up photo of {article} {}.", 67 | "a close-up photo of the {}.", 68 | "a jpeg corrupted photo of {article} {}.", 69 | "a jpeg corrupted photo of the {}.", 70 | "a blurry photo of {article} {}.", 71 | "a blurry photo of the {}.", 72 | "a pixelated photo of {article} {}.", 73 | "a pixelated photo of the 
{}.", 74 | "a black and white photo of the {}.", 75 | "a black and white photo of {article} {}.", 76 | "a plastic {}.", 77 | "the plastic {}.", 78 | "a toy {}.", 79 | "the toy {}.", 80 | "a plushie {}.", 81 | "the plushie {}.", 82 | "a cartoon {}.", 83 | "the cartoon {}.", 84 | "an embroidered {}.", 85 | "the embroidered {}.", 86 | "a painting of the {}.", 87 | "a painting of a {}.", 88 | ] 89 | 90 | 91 | openimages_rare_unseen = ['Aerial photography', 92 | 'Aircraft engine', 93 | 'Ale', 94 | 'Aloe', 95 | 'Amphibian', 96 | 'Angling', 97 | 'Anole', 98 | 'Antique car', 99 | 'Arcade game', 100 | 'Arthropod', 101 | 'Assault rifle', 102 | 'Athletic shoe', 103 | 'Auto racing', 104 | 'Backlighting', 105 | 'Bagpipes', 106 | 'Ball game', 107 | 'Barbecue chicken', 108 | 'Barechested', 109 | 'Barquentine', 110 | 'Beef tenderloin', 111 | 'Billiard room', 112 | 'Billiards', 113 | 'Bird of prey', 114 | 'Black swan', 115 | 'Black-and-white', 116 | 'Blond', 117 | 'Boating', 118 | 'Bonbon', 119 | 'Bottled water', 120 | 'Bouldering', 121 | 'Bovine', 122 | 'Bratwurst', 123 | 'Breadboard', 124 | 'Briefs', 125 | 'Brisket', 126 | 'Brochette', 127 | 'Calabaza', 128 | 'Camera operator', 129 | 'Canola', 130 | 'Childbirth', 131 | 'Chordophone', 132 | 'Church bell', 133 | 'Classical sculpture', 134 | 'Close-up', 135 | 'Cobblestone', 136 | 'Coca-cola', 137 | 'Combat sport', 138 | 'Comics', 139 | 'Compact car', 140 | 'Computer speaker', 141 | 'Cookies and crackers', 142 | 'Coral reef fish', 143 | 'Corn on the cob', 144 | 'Cosmetics', 145 | 'Crocodilia', 146 | 'Digital camera', 147 | 'Dishware', 148 | 'Divemaster', 149 | 'Dobermann', 150 | 'Dog walking', 151 | 'Domestic rabbit', 152 | 'Domestic short-haired cat', 153 | 'Double-decker bus', 154 | 'Drums', 155 | 'Electric guitar', 156 | 'Electric piano', 157 | 'Electronic instrument', 158 | 'Equestrianism', 159 | 'Equitation', 160 | 'Erinaceidae', 161 | 'Extreme sport', 162 | 'Falafel', 163 | 'Figure skating', 164 | 'Filling station', 165 | 'Fire apparatus', 166 | 'Firearm', 167 | 'Flatbread', 168 | 'Floristry', 169 | 'Forklift truck', 170 | 'Freight transport', 171 | 'Fried food', 172 | 'Fried noodles', 173 | 'Frigate', 174 | 'Frozen yogurt', 175 | 'Frying', 176 | 'Full moon', 177 | 'Galleon', 178 | 'Glacial landform', 179 | 'Gliding', 180 | 'Go-kart', 181 | 'Goats', 182 | 'Grappling', 183 | 'Great white shark', 184 | 'Gumbo', 185 | 'Gun turret', 186 | 'Hair coloring', 187 | 'Halter', 188 | 'Headphones', 189 | 'Heavy cruiser', 190 | 'Herding', 191 | 'High-speed rail', 192 | 'Holding hands', 193 | 'Horse and buggy', 194 | 'Horse racing', 195 | 'Hound', 196 | 'Hunting knife', 197 | 'Hurdling', 198 | 'Inflatable', 199 | 'Jackfruit', 200 | 'Jeans', 201 | 'Jiaozi', 202 | 'Junk food', 203 | 'Khinkali', 204 | 'Kitesurfing', 205 | 'Lawn game', 206 | 'Leaf vegetable', 207 | 'Lechon', 208 | 'Lifebuoy', 209 | 'Locust', 210 | 'Lumpia', 211 | 'Luxury vehicle', 212 | 'Machine tool', 213 | 'Medical imaging', 214 | 'Melee weapon', 215 | 'Microcontroller', 216 | 'Middle ages', 217 | 'Military person', 218 | 'Military vehicle', 219 | 'Milky way', 220 | 'Miniature Poodle', 221 | 'Modern dance', 222 | 'Molluscs', 223 | 'Monoplane', 224 | 'Motorcycling', 225 | 'Musical theatre', 226 | 'Narcissus', 227 | 'Nest box', 228 | 'Newsagent\'s shop', 229 | 'Nile crocodile', 230 | 'Nordic skiing', 231 | 'Nuclear power plant', 232 | 'Orator', 233 | 'Outdoor shoe', 234 | 'Parachuting', 235 | 'Pasta salad', 236 | 'Peafowl', 237 | 'Pelmeni', 238 | 'Perching bird', 239 | 'Performance car', 240 | 'Personal 
water craft', 241 | 'Pit bull', 242 | 'Plant stem', 243 | 'Pork chop', 244 | 'Portrait photography', 245 | 'Primate', 246 | 'Procyonidae', 247 | 'Prosciutto', 248 | 'Public speaking', 249 | 'Racewalking', 250 | 'Ramen', 251 | 'Rear-view mirror', 252 | 'Residential area', 253 | 'Ribs', 254 | 'Rice ball', 255 | 'Road cycling', 256 | 'Roller skating', 257 | 'Roman temple', 258 | 'Rowing', 259 | 'Rural area', 260 | 'Sailboat racing', 261 | 'Scaled reptile', 262 | 'Scuba diving', 263 | 'Senior citizen', 264 | 'Shallot', 265 | 'Shinto shrine', 266 | 'Shooting range', 267 | 'Siberian husky', 268 | 'Sledding', 269 | 'Soba', 270 | 'Solar energy', 271 | 'Sport climbing', 272 | 'Sport utility vehicle', 273 | 'Steamed rice', 274 | 'Stemware', 275 | 'Sumo', 276 | 'Surfing Equipment', 277 | 'Team sport', 278 | 'Touring car', 279 | 'Toy block', 280 | 'Trampolining', 281 | 'Underwater diving', 282 | 'Vegetarian food', 283 | 'Wallaby', 284 | 'Water polo', 285 | 'Watercolor paint', 286 | 'Whiskers', 287 | 'Wind wave', 288 | 'Woodwind instrument', 289 | 'Yakitori', 290 | 'Zeppelin'] 291 | 292 | 293 | def build_openset_label_embedding(categories=None): 294 | if categories is None: 295 | categories = openimages_rare_unseen 296 | print("Creating pretrained CLIP model") 297 | model, _ = clip.load("ViT-B/16") 298 | templates = multiple_templates 299 | 300 | run_on_gpu = torch.cuda.is_available() 301 | 302 | with torch.no_grad(): 303 | openset_label_embedding = [] 304 | for category in categories: 305 | texts = [ 306 | template.format( 307 | processed_name(category, rm_dot=True), article=article(category) 308 | ) 309 | for template in templates 310 | ] 311 | texts = [ 312 | "This is " + text if text.startswith("a") or text.startswith("the") else text 313 | for text in texts 314 | ] 315 | texts = clip.tokenize(texts) # tokenize 316 | if run_on_gpu: 317 | texts = texts.cuda() 318 | model = model.cuda() 319 | text_embeddings = model.encode_text(texts) 320 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 321 | text_embedding = text_embeddings.mean(dim=0) 322 | text_embedding /= text_embedding.norm() 323 | openset_label_embedding.append(text_embedding) 324 | openset_label_embedding = torch.stack(openset_label_embedding, dim=1) 325 | if run_on_gpu: 326 | openset_label_embedding = openset_label_embedding.cuda() 327 | 328 | openset_label_embedding = openset_label_embedding.t() 329 | return openset_label_embedding, categories 330 | 331 | 332 | 333 | import json 334 | from tqdm import tqdm 335 | 336 | def build_openset_llm_label_embedding(llm_tag_des): 337 | print("Creating pretrained CLIP model") 338 | model, _ = clip.load("ViT-B/16") 339 | llm_tag_des = llm_tag_des 340 | categories = [] 341 | 342 | run_on_gpu = torch.cuda.is_available() 343 | 344 | with torch.no_grad(): 345 | openset_label_embedding = [] 346 | for item in tqdm(llm_tag_des): 347 | category = list(item.keys())[0] 348 | des = list(item.values())[0] 349 | 350 | categories.append(category) 351 | 352 | texts = clip.tokenize(des, truncate=True) # tokenize 353 | if run_on_gpu: 354 | texts = texts.cuda() 355 | model = model.cuda() 356 | text_embeddings = model.encode_text(texts) 357 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 358 | # text_embedding = text_embeddings.mean(dim=0) 359 | # text_embedding /= text_embedding.norm() 360 | # openset_label_embedding.append(text_embedding) 361 | openset_label_embedding.append(text_embeddings) 362 | # openset_label_embedding = torch.stack(openset_label_embedding, dim=1) 363 | 
openset_label_embedding = torch.cat(openset_label_embedding, dim=0) 364 | if run_on_gpu: 365 | openset_label_embedding = openset_label_embedding.cuda() 366 | 367 | # openset_label_embedding = openset_label_embedding.t() 368 | return openset_label_embedding, categories 369 | 370 | 371 | 372 | 373 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.32.1 2 | einops==0.8.0 3 | gradio==4.38.1 4 | joblib==1.2.0 5 | numpy==1.26.4 6 | omegaconf==2.3.0 7 | packaging==24.1 8 | pandas==1.5.3 9 | peft==0.11.1 10 | Pillow==9.4.0 11 | pycocoevalcap==1.2 12 | pycocotools==2.0.8 13 | pynvml==11.5.0 14 | quantization==0.0.1 15 | Requests==2.32.3 16 | scikit_learn==1.5.1 17 | spacy==3.7.5 18 | timm==0.4.12 19 | torch==2.3.1 20 | torchmetrics==0.11.4 21 | torchvision==0.18.1 22 | tqdm==4.64.1 23 | transformers==4.42.4 24 | webdataset==0.2.86 25 | fairscale==0.4.4 --------------------------------------------------------------------------------
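A few usage sketches for the modules in this listing follow. They are illustrative only: import paths assume the repository root is on `PYTHONPATH`, and every file name, tag, tensor, and score below is invented for the example. First, `interpolate_pos_embed` (assumed to live at the end of `ram/models/vit.py`, whose tail opens this section) resizes a pretrained position embedding when the fine-tuning resolution differs from the pretraining one, keeping the extra (cls) tokens unchanged and bicubically interpolating the patch grid. A minimal sketch with dummy tensors, assuming a 224px/patch-16 checkpoint adapted to a 384px/patch-16 encoder:

# Illustrative only: dummy tensors stand in for a real checkpoint and encoder.
import torch
from types import SimpleNamespace

from ram.models.vit import interpolate_pos_embed  # assumed import path

# A 224px, patch-16 checkpoint: 1 cls token + 14*14 = 196 patch positions, dim 768.
pos_embed_ckpt = torch.randn(1, 1 + 196, 768)

# Stand-in for a 384px, patch-16 encoder: 1 cls token + 24*24 = 576 patch positions.
dummy_encoder = SimpleNamespace(
    patch_embed=SimpleNamespace(num_patches=576),
    pos_embed=torch.zeros(1, 1 + 576, 768),
)

new_pos_embed = interpolate_pos_embed(pos_embed_ckpt, dummy_encoder)
print(new_pos_embed.shape)  # torch.Size([1, 577, 768]); the cls token is copied through unchanged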
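Next, `ram/transform.py` exposes the single preprocessing pipeline used for tagging: RGB conversion, a square resize, tensor conversion, and normalization with ImageNet statistics. A sketch (the image path is a placeholder):

# Illustrative only: replace the path with a real image.
from PIL import Image
from ram.transform import get_transform

transform = get_transform(image_size=384)
pixel_values = transform(Image.open("path/to/image.jpg")).unsqueeze(0)
print(pixel_values.shape)  # torch.Size([1, 3, 384, 384])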
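The evaluation helpers in `ram/utils/metrics.py` read CSV annotation files in which each line holds an image identifier followed by that image's tags, with rows in the same order as the prediction matrix. `get_mAP` takes continuous scores per tag; `get_PR` takes already-thresholded tag lists in the same CSV layout. Note that `ram/utils/__init__.py` also imports `openset_utils`, so the OpenAI `clip` package must be installed even when only the metrics are used. A toy end-to-end sketch with made-up files and scores:

# Illustrative only: two tags, three images, synthetic annotations and scores.
import numpy as np
from ram.utils import get_mAP, get_PR

taglist = ["cat", "dog"]

# Ground truth: "<image_id>,<tag>,<tag>,..." with one line per image.
with open("gt.csv", "w") as f:
    f.write("img0,cat\nimg1,dog\nimg2,cat,dog\n")

# Continuous scores, one row per image in the same order as gt.csv.
preds = np.array([
    [0.9, 0.2],
    [0.1, 0.8],
    [0.7, 0.6],
])
mAP, APs = get_mAP(preds, "gt.csv", taglist)
print(f"mAP={mAP:.3f}", APs)

# Thresholded predictions use the same CSV layout as the ground truth.
with open("pred.csv", "w") as f:
    f.write("img0,cat\nimg1,dog\nimg2,cat,dog\n")
mP, mR, Ps, Rs = get_PR("pred.csv", "gt.csv", taglist)
print(f"P={mP:.3f} R={mR:.3f}")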
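Finally, `ram/utils/openset_utils.py` builds text embeddings for tags outside the pretrained label set. `build_openset_label_embedding` averages CLIP text embeddings over the prompt templates for each category name and returns one pooled vector per tag, while `build_openset_llm_label_embedding` encodes a list of LLM-written descriptions per tag and concatenates them without pooling. Both load CLIP ViT-B/16 via `from clip import clip`, which is not pinned in requirements.txt, and download the weights on first use. A sketch with made-up tags and descriptions:

# Illustrative only: tiny custom tag sets; CLIP ViT-B/16 is downloaded on the first call.
from ram.utils import build_openset_label_embedding, build_openset_llm_label_embedding

# Template-based embeddings: one pooled vector per tag.
tags = ["red panda", "espresso machine", "paraglider"]
label_embedding, categories = build_openset_label_embedding(tags)
print(label_embedding.shape)  # (num_tags, text_dim), e.g. torch.Size([3, 512]) for ViT-B/16

# Description-based embeddings: one vector per description, concatenated along dim 0.
llm_tag_des = [
    {"red panda": ["a small reddish-brown mammal with a ringed tail",
                   "an animal resembling a raccoon that climbs trees"]},
    {"paraglider": ["a person hanging under a curved fabric wing in the sky"]},
]
des_embedding, des_categories = build_openset_llm_label_embedding(llm_tag_des)
print(des_embedding.shape, des_categories)  # three description vectors, ['red panda', 'paraglider']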