├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── base_models ├── __init__.py ├── inspyrenet │ ├── InSPyReNet.py │ ├── __init__.py │ ├── backbones │ │ ├── Res2Net_v1b.py │ │ └── SwinTransformer.py │ ├── modules │ │ ├── attention_module.py │ │ ├── context_module.py │ │ ├── decoder_module.py │ │ └── layers.py │ ├── optim │ │ ├── __init__.py │ │ ├── losses.py │ │ └── scheduler.py │ └── saliency_transforms.py ├── tcl │ ├── __init__.py │ ├── tcl_config_bert.json │ ├── tcl_model_pretrain.py │ ├── tcl_tokenization_bert.py │ ├── tcl_vit.py │ └── tcl_xbert.py └── xvlm │ ├── config_bert.json │ ├── swin_transformer.py │ ├── vit.py │ ├── xbert.py │ └── xvlm.py ├── configs ├── __init__.py ├── base_config.yaml ├── benchmarks │ ├── gqa.yaml │ ├── nextqa.yaml │ ├── okvqa.yaml │ └── refcoco.yaml ├── config_codellama.yaml └── my_config.yaml ├── data └── queries.csv ├── datasets ├── __init__.py ├── gqa.py ├── my_dataset.py ├── nextqa.py ├── okvqa.py └── refcoco.py ├── download_models.sh ├── image_patch.py ├── main_batch.py ├── main_simple.ipynb ├── main_simple_lib.py ├── prompts ├── api.prompt ├── benchmarks │ ├── gqa.prompt │ ├── nextqa.prompt │ ├── okvqa.prompt │ └── refcoco.prompt ├── chatapi.prompt ├── fixed_code │ ├── blip2.prompt │ ├── blip2_video.prompt │ └── glip.prompt └── gpt3 │ ├── gpt3_process_guess.txt │ ├── gpt3_qa.txt │ └── video_question.txt ├── requirements.txt ├── setup.sh ├── setup_env.sh ├── teaser.gif ├── useful_lists ├── possible_options.json └── random_negatives.txt ├── utils.py ├── video_segment.py ├── vision_models.py └── vision_processes.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,jupyternotebooks 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,jupyternotebooks 3 | 4 | ### JupyterNotebooks ### 5 | # gitignore template for Jupyter Notebooks 6 | # website: http://jupyter.org/ 7 | 8 | .ipynb_checkpoints 9 | */.ipynb_checkpoints/* 10 | 11 | # IPython 12 | profile_default/ 13 | ipython_config.py 14 | 15 | # Remove previous ipynb_checkpoints 16 | # git rm -r .ipynb_checkpoints/ 17 | 18 | ### Python ### 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | cover/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | .pybuilder/ 94 | target/ 95 | 96 | # Jupyter Notebook 97 | 98 | # IPython 99 | 100 | # pyenv 101 | # For a library or package, you might want to ignore these files since the code is 102 | # intended to run in multiple environments; otherwise, check them in: 103 | # .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # poetry 113 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 114 | # This is especially recommended for binary packages to ensure reproducibility, and is more 115 | # commonly ignored for libraries. 116 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 117 | #poetry.lock 118 | 119 | # pdm 120 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 121 | #pdm.lock 122 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 123 | # in version control. 124 | # https://pdm.fming.dev/#use-with-ide 125 | .pdm.toml 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | .dmypy.json 159 | dmypy.json 160 | 161 | # Pyre type checker 162 | .pyre/ 163 | 164 | # pytype static type analyzer 165 | .pytype/ 166 | 167 | # Cython debug symbols 168 | cython_debug/ 169 | 170 | # PyCharm 171 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 172 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 173 | # and can be added to the global gitignore or merged into this file. For a more nuclear 174 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
175 | #.idea/ 176 | 177 | ### Python Patch ### 178 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 179 | poetry.toml 180 | 181 | 182 | ### VisualStudioCode ### 183 | .vscode/* 184 | !.vscode/settings.json 185 | !.vscode/tasks.json 186 | !.vscode/launch.json 187 | !.vscode/extensions.json 188 | !.vscode/*.code-snippets 189 | 190 | # Local History for Visual Studio Code 191 | .history/ 192 | 193 | # Built Visual Studio Code Extensions 194 | *.vsix 195 | 196 | ### VisualStudioCode Patch ### 197 | # Ignore all local history of files 198 | .history 199 | .ionide 200 | 201 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,jupyternotebooks 202 | n 203 | *cache/ 204 | images/ 205 | /chekpoints/ 206 | /outputs/ 207 | /results/ 208 | /api.key 209 | /.vscode/ 210 | /wandb/ 211 | *.pyscratch 212 | *.key 213 | MKT/ 214 | RLIP/ 215 | scratchwork/ 216 | blip2_runner.py 217 | LAVIS/ 218 | pretrained_models/ 219 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "GLIP"] 2 | path = GLIP 3 | url = https://github.com/sachit-menon/GLIP.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ViperGPT: Visual Inference via Python Execution for Reasoning 2 | 3 | This is the code for the paper [ViperGPT: Visual Inference via Python Execution for Reasoning](https://viper.cs.columbia.edu) by [Dídac Surís](https://www.didacsuris.com/)\*, [Sachit Menon](https://sachit-menon.github.io/)\* and [Carl Vondrick](https://www.cs.columbia.edu/~vondrick/). 4 | 5 | ![teaser](teaser.gif "Teaser") 6 | 7 | ## Quickstart 8 | Clone recursively: 9 | ```bash 10 | git clone --recurse-submodules https://github.com/cvlab-columbia/viper.git 11 | ``` 12 | 13 | After cloning: 14 | ```bash 15 | cd viper 16 | export PATH=/usr/local/cuda/bin:$PATH 17 | bash setup.sh # This may take a while. Make sure the vipergpt environment is active 18 | cd GLIP 19 | python setup.py clean --all build develop --user 20 | cd .. 21 | echo YOUR_OPENAI_API_KEY_HERE > api.key 22 | ``` 23 | Then you can start exploring with the `main_simple.ipynb` notebook. For running on datasets instead of individual 24 | examples, use `main_batch.py` as discussed later on. 25 | 26 | > :warning: WARNING: ViperGPT runs code generated by a large language model. We do not have direct control over this 27 | > code, so it can be dangerous to run it, especially if modifications to the API are made (the current prompts do not 28 | > have any dangerous functions like interaction with the filesystem, so it is unlikely that any malicious code can be 29 | > generated). We cannot guarantee that the code is safe, so use at your own risk, or run in a sandboxed environment. 30 | > For this reason, the default `execute_code` parameter in the config is `False`. Set it to `True` if you would like the 31 | > generated code to be executed automatically in `main_batch.py`, otherwise you can execute it yourself (as in 32 | > `main_simple.ipynb`). 33 | 34 | 35 | > :information_source: NOTE: OpenAI discontinued support for the Codex API on March 23rd, 2023. This repository implements 36 | > GPT-3.5 Turbo and GPT-4 as alternatives, but we have not tested them extensively; as they are chat models and not completion, their behavior likely differs. 
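As a concrete illustration of executing generated code by hand (with `execute_code` left as `False`), the sketch below inspects the code first and then runs it with an explicit namespace. The `ImagePatch` class and the `execute_command` entry point come from this repository's API prompt, but the snippet itself, the example query, and the image path are illustrative assumptions rather than the repository's own execution path.

```python
from PIL import Image
from image_patch import ImagePatch  # the API class exposed to generated code

# Stand-in for what the code-generation model might return (made-up example).
generated_code = '''
def execute_command(image):
    image_patch = ImagePatch(image)
    return image_patch.simple_query("What is in the image?")
'''

print(generated_code)  # 1) read the generated code before running anything

namespace = {"ImagePatch": ImagePatch}  # 2) expose only the API you trust
exec(compile(generated_code, "<generated>", "exec"), namespace)

image = Image.open("my_image.jpg")  # hypothetical input image
print(namespace["execute_command"](image))  # 3) run it on your image
```

The notebook `main_simple.ipynb` walks through this generate-then-execute workflow interactively.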
37 | 38 | ## Detailed Installation 39 | The easiest way to get started exploring ViperGPT is through `main_simple.ipynb`. To run it, you will need to do the following: 40 | 1. Clone this repository with its submodules. 41 | 2. Install the dependencies. See [Dependencies](#Dependencies). 42 | 3. Download two pretrained models (the rest are downloaded automatically). See [Pretrained models](#Pretrained-models). 43 | 4. Set up the OpenAI key. See [OpenAI key](#OpenAI-key). 44 | 45 | ### Cloning this Repo 46 | 47 | ```bash 48 | git clone --recurse-submodules https://github.com/cvlab-columbia/viper.git 49 | ``` 50 | 51 | ### Dependencies 52 | 53 | First, create a conda environment using `setup_env.sh` and then install our modified version of GLIP. 54 | To do so, just `cd` into the `viper` directory, and run: 55 | 56 | ```bash 57 | export PATH=/usr/local/cuda/bin:$PATH 58 | bash setup_env.sh 59 | conda activate vipergpt 60 | cd GLIP 61 | python setup.py clean --all build develop --user 62 | ``` 63 | 64 | Please make sure to install GLIP as described (i.e., from our provided repo) as we have updated the CUDA kernels to be 65 | compatible with newer versions of PyTorch, which are required for other models. 66 | 67 | ### Pretrained models 68 | 69 | Note that ViperGPT may inherit biases from the pretrained models it uses. These biases may be reflected in the outputs 70 | generated by our model. It is recommended to consider this potential bias when using ViperGPT and interpreting its 71 | outputs. 72 | 73 | This repository implements more models than the ones described in the paper, which can be useful for further research. 74 | Most of the implemented modules automatically download the pretrained models. However, there are four models that 75 | need to be downloaded manually, if they are to be used. They have to be stored in the same directory 76 | `/path/to/pretrained_models`, by default `./pretrained_models/`, which has to be specified in the configuration (see [Configuration](#Configuration)). 77 | 78 | We provide the convenience script `download_models.sh` to perform this download for you; you can set the variable `$PRETRAINED_MODEL_PATH` to match your config's `/path/to/pretrained_models/`. 79 | 80 | #### Pretrained model system requirements 81 | 82 | Many of the models used are very large, and require quite a bit of GPU memory. In particular, GLIP and BLIP2 are especially large. Please use smaller variants of those models if running on hardware that cannot support the larger ones; however, this comes at the expense of performance. 83 | 84 | ### OpenAI key 85 | 86 | To run the OpenAI models, you will need to configure an OpenAI key. This can be done by signing up for an account [e.g. here](https://platform.openai.com/), and then creating a key in [account/api-keys](https://platform.openai.com/account/api-keys). 87 | **Create a file `api.key` and store the key in it.** 88 | 89 | ## Running the code 90 | 91 | Once the previous steps are done, you can run the Jupyter Notebook `main_simple.ipynb`. This notebook contains 92 | the code to try ViperGPT on your own images. The notebook is well documented, and it describes how to use the code. 93 | 94 | ## Dataset 95 | 96 | You can run ViperGPT on a pre-defined set of query-image/video pairs as well. In order to do that, you will have to 97 | create a `queries.csv` file, which contains the queries and the filenames for the corresponding images/videos. The format of the file is 98 | `query,answer,image_name/video_name`.
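For instance, a minimal `queries.csv` in that format could be written as follows (the example rows and the header line are illustrative assumptions, not the file shipped in `data`):

```python
import csv

# Hypothetical entries following the query,answer,image_name format described above.
rows = [
    ("How many muffins can each kid have?", "2", "muffins.jpg"),
    ("What color is the car on the left?", "red", "street.jpg"),
]

with open("queries.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["query", "answer", "image_name"])  # header row (assumed)
    writer.writerows(rows)
```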
The answer is optional, and only needed for evaluation. See `data` for an example. 99 | 100 | Your dataset directory will contain the `queries.csv` file as well as the images/videos in the `images`/`videos` 101 | directory. Add the path to the dataset directory in the configuration (see [Configuration](#Configuration)). 102 | 103 | ## Configuration 104 | 105 | All the configuration parameters are defined in `configs/base_config.yaml`. In order to run the code, 106 | modify the paths in the parameters `path_pretrained_models` and optionally `dataset.data_path` to point to the correct 107 | directories. 108 | 109 | For every new configuration you need to run, create a new yaml file in the `configs` directory (like `my_config.yaml`), 110 | and modify the parameters you need to change. The parameters in the new file will overwrite 111 | the ones in `base_config.yaml`. Any number of configuration files can be specified; they will be merged in the order 112 | they are specified in the command line. 113 | 114 | The `multiprocessing` parameter refers to *both* the batch (every sample is run by a different worker) and the models 115 | (every model runs in its own process). 116 | 117 | ## Running the code on a dataset, without the Jupyter notebook 118 | 119 | The code can be run using the following command: 120 | 121 | ```bash 122 | CONFIG_NAMES=your_config_name python main_batch.py 123 | ``` 124 | 125 | `CONFIG_NAMES` is an environment variable that specifies the configuration files to use. 126 | 127 | If you want to run the code using multiprocessing, set `multiprocessing: True` in the config file. 128 | 129 | It is especially important to consider the risks of executing arbitrary code when running in a batch; in particular, if you modify the API or any inputs to Codex, be mindful not to include potentially damaging abilities such as file modification/deletion. 130 | 131 | ## Code structure 132 | 133 | The code is prepared to run in a multiprocessing manner, from two points of view. First, it runs the models in parallel, 134 | meaning that each pretrained model runs in its own process. Second, it runs the samples in parallel, meaning that 135 | several workers are created to run the samples for a given batch. There is a producer-consumer queuing mechanism where 136 | the processes controlling the models are the consumers of inputs coming from the workers that run each sample 137 | (producer). Our implementation allows for batching of samples, which means that several workers can send their inputs to 138 | the same model process, which will run them as a batch, and return the output to each worker separately. 139 | 140 | The code has comments and docstrings, but here is a brief overview of the code structure: 141 | - `vision_models.py`: Contains the code for the pretrained models. Each one of them is a subclass of `BaseModel`. 142 | Implementing a new model is easy. Just create a new class that inherits from `BaseModel` and implement the `forward` 143 | method, as well as the `name` method. The latter will be used to call the model. 144 | - `vision_processes.py`: Acts as a bridge between the models and the rest of the code. It contains the code to start 145 | all the required processes, whether multiprocessing or not. It automatically detects all the new models implemented in 146 | `vision_models.py`. It defines a `forward` method that takes a name as input (as well as arguments), and calls the 147 | appropriate model.
148 | - `main_batch.py` and `main_simple.ipynb`: These are the main files to run the code. The former runs the whole dataset and 149 | is suited for parallel processing of samples, while the latter runs a single image/video and is suited for debugging. 150 | - `image_patch.py` and `video_segment.py`: These are the classes that represent the image patches and video segments. 151 | They contain all the methods that call the `forward` method of `vision_processes.py` and therefore call the models. 152 | - `configs`: Directory containing the configuration files. The configuration files are in YAML format, and read using 153 | OmegaConf. 154 | - `datasets`: Directory containing the code for the datasets. The datasets are subclasses of `torch.utils.data.Dataset`. 155 | - `prompts`: Directory containing the prompts for Codex and GPT-3. The Codex ones define the API specifications. 156 | - `utils.py`, `useful_lists` and `base_models`: Auxiliary files containing useful functions, lists and pretrained model 157 | implementations. 158 | 159 | ## Citation 160 | 161 | If you use this code, please consider citing the paper as: 162 | 163 | ``` 164 | @article{surismenon2023vipergpt, 165 | title={ViperGPT: Visual Inference via Python Execution for Reasoning}, 166 | author={D\'idac Sur\'is and Sachit Menon and Carl Vondrick}, 167 | journal={arXiv preprint arXiv:2303.08128}, 168 | year={2023} 169 | } 170 | ``` -------------------------------------------------------------------------------- /base_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-columbia/viper/09fe3465224766860d8dd4ec48db942f22b05092/base_models/__init__.py -------------------------------------------------------------------------------- /base_models/inspyrenet/InSPyReNet.py: -------------------------------------------------------------------------------- 1 | from base_models.inspyrenet.optim import * 2 | from base_models.inspyrenet.modules.context_module import * 3 | from base_models.inspyrenet.modules.attention_module import * 4 | from base_models.inspyrenet.modules.decoder_module import * 5 | from base_models.inspyrenet.backbones.Res2Net_v1b import res2net50_v1b_26w_4s 6 | from base_models.inspyrenet.backbones.SwinTransformer import SwinB 7 | 8 | class InSPyReNet(nn.Module): 9 | def __init__(self, backbone, in_channels, depth=64, base_size=[384, 384], threshold=512, **kwargs): 10 | super(InSPyReNet, self).__init__() 11 | self.backbone = backbone 12 | self.in_channels = in_channels 13 | self.depth = depth 14 | self.base_size = base_size 15 | self.threshold = threshold 16 | 17 | self.context1 = PAA_e(self.in_channels[0], self.depth, base_size=self.base_size, stage=0) 18 | self.context2 = PAA_e(self.in_channels[1], self.depth, base_size=self.base_size, stage=1) 19 | self.context3 = PAA_e(self.in_channels[2], self.depth, base_size=self.base_size, stage=2) 20 | self.context4 = PAA_e(self.in_channels[3], self.depth, base_size=self.base_size, stage=3) 21 | self.context5 = PAA_e(self.in_channels[4], self.depth, base_size=self.base_size, stage=4) 22 | 23 | self.decoder = PAA_d(self.depth * 3, depth=self.depth, base_size=base_size, stage=2) 24 | 25 | self.attention0 = SICA(self.depth , depth=self.depth, base_size=self.base_size, stage=0, lmap_in=True) 26 | self.attention1 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=1, lmap_in=True) 27 | self.attention2 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=2 ) 28 | 
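# Losses: the saliency loss below combines weighted BCE and IoU (both on logits); the pyramid-consistency loss (pc_loss_fn) is plain L1 between adjacent pyramid levels.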
29 | self.sod_loss_fn = lambda x, y: weighted_bce_loss_with_logits(x, y, reduction='mean') + iou_loss_with_logits(x, y, reduction='mean') 30 | self.pc_loss_fn = nn.L1Loss() 31 | 32 | self.ret = lambda x, target: F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False) 33 | self.res = lambda x, size: F.interpolate(x, size=size, mode='bilinear', align_corners=False) 34 | self.des = lambda x, size: F.interpolate(x, size=size, mode='nearest') 35 | 36 | self.image_pyramid = ImagePyramid(7, 1) 37 | 38 | self.transition0 = Transition(17) 39 | self.transition1 = Transition(9) 40 | self.transition2 = Transition(5) 41 | 42 | self.forward = self.forward_inference 43 | 44 | def to(self, device): 45 | self.image_pyramid.to(device) 46 | self.transition0.to(device) 47 | self.transition1.to(device) 48 | self.transition2.to(device) 49 | super(InSPyReNet, self).to(device) 50 | return self 51 | 52 | def cuda(self, idx=None): 53 | if idx is None: 54 | idx = torch.cuda.current_device() 55 | 56 | self.to(device="cuda:{}".format(idx)) 57 | return self 58 | 59 | def train(self, mode=True): 60 | super(InSPyReNet, self).train(mode) 61 | self.forward = self.forward_train 62 | return self 63 | 64 | def eval(self): 65 | super(InSPyReNet, self).train(False) 66 | self.forward = self.forward_inference 67 | return self 68 | 69 | def forward_inspyre(self, x): 70 | B, _, H, W = x.shape 71 | 72 | x1, x2, x3, x4, x5 = self.backbone(x) 73 | 74 | x1 = self.context1(x1) #4 75 | x2 = self.context2(x2) #4 76 | x3 = self.context3(x3) #8 77 | x4 = self.context4(x4) #16 78 | x5 = self.context5(x5) #32 79 | 80 | f3, d3 = self.decoder([x3, x4, x5]) #16 81 | 82 | f3 = self.res(f3, (H // 4, W // 4 )) 83 | f2, p2 = self.attention2(torch.cat([x2, f3], dim=1), d3.detach()) 84 | d2 = self.image_pyramid.reconstruct(d3.detach(), p2) #4 85 | 86 | x1 = self.res(x1, (H // 2, W // 2)) 87 | f2 = self.res(f2, (H // 2, W // 2)) 88 | f1, p1 = self.attention1(torch.cat([x1, f2], dim=1), d2.detach(), p2.detach()) #2 89 | d1 = self.image_pyramid.reconstruct(d2.detach(), p1) #2 90 | 91 | f1 = self.res(f1, (H, W)) 92 | _, p0 = self.attention0(f1, d1.detach(), p1.detach()) #2 93 | d0 = self.image_pyramid.reconstruct(d1.detach(), p0) #2 94 | 95 | out = dict() 96 | out['saliency'] = [d3, d2, d1, d0] 97 | out['laplacian'] = [p2, p1, p0] 98 | 99 | return out 100 | 101 | def forward_train(self, sample): 102 | x = sample['image'] 103 | B, _, H, W = x.shape 104 | out = self.forward_inspyre(x) 105 | 106 | d3, d2, d1, d0 = out['saliency'] 107 | p2, p1, p0 = out['laplacian'] 108 | 109 | if type(sample) == dict and 'gt' in sample.keys() and sample['gt'] is not None: 110 | y = sample['gt'] 111 | 112 | y1 = self.image_pyramid.reduce(y) 113 | y2 = self.image_pyramid.reduce(y1) 114 | y3 = self.image_pyramid.reduce(y2) 115 | 116 | loss = self.pc_loss_fn(self.des(d3, (H, W)), self.des(self.image_pyramid.reduce(d2), (H, W)).detach()) * 0.0001 117 | loss += self.pc_loss_fn(self.des(d2, (H, W)), self.des(self.image_pyramid.reduce(d1), (H, W)).detach()) * 0.0001 118 | loss += self.pc_loss_fn(self.des(d1, (H, W)), self.des(self.image_pyramid.reduce(d0), (H, W)).detach()) * 0.0001 119 | 120 | loss += self.sod_loss_fn(self.des(d3, (H, W)), self.des(y3, (H, W))) 121 | loss += self.sod_loss_fn(self.des(d2, (H, W)), self.des(y2, (H, W))) 122 | loss += self.sod_loss_fn(self.des(d1, (H, W)), self.des(y1, (H, W))) 123 | loss += self.sod_loss_fn(self.des(d0, (H, W)), self.des(y, (H, W))) 124 | 125 | else: 126 | loss = 0 127 | 128 | pred = torch.sigmoid(d0) 129 | 
pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8) 130 | 131 | sample['pred'] = pred 132 | sample['loss'] = loss 133 | sample['saliency'] = [d3, d2, d1, d0] 134 | sample['laplacian'] = [p2, p1, p0] 135 | return sample 136 | 137 | def forward_inference(self, sample): 138 | B, _, H, W = sample['image'].shape 139 | 140 | if self.threshold is None: 141 | out = self.forward_inspyre(sample['image']) 142 | d3, d2, d1, d0 = out['saliency'] 143 | p2, p1, p0 = out['laplacian'] 144 | 145 | elif (H <= self.threshold or W <= self.threshold): 146 | if 'image_resized' in sample.keys(): 147 | out = self.forward_inspyre(sample['image_resized']) 148 | else: 149 | out = self.forward_inspyre(sample['image']) 150 | d3, d2, d1, d0 = out['saliency'] 151 | p2, p1, p0 = out['laplacian'] 152 | 153 | else: 154 | # LR Saliency Pyramid 155 | lr_out = self.forward_inspyre(sample['image_resized']) 156 | lr_d3, lr_d2, lr_d1, lr_d0 = lr_out['saliency'] 157 | lr_p2, lr_p1, lr_p0 = lr_out['laplacian'] 158 | 159 | # HR Saliency Pyramid 160 | hr_out = self.forward_inspyre(sample['image']) 161 | hr_d3, hr_d2, hr_d1, hr_d0 = hr_out['saliency'] 162 | hr_p2, hr_p1, hr_p0 = hr_out['laplacian'] 163 | 164 | # Pyramid Blending 165 | d3 = self.ret(lr_d0, hr_d3) 166 | 167 | t2 = self.ret(self.transition2(d3), hr_p2) 168 | p2 = t2 * hr_p2 169 | d2 = self.image_pyramid.reconstruct(d3, p2) 170 | 171 | t1 = self.ret(self.transition1(d2), hr_p1) 172 | p1 = t1 * hr_p1 173 | d1 = self.image_pyramid.reconstruct(d2, p1) 174 | 175 | t0 = self.ret(self.transition0(d1), hr_p0) 176 | p0 = t0 * hr_p0 177 | d0 = self.image_pyramid.reconstruct(d1, p0) 178 | 179 | pred = torch.sigmoid(d0) 180 | pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8) 181 | 182 | sample['pred'] = pred 183 | sample['loss'] = 0 184 | sample['saliency'] = [d3, d2, d1, d0] 185 | sample['laplacian'] = [p2, p1, p0] 186 | return sample 187 | 188 | def InSPyReNet_Res2Net50(depth, pretrained, base_size, **kwargs): 189 | return InSPyReNet(res2net50_v1b_26w_4s(pretrained=pretrained), [64, 256, 512, 1024, 2048], depth, base_size, **kwargs) 190 | 191 | def InSPyReNet_SwinB(depth, pretrained, base_size, **kwargs): 192 | return InSPyReNet(SwinB(pretrained=pretrained), [128, 128, 256, 512, 1024], depth, base_size, **kwargs) -------------------------------------------------------------------------------- /base_models/inspyrenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-columbia/viper/09fe3465224766860d8dd4ec48db942f22b05092/base_models/inspyrenet/__init__.py -------------------------------------------------------------------------------- /base_models/inspyrenet/backbones/Res2Net_v1b.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['Res2Net', 'res2net50_v1b', 8 | 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 9 | 10 | model_urls = { 11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 13 | } 14 | class Bottle2neck(nn.Module): 15 | expansion = 4 16 | 17 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 18 | super(Bottle2neck, 
self).__init__() 19 | 20 | width = int(math.floor(planes * (baseWidth / 64.0))) 21 | self.conv1 = nn.Conv2d(inplanes, width * scale, 22 | kernel_size=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(width * scale) 24 | 25 | if scale == 1: 26 | self.nums = 1 27 | else: 28 | self.nums = scale - 1 29 | if stype == 'stage': 30 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) 31 | convs = [] 32 | bns = [] 33 | for i in range(self.nums): 34 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, 35 | dilation=dilation, padding=dilation, bias=False)) 36 | bns.append(nn.BatchNorm2d(width)) 37 | self.convs = nn.ModuleList(convs) 38 | self.bns = nn.ModuleList(bns) 39 | 40 | self.conv3 = nn.Conv2d(width * scale, planes * 41 | self.expansion, kernel_size=1, bias=False) 42 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 43 | 44 | self.relu = nn.ReLU(inplace=True) 45 | self.downsample = downsample 46 | self.stype = stype 47 | self.scale = scale 48 | self.width = width 49 | 50 | def forward(self, x): 51 | residual = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | spx = torch.split(out, self.width, 1) 58 | for i in range(self.nums): 59 | if i == 0 or self.stype == 'stage': 60 | sp = spx[i] 61 | else: 62 | sp = sp + spx[i] 63 | sp = self.convs[i](sp) 64 | sp = self.relu(self.bns[i](sp)) 65 | if i == 0: 66 | out = sp 67 | else: 68 | out = torch.cat((out, sp), 1) 69 | if self.scale != 1 and self.stype == 'normal': 70 | out = torch.cat((out, spx[self.nums]), 1) 71 | elif self.scale != 1 and self.stype == 'stage': 72 | out = torch.cat((out, self.pool(spx[self.nums])), 1) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | 77 | if self.downsample is not None: 78 | residual = self.downsample(x) 79 | 80 | out += residual 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Res2Net(nn.Module): 87 | 88 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000, output_stride=32): 89 | self.inplanes = 64 90 | super(Res2Net, self).__init__() 91 | self.baseWidth = baseWidth 92 | self.scale = scale 93 | self.output_stride = output_stride 94 | if self.output_stride == 8: 95 | self.grid = [1, 2, 1] 96 | self.stride = [1, 2, 1, 1] 97 | self.dilation = [1, 1, 2, 4] 98 | elif self.output_stride == 16: 99 | self.grid = [1, 2, 4] 100 | self.stride = [1, 2, 2, 1] 101 | self.dilation = [1, 1, 1, 2] 102 | elif self.output_stride == 32: 103 | self.grid = [1, 2, 4] 104 | self.stride = [1, 2, 2, 2] 105 | self.dilation = [1, 1, 2, 4] 106 | self.conv1 = nn.Sequential( 107 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 108 | nn.BatchNorm2d(32), 109 | nn.ReLU(inplace=True), 110 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 111 | nn.BatchNorm2d(32), 112 | nn.ReLU(inplace=True), 113 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 114 | ) 115 | self.bn1 = nn.BatchNorm2d(64) 116 | self.relu = nn.ReLU() 117 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 118 | self.layer1 = self._make_layer( 119 | block, 64, layers[0], stride=self.stride[0], dilation=self.dilation[0]) 120 | self.layer2 = self._make_layer( 121 | block, 128, layers[1], stride=self.stride[1], dilation=self.dilation[1]) 122 | self.layer3 = self._make_layer( 123 | block, 256, layers[2], stride=self.stride[2], dilation=self.dilation[2]) 124 | self.layer4 = self._make_layer( 125 | block, 512, layers[3], stride=self.stride[3], dilation=self.dilation[3], grid=self.grid) 126 | 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | 
nn.init.kaiming_normal_( 130 | m.weight, mode='fan_out', nonlinearity='relu') 131 | elif isinstance(m, nn.BatchNorm2d): 132 | nn.init.constant_(m.weight, 1) 133 | nn.init.constant_(m.bias, 0) 134 | 135 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, grid=None): 136 | downsample = None 137 | if stride != 1 or self.inplanes != planes * block.expansion: 138 | downsample = nn.Sequential( 139 | nn.AvgPool2d(kernel_size=stride, stride=stride, 140 | ceil_mode=True, count_include_pad=False), 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=1, bias=False), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = [] 147 | layers.append(block(self.inplanes, planes, stride, dilation, downsample=downsample, 148 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 149 | self.inplanes = planes * block.expansion 150 | 151 | if grid is not None: 152 | assert len(grid) == blocks 153 | else: 154 | grid = [1] * blocks 155 | 156 | for i in range(1, blocks): 157 | layers.append(block(self.inplanes, planes, dilation=dilation * 158 | grid[i], baseWidth=self.baseWidth, scale=self.scale)) 159 | 160 | return nn.Sequential(*layers) 161 | 162 | def change_stride(self, output_stride=16): 163 | if output_stride == self.output_stride: 164 | return 165 | else: 166 | self.output_stride = output_stride 167 | if self.output_stride == 8: 168 | self.grid = [1, 2, 1] 169 | self.stride = [1, 2, 1, 1] 170 | self.dilation = [1, 1, 2, 4] 171 | elif self.output_stride == 16: 172 | self.grid = [1, 2, 4] 173 | self.stride = [1, 2, 2, 1] 174 | self.dilation = [1, 1, 1, 2] 175 | elif self.output_stride == 32: 176 | self.grid = [1, 2, 4] 177 | self.stride = [1, 2, 2, 2] 178 | self.dilation = [1, 1, 2, 4] 179 | 180 | for i, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]): 181 | for j, block in enumerate(layer): 182 | if block.downsample is not None: 183 | block.downsample[0].kernel_size = ( 184 | self.stride[i], self.stride[i]) 185 | block.downsample[0].stride = ( 186 | self.stride[i], self.stride[i]) 187 | if hasattr(block, 'pool'): 188 | block.pool.stride = ( 189 | self.stride[i], self.stride[i]) 190 | for conv in block.convs: 191 | conv.stride = (self.stride[i], self.stride[i]) 192 | for conv in block.convs: 193 | d = self.dilation[i] if i != 3 else self.dilation[i] * \ 194 | self.grid[j] 195 | conv.dilation = (d, d) 196 | conv.padding = (d, d) 197 | 198 | def forward(self, x): 199 | x = self.conv1(x) 200 | x = self.bn1(x) 201 | x = self.relu(x) 202 | x = self.maxpool(x) 203 | 204 | out = [x] 205 | 206 | x = self.layer1(x) 207 | out.append(x) 208 | x = self.layer2(x) 209 | out.append(x) 210 | x = self.layer3(x) 211 | out.append(x) 212 | x = self.layer4(x) 213 | out.append(x) 214 | 215 | return out 216 | 217 | 218 | def res2net50_v1b(pretrained=False, **kwargs): 219 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 220 | if pretrained: 221 | model.load_state_dict(model_zoo.load_url( 222 | model_urls['res2net50_v1b_26w_4s'])) 223 | 224 | return model 225 | 226 | 227 | def res2net101_v1b(pretrained=False, **kwargs): 228 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], 229 | baseWidth=26, scale=4, **kwargs) 230 | if pretrained: 231 | model.load_state_dict(model_zoo.load_url( 232 | model_urls['res2net101_v1b_26w_4s'])) 233 | 234 | return model 235 | 236 | 237 | def res2net50_v1b_26w_4s(pretrained=True, **kwargs): 238 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 239 | if 
pretrained is True: 240 | model.load_state_dict(torch.load('data/backbone_ckpt/res2net50_v1b_26w_4s-3cf99910.pth', map_location='cpu')) 241 | 242 | return model 243 | 244 | 245 | def res2net101_v1b_26w_4s(pretrained=True, **kwargs): 246 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], 247 | baseWidth=26, scale=4, **kwargs) 248 | if pretrained is True: 249 | model.load_state_dict(torch.load('data/backbone_ckpt/res2net101_v1b_26w_4s-0812c246.pth', map_location='cpu')) 250 | 251 | return model 252 | 253 | 254 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 255 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], 256 | baseWidth=26, scale=4, **kwargs) 257 | if pretrained: 258 | model.load_state_dict(model_zoo.load_url( 259 | model_urls['res2net152_v1b_26w_4s'])) 260 | 261 | return model 262 | 263 | 264 | if __name__ == '__main__': 265 | images = torch.rand(1, 3, 224, 224).cuda(0) 266 | model = res2net50_v1b_26w_4s(pretrained=True) 267 | model = model.cuda(0) 268 | print(model(images).size()) 269 | -------------------------------------------------------------------------------- /base_models/inspyrenet/modules/attention_module.py: -------------------------------------------------------------------------------- 1 | from operator import xor 2 | 3 | from base_models.inspyrenet.modules.layers import * 4 | # from utils.misc import * 5 | class SICA(nn.Module): 6 | def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None, lmap_in=False): 7 | super(SICA, self).__init__() 8 | self.in_channel = in_channel 9 | self.depth = depth 10 | self.lmap_in = lmap_in 11 | if base_size is not None and stage is not None: 12 | self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage)) 13 | else: 14 | self.stage_size = None 15 | 16 | self.conv_query = nn.Sequential(Conv2d(in_channel, depth, 3, relu=True), 17 | Conv2d(depth, depth, 3, relu=True)) 18 | self.conv_key = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True), 19 | Conv2d(depth, depth, 1, relu=True)) 20 | self.conv_value = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True), 21 | Conv2d(depth, depth, 1, relu=True)) 22 | 23 | if self.lmap_in is True: 24 | self.ctx = 5 25 | else: 26 | self.ctx = 3 27 | 28 | self.conv_out1 = Conv2d(depth, depth, 3, relu=True) 29 | self.conv_out2 = Conv2d(in_channel + depth, depth, 3, relu=True) 30 | self.conv_out3 = Conv2d(depth, depth, 3, relu=True) 31 | self.conv_out4 = Conv2d(depth, out_channel, 1) 32 | 33 | self.threshold = Parameter(torch.tensor([0.5])) 34 | 35 | if self.lmap_in is True: 36 | self.lthreshold = Parameter(torch.tensor([0.5])) 37 | 38 | def forward(self, x, smap, lmap: Optional[torch.Tensor]=None): 39 | assert not xor(self.lmap_in is True, lmap is not None) 40 | b, c, h, w = x.shape 41 | 42 | # compute class probability 43 | smap = F.interpolate(smap, size=x.shape[-2:], mode='bilinear', align_corners=False) 44 | smap = torch.sigmoid(smap) 45 | p = smap - self.threshold 46 | 47 | fg = torch.clip(p, 0, 1) # foreground 48 | bg = torch.clip(-p, 0, 1) # background 49 | cg = self.threshold - torch.abs(p) # confusion area 50 | 51 | if self.lmap_in is True and lmap is not None: 52 | lmap = F.interpolate(lmap, size=x.shape[-2:], mode='bilinear', align_corners=False) 53 | lmap = torch.sigmoid(lmap) 54 | lp = lmap - self.lthreshold 55 | fp = torch.clip(lp, 0, 1) # foreground 56 | bp = torch.clip(-lp, 0, 1) # background 57 | 58 | prob = [fg, bg, cg, fp, bp] 59 | else: 60 | prob = [fg, bg, cg] 61 | 62 | prob = torch.cat(prob, dim=1) 63 | 64 | # reshape feature & prob 65 | if 
self.stage_size is not None: 66 | shape = self.stage_size 67 | shape_mul = self.stage_size[0] * self.stage_size[1] 68 | else: 69 | shape = (h, w) 70 | shape_mul = h * w 71 | 72 | f = F.interpolate(x, size=shape, mode='bilinear', align_corners=False).view(b, shape_mul, -1) 73 | prob = F.interpolate(prob, size=shape, mode='bilinear', align_corners=False).view(b, self.ctx, shape_mul) 74 | 75 | # compute context vector 76 | context = torch.bmm(prob, f).permute(0, 2, 1).unsqueeze(3) # b, 3, c 77 | 78 | # k q v compute 79 | query = self.conv_query(x).view(b, self.depth, -1).permute(0, 2, 1) 80 | key = self.conv_key(context).view(b, self.depth, -1) 81 | value = self.conv_value(context).view(b, self.depth, -1).permute(0, 2, 1) 82 | 83 | # compute similarity map 84 | sim = torch.bmm(query, key) # b, hw, c x b, c, 2 85 | sim = (self.depth ** -.5) * sim 86 | sim = F.softmax(sim, dim=-1) 87 | 88 | # compute refined feature 89 | context = torch.bmm(sim, value).permute(0, 2, 1).contiguous().view(b, -1, h, w) 90 | context = self.conv_out1(context) 91 | 92 | x = torch.cat([x, context], dim=1) 93 | x = self.conv_out2(x) 94 | x = self.conv_out3(x) 95 | out = self.conv_out4(x) 96 | 97 | return x, out -------------------------------------------------------------------------------- /base_models/inspyrenet/modules/context_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .layers import * 6 | 7 | class PAA_kernel(nn.Module): 8 | def __init__(self, in_channel, out_channel, receptive_size, stage_size=None): 9 | super(PAA_kernel, self).__init__() 10 | self.conv0 = Conv2d(in_channel, out_channel, 1) 11 | self.conv1 = Conv2d(out_channel, out_channel, kernel_size=(1, receptive_size)) 12 | self.conv2 = Conv2d(out_channel, out_channel, kernel_size=(receptive_size, 1)) 13 | self.conv3 = Conv2d(out_channel, out_channel, 3, dilation=receptive_size) 14 | self.Hattn = SelfAttention(out_channel, 'h', stage_size[0] if stage_size is not None else None) 15 | self.Wattn = SelfAttention(out_channel, 'w', stage_size[1] if stage_size is not None else None) 16 | 17 | def forward(self, x): 18 | x = self.conv0(x) 19 | x = self.conv1(x) 20 | x = self.conv2(x) 21 | 22 | Hx = self.Hattn(x) 23 | Wx = self.Wattn(x) 24 | 25 | x = self.conv3(Hx + Wx) 26 | return x 27 | 28 | class PAA_e(nn.Module): 29 | def __init__(self, in_channel, out_channel, base_size=None, stage=None): 30 | super(PAA_e, self).__init__() 31 | self.relu = nn.ReLU(True) 32 | if base_size is not None and stage is not None: 33 | self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage)) 34 | else: 35 | self.stage_size = None 36 | 37 | self.branch0 = Conv2d(in_channel, out_channel, 1) 38 | self.branch1 = PAA_kernel(in_channel, out_channel, 3, self.stage_size) 39 | self.branch2 = PAA_kernel(in_channel, out_channel, 5, self.stage_size) 40 | self.branch3 = PAA_kernel(in_channel, out_channel, 7, self.stage_size) 41 | 42 | self.conv_cat = Conv2d(4 * out_channel, out_channel, 3) 43 | self.conv_res = Conv2d(in_channel, out_channel, 1) 44 | 45 | def forward(self, x): 46 | x0 = self.branch0(x) 47 | x1 = self.branch1(x) 48 | x2 = self.branch2(x) 49 | x3 = self.branch3(x) 50 | 51 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1)) 52 | x = self.relu(x_cat + self.conv_res(x)) 53 | 54 | return x 55 | -------------------------------------------------------------------------------- /base_models/inspyrenet/modules/decoder_module.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .layers import * 6 | class PAA_d(nn.Module): 7 | def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None): 8 | super(PAA_d, self).__init__() 9 | self.conv1 = Conv2d(in_channel ,depth, 3) 10 | self.conv2 = Conv2d(depth, depth, 3) 11 | self.conv3 = Conv2d(depth, depth, 3) 12 | self.conv4 = Conv2d(depth, depth, 3) 13 | self.conv5 = Conv2d(depth, out_channel, 3, bn=False) 14 | 15 | self.base_size = base_size 16 | self.stage = stage 17 | 18 | if base_size is not None and stage is not None: 19 | self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage)) 20 | else: 21 | self.stage_size = [None, None] 22 | 23 | self.Hattn = SelfAttention(depth, 'h', self.stage_size[0]) 24 | self.Wattn = SelfAttention(depth, 'w', self.stage_size[1]) 25 | 26 | self.upsample = lambda img, size: F.interpolate(img, size=size, mode='bilinear', align_corners=True) 27 | 28 | def forward(self, fs): #f3 f4 f5 -> f3 f2 f1 29 | fx = fs[0] 30 | for i in range(1, len(fs)): 31 | fs[i] = self.upsample(fs[i], fx.shape[-2:]) 32 | fx = torch.cat(fs[::-1], dim=1) 33 | 34 | fx = self.conv1(fx) 35 | 36 | Hfx = self.Hattn(fx) 37 | Wfx = self.Wattn(fx) 38 | 39 | fx = self.conv2(Hfx + Wfx) 40 | fx = self.conv3(fx) 41 | fx = self.conv4(fx) 42 | out = self.conv5(fx) 43 | 44 | return fx, out -------------------------------------------------------------------------------- /base_models/inspyrenet/modules/layers.py: -------------------------------------------------------------------------------- 1 | from optparse import Option 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from kornia.morphology import dilation, erosion 10 | from torch.nn.parameter import Parameter 11 | from typing import Optional 12 | class ImagePyramid: 13 | def __init__(self, ksize=7, sigma=1, channels=1): 14 | self.ksize = ksize 15 | self.sigma = sigma 16 | self.channels = channels 17 | 18 | k = cv2.getGaussianKernel(ksize, sigma) 19 | k = np.outer(k, k) 20 | k = torch.tensor(k).float() 21 | self.kernel = k.repeat(channels, 1, 1, 1) 22 | 23 | def to(self, device): 24 | self.kernel = self.kernel.to(device) 25 | return self 26 | 27 | def cuda(self, idx=None): 28 | if idx is None: 29 | idx = torch.cuda.current_device() 30 | 31 | self.to(device="cuda:{}".format(idx)) 32 | return self 33 | 34 | def expand(self, x): 35 | z = torch.zeros_like(x) 36 | x = torch.cat([x, z, z, z], dim=1) 37 | x = F.pixel_shuffle(x, 2) 38 | x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect') 39 | x = F.conv2d(x, self.kernel * 4, groups=self.channels) 40 | return x 41 | 42 | def reduce(self, x): 43 | x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect') 44 | x = F.conv2d(x, self.kernel, groups=self.channels) 45 | x = x[:, :, ::2, ::2] 46 | return x 47 | 48 | def deconstruct(self, x): 49 | reduced_x = self.reduce(x) 50 | expanded_reduced_x = self.expand(reduced_x) 51 | 52 | if x.shape != expanded_reduced_x.shape: 53 | expanded_reduced_x = F.interpolate(expanded_reduced_x, x.shape[-2:]) 54 | 55 | laplacian_x = x - expanded_reduced_x 56 | return reduced_x, laplacian_x 57 | 58 | def reconstruct(self, x, laplacian_x): 59 | expanded_x = self.expand(x) 60 | if laplacian_x.shape != expanded_x: 61 | laplacian_x = F.interpolate(laplacian_x, expanded_x.shape[-2:], mode='bilinear', align_corners=True) 62 | return 
expanded_x + laplacian_x 63 | 64 | class Transition: 65 | def __init__(self, k=3): 66 | self.kernel = torch.tensor(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))).float() 67 | 68 | def to(self, device): 69 | self.kernel = self.kernel.to(device) 70 | return self 71 | 72 | def cuda(self, idx=None): 73 | if idx is None: 74 | idx = torch.cuda.current_device() 75 | 76 | self.to(device="cuda:{}".format(idx)) 77 | return self 78 | 79 | def __call__(self, x): 80 | x = torch.sigmoid(x) 81 | dx = dilation(x, self.kernel) 82 | ex = erosion(x, self.kernel) 83 | 84 | return ((dx - ex) > .5).float() 85 | 86 | class Conv2d(nn.Module): 87 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, padding='same', bias=False, bn=True, relu=False): 88 | super(Conv2d, self).__init__() 89 | if '__iter__' not in dir(kernel_size): 90 | kernel_size = (kernel_size, kernel_size) 91 | if '__iter__' not in dir(stride): 92 | stride = (stride, stride) 93 | if '__iter__' not in dir(dilation): 94 | dilation = (dilation, dilation) 95 | 96 | if padding == 'same': 97 | width_pad_size = kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1) 98 | height_pad_size = kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1) 99 | elif padding == 'valid': 100 | width_pad_size = 0 101 | height_pad_size = 0 102 | else: 103 | if '__iter__' in dir(padding): 104 | width_pad_size = padding[0] * 2 105 | height_pad_size = padding[1] * 2 106 | else: 107 | width_pad_size = padding * 2 108 | height_pad_size = padding * 2 109 | 110 | width_pad_size = width_pad_size // 2 + (width_pad_size % 2 - 1) 111 | height_pad_size = height_pad_size // 2 + (height_pad_size % 2 - 1) 112 | pad_size = (width_pad_size, height_pad_size) 113 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_size, dilation, groups, bias=bias) 114 | self.reset_parameters() 115 | 116 | if bn is True: 117 | self.bn = nn.BatchNorm2d(out_channels) 118 | else: 119 | self.bn = None 120 | 121 | if relu is True: 122 | self.relu = nn.ReLU(inplace=True) 123 | else: 124 | self.relu = None 125 | 126 | def forward(self, x): 127 | x = self.conv(x) 128 | if self.bn is not None: 129 | x = self.bn(x) 130 | if self.relu is not None: 131 | x = self.relu(x) 132 | return x 133 | 134 | def reset_parameters(self): 135 | nn.init.kaiming_normal_(self.conv.weight) 136 | 137 | 138 | class SelfAttention(nn.Module): 139 | def __init__(self, in_channels, mode='hw', stage_size=None): 140 | super(SelfAttention, self).__init__() 141 | 142 | self.mode = mode 143 | 144 | self.query_conv = Conv2d(in_channels, in_channels // 8, kernel_size=(1, 1)) 145 | self.key_conv = Conv2d(in_channels, in_channels // 8, kernel_size=(1, 1)) 146 | self.value_conv = Conv2d(in_channels, in_channels, kernel_size=(1, 1)) 147 | 148 | self.gamma = Parameter(torch.zeros(1)) 149 | self.softmax = nn.Softmax(dim=-1) 150 | 151 | self.stage_size = stage_size 152 | 153 | def forward(self, x): 154 | batch_size, channel, height, width = x.size() 155 | 156 | axis = 1 157 | if 'h' in self.mode: 158 | axis *= height 159 | if 'w' in self.mode: 160 | axis *= width 161 | 162 | view = (batch_size, -1, axis) 163 | 164 | projected_query = self.query_conv(x).view(*view).permute(0, 2, 1) 165 | projected_key = self.key_conv(x).view(*view) 166 | 167 | attention_map = torch.bmm(projected_query, projected_key) 168 | attention = self.softmax(attention_map) 169 | projected_value = self.value_conv(x).view(*view) 170 | 171 | out = torch.bmm(projected_value, attention.permute(0, 2, 1)) 172 | out = 
out.view(batch_size, channel, height, width) 173 | 174 | out = self.gamma * out + x 175 | return out 176 | -------------------------------------------------------------------------------- /base_models/inspyrenet/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * 2 | from .scheduler import * -------------------------------------------------------------------------------- /base_models/inspyrenet/optim/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def bce_loss(pred, mask, reduction='none'): 5 | bce = F.binary_cross_entropy(pred, mask, reduction=reduction) 6 | return bce 7 | 8 | def weighted_bce_loss(pred, mask, reduction='none'): 9 | weight = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 10 | weight = weight.flatten() 11 | 12 | bce = weight * bce_loss(pred, mask, reduction='none').flatten() 13 | 14 | if reduction == 'mean': 15 | bce = bce.mean() 16 | 17 | return bce 18 | 19 | def iou_loss(pred, mask, reduction='none'): 20 | inter = pred * mask 21 | union = pred + mask 22 | iou = 1 - (inter + 1) / (union - inter + 1) 23 | 24 | if reduction == 'mean': 25 | iou = iou.mean() 26 | 27 | return iou 28 | 29 | def bce_loss_with_logits(pred, mask, reduction='none'): 30 | return bce_loss(torch.sigmoid(pred), mask, reduction=reduction) 31 | 32 | def weighted_bce_loss_with_logits(pred, mask, reduction='none'): 33 | return weighted_bce_loss(torch.sigmoid(pred), mask, reduction=reduction) 34 | 35 | def iou_loss_with_logits(pred, mask, reduction='none'): 36 | return iou_loss(torch.sigmoid(pred), mask, reduction=reduction) -------------------------------------------------------------------------------- /base_models/inspyrenet/optim/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class PolyLr(_LRScheduler): 5 | def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1): 6 | self.gamma = gamma 7 | self.max_iteration = max_iteration 8 | self.minimum_lr = minimum_lr 9 | self.warmup_iteration = warmup_iteration 10 | 11 | self.last_epoch = None 12 | self.base_lrs = [] 13 | 14 | super(PolyLr, self).__init__(optimizer, last_epoch) 15 | 16 | def poly_lr(self, base_lr, step): 17 | return (base_lr - self.minimum_lr) * ((1 - (step / self.max_iteration)) ** self.gamma) + self.minimum_lr 18 | 19 | def warmup_lr(self, base_lr, alpha): 20 | return base_lr * (1 / 10.0 * (1 - alpha) + alpha) 21 | 22 | def get_lr(self): 23 | if self.last_epoch < self.warmup_iteration: 24 | alpha = self.last_epoch / self.warmup_iteration 25 | lrs = [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, self.last_epoch)) for base_lr in 26 | self.base_lrs] 27 | else: 28 | lrs = [self.poly_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] 29 | 30 | return lrs -------------------------------------------------------------------------------- /base_models/inspyrenet/saliency_transforms.py: -------------------------------------------------------------------------------- 1 | from email.mime import base 2 | import numpy as np 3 | from PIL import Image 4 | import os 5 | import sys 6 | import torch 7 | import torchvision.transforms as transforms 8 | import torch.nn.functional as F 9 | from PIL import Image, ImageOps, ImageFilter, ImageEnhance 10 | from typing import Optional 11 
| 12 | filepath = os.path.split(__file__)[0] 13 | repopath = os.path.split(filepath)[0] 14 | sys.path.append(repopath) 15 | 16 | # from utils.misc import * 17 | 18 | 19 | class static_resize: 20 | # Resize for training 21 | # size: h x w 22 | def __init__(self, size=[384, 384], base_size=None): 23 | self.size = size[::-1] 24 | self.base_size = base_size[::-1] if base_size is not None else None 25 | 26 | def __call__(self, sample): 27 | sample['image'] = sample['image'].resize(self.size, Image.BILINEAR) 28 | if 'gt' in sample.keys(): 29 | sample['gt'] = sample['gt'].resize(self.size, Image.NEAREST) 30 | 31 | if self.base_size is not None: 32 | sample['image_resized'] = sample['image'].resize(self.size, Image.BILINEAR) 33 | if 'gt' in sample.keys(): 34 | sample['gt_resized'] = sample['gt'].resize(self.size, Image.NEAREST) 35 | 36 | return sample 37 | 38 | 39 | class dynamic_resize: 40 | # base_size: h x w 41 | def __init__(self, L=1280, base_size=[384, 384]): 42 | self.L = L 43 | self.base_size = base_size[::-1] 44 | 45 | def __call__(self, sample): 46 | size = list(sample['image'].size) 47 | if (size[0] >= size[1]) and size[1] > self.L: 48 | size[0] = size[0] / (size[1] / self.L) 49 | size[1] = self.L 50 | elif (size[1] > size[0]) and size[0] > self.L: 51 | size[1] = size[1] / (size[0] / self.L) 52 | size[0] = self.L 53 | size = (int(round(size[0] / 32)) * 32, int(round(size[1] / 32)) * 32) 54 | 55 | if 'image' in sample.keys(): 56 | sample['image_resized'] = sample['image'].resize(self.base_size, Image.BILINEAR) 57 | sample['image'] = sample['image'].resize(size, Image.BILINEAR) 58 | 59 | if 'gt' in sample.keys(): 60 | sample['gt_resized'] = sample['gt'].resize(self.base_size, Image.NEAREST) 61 | sample['gt'] = sample['gt'].resize(size, Image.NEAREST) 62 | 63 | return sample 64 | 65 | 66 | class random_scale_crop: 67 | def __init__(self, range=[0.75, 1.25]): 68 | self.range = range 69 | 70 | def __call__(self, sample): 71 | scale = np.random.random() * (self.range[1] - self.range[0]) + self.range[0] 72 | if np.random.random() < 0.5: 73 | for key in sample.keys(): 74 | if key in ['image', 'gt']: 75 | base_size = sample[key].size 76 | 77 | scale_size = tuple((np.array(base_size) * scale).round().astype(int)) 78 | sample[key] = sample[key].resize(scale_size) 79 | 80 | lf = (sample[key].size[0] - base_size[0]) // 2 81 | up = (sample[key].size[1] - base_size[1]) // 2 82 | rg = (sample[key].size[0] + base_size[0]) // 2 83 | lw = (sample[key].size[1] + base_size[1]) // 2 84 | 85 | border = -min(0, min(lf, up)) 86 | sample[key] = ImageOps.expand(sample[key], border=border) 87 | sample[key] = sample[key].crop((lf + border, up + border, rg + border, lw + border)) 88 | return sample 89 | 90 | 91 | class random_flip: 92 | def __init__(self, lr=True, ud=True): 93 | self.lr = lr 94 | self.ud = ud 95 | 96 | def __call__(self, sample): 97 | lr = np.random.random() < 0.5 and self.lr is True 98 | ud = np.random.random() < 0.5 and self.ud is True 99 | 100 | for key in sample.keys(): 101 | if key in ['image', 'gt']: 102 | sample[key] = np.array(sample[key]) 103 | if lr: 104 | sample[key] = np.fliplr(sample[key]) 105 | if ud: 106 | sample[key] = np.flipud(sample[key]) 107 | sample[key] = Image.fromarray(sample[key]) 108 | 109 | return sample 110 | 111 | 112 | class random_rotate: 113 | def __init__(self, range=[0, 360], interval=1): 114 | self.range = range 115 | self.interval = interval 116 | 117 | def __call__(self, sample): 118 | rot = (np.random.randint(*self.range) // self.interval) * self.interval 119 | 
rot = rot + 360 if rot < 0 else rot 120 | 121 | if np.random.random() < 0.5: 122 | for key in sample.keys(): 123 | if key in ['image', 'gt']: 124 | base_size = sample[key].size 125 | sample[key] = sample[key].rotate(rot, expand=True, fillcolor=255 if key == 'depth' else None) 126 | 127 | sample[key] = sample[key].crop(((sample[key].size[0] - base_size[0]) // 2, 128 | (sample[key].size[1] - base_size[1]) // 2, 129 | (sample[key].size[0] + base_size[0]) // 2, 130 | (sample[key].size[1] + base_size[1]) // 2)) 131 | 132 | return sample 133 | 134 | 135 | class random_image_enhance: 136 | def __init__(self, methods=['contrast', 'brightness', 'sharpness']): 137 | self.enhance_method = [] 138 | if 'contrast' in methods: 139 | self.enhance_method.append(ImageEnhance.Contrast) 140 | if 'brightness' in methods: 141 | self.enhance_method.append(ImageEnhance.Brightness) 142 | if 'sharpness' in methods: 143 | self.enhance_method.append(ImageEnhance.Sharpness) 144 | 145 | def __call__(self, sample): 146 | if 'image' in sample.keys(): 147 | np.random.shuffle(self.enhance_method) 148 | 149 | for method in self.enhance_method: 150 | if np.random.random() > 0.5: 151 | enhancer = method(sample['image']) 152 | factor = float(1 + np.random.random() / 10) 153 | sample['image'] = enhancer.enhance(factor) 154 | 155 | return sample 156 | 157 | 158 | class tonumpy: 159 | def __init__(self): 160 | pass 161 | 162 | def __call__(self, sample): 163 | for key in sample.keys(): 164 | if key in ['image', 'image_resized', 'gt', 'gt_resized']: 165 | sample[key] = np.array(sample[key], dtype=np.float32) 166 | 167 | return sample 168 | 169 | 170 | class normalize: 171 | def __init__(self, mean: Optional[list] = None, std: Optional[list] = None, div=255): 172 | self.mean = mean if mean is not None else 0.0 173 | self.std = std if std is not None else 1.0 174 | self.div = div 175 | 176 | def __call__(self, sample): 177 | if 'image' in sample.keys(): 178 | sample['image'] /= self.div 179 | sample['image'] -= self.mean 180 | sample['image'] /= self.std 181 | 182 | if 'image_resized' in sample.keys(): 183 | sample['image_resized'] /= self.div 184 | sample['image_resized'] -= self.mean 185 | sample['image_resized'] /= self.std 186 | 187 | if 'gt' in sample.keys(): 188 | sample['gt'] /= self.div 189 | 190 | if 'gt_resized' in sample.keys(): 191 | sample['gt_resized'] /= self.div 192 | 193 | return sample 194 | 195 | 196 | class totensor: 197 | def __init__(self): 198 | pass 199 | 200 | def __call__(self, sample): 201 | if 'image' in sample.keys(): 202 | sample['image'] = sample['image'].transpose((2, 0, 1)) 203 | sample['image'] = torch.from_numpy(sample['image']).float() 204 | 205 | if 'image_resized' in sample.keys(): 206 | sample['image_resized'] = sample['image_resized'].transpose((2, 0, 1)) 207 | sample['image_resized'] = torch.from_numpy(sample['image_resized']).float() 208 | 209 | if 'gt' in sample.keys(): 210 | sample['gt'] = torch.from_numpy(sample['gt']) 211 | sample['gt'] = sample['gt'].unsqueeze(dim=0) 212 | 213 | if 'gt_resized' in sample.keys(): 214 | sample['gt_resized'] = torch.from_numpy(sample['gt_resized']) 215 | sample['gt_resized'] = sample['gt_resized'].unsqueeze(dim=0) 216 | 217 | return sample 218 | 219 | 220 | def get_transform(tfs): 221 | comp = [] 222 | for key, value in zip(tfs.keys(), tfs.values()): 223 | if value is not None: 224 | tf = eval(key)(**value) 225 | else: 226 | tf = eval(key)() 227 | comp.append(tf) 228 | return transforms.Compose(comp) 229 | 
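# Usage sketch for get_transform (illustrative only; the kwargs below are
# assumptions, not the settings used by the InSPyReNet configs in this repo).
# The keys of the dict name the transform classes defined above and the values
# are their constructor kwargs; entries with value None are built without args.
if __name__ == '__main__':
    tfs = {
        'static_resize': {'size': [384, 384]},
        'tonumpy': None,
        'normalize': {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]},
        'totensor': None,
    }
    transform = get_transform(tfs)
    sample = {'image': Image.open('example.jpg').convert('RGB')}  # hypothetical image path
    sample = transform(sample)
    print(sample['image'].shape)  # torch.Size([3, 384, 384])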
-------------------------------------------------------------------------------- /base_models/tcl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-columbia/viper/09fe3465224766860d8dd4ec48db942f22b05092/base_models/tcl/__init__.py -------------------------------------------------------------------------------- /base_models/tcl/tcl_config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 6, 20 | "encoder_width": 768 21 | } -------------------------------------------------------------------------------- /base_models/tcl/tcl_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used for TCL, from https://github.com/uta-smile/TCL/ 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from functools import partial 9 | 10 | from timm.models.vision_transformer import _cfg, PatchEmbed 11 | from timm.models.registry import register_model 12 | from timm.models.layers import trunc_normal_, DropPath 13 | 14 | 15 | class Mlp(nn.Module): 16 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 17 | """ 18 | 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 24 | self.act = act_layer() 25 | self.fc2 = nn.Linear(hidden_features, out_features) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | x = self.drop(x) 32 | x = self.fc2(x) 33 | x = self.drop(x) 34 | return x 35 | 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 43 | self.scale = qk_scale or head_dim ** -0.5 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | self.attn_gradients = None 49 | self.attention_map = None 50 | 51 | def save_attn_gradients(self, attn_gradients): 52 | self.attn_gradients = attn_gradients 53 | 54 | def get_attn_gradients(self): 55 | return self.attn_gradients 56 | 57 | def save_attention_map(self, attention_map): 58 | self.attention_map = attention_map 59 | 60 | def get_attention_map(self): 61 | return self.attention_map 62 | 63 | def forward(self, x, register_hook=False): 64 | B, N, C = x.shape 65 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 66 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 67 | 68 | attn = 
(q @ k.transpose(-2, -1)) * self.scale 69 | attn = attn.softmax(dim=-1) 70 | attn = self.attn_drop(attn) 71 | 72 | if register_hook: 73 | self.save_attention_map(attn) 74 | attn.register_hook(self.save_attn_gradients) 75 | 76 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 77 | x = self.proj(x) 78 | x = self.proj_drop(x) 79 | return x 80 | 81 | 82 | class Block(nn.Module): 83 | 84 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 85 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 86 | super().__init__() 87 | self.norm1 = norm_layer(dim) 88 | self.attn = Attention( 89 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 90 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 91 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 92 | self.norm2 = norm_layer(dim) 93 | mlp_hidden_dim = int(dim * mlp_ratio) 94 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 95 | 96 | def forward(self, x, register_hook=False): 97 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 98 | x = x + self.drop_path(self.mlp(self.norm2(x))) 99 | return x 100 | 101 | 102 | class VisionTransformer(nn.Module): 103 | """ Vision Transformer 104 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 105 | https://arxiv.org/abs/2010.11929 106 | """ 107 | 108 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 109 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 110 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None): 111 | """ 112 | Args: 113 | img_size (int, tuple): input image size 114 | patch_size (int, tuple): patch size 115 | in_chans (int): number of input channels 116 | num_classes (int): number of classes for classification head 117 | embed_dim (int): embedding dimension 118 | depth (int): depth of transformer 119 | num_heads (int): number of attention heads 120 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 121 | qkv_bias (bool): enable bias for qkv if True 122 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 123 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 124 | drop_rate (float): dropout rate 125 | attn_drop_rate (float): attention dropout rate 126 | drop_path_rate (float): stochastic depth rate 127 | norm_layer: (nn.Module): normalization layer 128 | """ 129 | super().__init__() 130 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 131 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 132 | 133 | self.patch_embed = PatchEmbed( 134 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 135 | num_patches = self.patch_embed.num_patches 136 | 137 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 138 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 139 | self.pos_drop = nn.Dropout(p=drop_rate) 140 | 141 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 142 | self.blocks = nn.ModuleList([ 143 | Block( 144 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 145 | drop=drop_rate, 
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 146 | for i in range(depth)]) 147 | self.norm = norm_layer(embed_dim) 148 | 149 | trunc_normal_(self.pos_embed, std=.02) 150 | trunc_normal_(self.cls_token, std=.02) 151 | self.apply(self._init_weights) 152 | 153 | def _init_weights(self, m): 154 | if isinstance(m, nn.Linear): 155 | trunc_normal_(m.weight, std=.02) 156 | if isinstance(m, nn.Linear) and m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | elif isinstance(m, nn.LayerNorm): 159 | nn.init.constant_(m.bias, 0) 160 | nn.init.constant_(m.weight, 1.0) 161 | 162 | @torch.jit.ignore 163 | def no_weight_decay(self): 164 | return {'pos_embed', 'cls_token'} 165 | 166 | def forward(self, x, register_blk=-1): 167 | B = x.shape[0] 168 | x = self.patch_embed(x) 169 | 170 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 171 | x = torch.cat((cls_tokens, x), dim=1) 172 | 173 | x = x + self.pos_embed[:, :x.size(1), :] 174 | x = self.pos_drop(x) 175 | 176 | for i, blk in enumerate(self.blocks): 177 | x = blk(x, register_blk == i) 178 | x = self.norm(x) 179 | 180 | return x 181 | 182 | 183 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 184 | # interpolate position embedding 185 | embedding_size = pos_embed_checkpoint.shape[-1] 186 | num_patches = visual_encoder.patch_embed.num_patches 187 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 188 | # height (== width) for the checkpoint position embedding 189 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 190 | # height (== width) for the new position embedding 191 | new_size = int(num_patches ** 0.5) 192 | 193 | if orig_size != new_size: 194 | # class_token and dist_token are kept unchanged 195 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 196 | # only the position tokens are interpolated 197 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 198 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 199 | pos_tokens = torch.nn.functional.interpolate( 200 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 201 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 202 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 203 | # print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2)) 204 | 205 | return new_pos_embed 206 | else: 207 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /base_models/xvlm/config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 6, 20 | "encoder_width": 1024 21 | } 22 | -------------------------------------------------------------------------------- /base_models/xvlm/vit.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from functools 
import partial 7 | 8 | from timm.models.vision_transformer import _cfg, PatchEmbed 9 | from timm.models.registry import register_model 10 | from timm.models.layers import trunc_normal_, DropPath 11 | 12 | 13 | class Mlp(nn.Module): 14 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 15 | """ 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 36 | super().__init__() 37 | self.num_heads = num_heads 38 | head_dim = dim // num_heads 39 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 40 | self.scale = qk_scale or head_dim ** -0.5 41 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | self.attn_gradients = None 46 | self.attention_map = None 47 | 48 | def save_attn_gradients(self, attn_gradients): 49 | self.attn_gradients = attn_gradients 50 | 51 | def get_attn_gradients(self): 52 | return self.attn_gradients 53 | 54 | def save_attention_map(self, attention_map): 55 | self.attention_map = attention_map 56 | 57 | def get_attention_map(self): 58 | return self.attention_map 59 | 60 | def forward(self, x, register_hook=False, image_atts=None): 61 | B, N, C = x.shape 62 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 63 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 64 | 65 | attn = (q @ k.transpose(-2, -1)) * self.scale 66 | 67 | if image_atts is not None: 68 | attn += image_atts 69 | 70 | attn = attn.softmax(dim=-1) 71 | attn = self.attn_drop(attn) 72 | 73 | if register_hook: 74 | self.save_attention_map(attn) 75 | attn.register_hook(self.save_attn_gradients) 76 | 77 | # attn: (bs, num_heads, num_patches, num_patches) 78 | # v: (bs, num_heads, num_patches, d) 79 | # attn @ v: (bs, num_heads, num_patches, d) 80 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 81 | x = self.proj(x) 82 | x = self.proj_drop(x) 83 | return x 84 | 85 | 86 | class Block(nn.Module): 87 | 88 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 89 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 90 | super().__init__() 91 | self.norm1 = norm_layer(dim) 92 | self.attn = Attention( 93 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 94 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 95 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 96 | self.norm2 = norm_layer(dim) 97 | mlp_hidden_dim = int(dim * mlp_ratio) 98 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 99 | 100 | def forward(self, x, register_hook=False, image_atts=None): 101 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook, image_atts=image_atts)) 102 | x = x + self.drop_path(self.mlp(self.norm2(x))) 103 | return x 104 | 105 | 106 | class VisionTransformer(nn.Module): 107 | """ Vision Transformer 108 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 109 | https://arxiv.org/abs/2010.11929 110 | """ 111 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 112 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 113 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, local_attn_depth=0): 114 | """ 115 | Args: 116 | img_size (int, tuple): input image size 117 | patch_size (int, tuple): patch size 118 | in_chans (int): number of input channels 119 | num_classes (int): number of classes for classification head 120 | embed_dim (int): embedding dimension 121 | depth (int): depth of transformer 122 | num_heads (int): number of attention heads 123 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 124 | qkv_bias (bool): enable bias for qkv if True 125 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 126 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 127 | drop_rate (float): dropout rate 128 | attn_drop_rate (float): attention dropout rate 129 | drop_path_rate (float): stochastic depth rate 130 | norm_layer: (nn.Module): normalization layer 131 | """ 132 | super().__init__() 133 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 134 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 135 | 136 | self.patch_embed = PatchEmbed( 137 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 138 | 139 | self.num_patch_embed = self.patch_embed.num_patches 140 | 141 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 142 | 143 | self.num_pos_embed = self.num_patch_embed + 1 144 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_pos_embed, embed_dim)) 145 | 146 | self.pos_drop = nn.Dropout(p=drop_rate) 147 | 148 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 149 | self.blocks = nn.ModuleList([ 150 | Block( 151 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 152 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 153 | for i in range(depth)]) 154 | 155 | self.depth = depth 156 | self.local_attn_depth = local_attn_depth # do local attn from index=(depth - local_attn_depth) 157 | 158 | self.norm = norm_layer(embed_dim) 159 | 160 | trunc_normal_(self.pos_embed, std=.02) 161 | trunc_normal_(self.cls_token, std=.02) 162 | self.apply(self._init_weights) 163 | 164 | def _init_weights(self, m): 165 | if isinstance(m, nn.Linear): 166 | trunc_normal_(m.weight, std=.02) 167 | if isinstance(m, nn.Linear) and m.bias is not None: 168 | nn.init.constant_(m.bias, 0) 169 | elif isinstance(m, nn.LayerNorm): 170 | nn.init.constant_(m.bias, 0) 171 | nn.init.constant_(m.weight, 1.0) 172 | 173 | @torch.jit.ignore 174 | def 
no_weight_decay(self): 175 | return {'pos_embed', 'cls_token'} 176 | 177 | def forward(self, x, register_blk=-1, idx_to_group_img=None, image_atts=None): 178 | 179 | B = x.shape[0] 180 | x = self.patch_embed(x) 181 | 182 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 183 | x = torch.cat((cls_tokens, x), dim=1) 184 | 185 | x = x + self.pos_embed[:,:x.size(1),:] 186 | x = self.pos_drop(x) 187 | 188 | do_gather = True if idx_to_group_img is not None else False 189 | 190 | if do_gather and (image_atts is not None): 191 | full_atts = torch.ones(x.shape[:2], dtype=x.dtype).to(x.device) 192 | image_atts_blk = torch.cat([image_atts, full_atts], dim=0) 193 | 194 | image_atts_blk = image_atts_blk.unsqueeze(1).unsqueeze(2) 195 | image_atts_blk = (1.0 - image_atts_blk) * -10000.0 196 | else: 197 | image_atts_blk = None 198 | 199 | for i, blk in enumerate(self.blocks): 200 | if (self.local_attn_depth > 0) and (i >= self.depth-self.local_attn_depth): 201 | if do_gather: 202 | do_gather = False 203 | 204 | x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2])) 205 | x = torch.cat([x_bs, x], dim=0) 206 | 207 | x = blk(x, register_blk == i, image_atts=image_atts_blk) 208 | 209 | else: 210 | x = blk(x, register_blk==i, image_atts=None) 211 | 212 | x = self.norm(x) 213 | 214 | if idx_to_group_img is not None: 215 | bs = len(idx_to_group_img) 216 | x_bs, x_fullatts = torch.split(x, [bs, x.size(0)-bs]) 217 | return x_bs, x_fullatts 218 | 219 | return x 220 | 221 | 222 | def interpolate_pos_embed(pos_embed_checkpoint, num_patches, num_extra_tokens=1): 223 | # num_patches = visual_encoder.num_patch_embed 224 | # num_extra_tokens = visual_encoder.num_pos_embed - visual_encoder.num_patch_embed 225 | 226 | # interpolate position embedding 227 | embedding_size = pos_embed_checkpoint.shape[-1] 228 | # height (== width) for the checkpoint position embedding 229 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 230 | # height (== width) for the new position embedding 231 | new_size = int(num_patches ** 0.5) 232 | 233 | if orig_size != new_size: 234 | # class_token and dist_token are kept unchanged 235 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 236 | # only the position tokens are interpolated 237 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 238 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 239 | pos_tokens = torch.nn.functional.interpolate( 240 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 241 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 242 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 243 | # print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2)) 244 | 245 | return new_pos_embed 246 | else: 247 | return pos_embed_checkpoint 248 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import OmegaConf 3 | 4 | # The default 5 | config_names = os.getenv('CONFIG_NAMES', None) 6 | if config_names is None: 7 | config_names = 'my_config' # Modify this if you want to use another default config 8 | 9 | configs = [OmegaConf.load('configs/base_config.yaml')] 10 | 11 | if config_names is not None: 12 | for config_name in config_names.split(','): 13 | 
configs.append(OmegaConf.load(f'configs/{config_name.strip()}.yaml')) 14 | 15 | # unsafe_merge makes the individual configs unusable, but it is faster 16 | config = OmegaConf.unsafe_merge(*configs) 17 | 18 | -------------------------------------------------------------------------------- /configs/base_config.yaml: -------------------------------------------------------------------------------- 1 | multiprocessing: False # Run the models and samples in parallel 2 | path_pretrained_models: './pretrained_models' # Path to the pretrained models 3 | execute_code: False # Execute the code after generating it. Only applies to main_batch 4 | 5 | dataset: # Dataset configuration 6 | dataset_name: 'MyDataset' # Dataset name 7 | data_path: 'data' # Dataset path 8 | split: '' # Dataset split. If '', it assumes there is only one split 9 | max_samples: # Maximum number of samples to load 10 | batch_size: 20 # Batch size 11 | start_sample: 0 # Start sample index. Only used if max_samples is not None 12 | 13 | load_models: # Which pretrained models to load 14 | maskrcnn: False 15 | clip: False 16 | glip: True 17 | owlvit: False 18 | tcl: False 19 | gpt3_qa: True 20 | gpt3_general: True 21 | depth: True 22 | blip: True 23 | saliency: False 24 | xvlm: True 25 | codex: True 26 | codellama: False 27 | 28 | detect_thresholds: # Thresholds for the models that perform detection 29 | glip: 0.5 30 | maskrcnn: 0.8 31 | owlvit: 0.1 32 | ratio_box_area_to_image_area: 0.0 # Any detected patch under this size will not be returned 33 | crop_larger_margin: True # Increase size of crop by 10% to include more context 34 | 35 | verify_property: # Parameters for verify_property 36 | model: xvlm # Model to use for verify_property 37 | thresh_clip: 0.6 38 | thresh_tcl: 0.25 39 | thresh_xvlm: 0.6 40 | 41 | best_match_model: xvlm # Which model to use for best_[image, text]_match 42 | 43 | gpt3: # GPT-3 configuration 44 | n_votes: 1 # Number of tries to use for GPT-3. Use with temperature > 0 45 | qa_prompt: ./prompts/gpt3/gpt3_qa.txt 46 | guess_prompt: ./prompts/gpt3/gpt3_process_guess.txt 47 | temperature: 0. # Temperature for GPT-3. Almost deterministic if 0 48 | model: text-davinci-003 # See openai.Model.list() for available models 49 | 50 | codex: 51 | temperature: 0. # Temperature for Codex. (Almost) deterministic if 0 52 | best_of: 1 # Number of tries to choose from. Use when temperature > 0 53 | max_tokens: 512 # Maximum number of tokens to generate for Codex 54 | prompt: ./prompts/chatapi.prompt # Codex prompt file, which defines the API. (doesn't support video for now due to token limits) 55 | model: gpt-3.5-turbo # Codex model to use. [code-davinci-002, gpt-3.5-turbo, gpt-4]. 
See openai.Model.list() for available models 56 | 57 | # Saving and loading parameters 58 | save: True # Save the results to a file 59 | save_new_results: True # If False, overwrite the results file 60 | results_dir: ./results/ # Directory to save the results 61 | use_cache: True # Use cache for the models that support it (now, GPT-3) 62 | clear_cache: False # Clear stored cache 63 | use_cached_codex: False # Use previously-computed Codex results 64 | cached_codex_path: '' # Path to the csv results file from which to load Codex results 65 | log_every: 20 # Log accuracy every n batches 66 | wandb: False # Use Weights and Biases 67 | 68 | blip_half_precision: True # Use 8bit (Faster but slightly less accurate) for BLIP if True 69 | blip_v2_model_type: blip2-flan-t5-xxl # Which model to use for BLIP-2 70 | 71 | use_fixed_code: False # Use a fixed code for all samples (do not generate with Codex) 72 | fixed_code_file: ./prompts/fixed_code/blip2.prompt # Path to the fixed code file -------------------------------------------------------------------------------- /configs/benchmarks/gqa.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | data_path: '/path/to/datasets/GQA' 3 | dataset_name: GQA 4 | split: testdev 5 | testing: False 6 | max_samples: 1000 7 | batch_size: 20 8 | start_sample: 0 9 | 10 | prompt : ./prompts/benchmarks/gqa.prompt 11 | results_dir : ./results/gqa/ 12 | 13 | load_models: 14 | maskrcnn: True 15 | clip: False 16 | glip: True 17 | owlvit: False 18 | tcl: False 19 | gpt3_list: False 20 | gpt3_qa: False 21 | gpt3_guess: False 22 | depth: False 23 | blip: True 24 | saliency: False 25 | xvlm: True 26 | 27 | fixed_code_file: ./prompts/fixed_code/blip2.prompt -------------------------------------------------------------------------------- /configs/benchmarks/nextqa.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | data_path: /path/to/datasets/NExTQA/ 3 | dataset_name: NExTQA 4 | version: multiplechoice 5 | fps: 1 6 | max_num_frames: 30 7 | split: val 8 | batch_size: 10 9 | max_samples: 100 10 | 11 | codex: 12 | prompt: ./prompts/benchmarks/nextqa.prompt 13 | select_answer_prompt: ./prompts/gpt3/video_question.txt 14 | fixed_code_file: ./prompts/fixed_code/blip2_video.prompt 15 | 16 | results_dir: ./results/nextqa/ 17 | 18 | gpt3: 19 | model: chatgpt 20 | -------------------------------------------------------------------------------- /configs/benchmarks/okvqa.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | data_path: /path/to/datasets/OKVQA 3 | dataset_name: OKVQA 4 | split: test 5 | testing: False 6 | max_samples: 7 | batch_size: 20 8 | start_sample: 0 9 | 10 | prompt: ./prompts/benchmarks/okvqa.prompt 11 | 12 | results_dir: ./results/okvqa/ 13 | 14 | load_models: 15 | maskrcnn: False 16 | clip: False 17 | glip: False 18 | owlvit: False 19 | tcl: False 20 | gpt3_list: True 21 | gpt3_qa: True 22 | gpt3_guess: True 23 | depth: False 24 | blip: True 25 | saliency: False 26 | xvlm: False 27 | 28 | fixed_code_file: ./prompts/fixed_code/blip2.prompt -------------------------------------------------------------------------------- /configs/benchmarks/refcoco.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | data_path: /path/to/datasets/refcoco 3 | dataset_name: RefCOCO 4 | split_by: unc # [google, unc] 5 | split: testA 6 | version: refcoco 7 | batch_size: 20 8 | 9 | codex: 10 
| prompt: ./prompts/benchmarks/refcoco.prompt 11 | fixed_code_file: ./prompts/fixed_code/glip.prompt 12 | 13 | results_dir: ./results/refcoco/ 14 | 15 | load_models: 16 | gpt3_general: False 17 | 18 | ratio_box_area_to_image_area: 0.03 19 | crop_larger_margin: False 20 | 21 | 22 | -------------------------------------------------------------------------------- /configs/config_codellama.yaml: -------------------------------------------------------------------------------- 1 | codex: 2 | model: codellama 3 | codellama_model_name: /path/to/code_llama_models/CodeLlama-34b-Python-hf 4 | load_models: 5 | codex: False 6 | codellama: True -------------------------------------------------------------------------------- /configs/my_config.yaml: -------------------------------------------------------------------------------- 1 | # For example: 2 | multiprocessing: False 3 | path_pretrained_models: './pretrained_models' 4 | dataset: 5 | data_path: 'data' 6 | blip_v2_model_type: blip2-flan-t5-xxl # Change to blip2-flan-t5-xl for smaller GPUs 7 | blip_half_precision: True 8 | # Add more changes here, following the same format as base_config.yaml 9 | -------------------------------------------------------------------------------- /data/queries.csv: -------------------------------------------------------------------------------- 1 | index,sample_id,possible_answers,query_type,query,answer,image_name 2 | 0,0,purple,,What color do you get if you combine the colors of the viper and the flower?,purple,viper_flower.png 3 | 0,0,,,Tell me about the competition between the two skyscrapers in the image.,,skyscrapers.png 4 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data loaders 3 | Adapted in part from https://github.com/phiyodr/vqaloader/blob/master/vqaloader/loaders.py 4 | """ 5 | 6 | import torch 7 | from torchvision import transforms 8 | 9 | 10 | # ----------------------------- General for all datasets ----------------------------- # 11 | def get_dataset(config_dataset): 12 | dataset_name = config_dataset.dataset_name 13 | 14 | if dataset_name == 'RefCOCO': 15 | from datasets.refcoco import RefCOCODataset 16 | dataset = RefCOCODataset(**config_dataset, 17 | image_transforms=transforms.Compose([transforms.ToTensor()])) 18 | elif dataset_name == 'GQA': 19 | from datasets.gqa import GQADataset 20 | dataset = GQADataset(**config_dataset, 21 | balanced=True, 22 | image_transforms=transforms.Compose([transforms.ToTensor()])) 23 | elif dataset_name == 'OKVQA': 24 | from datasets.okvqa import OKVQADataset 25 | dataset = OKVQADataset(**config_dataset, 26 | image_transforms=transforms.Compose([transforms.ToTensor()])) 27 | elif dataset_name == 'NExTQA': 28 | from datasets.nextqa import NExTQADataset 29 | dataset = NExTQADataset(**config_dataset) 30 | elif dataset_name == 'MyDataset': 31 | from datasets.my_dataset import MyDataset 32 | dataset = MyDataset(**config_dataset) 33 | else: 34 | raise ValueError(f"Unknown dataset {dataset_name}") 35 | return dataset 36 | 37 | 38 | def all_answers_from_dict(dct): 39 | return [x["answer"] for x in dct] 40 | 41 | 42 | def general_postprocessing(prediction): 43 | try: 44 | if type(prediction).__name__ == 'ImagePatch': 45 | prediction = prediction.classify_object() 46 | 47 | if isinstance(prediction, list): 48 | prediction = prediction[0] if len(prediction) > 0 else "no" 49 | 50 | if isinstance(prediction, torch.Tensor): 51 | 
prediction = prediction.item() 52 | if prediction is None: 53 | prediction = "no" 54 | if isinstance(prediction, bool): 55 | if prediction: 56 | prediction = "yes" 57 | else: 58 | prediction = "no" 59 | elif isinstance(prediction, int): 60 | prediction = str(prediction) 61 | print("No answer is a number, so this will be wrong") 62 | except: 63 | prediction = str(prediction) 64 | 65 | prediction = str(prediction) 66 | 67 | prediction = prediction.replace('\n', ' ') 68 | prediction = prediction.replace('\t', ' ') 69 | prediction = prediction.strip() 70 | prediction = prediction.lower() 71 | 72 | if prediction == 'true': 73 | prediction = 'yes' 74 | elif prediction == 'false': 75 | prediction = 'no' 76 | return prediction 77 | 78 | 79 | def accuracy(prediction, ground_truth, *args): 80 | """ 81 | Args: 82 | prediction (list): List of predicted answers. 83 | ground_truth (list): List of ground truth answers. 84 | Returns: 85 | score (float): Score of the prediction. 86 | """ 87 | if len(prediction) == 0: # if no prediction, return 0 88 | return 0 89 | assert len(prediction) == len(ground_truth) 90 | pred_gt_filtered = [(pred, gt) for pred, gt in zip(prediction, ground_truth) if gt != ''] 91 | score = 0 92 | for p, g in pred_gt_filtered: 93 | if general_postprocessing(p) == g: 94 | score += 1 95 | return score / len(pred_gt_filtered) 96 | -------------------------------------------------------------------------------- /datasets/gqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import re 4 | 5 | from PIL import Image 6 | import pandas as pd 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | 10 | from datasets import general_postprocessing 11 | 12 | 13 | class GQADataset(Dataset): 14 | BALANCED_TYPE = { 15 | True: "balanced", 16 | False: "all" 17 | } 18 | 19 | def __init__(self, split, balanced=True, data_path="", 20 | image_transforms=None, question_transforms=None, tokenize=None, 21 | verbose=False, testing=False, max_samples=None, first_n=None, return_pil=True): 22 | """ 23 | Args: 24 | split (str): Data split. One of ["challenge", "submission", "test", "testdev", "train", "val"] 25 | balanced (bool): You balanced version or full version. 26 | image_transforms: 27 | question_transforms: 28 | tokenize (fct): 29 | verbose (bool): Print some infos. Default=True 30 | testing (bool): Set to true for data splits without targets. Default=False. 31 | first_n (int): Only use the first n samples. Default=None. Only valid if loading from hdf. 
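            data_path (str): Root GQA folder; questions are read from `questions/` and images from `images/` inside it.
            max_samples (int): If set, randomly subsample the loaded questions to this many. Default=None.
            return_pil (bool): Also return the raw PIL image in the output dict. Default=True.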
32 | """ 33 | start_time = time.time() 34 | self.split = split 35 | self.testing = testing 36 | assert split in ["challenge", "submission", "test", "testdev", "train", "val"] 37 | self.balanced = balanced 38 | self.balanced_type = self.BALANCED_TYPE[balanced] 39 | self.data_path = data_path 40 | self.image_transforms = image_transforms 41 | self.question_transforms = question_transforms 42 | self.tokenize = tokenize 43 | self.input_type = 'image' 44 | self.return_pil = return_pil 45 | 46 | if not balanced and split == "train": 47 | raise NotImplementedError 48 | else: 49 | # check path to cached df exists 50 | if self.split == 'train' and self.balanced_type == 'balanced' and os.path.exists( 51 | os.path.join(data_path, f"questions/{self.split}_{self.balanced_type}_questions.h5")): 52 | if verbose: 53 | print(f"Loading GQA Dataset from {data_path}", flush=True) 54 | self.df = pd.read_hdf( 55 | os.path.join(data_path, f"questions/{self.split}_{self.balanced_type}_questions.h5"), "table", stop=first_n) 56 | else: 57 | self.file_name = f"questions/{self.split}_{self.balanced_type}_questions.json" 58 | path = os.path.expanduser(os.path.join(data_path, self.file_name)) 59 | if verbose: 60 | print(f"Loading GQA Dataset from {path}", flush=True) 61 | self.df = pd.read_json(path, orient="index") 62 | 63 | if max_samples is not None: 64 | self.df = self.df.sample(n=max_samples) 65 | 66 | self.n_samples = self.df.shape[0] 67 | if verbose: 68 | print( 69 | f"Loading GQA Dataset done in {time.time() - start_time:.1f} seconds. Loaded {self.n_samples} samples.") 70 | 71 | # For evaluation 72 | self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", 73 | "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've", 74 | "didnt": "didn't","doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", 75 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", 76 | "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've", "hes": "he's", "howd": "how'd", 77 | "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", 78 | "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", 79 | "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't", 80 | "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 81 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", 82 | "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", 83 | "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", 84 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", 85 | "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", "somebody'd": "somebodyd", 86 | "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", 87 | "somebodyll": "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 88 | "someoned've": "someone'd've", "someone'dve": "someone'd've", "someonell": "someone'll", 89 | "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", 90 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", 91 | "thered": "there'd", "thered've": "there'd've", "there'dve": "there'd've", 92 | "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", 93 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", 
94 | "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've", "we'dve": "we'd've", 95 | "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", 96 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", 97 | "wheres": "where's", "whereve": "where've", "whod": "who'd", "whod've": "who'd've", 98 | "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", 99 | "whyll": "why'll", "whyre": "why're", "whys": "why's", "wont": "won't", 100 | "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 101 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", 102 | "yall'd've": "y'all'd've", "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", 103 | "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll", 104 | "youre": "you're", "youve": "you've"} 105 | self.manualMap = {'none': '0', 106 | 'zero': '0', 107 | 'one': '1', 108 | 'two': '2', 109 | 'three': '3', 110 | 'four': '4', 111 | 'five': '5', 112 | 'six': '6', 113 | 'seven': '7', 114 | 'eight': '8', 115 | 'nine': '9', 116 | 'ten': '10' 117 | } 118 | self.articles = ['a', 119 | 'an', 120 | 'the' 121 | ] 122 | 123 | self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 124 | self.commaStrip = re.compile("(\d)(\,)(\d)") 125 | self.punct = [';', r"/", '[', ']', '"', '{', '}', 126 | '(', ')', '=', '+', '\\', '_', '-', 127 | '>', '<', '@', '`', ',', '?', '!'] 128 | 129 | self.max_words = 50 130 | 131 | def processPunctuation(self, inText): 132 | outText = inText 133 | for p in self.punct: 134 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): 135 | outText = outText.replace(p, '') 136 | else: 137 | outText = outText.replace(p, ' ') 138 | outText = self.periodStrip.sub("", outText, re.UNICODE) 139 | return outText 140 | 141 | def processDigitArticle(self, inText): 142 | outText = [] 143 | tempText = inText.lower().split() 144 | for word in tempText: 145 | word = self.manualMap.setdefault(word, word) 146 | if word not in self.articles: 147 | outText.append(word) 148 | else: 149 | pass 150 | for wordId, word in enumerate(outText): 151 | if word in self.contractions: 152 | outText[wordId] = self.contractions[word] 153 | outText = ' '.join(outText) 154 | return outText 155 | 156 | def get_img_path(self, index): 157 | if "imageId" in self.df.columns: 158 | image_id = self.df.iloc[index]["imageId"] 159 | else: 160 | image_id = self.df.iloc[index]["image_id"] 161 | return os.path.expanduser(os.path.join(self.data_path, "../images", f"{image_id}.jpg")) 162 | 163 | def get_index_from_sample_id(self, sample_id): 164 | return np.where(self.df.index == sample_id)[0][0].item() 165 | 166 | def __getitem__(self, index): 167 | # image input 168 | sample_id = self.df.iloc[index].name 169 | if "imageId" in self.df.columns: 170 | image_id = self.df.iloc[index]["imageId"] 171 | else: 172 | image_id = self.df.iloc[index]["image_id"] 173 | question = self.df.iloc[index]["question"] 174 | 175 | question_type = -1 176 | answer = -1 177 | if not self.testing: 178 | answer = self.df.iloc[index]["answer"] 179 | question_type = self.df.iloc[index]["groups"]["global"] 180 | if question_type is None: 181 | question_type = -1 # can't have None for DataLoader 182 | 183 | # Load and transform image 184 | image_path = os.path.expanduser(os.path.join(self.data_path, "images", f"{image_id}.jpg")) 185 | with open(image_path, "rb") as f: 186 | pil_img = Image.open(f).convert("RGB") 187 | if 
self.image_transforms: 188 | img = self.image_transforms(pil_img) 189 | else: 190 | img = pil_img 191 | 192 | # Load, transform and tokenize question 193 | if self.question_transforms: 194 | question = self.question_transforms(question) 195 | if self.tokenize: 196 | question = self.tokenize(question) 197 | 198 | # Return 199 | if self.testing: 200 | if (sample_id is None) or (img is None) or (question is None): 201 | raise Exception(f"Error in GQA Dataset: sample_id={sample_id}, img={img}, question={question}") 202 | out_dict = {"sample_id": sample_id, "img": img, "question": question, 'index': index} 203 | if self.return_pil: 204 | out_dict["pil_img"] = pil_img 205 | return out_dict 206 | else: 207 | out_dict = {"sample_id": sample_id, "answer": answer, "img": img, "question": question, 'pil_img': pil_img, 208 | "question_type": question_type, 'index': index, 'possible_answers': [], 209 | 'info_to_prompt': question} 210 | return out_dict 211 | 212 | def post_process(self, prediction, stem=True): 213 | """ 214 | Code from https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py, 215 | as indicated here https://okvqa.allenai.org/leaderboard.html 216 | :return: 217 | """ 218 | prediction = general_postprocessing(prediction) 219 | 220 | prediction = prediction.replace('\n', ' ') 221 | prediction = prediction.replace('\t', ' ') 222 | prediction = prediction.strip() 223 | prediction = self.processPunctuation(prediction) 224 | prediction = self.processDigitArticle(prediction) 225 | return prediction 226 | 227 | def accuracy(self, prediction, ground_truth, *args): 228 | """ 229 | Args: 230 | prediction (list): List of predicted answers. 231 | ground_truth (list): List of ground truth answers. 232 | Returns: 233 | score (float): Score of the prediction. 234 | """ 235 | if len(prediction) == 0: # if no prediction, return 0 236 | return 0 237 | assert len(prediction) == len(ground_truth) 238 | score = 0 239 | for p, g in zip(prediction, ground_truth): 240 | if self.post_process(p) == g: 241 | score += 1 242 | return score / len(prediction) 243 | 244 | # we can call len(dataset) to return the size 245 | def __len__(self): 246 | return self.n_samples 247 | -------------------------------------------------------------------------------- /datasets/my_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import decord 4 | from decord import cpu, gpu 5 | import numpy as np 6 | from PIL import Image 7 | import pandas as pd 8 | from torch.utils.data import Dataset 9 | 10 | from datasets import accuracy as general_accuracy 11 | 12 | 13 | class MyDataset(Dataset): 14 | def __init__(self, split, data_path="", input_type='image', image_transforms=None, fps=30, max_num_frames=30, 15 | max_samples=None, start_sample=0, **kwargs): 16 | """ 17 | Args: 18 | split (str): Data split. 19 | data_path (str): Path to the data folder 20 | input_type (str): Type of input. One of ["image", "video"] 21 | image_transforms (callable, optional): Optional transform to be applied on an image. Only used if input_type 22 | is "image". 23 | fps (int): Frames per second. Only used if input_type is "video". 24 | max_num_frames (int): Maximum number of frames to use. Only used if input_type is "video". 25 | max_samples (int, optional): Maximum number of samples to load. If None, load all samples. 26 | start_sample (int, optional): Index of the first sample to load. If None, start from the beginning. 
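            Expected layout (based on the loading code below): `{data_path}/{split}/queries.csv` with
            columns [query, answer, image_name or video_name], and the media files under
            `{data_path}/images/` or `{data_path}/videos/` depending on input_type.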
27 | """ 28 | 29 | self.split = split 30 | self.data_path = Path(data_path) 31 | self.input_type = input_type 32 | self.image_transforms = image_transforms 33 | self.fps = fps 34 | self.max_num_frames = max_num_frames 35 | 36 | # Load questions, answers, and image ids 37 | with open(self.data_path / self.split / 'queries.csv', 'r') as f: 38 | # The csv has the rows [query, answer, image_name or video_name] 39 | self.df = pd.read_csv(f, index_col=None, keep_default_na=False) 40 | 41 | if max_samples is not None: 42 | self.df = self.df.iloc[start_sample:start_sample + max_samples] 43 | 44 | self.n_samples = len(self.df) 45 | 46 | def get_sample_path(self, index): 47 | sample_name = self.df.iloc[index][f"{self.input_type}_name"] 48 | sample_path = self.data_path / f"{self.input_type}s" / sample_name 49 | return sample_path 50 | 51 | def get_image(self, image_path): 52 | with open(image_path, "rb") as f: 53 | pil_image = Image.open(f).convert("RGB") 54 | if self.image_transforms: 55 | image = self.image_transforms(pil_image)[:3] 56 | else: 57 | image = pil_image 58 | return image 59 | 60 | def get_video(self, video_path): 61 | # If fixed width and height are required, VideoReader takes width and height as arguments. 62 | video_reader = decord.VideoReader(str(video_path), num_threads=1, ctx=cpu(0)) 63 | decord.bridge.set_bridge('torch') 64 | vlen = len(video_reader) 65 | original_fps = video_reader.get_avg_fps() 66 | num_frames = int(vlen * self.fps / original_fps) 67 | num_frames = min(self.max_num_frames, num_frames) 68 | frame_idxs = np.linspace(0, vlen, num_frames, endpoint=False).astype(np.int) 69 | video = video_reader.get_batch(frame_idxs).byte() 70 | video = video.permute(0, 3, 1, 2) 71 | return video 72 | 73 | def __getitem__(self, index): 74 | 75 | out_dict = self.df.iloc[index].to_dict() 76 | 77 | sample_path = self.get_sample_path(index) 78 | 79 | # Load and transform image 80 | image = self.get_image(sample_path) if self.input_type == "image" else self.get_video(sample_path) 81 | 82 | out_dict["image"] = image 83 | out_dict["index"] = index 84 | 85 | if 'extra_context' not in out_dict: 86 | out_dict['extra_context'] = '' 87 | 88 | return out_dict 89 | 90 | def __len__(self): 91 | return self.n_samples 92 | 93 | @classmethod 94 | def accuracy(cls, *args, **kwargs): 95 | return general_accuracy(*args, **kwargs) 96 | -------------------------------------------------------------------------------- /datasets/nextqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import pandas as pd 5 | from torch.utils.data import Dataset 6 | import decord 7 | from decord import cpu, gpu 8 | import numpy as np 9 | import spacy 10 | 11 | from nltk.tokenize import word_tokenize 12 | from nltk.corpus import wordnet 13 | import numpy as np 14 | 15 | from pywsd.utils import lemmatize_sentence 16 | from collections import Counter 17 | 18 | 19 | def load_file(file_name): 20 | annos = None 21 | if os.path.splitext(file_name)[-1] == '.csv': 22 | return pd.read_csv(file_name) 23 | with open(file_name, 'r') as fp: 24 | if os.path.splitext(file_name)[1]== '.txt': 25 | annos = fp.readlines() 26 | annos = [line.rstrip() for line in annos] 27 | if os.path.splitext(file_name)[1] == '.json': 28 | annos = json.load(fp) 29 | 30 | return annos 31 | 32 | 33 | def save_file(obj, filename): 34 | """ 35 | save obj to filename 36 | :param obj: 37 | :param filename: 38 | :return: 39 | """ 40 | filepath = os.path.dirname(filename) 41 | if filepath != '' 
and not os.path.exists(filepath): 42 | os.makedirs(filepath) 43 | else: 44 | with open(filename, 'w') as fp: 45 | json.dump(obj, fp, indent=4) 46 | 47 | 48 | class NExTQADataset(Dataset): 49 | def __init__(self, split, data_path="", tokenize=None, max_samples=None, version='openended', fps=30, 50 | max_num_frames=30, start_sample=0, **kwargs): 51 | 52 | assert version in ['openended', 'multiplechoice'] 53 | directory = 'nextqa' if version == 'multiplechoice' else 'nextoe' 54 | 55 | self.split = split 56 | self.data_path = data_path 57 | self.tokenize = tokenize 58 | self.version = version 59 | self.fps = fps 60 | self.input_type = 'video' 61 | self.max_num_frames = max_num_frames 62 | 63 | sample_list_path = os.path.join(self.data_path, directory, f'{split}.csv') 64 | self.sample_list = load_file(sample_list_path) 65 | 66 | if max_samples is not None: 67 | # self.sample_list = self.sample_list.sample(n=max_samples) 68 | self.sample_list = self.sample_list[start_sample:start_sample+max_samples] 69 | 70 | self.sample_ids = self.sample_list.index 71 | self.sample_id_to_index = {sample_id: idx for idx, sample_id in enumerate(self.sample_ids)} 72 | 73 | self.video_to_dir = {} 74 | for directory in os.listdir(os.path.join(self.data_path, 'videos')): 75 | for video in os.listdir(os.path.join(self.data_path, 'videos', directory)): 76 | self.video_to_dir[video.split('.')[0]] = directory 77 | 78 | def get_sample_path(self, index): 79 | sample_id = self.sample_ids[index] 80 | cur_sample = self.sample_list.loc[sample_id] 81 | video_name = str(cur_sample['video']) 82 | video_path = os.path.join(self.data_path, 'videos', self.video_to_dir[video_name], video_name + '.mp4') 83 | return video_path 84 | 85 | def get_video(self, video_path): 86 | # If fixed width and height are required, VideoReader takes width and height as arguments. 
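        # Frame sampling: the video's length is rescaled from its native fps to self.fps and capped
        # at self.max_num_frames; the chosen indices are spread uniformly over the clip, and decord
        # returns uint8 frames of shape (T, H, W, C) that are permuted below to (T, C, H, W).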
87 | video_reader = decord.VideoReader(video_path, num_threads=1, ctx=cpu(0)) 88 | decord.bridge.set_bridge('torch') 89 | vlen = len(video_reader) 90 | original_fps = video_reader.get_avg_fps() 91 | num_frames = int(vlen * self.fps / original_fps) 92 | num_frames = min(self.max_num_frames, num_frames) 93 | frame_idxs = np.linspace(0, vlen, num_frames, endpoint=False).astype(np.int) 94 | video = video_reader.get_batch(frame_idxs).byte() 95 | video = video.permute(0, 3, 1, 2) 96 | return video 97 | 98 | def __getitem__(self, idx): 99 | sample_id = self.sample_ids[idx] 100 | cur_sample = self.sample_list.loc[sample_id] 101 | 102 | question = str(cur_sample['question']) 103 | if self.tokenize: 104 | question = self.tokenize(question) 105 | 106 | video_name = str(cur_sample['video']) 107 | video_path = os.path.join(self.data_path, 'videos', self.video_to_dir[video_name], video_name + '.mp4') 108 | video = self.get_video(video_path) 109 | 110 | if self.version == 'openended': 111 | answer = str(cur_sample['answer']) 112 | if self.tokenize: 113 | answer = self.tokenize(answer) 114 | possible_answers = '' 115 | else: # multiple choice 116 | answer_idx = int(cur_sample['answer']) 117 | possible_answers = [str(cur_sample[f'a{i}']) for i in range(5)] 118 | answer = possible_answers[answer_idx] 119 | 120 | query_type = str(cur_sample['type']) 121 | 122 | out_dict = {"sample_id": sample_id, "answer": answer, "image": video, "query": question, 'pil_img': -1, 123 | "query_type": query_type, 'index': idx, 'possible_answers': possible_answers, 124 | 'extra_context': possible_answers} 125 | 126 | return out_dict 127 | 128 | def __len__(self): 129 | return self.sample_list.shape[0] 130 | 131 | def get_index_from_sample_id(self, sample_id): 132 | return self.sample_id_to_index[sample_id] 133 | 134 | def get_img_path(self, index): 135 | sample_id = self.sample_ids[index] 136 | cur_sample = self.sample_list.loc[sample_id] 137 | video_name = str(cur_sample['video']) 138 | video_path = os.path.join(self.data_path, 'videos', self.video_to_dir[video_name], video_name + '.mp4') 139 | return video_path 140 | 141 | def accuracy(self, prediction, ground_truth, possible_answers, query_type): 142 | """ 143 | Args: 144 | prediction (list): List of predicted answers. 145 | ground_truth (list): List of ground truth answers. 146 | possible_answers (list): List of possible answers. 147 | query_type (list): List of query types 148 | Returns: 149 | score (float): Score of the prediction. 
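            Open-ended (nextoe) predictions are scored with WUPS, except for 'DC'/'DB' question types,
            which use exact match after stop-word removal; per-sample scores are scaled by 100.
            Multiple-choice predictions score 1 when they equal the ground-truth option; free-form
            predictions are first mapped to the closest option via spaCy similarity.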
150 | """ 151 | 152 | assert len(prediction) == len(ground_truth) 153 | score = 0 154 | 155 | if self.version == 'openended': 156 | for p, g, qt in zip(prediction, ground_truth, query_type): 157 | if isinstance(p, list) or isinstance(p, tuple): 158 | p = p[0] # p[1] is the info dict 159 | if p is None: 160 | print('None case') 161 | p = 'object' # To select some word 162 | if qt == 'DC' or qt == 'DB': 163 | s = 1 if remove_stop(p) == remove_stop(g) else 0 164 | else: 165 | s = get_wups(remove_stop(p), remove_stop(g), 0) 166 | score += 100 * s 167 | else: 168 | nlp = spacy.load('en_core_web_lg') 169 | for p, g, a in zip(prediction, ground_truth, possible_answers): 170 | if isinstance(p, list) or isinstance(p, tuple): 171 | if len(p) == 2: 172 | p = p[0] # p[1] is the info dict 173 | else: # Multiple predictions 174 | all_answers = [] 175 | for pp in p: 176 | if pp not in a: 177 | pred_tokens = nlp(pp) 178 | a.sort(key=lambda x: pred_tokens.similarity(nlp(x)), reverse=True) 179 | pp = a[0] 180 | all_answers.append(pp) 181 | # Majority vote 182 | c = Counter(all_answers).most_common(1)[0] 183 | if c[1] == 1: 184 | # If no majority, select the middle one 185 | p = all_answers[1] 186 | else: 187 | p = c[0] 188 | if p not in a: 189 | if p is None: 190 | print('None case') # Should not happen 191 | else: 192 | pred_tokens = nlp(p) 193 | a.sort(key=lambda x: pred_tokens.similarity(nlp(x)), reverse=True) 194 | p = a[0] 195 | if p == g: 196 | score += 1 197 | return score / len(prediction) 198 | 199 | 200 | # Below is code from https://github.com/doc-doc/NExT-OE/blob/main/eval_oe.py 201 | 202 | stopwords = "i, me, my, myself, we, our, ours, ourselves, you, you're, you've, you'll, you'd, your, yours, yourself, " \ 203 | "yourselves, he, him, his, himself, she, she's, her, hers, herself, it, it's, its, itself, they, them, " \ 204 | "their, theirs, themselves, what, which, who, whom, this, that, that'll, these, those, am, is, are, was, " \ 205 | "were, be, been, being, have, has, had, having, do, does, did, doing, a, an, the, and, but, if, or, " \ 206 | "because, as, until, while, to, from, of, at, for, with, about, into, through, during, again, further, " \ 207 | "then, here, there, when, where, why, how, all, any, each, most, other, some, such, only, own, so, than, " \ 208 | "too, very, s, t, can, will, just, don, don't, should, should've, now, d, ll, m, o, re, ve, y, ain, " \ 209 | "aren, aren't, couldn, couldn't, didn, didn't, doesn, doesn't, hadn, hadn't, hasn, hasn't, haven, " \ 210 | "haven't, isn, isn't, ma, mightn, mightn't, mustn, mustn't, needn, needn't, shan, shan't, shouldn, " \ 211 | "shouldn't, wasn, wasn't, weren, weren't, won, won't, wouldn, wouldn't" 212 | 213 | 214 | def remove_stop(sentence): 215 | 216 | words = lemmatize_sentence(sentence) 217 | words = [w for w in words if not w in stopwords] 218 | return ' '.join(words) 219 | 220 | 221 | def wup(word1, word2, alpha): 222 | """ 223 | calculate the wup similarity 224 | :param word1: 225 | :param word2: 226 | :param alpha: 227 | :return: 228 | """ 229 | # print(word1, word2) 230 | if word1 == word2: 231 | return 1.0 232 | 233 | w1 = wordnet.synsets(word1) 234 | w1_len = len(w1) 235 | if w1_len == 0: return 0.0 236 | w2 = wordnet.synsets(word2) 237 | w2_len = len(w2) 238 | if w2_len == 0: return 0.0 239 | 240 | #match the first 241 | word_sim = w1[0].wup_similarity(w2[0]) 242 | if word_sim is None: 243 | word_sim = 0.0 244 | 245 | if word_sim < alpha: 246 | word_sim = 0.1*word_sim 247 | return word_sim 248 | 249 | 250 | def 
wups(words1, words2, alpha): 251 | """ 252 | :param pred: 253 | :param truth: 254 | :param alpha: 255 | :return: 256 | """ 257 | sim = 1.0 258 | flag = False 259 | for w1 in words1: 260 | max_sim = 0 261 | for w2 in words2: 262 | word_sim = wup(w1, w2, alpha) 263 | if word_sim > max_sim: 264 | max_sim = word_sim 265 | if max_sim == 0: continue 266 | sim *= max_sim 267 | flag = True 268 | if not flag: 269 | sim = 0.0 270 | return sim 271 | 272 | 273 | def get_wups(pred, truth, alpha): 274 | """ 275 | calculate the wups score 276 | :param pred: 277 | :param truth: 278 | :return: 279 | """ 280 | pred = word_tokenize(pred) 281 | truth = word_tokenize(truth) 282 | item1 = wups(pred, truth, alpha) 283 | item2 = wups(truth, pred, alpha) 284 | value = min(item1, item2) 285 | return value -------------------------------------------------------------------------------- /datasets/refcoco.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | import json 7 | import pickle 8 | import itertools 9 | import torch 10 | from torchvision.ops import box_iou 11 | 12 | 13 | class RefCOCODataset(Dataset): 14 | """ 15 | Used code from https://github.com/lichengunc/refer/blob/master/refer.py 16 | """ 17 | def __init__(self, split, data_path="", image_transforms=None, question_transforms=None, tokenize=None, 18 | max_samples=None, version='refcoco', split_by='unc', **kwargs): 19 | 20 | self.split = split 21 | self.data_path = data_path 22 | self.max_samples = max_samples 23 | self.image_transforms = image_transforms 24 | self.question_transforms = question_transforms 25 | self.tokenize = tokenize 26 | self.input_type = 'image' 27 | 28 | assert version in ['refcoco', 'refcoco+', 'refcocog'] 29 | 30 | # load refs from data/dataset/refs(dataset).json 31 | ref_file = os.path.join(data_path, version, 'refs(' + split_by + ').p') 32 | with open(ref_file, 'rb') as f: 33 | self.refs = pickle.load(f) 34 | 35 | # load annotations from data/dataset/instances.json 36 | 37 | instances_file = os.path.join(data_path, version, 'instances.json') 38 | with open(instances_file, 'r') as f: 39 | instances = json.load(f) 40 | self.images = instances['images'] 41 | self.annotations = instances['annotations'] 42 | self.categories = instances['categories'] 43 | 44 | self.create_index() 45 | 46 | ref_ids = self.get_ref_ids(split=split) 47 | self.samples = [] 48 | for ref_id in ref_ids: 49 | ref = self.Refs[ref_id] 50 | for i in range(len(ref['sent_ids'])): 51 | self.samples.append((ref_id, i)) 52 | 53 | np.random.seed(4) 54 | np.random.shuffle(self.samples) 55 | 56 | if max_samples is not None: 57 | self.samples = self.samples[:max_samples] 58 | 59 | def create_index(self): 60 | # create sets of mapping 61 | # 1) Refs: {ref_id: ref} 62 | # 2) Anns: {ann_id: ann} 63 | # 3) Imgs: {image_id: image} 64 | # 4) Cats: {category_id: category_name} 65 | # 5) Sents: {sent_id: sent} 66 | # 6) imgToRefs: {image_id: refs} 67 | # 7) imgToAnns: {image_id: anns} 68 | # 8) refToAnn: {ref_id: ann} 69 | # 9) annToRef: {ann_id: ref} 70 | # 10) catToRefs: {category_id: refs} 71 | # 11) sentToRef: {sent_id: ref} 72 | # 12) sentToTokens: {sent_id: tokens} 73 | 74 | # fetch info from instances 75 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} 76 | for img in self.images: 77 | Imgs[img['id']] = img 78 | for ann in self.annotations: 79 | Anns[ann['id']] = ann 80 | height = Imgs[ann['image_id']]['height'] 81 | ann['bbox'] = 
[ann['bbox'][0], height-(ann['bbox'][1]+ann['bbox'][3]), ann['bbox'][2]+ann['bbox'][0], 82 | height-ann['bbox'][1]] 83 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann] 84 | for cat in self.categories: 85 | Cats[cat['id']] = cat['name'] 86 | 87 | # fetch info from refs 88 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} 89 | Sents, sentToRef, sentToTokens = {}, {}, {} 90 | for ref in self.refs: 91 | # ids 92 | ref_id = ref['ref_id'] 93 | ann_id = ref['ann_id'] 94 | category_id = ref['category_id'] 95 | image_id = ref['image_id'] 96 | 97 | # add mapping related to ref 98 | Refs[ref_id] = ref 99 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] 100 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] 101 | refToAnn[ref_id] = Anns[ann_id] 102 | annToRef[ann_id] = ref 103 | 104 | # add mapping of sent 105 | for sent in ref['sentences']: 106 | Sents[sent['sent_id']] = sent 107 | sentToRef[sent['sent_id']] = ref 108 | sentToTokens[sent['sent_id']] = sent['tokens'] 109 | 110 | # create class members 111 | self.Refs = Refs 112 | self.Anns = Anns 113 | self.Imgs = Imgs 114 | self.Cats = Cats 115 | self.Sents = Sents 116 | self.imgToRefs = imgToRefs 117 | self.imgToAnns = imgToAnns 118 | self.refToAnn = refToAnn 119 | self.annToRef = annToRef 120 | self.catToRefs = catToRefs 121 | self.sentToRef = sentToRef 122 | self.sentToTokens = sentToTokens 123 | 124 | def get_ref_ids(self, image_ids=[], cat_ids=[], ref_ids=[], split=''): 125 | image_ids = image_ids if type(image_ids) == list else [image_ids] 126 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 127 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 128 | 129 | if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: 130 | refs = self.data['refs'] 131 | else: 132 | if not len(image_ids) == 0: 133 | refs = [self.imgToRefs[image_id] for image_id in image_ids] 134 | else: 135 | refs = self.refs 136 | if not len(cat_ids) == 0: 137 | refs = [ref for ref in refs if ref['category_id'] in cat_ids] 138 | if not len(ref_ids) == 0: 139 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids] 140 | if not len(split) == 0: 141 | if split in ['testA', 'testB', 'testC']: 142 | refs = [ref for ref in refs if 143 | split[-1] in ref['split']] # we also consider testAB, testBC, ... 144 | elif split in ['testAB', 'testBC', 'testAC']: 145 | refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess... 
146 | elif split == 'test': 147 | refs = [ref for ref in refs if 'test' in ref['split']] 148 | elif split == 'train' or split == 'val': 149 | refs = [ref for ref in refs if ref['split'] == split] 150 | else: 151 | raise KeyError(f'No split {split}') 152 | ref_ids = [ref['ref_id'] for ref in refs] 153 | return ref_ids 154 | 155 | def get_ann_ids(self, image_ids=[], cat_ids=[], ref_ids=[]): 156 | image_ids = image_ids if type(image_ids) == list else [image_ids] 157 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 158 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 159 | 160 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: 161 | ann_ids = [ann['id'] for ann in self.annotations] 162 | else: 163 | if not len(image_ids) == 0: 164 | lists = [self.imgToAnns[image_id] for image_id in image_ids if 165 | image_id in self.imgToAnns] # list of [anns] 166 | anns = list(itertools.chain.from_iterable(lists)) 167 | else: 168 | anns = self.annotations 169 | if not len(cat_ids) == 0: 170 | anns = [ann for ann in anns if ann['category_id'] in cat_ids] 171 | ann_ids = [ann['id'] for ann in anns] 172 | if not len(ref_ids) == 0: 173 | ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])) 174 | return ann_ids 175 | 176 | def get_img_ids(self, ref_ids=[]): 177 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 178 | 179 | if not len(ref_ids) == 0: 180 | image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])) 181 | else: 182 | image_ids = self.Imgs.keys() 183 | return image_ids 184 | 185 | def load_refs(self, ref_ids=[]): 186 | if type(ref_ids) == list: 187 | return [self.Refs[ref_id] for ref_id in ref_ids] 188 | elif type(ref_ids) == int: 189 | return [self.Refs[ref_ids]] 190 | 191 | def get_index_from_sample_id(self, sample_id): 192 | return sample_id 193 | 194 | def get_sample_path(self, index=None, ref=None): 195 | if ref is None: 196 | assert index is not None 197 | ref_id, i = self.samples[index] 198 | ref = self.load_refs(ref_id)[0] 199 | 200 | file_name = '_'.join(ref['file_name'].split('_')[:-1]) + '.' + ref['file_name'].split('.')[-1] 201 | coco_split = file_name.split('_')[1] 202 | 203 | img_path = os.path.join(self.data_path, 'mscoco', coco_split, file_name) 204 | return img_path 205 | 206 | def __getitem__(self, index): 207 | ref_id, i = self.samples[index] 208 | ref = self.load_refs(ref_id)[0] 209 | 210 | img_path = self.get_sample_path(ref=ref) 211 | 212 | with open(img_path, "rb") as f: 213 | pil_img = Image.open(f).convert("RGB") 214 | if self.image_transforms: 215 | img = self.image_transforms(pil_img) 216 | else: 217 | img = pil_img 218 | 219 | # There are different texts associated to every image 220 | text = ref['sentences'][i]['sent'] 221 | 222 | answer = self.refToAnn[ref_id]['bbox'] 223 | 224 | return {'query': text, 'image': img, 'sample_id': index, 'answer': answer, 'index': index, 225 | 'possible_answers': [], 'info_to_prompt': text, "query_type": -1, 'extra_context': ''} 226 | 227 | def __len__(self): 228 | return len(self.samples) 229 | 230 | @classmethod 231 | def accuracy(cls, prediction, ground_truth, *args): 232 | """ 233 | Compute IoU score 234 | Args: 235 | prediction (list): List of predicted answers. 236 | ground_truth (list): List of ground truth answers. 237 | Returns: 238 | score (float): Score of the prediction. 
It is an IoU score 239 | """ 240 | assert len(prediction) == len(ground_truth) 241 | num_samples = 0 242 | iou = 0 243 | acc = 0 244 | for p, g in zip(prediction, ground_truth): 245 | try: 246 | if p is None: 247 | # Average bounding box 248 | p = torch.tensor([50.9, 39.1, 493.5, 356.5])[None] # Mean IoU is 22.64% 249 | else: 250 | if type(p) == list: 251 | p = torch.tensor(p)[None] 252 | elif type(p) == str: 253 | p = torch.tensor([float(x) for x in p.split('(')[1].split(')')[0].split(',')])[None] 254 | else: 255 | p = torch.tensor([p.left, p.lower, p.right, p.upper])[None] 256 | if type(g) == str: 257 | g = [float(x) for x in g.split('[')[1].split(']')[0].split(',')] 258 | g = torch.tensor([g[0], g[1], g[2], g[3]])[None] 259 | iou_ = box_iou(p, g).item() # Expects (x1, y1, x2, y2) format. So (left, lower, right, upper) 260 | iou += iou_ 261 | if iou_ > 0.7: 262 | acc += 1 263 | except Exception as e: 264 | pass # If the prediction is not a box, we consider iou = 0 265 | num_samples += 1 266 | return iou / max(num_samples, 1), acc / max(num_samples, 1) 267 | 268 | -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # change this to your preferred download location 4 | PRETRAINED_MODELS_PATH=./pretrained_models 5 | 6 | # GLIP model 7 | mkdir -p $PRETRAINED_MODELS_PATH/GLIP/checkpoints 8 | mkdir -p $PRETRAINED_MODELS_PATH/GLIP/configs 9 | wget -nc -P $PRETRAINED_MODELS_PATH/GLIP/checkpoints https://huggingface.co/GLIPModel/GLIP/resolve/main/glip_large_model.pth 10 | wget -nc -P $PRETRAINED_MODELS_PATH/GLIP/configs https://raw.githubusercontent.com/microsoft/GLIP/main/configs/pretrain/glip_Swin_L.yaml 11 | 12 | # X-VLM model 13 | mkdir -p $PRETRAINED_MODELS_PATH/xvlm 14 | gdown "https://drive.google.com/u/0/uc?id=1bv6_pZOsXW53EhlwU0ZgSk03uzFI61pN" -O $PRETRAINED_MODELS_PATH/xvlm/retrieval_mscoco_checkpoint_9.pth 15 | 16 | # TCL model 17 | mkdir -p $PRETRAINED_MODELS_PATH/TCL 18 | gdown "https://drive.google.com/uc?id=1Cb1azBdcdbm0pRMFs-tupKxILTCXlB4O" -O $PRETRAINED_MODELS_PATH/TCL/TCL_4M.pth 19 | 20 | # InSPyReNet model 21 | mkdir -p $PRETRAINED_MODELS_PATH/saliency_inspyrenet_plus_ultra 22 | gdown "https://drive.google.com/uc?id=13oBl5MTVcWER3YU4fSxW3ATlVfueFQPY" -O $PRETRAINED_MODELS_PATH/saliency_inspyrenet_plus_ultra/latest.pth 23 | -------------------------------------------------------------------------------- /main_batch.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | import pathlib 5 | from functools import partial 6 | import warnings 7 | import traceback 8 | 9 | 10 | import pandas as pd 11 | import torch.multiprocessing as mp 12 | from joblib import Memory 13 | from num2words import num2words 14 | import numpy as np 15 | from omegaconf import OmegaConf 16 | from rich.console import Console 17 | from torch.utils.data import DataLoader 18 | from tqdm import tqdm 19 | 20 | from configs import config 21 | from utils import seed_everything 22 | import datasets 23 | 24 | # See https://github.com/pytorch/pytorch/issues/11201, https://github.com/pytorch/pytorch/issues/973 25 | # Not for dataloader, but for multiprocessing batches 26 | mp.set_sharing_strategy('file_system') 27 | queue_results = None 28 | 29 | cache = Memory('cache/' if config.use_cache else None, verbose=0) 30 | runs_dict = {} 31 | seed_everything() 32 | console = 
Console(highlight=False) 33 | 34 | 35 | def my_collate(batch): 36 | # Avoid stacking images (different size). Return everything as a list 37 | to_return = {k: [d[k] for d in batch] for k in batch[0].keys()} 38 | return to_return 39 | 40 | 41 | def run_program(parameters, queues_in_, input_type_, retrying=False): 42 | from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno 43 | from video_segment import VideoSegment 44 | 45 | global queue_results 46 | 47 | code, sample_id, image, possible_answers, query = parameters 48 | 49 | code_header = f'def execute_command_{sample_id}(' \ 50 | f'{input_type_}, possible_answers, query, ' \ 51 | f'ImagePatch, VideoSegment, ' \ 52 | 'llm_query, bool_to_yesno, distance, best_image_match):\n' \ 53 | f' # Answer is:' 54 | code = code_header + code 55 | 56 | try: 57 | exec(compile(code, 'Codex', 'exec'), globals()) 58 | except Exception as e: 59 | print(f'Sample {sample_id} failed at compilation time with error: {e}') 60 | try: 61 | with open(config.fixed_code_file, 'r') as f: 62 | fixed_code = f.read() 63 | code = code_header + fixed_code 64 | exec(compile(code, 'Codex', 'exec'), globals()) 65 | except Exception as e2: 66 | print(f'Not even the fixed code worked. Sample {sample_id} failed at compilation time with error: {e2}') 67 | return None, code 68 | 69 | queues = [queues_in_, queue_results] 70 | 71 | image_patch_partial = partial(ImagePatch, queues=queues) 72 | video_segment_partial = partial(VideoSegment, queues=queues) 73 | llm_query_partial = partial(llm_query, queues=queues) 74 | 75 | try: 76 | result = globals()[f'execute_command_{sample_id}']( 77 | # Inputs to the function 78 | image, possible_answers, query, 79 | # Classes to be used 80 | image_patch_partial, video_segment_partial, 81 | # Functions to be used 82 | llm_query_partial, bool_to_yesno, distance, best_image_match) 83 | except Exception as e: 84 | # print full traceback 85 | traceback.print_exc() 86 | if retrying: 87 | return None, code 88 | print(f'Sample {sample_id} failed with error: {e}. Next you will see an "expected an indented block" error. ') 89 | # Retry again with fixed code 90 | new_code = "[" # This code will break upon execution, and it will be caught by the except clause 91 | result = run_program((new_code, sample_id, image, possible_answers, query), queues_in_, input_type_, 92 | retrying=True)[0] 93 | 94 | # The function run_{sample_id} is defined globally (exec doesn't work locally). A cleaner alternative would be to 95 | # save it in a global dict (replace globals() for dict_name in exec), but then it doesn't detect the imported 96 | # libraries for some reason. Because defining it globally is not ideal, we just delete it after running it. 
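    # A minimal, standalone sketch of the exec-into-globals pattern described in the comment
    # above, shown here with hypothetical names (generated_code, execute_command_demo) rather
    # than the variables run_program actually uses:
    #
    #     generated_code = "def execute_command_demo(x):\n    return x + 1\n"
    #     exec(compile(generated_code, 'Codex', 'exec'), globals())  # defines execute_command_demo globally
    #     assert globals()['execute_command_demo'](41) == 42         # look the generated function up by name and call it
    #     del globals()['execute_command_demo']                      # then remove it again, as done just below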
97 | if f'execute_command_{sample_id}' in globals(): 98 | del globals()[f'execute_command_{sample_id}'] # If it failed to compile the code, it won't be defined 99 | return result, code 100 | 101 | 102 | def worker_init(queue_results_): 103 | global queue_results 104 | index_queue = mp.current_process()._identity[0] % len(queue_results_) 105 | queue_results = queue_results_[index_queue] 106 | 107 | 108 | def main(): 109 | mp.set_start_method('spawn') 110 | 111 | from vision_processes import queues_in, finish_all_consumers, forward, manager 112 | from datasets import get_dataset 113 | 114 | batch_size = config.dataset.batch_size 115 | num_processes = min(batch_size, 50) 116 | 117 | if config.multiprocessing: 118 | queue_results_main = manager.Queue() 119 | queues_results = [manager.Queue() for _ in range(batch_size)] 120 | else: 121 | queue_results_main = None 122 | queues_results = [None for _ in range(batch_size)] 123 | 124 | model_name_codex = 'codellama' if config.codex.model == 'codellama' else 'codex' 125 | codex = partial(forward, model_name=model_name_codex, queues=[queues_in, queue_results_main]) 126 | 127 | if config.clear_cache: 128 | cache.clear() 129 | 130 | if config.wandb: 131 | import wandb 132 | wandb.init(project="viper", config=OmegaConf.to_container(config)) 133 | # log the prompt file 134 | wandb.save(config.codex.prompt) 135 | 136 | dataset = get_dataset(config.dataset) 137 | 138 | with open(config.codex.prompt) as f: 139 | base_prompt = f.read().strip() 140 | 141 | codes_all = None 142 | if config.use_cached_codex: 143 | results = pd.read_csv(config.cached_codex_path) 144 | codes_all = [r.split('# Answer is:')[1] for r in results['code']] 145 | # python -c "from joblib import Memory; cache = Memory('cache/', verbose=0); cache.clear()" 146 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, 147 | collate_fn=my_collate) 148 | input_type = dataset.input_type 149 | 150 | all_results = [] 151 | all_answers = [] 152 | all_codes = [] 153 | all_ids = [] 154 | all_queries = [] 155 | all_img_paths = [] 156 | all_possible_answers = [] 157 | all_query_types = [] 158 | 159 | with mp.Pool(processes=num_processes, initializer=worker_init, initargs=(queues_results,)) \ 160 | if config.multiprocessing else open(os.devnull, "w") as pool: 161 | try: 162 | n_batches = len(dataloader) 163 | 164 | for i, batch in tqdm(enumerate(dataloader), total=n_batches): 165 | 166 | # Combine all queries and get Codex predictions for them 167 | # TODO compute Codex for next batch as current batch is being processed 168 | 169 | if not config.use_cached_codex: 170 | codes = codex(prompt=batch['query'], base_prompt=base_prompt, input_type=input_type, 171 | extra_context=batch['extra_context']) 172 | 173 | else: 174 | codes = codes_all[i * batch_size:(i + 1) * batch_size] # If cache 175 | 176 | # Run the code 177 | if config.execute_code: 178 | if not config.multiprocessing: 179 | # Otherwise, we would create a new model for every process 180 | results = [] 181 | for c, sample_id, img, possible_answers, query in \ 182 | zip(codes, batch['sample_id'], batch['image'], batch['possible_answers'], batch['query']): 183 | result = run_program([c, sample_id, img, possible_answers, query], queues_in, input_type) 184 | results.append(result) 185 | else: 186 | results = list(pool.imap(partial( 187 | run_program, queues_in_=queues_in, input_type_=input_type), 188 | zip(codes, batch['sample_id'], batch['image'], batch['possible_answers'], batch['query']))) 189 | else: 
190 | results = [(None, c) for c in codes] 191 | warnings.warn("Not executing code! This is only generating the code. We set the flag " 192 | "'execute_code' to False by default, because executing code generated by a language " 193 | "model can be dangerous. Set the flag 'execute_code' to True if you want to execute " 194 | "it.") 195 | 196 | all_results += [r[0] for r in results] 197 | all_codes += [r[1] for r in results] 198 | all_ids += batch['sample_id'] 199 | all_answers += batch['answer'] 200 | all_possible_answers += batch['possible_answers'] 201 | all_query_types += batch['query_type'] 202 | all_queries += batch['query'] 203 | all_img_paths += [dataset.get_sample_path(idx) for idx in batch['index']] 204 | if i % config.log_every == 0: 205 | try: 206 | accuracy = dataset.accuracy(all_results, all_answers, all_possible_answers, all_query_types) 207 | console.print(f'Accuracy at Batch {i}/{n_batches}: {accuracy}') 208 | except Exception as e: 209 | console.print(f'Error computing accuracy: {e}') 210 | 211 | except Exception as e: 212 | # print full stack trace 213 | traceback.print_exc() 214 | console.print(f'Exception: {e}') 215 | console.print("Completing logging and exiting...") 216 | 217 | try: 218 | accuracy = dataset.accuracy(all_results, all_answers, all_possible_answers, all_query_types) 219 | console.print(f'Final accuracy: {accuracy}') 220 | except Exception as e: 221 | print(f'Error computing accuracy: {e}') 222 | 223 | if config.save: 224 | results_dir = pathlib.Path(config['results_dir']) 225 | results_dir = results_dir / config.dataset.split 226 | results_dir.mkdir(parents=True, exist_ok=True) 227 | if not config.save_new_results: 228 | filename = 'results.csv' 229 | else: 230 | existing_files = list(results_dir.glob('results_*.csv')) 231 | if len(existing_files) == 0: 232 | filename = 'results_0.csv' 233 | else: 234 | filename = 'results_' + str(max([int(ef.stem.split('_')[-1]) for ef in existing_files if 235 | str.isnumeric(ef.stem.split('_')[-1])]) + 1) + '.csv' 236 | print('Saving results to', filename) 237 | df = pd.DataFrame([all_results, all_answers, all_codes, all_ids, all_queries, all_img_paths, 238 | all_possible_answers]).T 239 | df.columns = ['result', 'answer', 'code', 'id', 'query', 'img_path', 'possible_answers'] 240 | # make the result column a string 241 | df['result'] = df['result'].apply(str) 242 | df.to_csv(results_dir / filename, header=True, index=False, encoding='utf-8') 243 | # torch.save([all_results, all_answers, all_codes, all_ids, all_queries, all_img_paths], results_dir/filename) 244 | 245 | if config.wandb: 246 | wandb.log({'accuracy': accuracy}) 247 | wandb.log({'results': wandb.Table(dataframe=df, allow_mixed_types=True)}) 248 | 249 | finish_all_consumers() 250 | 251 | 252 | if __name__ == '__main__': 253 | main() 254 | -------------------------------------------------------------------------------- /main_simple_lib.py: -------------------------------------------------------------------------------- 1 | # General imports and variables, as well as config 2 | import ast 3 | import math 4 | import sys 5 | import time 6 | 7 | import requests 8 | import torch.multiprocessing as mp 9 | from joblib import Memory 10 | from rich.console import Console 11 | from rich.live import Live 12 | from rich.padding import Padding 13 | from rich.pretty import pprint 14 | from rich.prompt import Prompt 15 | from rich.syntax import Syntax 16 | from rich import print 17 | from rich.markup import escape as rich_escape 18 | 19 | from IPython.display import 
update_display, clear_output, display 20 | from PIL import Image 21 | import matplotlib.pyplot as plt 22 | 23 | from configs import config 24 | from utils import show_single_image 25 | 26 | from IPython.display import update_display, clear_output 27 | from IPython.core.display import HTML 28 | 29 | cache = Memory('cache/' if config.use_cache else None, verbose=0) 30 | 31 | mp.set_start_method('spawn', force=True) 32 | from vision_processes import forward, finish_all_consumers # This import loads all the models. May take a while 33 | from image_patch import * 34 | from video_segment import * 35 | from datasets.dataset import MyDataset 36 | 37 | console = Console(highlight=False, force_terminal=False) 38 | 39 | time_wait_between_lines = 0.5 40 | 41 | 42 | def inject_saver(code, show_intermediate_steps, syntax=None, time_wait_between_lines=None, console=None): 43 | injected_function_name = 'show_all' 44 | if injected_function_name in code: 45 | return code 46 | code = code.split("\n") 47 | newcode = [] 48 | for n, codeline in enumerate(code): 49 | codeline, indent = split_codeline_and_indent_level(codeline) 50 | 51 | if codeline.startswith('#') or codeline == '': # this will cause issues if you have lots of comment lines 52 | continue 53 | if '#' in codeline: 54 | codeline = codeline.split('#')[0] 55 | 56 | thing_to_show, code_type = get_thing_to_show_codetype(codeline) 57 | 58 | if code_type in ('assign', 'append', 'if', 'return', 'for', 'sort', 'add'): 59 | if '\'' in codeline: 60 | codeline.replace('\'', '\\\'') 61 | 62 | if show_intermediate_steps: 63 | escape_thing = lambda x: x.replace("'", "\\'") 64 | injection_string_format = \ 65 | lambda \ 66 | thing: f"{indent}{injected_function_name}(lineno={n},value=({thing}),valuename='{escape_thing(thing)}'," \ 67 | f"fig=my_fig,console_in=console,time_wait_between_lines=time_wait_between_lines); " \ 68 | f"CodexAtLine({n},syntax=syntax,time_wait_between_lines=time_wait_between_lines)" 69 | else: 70 | injection_string_format = lambda thing: f"{indent}CodexAtLine({n},syntax=syntax," \ 71 | f"time_wait_between_lines=time_wait_between_lines)" 72 | 73 | extension_list = [] 74 | if isinstance(thing_to_show, list): 75 | injection_string_list = [injection_string_format(f"{thing}") for thing in thing_to_show] 76 | extension_list.extend(injection_string_list) 77 | elif code_type == 'for': 78 | injection_string = injection_string_format(f"{thing_to_show}") 79 | injection_string = " " * 4 + injection_string 80 | extension_list.append(injection_string) 81 | else: 82 | extension_list.append(injection_string_format(f"{thing_to_show}")) 83 | 84 | if code_type in ('if', 'return'): 85 | extension_list = extension_list + [f"{indent}{codeline}"] 86 | else: 87 | extension_list = [f"{indent}{codeline}"] + extension_list 88 | 89 | newcode.extend(extension_list) 90 | 91 | elif code_type == 'elif_else': 92 | newcode.append(f"{indent}{codeline}") 93 | else: 94 | newcode.append(f"{indent}{codeline}") 95 | return "\n".join(newcode) 96 | 97 | 98 | def get_thing_to_show_codetype(codeline): 99 | # can output either a list of things to show, or a single thing to show 100 | things_to_show = [] 101 | if codeline.startswith("if"): 102 | condition, rest = codeline[3:].split(":", 1) 103 | codeline = f"if {condition}:{rest}" 104 | code_type = "if" 105 | 106 | operators = ['==', '!=', '>=', '<=', '>', '<'] 107 | things_to_show = [] 108 | for op in operators: 109 | if op in condition: 110 | things_to_show = [x.strip() for x in condition.split(op)] 111 | # print(things_to_show) 112 | 
break 113 | # things_to_show.append(thing_to_show) 114 | thing_to_show = things_to_show + [condition.strip()] 115 | 116 | elif codeline.startswith("for"): 117 | code_type = 'for' 118 | thing_to_show = codeline.split("for ")[1].split(" in ")[0] 119 | 120 | elif codeline.startswith("return"): 121 | thing_to_show = codeline.split("return ")[1] 122 | code_type = 'return' 123 | 124 | elif ' = ' in codeline: 125 | code_type = 'assign' 126 | thing_to_show = codeline.split(' = ')[0] 127 | elif ' += ' in codeline: 128 | code_type = 'assign' 129 | thing_to_show = codeline.split(' += ')[0] 130 | elif ' -= ' in codeline: 131 | code_type = 'assign' 132 | thing_to_show = codeline.split(' -= ')[0] 133 | elif ' *= ' in codeline: 134 | code_type = 'assign' 135 | thing_to_show = codeline.split(' *= ')[0] 136 | elif ' /= ' in codeline: 137 | code_type = 'assign' 138 | thing_to_show = codeline.split(' /= ')[0] 139 | 140 | elif '.append(' in codeline: 141 | code_type = 'append' 142 | thing_to_show = codeline.split('.append(')[0] + '[-1]' 143 | elif '.add(' in codeline: 144 | code_type = 'add' 145 | thing_to_show = codeline.split('.add(')[0] 146 | 147 | elif '.sort(' in codeline: 148 | code_type = 'sort' 149 | thing_to_show = codeline.split('.sort(')[0] 150 | 151 | elif codeline.startswith("elif") or codeline.startswith("else"): 152 | thing_to_show = None 153 | code_type = 'elif_else' 154 | else: 155 | thing_to_show = None 156 | code_type = 'other' 157 | 158 | if isinstance(thing_to_show, list): 159 | thing_to_show = [thing if not (thing.strip().startswith("'") and thing.strip().endswith("'")) 160 | else thing.replace("'", '"') for thing in thing_to_show if thing is not None] 161 | elif isinstance(thing_to_show, str): 162 | thing_to_show = thing_to_show if not (thing_to_show.strip().startswith("'") and 163 | thing_to_show.strip().endswith("'")) else thing_to_show.replace("'", '"') 164 | return thing_to_show, code_type 165 | 166 | 167 | def split_codeline_and_indent_level(codeline): 168 | origlen = len(codeline) 169 | codeline = codeline.lstrip() 170 | indent = origlen - len(codeline) 171 | indent = " " * indent 172 | return codeline, indent 173 | 174 | 175 | def show_one_image(image, ax): 176 | if isinstance(image, torch.Tensor): 177 | image = image.detach().cpu() 178 | if image.dtype == torch.float32: 179 | image = image.clamp(0, 1) 180 | image = image.squeeze(0).permute(1, 2, 0) 181 | ax.imshow(image) 182 | 183 | 184 | def CodexAtLine(lineno, syntax, time_wait_between_lines=1.): 185 | syntax._stylized_ranges = [] 186 | syntax.stylize_range('on red', (lineno + 1, 0), (lineno + 1, 80)) 187 | time.sleep(time_wait_between_lines) 188 | 189 | 190 | def show_all(lineno, value, valuename, fig=None, usefig=True, disp=True, console_in=None, time_wait_between_lines=None, 191 | lastlineno=[-1]): 192 | time.sleep(0.1) # to avoid race condition! 
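    # For reference, inject_saver (defined above) rewrites each generated line so that it is
    # followed by a call of roughly this shape; the line number and variable name are made up,
    # and my_fig, console, time_wait_between_lines and syntax are assumed to be in scope of the
    # generated function, as set up by execute_code:
    #
    #     dog_patches = image_patch.find('dog')
    #     show_all(lineno=3, value=(dog_patches), valuename='dog_patches',
    #              fig=my_fig, console_in=console, time_wait_between_lines=time_wait_between_lines)
    #     CodexAtLine(3, syntax=syntax, time_wait_between_lines=time_wait_between_lines)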
193 | 194 | if console_in is None: 195 | console_in = console 196 | 197 | thing_to_show = value 198 | 199 | if lineno is not None and lineno != lastlineno[0]: 200 | console_in.rule(f"[bold]Line {lineno}[/bold]", style="chartreuse2") 201 | lastlineno[0] = lineno # ugly hack 202 | 203 | if usefig: 204 | plt.clf() 205 | ax = fig.add_axes([0, 0, 1, 1]) 206 | ax.set_xticks([]) 207 | ax.set_yticks([]) 208 | if isinstance(thing_to_show, Image.Image): 209 | if valuename: 210 | console_in.print(f'{rich_escape(valuename)} = ') 211 | show_one_image(thing_to_show, ax) 212 | elif str(type(thing_to_show)) == "": 213 | if valuename: 214 | console_in.print(f'{rich_escape(valuename)} = ') 215 | show_one_image(thing_to_show.cropped_image, ax) 216 | elif isinstance(thing_to_show, list) or isinstance(thing_to_show, tuple): 217 | if len(thing_to_show) > 0: 218 | for i, thing in enumerate(thing_to_show): 219 | disp_ = disp or i < len(thing_to_show) - 1 220 | show_all(None, thing, f"{rich_escape(valuename)}[{i}]", fig=fig, disp=disp_, usefig=usefig) 221 | return 222 | else: 223 | console_in.print(f"{rich_escape(valuename)} is empty") 224 | elif isinstance(thing_to_show, dict): 225 | if len(thing_to_show) > 0: 226 | for i, (thing_k, thing_v) in enumerate(thing_to_show.items()): 227 | disp_ = disp or i < len(thing_to_show) - 1 228 | show_all(None, thing_v, f"{rich_escape(valuename)}['{thing_k}']", fig=fig, disp=disp_, usefig=usefig) 229 | return 230 | else: 231 | console_in.print(f"{rich_escape(valuename)} is empty") 232 | else: 233 | console_in.print(f"{rich_escape(valuename)} = {thing_to_show}") 234 | if time_wait_between_lines is not None: 235 | time.sleep(time_wait_between_lines / 2) 236 | return 237 | 238 | # display small 239 | if usefig: 240 | fig.set_size_inches(2, 2) 241 | if disp: 242 | display(fig) 243 | 244 | 245 | def load_image(path): 246 | if path.startswith("http://") or path.startswith("https://"): 247 | image = Image.open(requests.get(path, stream=True).raw).convert('RGB') 248 | image = transforms.ToTensor()(image) 249 | else: 250 | image = Image.open(path) 251 | image = transforms.ToTensor()(image) 252 | return image 253 | 254 | 255 | def get_code(query): 256 | model_name_codex = 'codellama' if config.codex.model == 'codellama' else 'codex' 257 | code = forward(model_name_codex, prompt=query, input_type="image") 258 | if config.codex.model not in ('gpt-3.5-turbo', 'gpt-4'): 259 | code = f'def execute_command(image, my_fig, time_wait_between_lines, syntax):' + code # chat models give execute_command due to system behaviour 260 | code_for_syntax = code.replace("(image, my_fig, time_wait_between_lines, syntax)", "(image)") 261 | syntax_1 = Syntax(code_for_syntax, "python", theme="monokai", line_numbers=True, start_line=0) 262 | console.print(syntax_1) 263 | code = ast.unparse(ast.parse(code)) 264 | code_for_syntax_2 = code.replace("(image, my_fig, time_wait_between_lines, syntax)", "(image)") 265 | syntax_2 = Syntax(code_for_syntax_2, "python", theme="monokai", line_numbers=True, start_line=0) 266 | return code, syntax_2 267 | 268 | 269 | def execute_code(code, im, show_intermediate_steps=True): 270 | code, syntax = code 271 | code_line = inject_saver(code, show_intermediate_steps, syntax, time_wait_between_lines, console) 272 | 273 | display(HTML("")) 274 | 275 | with Live(Padding(syntax, 1), refresh_per_second=10, console=console, auto_refresh=True) as live: 276 | my_fig = plt.figure(figsize=(4, 4)) 277 | try: 278 | exec(compile(code_line, 'Codex', 'exec'), globals()) 279 | result = 
execute_command(im, my_fig, time_wait_between_lines, syntax) # The code is created in the exec() 280 | except Exception as e: 281 | print(f"Encountered error {e} when trying to run with visualizations. Trying from scratch.") 282 | exec(compile(code, 'Codex', 'exec'), globals()) 283 | result = execute_command(im, my_fig, time_wait_between_lines, syntax) # The code is created in the exec() 284 | 285 | plt.close(my_fig) 286 | 287 | def is_not_fig(x): 288 | if x is None: 289 | return True 290 | elif isinstance(x, str): 291 | return True 292 | elif isinstance(x, float): 293 | return True 294 | elif isinstance(x, int): 295 | return True 296 | elif isinstance(x, list) or isinstance(x, tuple): 297 | return all([is_not_fig(xx) for xx in x]) 298 | elif isinstance(x, dict): 299 | return all([is_not_fig(xx) for xx in x.values()]) 300 | return False 301 | 302 | f = None 303 | usefig = False 304 | if not is_not_fig(result): 305 | f = plt.figure(figsize=(4, 4)) 306 | usefig = True 307 | 308 | console.rule(f"[bold]Final Result[/bold]", style="chartreuse2") 309 | show_all(None, result, 'Result', fig=f, usefig=usefig, disp=False, console_in=console, time_wait_between_lines=0) 310 | 311 | 312 | def show_single_image(im): 313 | im = Image.fromarray((im.detach().cpu().numpy().transpose(1, 2, 0) * 255).astype("uint8")) 314 | im.copy() 315 | im.thumbnail((400, 400)) 316 | display(im) 317 | -------------------------------------------------------------------------------- /prompts/benchmarks/gqa.prompt: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from vision_functions import find_in_image, simple_qa, verify_property, best_text_match 3 | 4 | def bool_to_yesno(bool_answer: bool)->str: 5 | return "yes" if bool_answer else "no" 6 | 7 | class ImagePatch: 8 | """A Python class containing a crop of an image centered around a particular object, as well as relevant information. 9 | Attributes 10 | ---------- 11 | cropped_image : array_like 12 | An array-like of the cropped image taken from the original image. 13 | left : int 14 | An int describing the position of the left border of the crop's bounding box in the original image. 15 | lower : int 16 | An int describing the position of the bottom border of the crop's bounding box in the original image. 17 | right : int 18 | An int describing the position of the right border of the crop's bounding box in the original image. 19 | upper : int 20 | An int describing the position of the top border of the crop's bounding box in the original image. 21 | 22 | Methods 23 | ------- 24 | find(object_name: str)->List[ImagePatch] 25 | Returns a list of new ImagePatch objects containing crops of the image centered around any objects found in the image matching the object_name. 26 | simple_query(question: str=None)->str 27 | Returns the answer to a basic question asked about the image. If no question is provided, returns the answer to "What is this?". 28 | exists(object_name: str)->bool 29 | Returns True if the object specified by object_name is found in the image, and False otherwise. 30 | verify_property(property: str)->bool 31 | Returns True if the property is met, and False otherwise. 32 | best_text_match(string1: str, string2: str)->str 33 | Returns the string that best matches the image. 34 | crop(left: int, lower: int, right: int, upper: int)->ImagePatch 35 | Returns a new ImagePatch object containing a crop of the image at the given coordinates. 
36 | """ 37 | 38 | def __init__(self, image, left: int=None, lower: int=None, right: int=None, upper: int=None): 39 | """Initializes an ImagePatch object by cropping the image at the given coordinates and stores the coordinates as attributes. 40 | If no coordinates are provided, the image is left unmodified, and the coordinates are set to the dimensions of the image. 41 | Parameters 42 | ------- 43 | image : array_like 44 | An array-like of the original image. 45 | left : int 46 | An int describing the position of the left border of the crop's bounding box in the original image. 47 | lower : int 48 | An int describing the position of the bottom border of the crop's bounding box in the original image. 49 | right : int 50 | An int describing the position of the right border of the crop's bounding box in the original image. 51 | upper : int 52 | An int describing the position of the top border of the crop's bounding box in the original image. 53 | 54 | """ 55 | if left is None and right is None and upper is None and lower is None: 56 | self.cropped_image = image 57 | self.left = 0 58 | self.lower = 0 59 | self.right = image.shape[2] # width 60 | self.upper = image.shape[1] # height 61 | else: 62 | self.cropped_image = image[:, lower:upper, left:right] 63 | self.left = left 64 | self.upper = upper 65 | self.right = right 66 | self.lower = lower 67 | 68 | self.width = self.cropped_image.shape[2] 69 | self.height = self.cropped_image.shape[1] 70 | 71 | self.horizontal_center = (self.left + self.right) / 2 72 | self.vertical_center = (self.lower + self.upper) / 2 73 | 74 | def find(self, object_name: str)->List["ImagePatch"]: 75 | """Returns a new ImagePatch object containing the crop of the image centered around the object specified by object_name. 76 | Parameters 77 | ------- 78 | object_name : str 79 | A string describing the name of the object to be found in the image. 80 | """ 81 | return find_in_image(self.cropped_image, object_name) 82 | 83 | def simple_query(self, question: str=None)->str: 84 | """Returns the answer to a basic question asked about the image. If no question is provided, returns the answer to "What is this?". 85 | Parameters 86 | ------- 87 | question : str 88 | A string describing the question to be asked. 89 | 90 | Examples 91 | ------- 92 | 93 | >>> # Which kind of animal is not eating? 94 | >>> def execute_command(image)->str: 95 | >>> image_patch = ImagePatch(image) 96 | >>> animal_patches = image_patch.find("animal") 97 | >>> for animal_patch in animal_patches: 98 | >>> if not animal_patch.verify_property("animal", "eating"): 99 | >>> return animal_patch.simple_query("What kind of animal is eating?") # crop would include eating so keep it in the query 100 | >>> # If no animal is not eating, query the image directly 101 | >>> return image_patch.simple_query("Which kind of animal is not eating?") 102 | 103 | >>> # What is in front of the horse? 104 | >>> # contains a relation (around, next to, on, near, on top of, in front of, behind, etc), so ask directly 105 | >>> return image_patch.simple_query("What is in front of the horse?") 106 | >>> 107 | """ 108 | return simple_qa(self.cropped_image, question) 109 | 110 | def exists(self, object_name: str)->bool: 111 | """Returns True if the object specified by object_name is found in the image, and False otherwise. 112 | Parameters 113 | ------- 114 | object_name : str 115 | A string describing the name of the object to be found in the image. 
116 | 117 | Examples 118 | ------- 119 | >>> # Are there both cakes and gummy bears in the photo? 120 | >>> def execute_command(image)->str: 121 | >>> image_patch = ImagePatch(image) 122 | >>> is_cake = image_patch.exists("cake") 123 | >>> is_gummy_bear = image_patch.exists("gummy bear") 124 | >>> return bool_to_yesno(is_cake and is_gummy_bear) 125 | """ 126 | return len(self.find(object_name)) > 0 127 | 128 | def verify_property(self, object_name: str, property: str)->bool: 129 | """Returns True if the object possesses the property, and False otherwise. 130 | Differs from 'exists' in that it presupposes the existence of the object specified by object_name, instead checking whether the object possesses the property. 131 | Parameters 132 | ------- 133 | object_name : str 134 | A string describing the name of the object to be found in the image. 135 | property : str 136 | A string describing the property to be checked. 137 | 138 | Examples 139 | ------- 140 | >>> # Do the letters have blue color? 141 | >>> def execute_command(image)->str: 142 | >>> image_patch = ImagePatch(image) 143 | >>> letters_patches = image_patch.find("letters") 144 | >>> # Question assumes only one letter patch 145 | >>> if len(letters_patches) == 0: 146 | >>> # If no letters are found, query the image directly 147 | >>> return image_patch.simple_query("Do the letters have blue color?") 148 | >>> return bool_to_yesno(letters_patches[0].verify_property("letters", "blue")) 149 | """ 150 | return verify_property(self.cropped_image, object_name, property) 151 | 152 | def best_text_match(self, option_list: List[str]) -> str: 153 | """Returns the string that best matches the image. 154 | Parameters 155 | ------- 156 | option_list : str 157 | A list with the names of the different options 158 | prefix : str 159 | A string with the prefixes to append to the options 160 | 161 | Examples 162 | ------- 163 | >>> # Is the cap gold or white? 164 | >>> def execute_command(image)->str: 165 | >>> image_patch = ImagePatch(image) 166 | >>> cap_patches = image_patch.find("cap") 167 | >>> # Question assumes one cap patch 168 | >>> if len(cap_patches) == 0: 169 | >>> # If no cap is found, query the image directly 170 | >>> return image_patch.simple_query("Is the cap gold or white?") 171 | >>> return cap_patches[0].best_text_match(["gold", "white"]) 172 | """ 173 | return best_text_match(self.cropped_image, option_list) 174 | 175 | def crop(self, left: int, lower: int, right: int, upper: int)->"ImagePatch": 176 | """Returns a new ImagePatch cropped from the current ImagePatch. 177 | Parameters 178 | ------- 179 | left : int 180 | The leftmost pixel of the cropped image. 181 | lower : int 182 | The lowest pixel of the cropped image. 183 | right : int 184 | The rightmost pixel of the cropped image. 185 | upper : int 186 | The uppermost pixel of the cropped image. 187 | ------- 188 | """ 189 | return ImagePatch(self.cropped_image, left, lower, right, upper) 190 | 191 | # Examples of using ImagePatch 192 | # Is there a backpack to the right of the man? 
193 | def execute_command(image)->str: 194 | image_patch = ImagePatch(image) 195 | man_patches = image_patch.find("man") 196 | # Question assumes one man patch 197 | if len(man_patches) == 0: 198 | # If no man is found, query the image directly 199 | return image_patch.simple_query("Is there a backpack to the right of the man?") 200 | man_patch = man_patches[0] 201 | backpack_patches = image_patch.find("backpack") 202 | # Question assumes one backpack patch 203 | if len(backpack_patches) == 0: 204 | return "no" 205 | for backpack_patch in backpack_patches: 206 | if backpack_patch.horizontal_center > man_patch.horizontal_center: 207 | return "yes" 208 | return "no" 209 | 210 | # In which part is the bread, the bottom or the top? 211 | def execute_command(image)->str: 212 | image_patch = ImagePatch(image) 213 | bread_patches = image_patch.find("bread") 214 | # Question assumes only one bread patch 215 | if len(bread_patches) == 0: 216 | # If no bread is found, query the image directly 217 | return image_patch.simple_query("In which part is the bread, the bottom or the top?") 218 | if bread_patches[0].vertical_center < image_patch.vertical_center: 219 | return "bottom" 220 | else: 221 | return "top" 222 | 223 | # What type of weather do you see in the photograph? 224 | def execute_command(image)->str: 225 | image_patch = ImagePatch(image) 226 | return image_patch.simple_query("What type of weather do you see in the photograph?") 227 | 228 | # Who is the man staring at? 229 | def execute_command(image)->str: 230 | # asks for the predicate of a relational verb (staring at), so ask directly 231 | image_patch = ImagePatch(image) 232 | return image_patch.simple_query("Who is the man staring at?") 233 | 234 | # What toy is wearing a shirt? 235 | def execute_command(image)->str: 236 | # not a relational verb so go step by step 237 | image_patch = ImagePatch(image) 238 | toy_patches = image_patch.find("toy") 239 | # Question assumes only one toy patch 240 | if len(toy_patches) == 0: 241 | # If no toy is found, query the image directly 242 | return image_patch.simple_query("What toy is wearing a shirt?") 243 | for toy_patch in toy_patches: 244 | is_wearing_shirt = (toy_patch.simple_query("Is the toy wearing a shirt?") == "yes") 245 | if is_wearing_shirt: 246 | return toy_patch.simple_query("What toy is wearing a shirt?") # crop would include the shirt so keep it in the query 247 | # If no toy is wearing a shirt, pick the first toy 248 | return toy_patches[0].simple_query("What toy is wearing a shirt?") 249 | 250 | # What is behind the pole? 251 | def execute_command(image)->str: 252 | image_patch = ImagePatch(image) 253 | # contains a relation (around, next to, on, near, on top of, in front of, behind, etc), so ask directly 254 | return image_patch.simple_query("What is behind the pole?") 255 | 256 | # Are there bagels or lemons? 257 | def execute_command(image)->str: 258 | image_patch = ImagePatch(image) 259 | is_bagel = image_patch.exists("bagel") 260 | is_lemon = image_patch.exists("lemon") 261 | return bool_to_yesno(is_bagel or is_lemon) 262 | 263 | # Is that blanket to the right of a pillow? 
264 | def execute_command(image)->str: 265 | image_patch = ImagePatch(image) 266 | blanket_patches = image_patch.find("blanket") 267 | # Question assumes only one blanket patch 268 | if len(blanket_patches) == 0: 269 | # If no blanket is found, query the image directly 270 | return image_patch.simple_query("Is that blanket to the right of a pillow?") 271 | for blanket_patch in blanket_patches: 272 | pillow_patches = image_patch.find("pillow") 273 | for pillow_patch in pillow_patches: 274 | if pillow_patch.horizontal_center > blanket_patch.horizontal_center: 275 | return "yes" 276 | return "no" 277 | 278 | # INSERT_PROMPT_HERE 279 | def execute_command(image)->str: -------------------------------------------------------------------------------- /prompts/benchmarks/okvqa.prompt: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from vision_functions import obtain_query_response_from_image 3 | from nlp_functions import llm_query, process_guesses 4 | 5 | def llm_query(question: str)->str: 6 | '''Answers a text question using GPT-3. The input question is always a formatted string with a variable in it. 7 | 8 | Parameters 9 | ---------- 10 | question: str 11 | the text question to ask. Must not contain any reference to 'the image' or 'the photo', etc. 12 | ''' 13 | return llm_query(question) 14 | 15 | def process_guesses(question: str, guesses: List[str])->str: 16 | '''Processes a list of guesses for an answer to a question and returns the best answer.''' 17 | return process_guesses(question, guesses) 18 | 19 | class ImagePatch: 20 | def __init__(self, image, left: int=None, lower: int=None, right: int=None, upper: int=None): 21 | if left is None and right is None and upper is None and lower is None: 22 | self.cropped_image = image 23 | self.left = 0 24 | self.lower = 0 25 | self.right = image.shape[2] 26 | self.upper = image.shape[1] 27 | else: 28 | self.cropped_image = image[:, lower:upper, left:right] 29 | self.left = left 30 | self.lower = lower 31 | self.right = right 32 | self.upper = upper 33 | 34 | self.width = self.cropped_image.shape[2] 35 | self.height = self.cropped_image.shape[1] 36 | 37 | def simple_query(self, query: str): 38 | """Answer basic queries about the image patch. 39 | Parameters 40 | ---------- 41 | query: str 42 | the simple query about the image patch in the form of a question 43 | 44 | Returns 45 | ------- 46 | str 47 | a guess for the answer to the question 48 | """ 49 | answer = obtain_query_response_from_image(self.cropped_image, query) 50 | return answer 51 | 52 | # Examples of using the ImagePatch class 53 | 54 | # What kind of flowers are these? 55 | def execute_command(image)->str: 56 | # The question is direct perception, so we can just ask the image 57 | # There is no additional information needed. 58 | image = ImagePatch(image) 59 | direct_guess = image.simple_query("What kind of flowers are these?") 60 | return direct_guess 61 | 62 | # What do these people on the bikes normally write and give out? 63 | def execute_command(image)->str: 64 | # The question is not direct perception, so we need to ask the image for more information 65 | # What information do we need? We need to know who these people on the bikes are. 
66 | image = ImagePatch(image) 67 | guesses = [] 68 | people_on_bikes = image.simple_query("Who are these people on the bikes?") 69 | external_knowledge_query = "What do {} normally write and give out?".format(people_on_bikes) 70 | step_by_step_guess = llm_query(external_knowledge_query) 71 | guesses.append("these people on the bikes are {}".format(people_on_bikes) + ", so " + step_by_step_guess) 72 | direct_guess = image.simple_query("What do these people on the bikes normally write and give out?") 73 | guesses.append(direct_guess) 74 | return process_guesses("What do these people on the bikes normally write and give out?", guesses) 75 | 76 | # What are these children doing? 77 | def execute_command(image)->str: 78 | # The question is direct perception, so we can just ask the image 79 | # Salient information: what are these children doing? 80 | image = ImagePatch(image) 81 | direct_guess = image.simple_query("What are these children watching?") 82 | return direct_guess 83 | 84 | # Is this the mountains or desert? 85 | def execute_command(image)->str: 86 | # The question is not direct perception, so we need to ask the image for more information 87 | # Salient information: is this the mountains? is this the desert? 88 | # Question contains "or", so the answer must be "mountains" or "desert" 89 | image = ImagePatch(image) 90 | guesses = [] 91 | mountains = image.simple_query("Is this the mountains?") 92 | desert = image.simple_query("Is this the desert?") 93 | if mountains == "yes": 94 | guesses.append("this is the mountains") 95 | if desert == "yes": 96 | guesses.append("this is the desert") 97 | direct_guess = image.simple_query("Is this the mountains or desert?") 98 | guesses.append(direct_guess) 99 | return process_guesses("Is this the mountains or desert?", guesses) 100 | 101 | # Who is famous for allegedly doing this in a lightning storm? 102 | def execute_command(image)->str: 103 | # The question is not direct perception, so we need to ask the image for more information 104 | # Salient information: what is being done? 105 | image = ImagePatch(image) 106 | guesses = [] 107 | action = image.simple_query("What is being done?") 108 | external_knowledge_query = "Who is famous for allegedly {} in a lightning storm?".format(action) 109 | step_by_step_guess = llm_query(external_knowledge_query) 110 | guesses.append("what is being done is {}".format(action) + ", so " + step_by_step_guess) 111 | direct_guess = image.simple_query("Who is famous for allegedly doing this in a lightning storm?") 112 | guesses.append(direct_guess) 113 | return process_guesses("Who is famous for allegedly doing this in a lightning storm?", guesses) 114 | 115 | # Should you bake or roast this item? 116 | def execute_command(image)->str: 117 | # The question is not direct perception, so we need to ask the image for more information 118 | # Salient information: what is this item? 
119 | # Question contains "or", so the answer must be "bake" or "roast" 120 | image = ImagePatch(image) 121 | guesses = [] 122 | item = image.simple_query("What is this item?") 123 | external_knowledge_query = "Should you bake or roast {}?".format(item) 124 | step_by_step_guess = llm_query(external_knowledge_query) 125 | guesses.append("what is this item is {}".format(item) + ", so " + step_by_step_guess) 126 | direct_guess = image.simple_query("Should you bake or roast this item?") 127 | guesses.append(direct_guess) 128 | return process_guesses("Should you bake or roast this item?", guesses) 129 | 130 | # Where can I get a laptop like the one found here? 131 | def execute_command(image)->str: 132 | # The question is not direct perception, so we need to ask the image for more information 133 | # Salient information: what kind of laptop is this? 134 | image = ImagePatch(image) 135 | guesses = [] 136 | laptop_type = image.simple_query("What kind of laptop is this?") 137 | external_knowledge_query = "Where can I get a {} like the one found here?".format(laptop_type) 138 | step_by_step_guess = llm_query(external_knowledge_query) 139 | guesses.append("what kind of laptop is this is {}".format(laptop_type) + ", so " + step_by_step_guess) 140 | direct_guess = image.simple_query("Where can I get a laptop like the one found here?") 141 | guesses.append(direct_guess) 142 | return process_guesses("Where can I get a laptop like the one found here?", guesses) 143 | 144 | # INSERT_PROMPT_HERE 145 | def execute_command(image)->str: -------------------------------------------------------------------------------- /prompts/fixed_code/blip2.prompt: -------------------------------------------------------------------------------- 1 | 2 | image_patch = ImagePatch(image) 3 | return image_patch.simple_query(question) -------------------------------------------------------------------------------- /prompts/fixed_code/blip2_video.prompt: -------------------------------------------------------------------------------- 1 | 2 | video_segment = VideoSegment(video) 3 | frame_patch = video_segment.frame_from_index(video_segment.num_frames // 2) 4 | query = query + '? The options to answer the previous question are: [' + ', '.join(possible_answers) + ']' 5 | 6 | return frame_patch.simple_query(query) -------------------------------------------------------------------------------- /prompts/fixed_code/glip.prompt: -------------------------------------------------------------------------------- 1 | 2 | image_patch = ImagePatch(image) 3 | bbox = image_patch.forward('glip', image_patch.cropped_image, query)[0] 4 | return image_patch.crop(*bbox) -------------------------------------------------------------------------------- /prompts/gpt3/gpt3_process_guess.txt: -------------------------------------------------------------------------------- 1 | Please answer the following questions using the given guesses. 2 | If a unique answer cannot be determined, choose only one of the possible answers. 3 | Aim to reply in ONE word (at MOST 2). 4 | 5 | Question: What kind of flowers are these? 6 | Guess 1: these flowers are purple, so lavender, lilac, iris, and hyacinth 7 | Guess 2: purple flowers 8 | Answer: lilac 9 | 10 | Question: What do these people on the bikes normally write and give out? 11 | Guess 1: the people on bikes are police, so Tickets 12 | Guess 2: tickets 13 | Answer: tickets 14 | 15 | Question: What kind of cold meet is this? 
16 | Guess 1: what kind of meat is this is beef, so roast beef 17 | Guess 2: beef 18 | Answer: beef 19 | 20 | Question: Can you guess the place shown in this picture? 21 | Guess 1: the place is tourist attraction, so the Eiffel Tower in Paris, France 22 | Guess 2: big ben 23 | Answer: big ben 24 | 25 | Question: When was this type of vehicle with two equal sized wheels invented? 26 | Guess 1: the vehicle is a bicycle, so 19th century 27 | Guess 2: 1819 28 | Answer: 1800s 29 | 30 | Question: What is the flavor of the pink topping on this dessert? 31 | Guess 1: the topping is whipped cream, so strawberry, vanilla, chocolate, and raspberry 32 | Guess 2: strawberry 33 | Answer: strawberry 34 | 35 | Question: How are these festive lights held in place? 36 | Guess 1: these festive lights are christmas lights, so with hooks clips 37 | Guess 2: string 38 | Answer: string 39 | 40 | Question: Who is famous for allegedly doing this in a lightning storm? 41 | Guess 1: what is being done is flying a kite, so Benjamin Franklin 42 | Guess 2: Charles Manson 43 | Answer: Benjamin Franklin 44 | 45 | Question: What is the object atop the skier's head used for? 46 | Guess 1: the object atop the skier's head is helmet, so protection from head injuries 47 | Guess 2: sunglasses 48 | Answer: protection 49 | 50 | Question: What rank is the man on the right? 51 | Guess 1: who is the man on the right is sailor, so seaman 52 | Guess 2: captain 53 | Answer: captain 54 | 55 | Question: Chemically what kind of water is in the picture? 56 | Guess 1: the water in the picture is waves, so salt water 57 | Guess 2: salt water 58 | Answer: salt 59 | 60 | Question: Is the material tweed or canvas? 61 | Guess 1: the material is fabric, so fabric 62 | Guess 2: canvas 63 | Answer: canvas 64 | 65 | Question: Which type of meat are in the photo? 66 | Guess 1: the meat in the photo is sausage, so pork 67 | Guess 2: hot dogs 68 | Answer: hotdogs 69 | 70 | Question: What sort of predator might there be in an area like this? 71 | Guess 1: this area is mountains, so predators like wolves fox 72 | Guess 2: shark 73 | Answer: shark 74 | 75 | Question: Can you name a sport this person could be a part of? 76 | Guess 1: this person is a racer, so racing such as auto 77 | Guess 2: motorcycle racing 78 | Answer: racing 79 | 80 | Question: Who makes the yellow top worn in this photograph? 81 | Guess 1: the top is red, so brand is unknown 82 | Guess 2: Burton 83 | Answer: Burton 84 | 85 | Question: Is the athlete right or left handed? 86 | Guess 1: what is the athlete doing is playing baseball, so unclear 87 | Guess 2: right handed 88 | Answer: right handed 89 | 90 | Question: Is this food high or low on fat? 91 | Guess 1: what kind of food is this is sandwich, so depends on ingredients 92 | Guess 2: high 93 | Answer: high 94 | 95 | Question: What wood are those cabinets made of? 96 | Guess 1: what kind of cabinets are these is kitchen cabinets, so typically wood such as oak 97 | Guess 2: maple 98 | 99 | Question: Which objects shown are typically associated with small children? 100 | Guess 1: what objects are shown are stuffed animals, so toys 101 | Guess 2: teddy bears 102 | Answer: teddy bears 103 | 104 | Question: What small appliance is that stuffed animal inside? 105 | Guess 1: the stuffed animal is a teddy bear, so vacuum cleaner 106 | Guess 2: microwave 107 | Answer: microwave 108 | 109 | Question: What is this made with? 
110 | Guess 1: what is this is muffin, so flour sugar eggs 111 | Guess 2: oats 112 | Answer: flour 113 | 114 | Question: What is the position name of the player squatting down? 115 | Guess 1: who is squatting down is the batter, so hitter 116 | Guess 2: catcher 117 | 118 | Question: {} 119 | Guess 1: {} 120 | Guess 2: {} 121 | Answer (remember, only 1-2 words): -------------------------------------------------------------------------------- /prompts/gpt3/gpt3_qa.txt: -------------------------------------------------------------------------------- 1 | I am a highly intelligent question answering bot. My goal is to answer the way a typical person might, even if this is not entirely accurate, in very brief answers. 2 | 3 | Q: How does rain help the plants? 4 | A: water them 5 | 6 | Q: Who was president of the United States in 1955? 7 | A: Eisenhower 8 | 9 | Q: What type of material is a bathroom made of? 10 | A: tile 11 | 12 | Q: Where were the 1992 Olympics held? 13 | A: Barcelona, Spain 14 | 15 | Q: What devices can be controlled by a universal remote? 16 | A: tvs 17 | 18 | Q: {} Be very concise, no ranges, no doubt. 19 | A: -------------------------------------------------------------------------------- /prompts/gpt3/video_question.txt: -------------------------------------------------------------------------------- 1 | We want to answer a question about a video. We have information about the video. We also have a question, and a list of options. We want to return the option that is most likely to be the correct answer to the question. 2 | 3 | Example: 4 | - Question: how did the boy in stripped open the book to see its contents 5 | - Caption of middle frame: two children playing on a couch in a living room 6 | - Is there a girl in the scene: True 7 | - Possible answers: ['asked the girl for help', 'observe the book', 'stare at the book', 'flip the pages', 'with a bookmark'] 8 | Take a close look at the question and information provided and select one of the possible answers 9 | - Selected answer: flip the pages 10 | 11 | Example: 12 | - Question: what does the man in checkered do after walking onto the stage with microphone stands at the start 13 | - Caption of frame after walking onto the stage: a man and woman in a kimono is shown 14 | - Possible answers: ['take away the stand', 'set up the stand', 'takes out some paper', 'bow to people', 'hands him a bottle'] 15 | Take a close look at the question and information provided and select one of the possible answers 16 | - Selected answer: set up the stand 17 | 18 | Example: 19 | - Question: what is the relation between the children 20 | - Caption of frame of interest: a group of children wearing hats is shown 21 | - Description of children: a group of children sitting in a room with hats on 22 | - Location: school 23 | - Possible answers: ['twins', 'siblings', 'band members', 'classmate', 'friends'] 24 | Take a close look at the question and information provided and select one of the possible answers 25 | - Selected answer: classmate 26 | 27 | Example: 28 | - Question: what does the man in white at the side do as the girl slide down 29 | - Caption of frame when the girl slides down: a child playing on a slide in a backyard 30 | - Action of man: swinging his son 31 | - Objects in the image: found: [], not found: ['rake', 'cup'] 32 | - Possible answers: ['push the rake again', 'look and walk around', 'push boy on swing', 'drink from cup', 'smiles'] 33 | Take a close look at the question and information provided and select one of the possible 
answers 34 | - Selected answer: push boy on swing 35 | 36 | Example 37 | - Question: why does the person in white put the weight onto the rack after carrying it 38 | - Caption of frame after carrying the weight: a man lifting a barbell in a gym 39 | - Description of person in white: a man doing a barbell squat 40 | - Possible answers: ['to support the man', 'be elegant', 'to demostrate', 'part of performance', 'take a break'] 41 | Take a close look at the question and information provided and select one of the possible answers 42 | - Selected answer: take a break 43 | 44 | Example 45 | - Question: why was the cup positioned under the tap 46 | - Caption of frame: a woman pouring a beer at a bar 47 | - Possible answers: ['get the drink', 'catch spills', 'make cat comfortable with water', 'soapy', 'facilitate sucking for water'] 48 | Take a close look at the question and information provided and select one of the possible answers 49 | - Selected answer: get the drink 50 | 51 | Example: 52 | - Question: why is the woman wearing a raincoat 53 | - Caption of frame: scene indoors with a lot of people 54 | - Possible answers: ['fashion modeling', 'protect from rain', 'cook', 'hug a friend', 'climatology'] 55 | Reason about the question and information provided and select one of the possible answers. 56 | - Selected answer: fashion modeling 57 | 58 | Example: 59 | - Question: why are there many lines on the snow 60 | - Caption of frame: a person skiing down a snowy hill 61 | - Objects in the image: found: ['ski', 'person', 'tree'], not found: ['vehicle', 'rope'] 62 | - Possible answers: ['due to ropes', 'snowmobile routes', 'due to vehicle', 'many people skied', 'tree roots'] 63 | Take a close look at the question and information provided and select one of the possible answers 64 | - Selected answer: many people skied 65 | 66 | Example: 67 | - Question: where did the man in grey put his right hand when he started sliding down the slope 68 | - Caption of frame when starting sliding: a person riding a bike down a hill 69 | - Right hand location: in his pocket 70 | - Possible answers: ['on his face', 'on her cheeks', 'back pocket', 'on table', 'railings'] 71 | Take a close look at the question and information provided and select one of the possible answers 72 | - Selected answer: back pocket 73 | 74 | Example: 75 | - Question: {question} 76 | {info} 77 | - Possible answers: {options} 78 | Take a close look at the question and information provided and select one of the possible answers. 
79 | - Selected answer: -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | backoff==2.2.1 3 | bitsandbytes==0.38.1 4 | cityscapesscripts==2.2.1 5 | git+https://github.com/openai/CLIP.git 6 | decord==0.6.0 7 | dill==0.3.6 8 | einops==0.6.0 9 | ftfy==6.1.1 10 | h5py==3.8.0 11 | inflect==6.0.2 12 | ipython==8.11.0 13 | ipykernel==6.22.0 14 | jupyter==1.0.0 15 | joblib==1.2.0 16 | kornia==0.6.9 17 | matplotlib==3.6.2 18 | nltk==3.8.1 19 | num2words==0.5.12 20 | numpy==1.23.5 21 | omegaconf==2.3.0 22 | git+https://github.com/openai/openai-python.git 23 | opencv_python_headless==4.5.5.64 24 | pandas==1.5.2 25 | Pillow==9.4.0 26 | prettytable==3.6.0 27 | pycocotools==2.0.6 28 | python_dateutil==2.8.2 29 | PyYAML==6.0 30 | qd==0.8.9 31 | regex==2022.10.31 32 | requests==2.28.1 33 | rich==13.3.2 34 | scipy==1.9.3 35 | setuptools==65.6.3 36 | tensorboardX==2.6 37 | tensorflow==2.11.1 38 | timm==0.6.12 39 | torch==1.13.1 40 | torchvision==0.14.1 41 | tqdm==4.64.1 42 | git+https://github.com/huggingface/transformers.git 43 | wandb==0.13.9 44 | word2number==1.1 45 | yacs==0.1.8 46 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | # create environment 2 | bash setup_env.sh 3 | # download models 4 | bash download_models.sh -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | conda create -n vipergpt python=3.10 2 | conda activate vipergpt 3 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia 4 | pip install -r requirements.txt -------------------------------------------------------------------------------- /teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-columbia/viper/09fe3465224766860d8dd4ec48db942f22b05092/teaser.gif -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import pandas as pd 6 | import pathlib 7 | import random 8 | import sys 9 | import time 10 | import torch 11 | from PIL import Image 12 | from torchvision import transforms 13 | from torchvision.utils import draw_bounding_boxes as tv_draw_bounding_boxes 14 | from torchvision.utils import make_grid 15 | from typing import Union 16 | 17 | clip_stats = (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 18 | 19 | 20 | def is_interactive() -> bool: 21 | try: 22 | from IPython import get_ipython 23 | if get_ipython() is not None: 24 | return True 25 | else: 26 | return False 27 | except NameError: 28 | return False # Probably standard Python interpreter 29 | 30 | 31 | def denormalize(images, means=(0.485, 0.456, 0.406), stds=(0.229, 0.224, 0.225)): 32 | means = torch.tensor(means).reshape(1, 3, 1, 1) 33 | stds = torch.tensor(stds).reshape(1, 3, 1, 1) 34 | return images * stds + means 35 | 36 | 37 | def show_batch(batch, stats=clip_stats): 38 | fig, ax = plt.subplots(figsize=(12, 12)) 39 | ax.set_xticks([]) 40 | ax.set_yticks([]) 41 | denorm_images = 
denormalize(batch, *stats) 42 | ax.imshow(make_grid(denorm_images[:64], nrow=8).permute(1, 2, 0).clamp(0, 1)) 43 | 44 | 45 | def show_batch_from_dl(dl): 46 | for images, labels in dl: 47 | show_batch(images) 48 | print(labels[:64]) 49 | break 50 | 51 | 52 | def show_single_image(image, denormalize_stats=None, bgr_image=False, save_path=None, size='small', bbox_info=None): 53 | if not is_interactive(): 54 | import matplotlib 55 | matplotlib.use("module://imgcat") 56 | if size == 'size_img': 57 | figsize = (image.shape[2] / 100, image.shape[1] / 100) # The default dpi of plt.savefig is 100 58 | elif size == 'small': 59 | figsize = (4, 4) 60 | else: 61 | figsize = (12, 12) 62 | 63 | fig = plt.figure(figsize=figsize) 64 | ax = fig.add_axes([0, 0, 1, 1]) 65 | ax.set_xticks([]) 66 | ax.set_yticks([]) 67 | 68 | if bbox_info is not None: 69 | image = draw_bounding_boxes(image, bbox_info['bboxes'], labels=bbox_info['labels'], colors=bbox_info['colors'], 70 | width=5) 71 | 72 | if isinstance(image, torch.Tensor): 73 | image = image.detach().cpu() 74 | if denormalize_stats is not None: 75 | image = denormalize(image.unsqueeze(0), *denormalize_stats) 76 | if image.dtype == torch.float32: 77 | image = image.clamp(0, 1) 78 | ax.imshow(image.squeeze(0).permute(1, 2, 0)) 79 | else: 80 | if bgr_image: 81 | image = image[..., ::-1] 82 | ax.imshow(image) 83 | 84 | if save_path is None: 85 | plt.show() 86 | # save image if save_path is provided 87 | if save_path is not None: 88 | # make path if it does not exist 89 | if not os.path.exists(os.path.dirname(save_path)): 90 | os.makedirs(os.path.dirname(save_path)) 91 | plt.savefig(save_path) 92 | 93 | 94 | def draw_bounding_boxes( 95 | image: Union[torch.Tensor, Image.Image], 96 | bboxes: Union[list, torch.Tensor], 97 | width: int = 5, 98 | **kwargs 99 | ): 100 | """ 101 | Wrapper around torchvision.utils.draw_bounding_boxes 102 | bboxes: [xmin, ymin, xmax, ymax] 103 | :return: 104 | """ 105 | if isinstance(image, Image.Image): 106 | if type(image) == Image.Image: 107 | image = transforms.ToTensor()(image) 108 | if isinstance(bboxes, list): 109 | bboxes = torch.tensor(bboxes) 110 | 111 | image = (image * 255).to(torch.uint8).cpu() 112 | height = image.shape[1] 113 | bboxes = torch.stack([bboxes[:, 0], height - bboxes[:, 3], bboxes[:, 2], height - bboxes[:, 1]], dim=1) 114 | return tv_draw_bounding_boxes(image, bboxes, width=width, **kwargs) 115 | 116 | 117 | def seed_everything(seed=0): 118 | random.seed(seed) 119 | np.random.seed(seed) 120 | torch.manual_seed(seed) 121 | torch.cuda.manual_seed_all(seed) 122 | 123 | 124 | def get_index_from_sample_id(sample_id, dataset): 125 | df = dataset.df 126 | return np.arange(df.shape[0])[df.index == sample_id] 127 | 128 | 129 | def save_json(data: dict, path: Union[str, pathlib.Path]): 130 | if isinstance(path, str): 131 | path = pathlib.Path(path) 132 | if not path.parent.exists(): 133 | path.parent.mkdir(parents=True) 134 | if path.suffix != '.json': 135 | path = path.with_suffix('.json') 136 | with open(path, 'w') as f: 137 | json.dump(data, f, indent=4, sort_keys=True) 138 | 139 | 140 | def load_json(path: Union[str, pathlib.Path]): 141 | if isinstance(path, str): 142 | path = pathlib.Path(path) 143 | if path.suffix != '.json': 144 | path = path.with_suffix('.json') 145 | with open(path, 'r') as f: 146 | data = json.load(f) 147 | return data 148 | 149 | 150 | def make_print_safe(string: str) -> str: 151 | return string.replace(r'[', r'\[') 152 | 153 | 154 | def sprint(string: str): 155 | print(make_print_safe(string)) 
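# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original utils.py): a minimal,
# hedged example of calling the draw_bounding_boxes wrapper defined above.
# It relies only on imports and helpers already present in this file; the
# function name and the dummy data below are hypothetical.
def _example_draw_bounding_boxes():
    # The wrapper expects a float image in [0, 1]; it scales to uint8 itself.
    image = torch.rand(3, 100, 100)
    # Boxes are given as [xmin, ymin, xmax, ymax] with y measured from the
    # bottom of the image; the wrapper flips them to the top-left origin that
    # torchvision.utils.draw_bounding_boxes expects.
    bboxes = [[10, 10, 60, 60]]
    annotated = draw_bounding_boxes(image, bboxes, labels=['object'], colors=['red'])
    return annotated  # uint8 tensor of shape (3, 100, 100) with the box drawn
# ---------------------------------------------------------------------------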
156 | 157 | 158 | def print_full_df(df): 159 | with pd.option_context('display.max_rows', None, 'display.max_columns', None): # more options can be specified also 160 | if is_interactive(): 161 | display(df) 162 | else: 163 | print(df) 164 | 165 | 166 | def code_to_paste(code): 167 | print('\n'.join([c[4:] for c in code.split('\n')[1:]]).replace('image', 'ip').replace('return ', '')) 168 | 169 | 170 | class HiddenPrints: 171 | hide_prints = False 172 | 173 | def __init__(self, model_name=None, console=None, use_newline=True): 174 | self.model_name = model_name 175 | self.console = console 176 | self.use_newline = use_newline 177 | self.tqdm_aux = None 178 | 179 | def __enter__(self): 180 | if self.hide_prints: 181 | import tqdm # We need to do an extra step to hide tqdm outputs. Does not work in Jupyter Notebooks. 182 | 183 | def nop(it, *a, **k): 184 | return it 185 | 186 | self.tqdm_aux = tqdm.tqdm 187 | tqdm.tqdm = nop 188 | 189 | if self.model_name is not None: 190 | self.console.print(f'Loading {self.model_name}...') 191 | self._original_stdout = sys.stdout 192 | self._original_stderr = sys.stderr 193 | sys.stdout = open(os.devnull, 'w') 194 | # May not be what we always want, but some annoying warnings end up to stderr 195 | sys.stderr = open(os.devnull, 'w') 196 | 197 | def __exit__(self, exc_type, exc_val, exc_tb): 198 | if self.hide_prints: 199 | sys.stdout.close() 200 | sys.stdout = self._original_stdout 201 | sys.stderr = self._original_stderr 202 | if self.model_name is not None: 203 | self.console.print(f'{self.model_name} loaded ') 204 | import tqdm 205 | tqdm.tqdm = self.tqdm_aux 206 | -------------------------------------------------------------------------------- /video_segment.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from typing import Union, Iterator 5 | 6 | from configs import config 7 | from image_patch import ImagePatch 8 | from vision_processes import forward 9 | 10 | 11 | class VideoSegment: 12 | """A Python class containing a set of frames represented as ImagePatch objects, as well as relevant information. 13 | Attributes 14 | ---------- 15 | video : torch.Tensor 16 | A tensor of the original video. 17 | start : int 18 | An int describing the starting frame in this video segment with respect to the original video. 19 | end : int 20 | An int describing the ending frame in this video segment with respect to the original video. 21 | num_frames : int 22 | An int containing the number of frames in the video segment. 23 | 24 | Methods 25 | ------- 26 | frame_iterator->Iterator[ImagePatch] 27 | trim(start, end)->VideoSegment 28 | Returns a new VideoSegment containing a trimmed version of the original video at the [start, end] segment. 29 | """ 30 | 31 | def __init__(self, video: torch.Tensor, start: int = None, end: int = None, parent_start=0, queues=None): 32 | """Initializes a VideoSegment object by trimming the video at the given [start, end] times and stores the 33 | start and end times as attributes. If no times are provided, the video is left unmodified, and the times are 34 | set to the beginning and end of the video. 35 | 36 | Parameters 37 | ------- 38 | video : torch.Tensor 39 | A tensor of the original video. 40 | start : int 41 | An int describing the starting frame in this video segment with respect to the original video. 42 | end : int 43 | An int describing the ending frame in this video segment with respect to the original video.
44 | """ 45 | 46 | if start is None and end is None: 47 | self.trimmed_video = video 48 | self.start = 0 49 | self.end = video.shape[0] # duration 50 | else: 51 | self.trimmed_video = video[start:end] 52 | if start is None: 53 | start = 0 54 | if end is None: 55 | end = video.shape[0] 56 | self.start = start + parent_start 57 | self.end = end + parent_start 58 | 59 | self.num_frames = self.trimmed_video.shape[0] 60 | 61 | self.cache = {} 62 | self.queues = (None, None) if queues is None else queues 63 | 64 | if self.trimmed_video.shape[0] == 0: 65 | raise Exception("VideoSegment has duration=0") 66 | 67 | def forward(self, model_name, *args, **kwargs): 68 | return forward(model_name, *args, queues=self.queues, **kwargs) 69 | 70 | def frame_from_index(self, index) -> ImagePatch: 71 | """Returns the frame at position 'index', as an ImagePatch object.""" 72 | if index < self.num_frames: 73 | image = self.trimmed_video[index] 74 | else: 75 | image = self.trimmed_video[-1] 76 | return ImagePatch(image, queues=self.queues) 77 | 78 | def trim(self, start: Union[int, None] = None, end: Union[int, None] = None) -> VideoSegment: 79 | """Returns a new VideoSegment containing a trimmed version of the original video at the [start, end] 80 | segment. 81 | 82 | Parameters 83 | ---------- 84 | start : Union[int, None] 85 | An int describing the starting frame in this video segment with respect to the original video. 86 | end : Union[int, None] 87 | An int describing the ending frame in this video segment with respect to the original video. 88 | 89 | Returns 90 | ------- 91 | VideoSegment 92 | a new VideoSegment containing a trimmed version of the original video at the [start, end] 93 | """ 94 | if start is not None: 95 | start = max(start, 0) 96 | if end is not None: 97 | end = min(end, self.num_frames) 98 | 99 | return VideoSegment(self.trimmed_video, start, end, self.start, queues=self.queues) 100 | 101 | def select_answer(self, info: dict, question: str, options=None) -> str: 102 | def format_dict(x): 103 | if isinstance(x, dict): 104 | x = ''.join([f'\n\t- {k}: {format_dict(v)}' for k, v in x.items()]) 105 | return x 106 | with open(config.select_answer_prompt, 'r') as f: 107 | prompt = f.read() 108 | info_formatting = '\n'.join([f"- {k}: {format_dict(v)}" for k, v in info.items()]) 109 | prompt = prompt.format(info=info_formatting, question=question, options=options) 110 | answer = self.forward('gpt3_general', prompt) 111 | answer = answer.strip() 112 | return answer 113 | 114 | def frame_iterator(self) -> Iterator[ImagePatch]: 115 | """Returns an iterator over the frames in the video segment.""" 116 | for i in range(self.num_frames): 117 | yield ImagePatch(self.trimmed_video[i], queues=self.queues) 118 | 119 | def __repr__(self): 120 | return "VideoSegment({}, {})".format(self.start, self.end) 121 | -------------------------------------------------------------------------------- /vision_processes.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script that contains the backend code. No need to look at this to implement new functionality 3 | Functions that run separate processes. 
These processes run on GPUs, and are queried by processes running only CPUs 4 | """ 5 | 6 | import dill 7 | import inspect 8 | import queue 9 | import torch 10 | import torch.multiprocessing as mp 11 | from rich.console import Console 12 | from time import time 13 | from typing import Callable, Union 14 | 15 | from configs import config 16 | 17 | console = Console(highlight=False) 18 | 19 | if mp.current_process().name == 'MainProcess': 20 | # No need to initialize the models inside each process 21 | import vision_models 22 | # Create a list of all the defined models 23 | list_models = [m[1] for m in inspect.getmembers(vision_models, inspect.isclass) 24 | if issubclass(m[1], vision_models.BaseModel) and m[1] != vision_models.BaseModel] 25 | # Sort by attribute "load_order" 26 | list_models.sort(key=lambda x: x.load_order) 27 | if config.multiprocessing: 28 | manager = mp.Manager() 29 | else: 30 | manager = None 31 | else: 32 | list_models = None 33 | manager = None 34 | 35 | 36 | def make_fn(model_class, process_name, counter): 37 | """ 38 | model_class.name and process_name will be the same unless the same model is used in multiple processes, for 39 | different tasks 40 | """ 41 | # We initialize each one on a separate GPU, to make sure there are no out of memory errors 42 | num_gpus = torch.cuda.device_count() 43 | gpu_number = counter % num_gpus 44 | 45 | model_instance = model_class(gpu_number=gpu_number) 46 | 47 | def _function(*args, **kwargs): 48 | if process_name != model_class.name: 49 | kwargs['process_name'] = process_name 50 | 51 | if model_class.to_batch and not config.multiprocessing: 52 | # Batchify the input. Model expects a batch. And later un-batchify the output. 53 | args = [[arg] for arg in args] 54 | kwargs = {k: [v] for k, v in kwargs.items()} 55 | 56 | # The defaults that are not in args or kwargs, also need to listify 57 | full_arg_spec = inspect.getfullargspec(model_instance.forward) 58 | if full_arg_spec.defaults is None: 59 | default_dict = {} 60 | else: 61 | default_dict = dict(zip(full_arg_spec.args[-len(full_arg_spec.defaults):], full_arg_spec.defaults)) 62 | non_given_args = full_arg_spec.args[1:][len(args):] 63 | non_given_args = set(non_given_args) - set(kwargs.keys()) 64 | for arg_name in non_given_args: 65 | kwargs[arg_name] = [default_dict[arg_name]] 66 | 67 | try: 68 | out = model_instance.forward(*args, **kwargs) 69 | if model_class.to_batch and not config.multiprocessing: 70 | out = out[0] 71 | except Exception as e: 72 | print(f'Error in {process_name} model:', e) 73 | out = None 74 | return out 75 | 76 | return _function 77 | 78 | 79 | if config.multiprocessing: 80 | 81 | def make_fn_process(model_class, process_name, counter): 82 | 83 | if model_class.to_batch: 84 | seconds_collect_data = model_class.seconds_collect_data # Window of seconds to group inputs 85 | max_batch_size = model_class.max_batch_size 86 | 87 | def _function(queue_in): 88 | 89 | fn = make_fn(model_class, process_name, counter) 90 | 91 | to_end = False 92 | while True: 93 | start_time = time() 94 | time_left = seconds_collect_data 95 | batch_inputs = [] 96 | batch_queues = [] 97 | while time_left > 0 and len(batch_inputs) < max_batch_size: 98 | try: 99 | received = queue_in.get(timeout=time_left) 100 | if received is None: 101 | to_end = True 102 | break 103 | else: 104 | batch_inputs.append(received[0]) 105 | batch_queues.append(received[1]) 106 | except queue.Empty: # Time-out expired 107 | break # Break inner loop (or do nothing, would break anyway because time_left < 0) 108 | 
time_left = seconds_collect_data - (time() - start_time) 109 | if len(batch_inputs) > 0: 110 | batch_kwargs = collate(batch_inputs, model_class.forward) 111 | outs = fn(**batch_kwargs) 112 | try: 113 | for out, qu in zip(outs, batch_queues): 114 | qu.put(out) 115 | except Exception as e: 116 | # No message, because we are just carrying the error from before 117 | for qu in batch_queues: 118 | qu.put(None) 119 | if to_end: 120 | print(f'{process_name} model exiting') 121 | break 122 | 123 | else: 124 | def _function(queue_in): 125 | fn = make_fn(model_class, process_name, counter) 126 | while True: 127 | received = queue_in.get() 128 | if received is None: 129 | print(f'{process_name} exiting') 130 | return 131 | (args, kwargs), queue_out = received 132 | out = fn(*args, **kwargs) 133 | queue_out.put(out) 134 | 135 | return _function 136 | 137 | 138 | if mp.current_process().name == 'MainProcess': 139 | queues_in: Union[dict[str, mp.Queue], None] = dict() 140 | consumers: dict[str, Union[mp.Process, Callable]] = dict() 141 | 142 | counter_ = 0 143 | for model_class_ in list_models: 144 | for process_name_ in model_class_.list_processes(): 145 | if process_name_ in config.load_models and config.load_models[process_name_]: 146 | queue_in_ = manager.Queue() # For transfer of data from producer to consumer 147 | queues_in[process_name_] = queue_in_ 148 | 149 | fn_process = make_fn_process(model_class_, process_name_, counter_) 150 | # Otherwise, it is not possible to pickle the _function (not defined at top level) 151 | aux = mp.reducer.dump 152 | mp.reducer.dump = dill.dump 153 | consumer = mp.Process(target=fn_process, kwargs={'queue_in': queue_in_}) 154 | consumer.start() 155 | mp.reducer.dump = aux 156 | consumers[process_name_] = consumer 157 | 158 | counter_ += 1 159 | 160 | else: 161 | queues_in = None 162 | 163 | 164 | def finish_all_consumers(): 165 | # Wait for consumers to finish 166 | for q_in in queues_in.values(): 167 | q_in.put(None) 168 | for cons in consumers.values(): 169 | cons.join() 170 | 171 | else: 172 | 173 | consumers = dict() 174 | 175 | counter_ = 0 176 | for model_class_ in list_models: 177 | for process_name_ in model_class_.list_processes(): 178 | if process_name_ in config.load_models and config.load_models[process_name_]: 179 | consumers[process_name_] = make_fn(model_class_, process_name_, counter_) 180 | counter_ += 1 181 | 182 | queues_in = None 183 | 184 | def finish_all_consumers(): 185 | pass 186 | 187 | 188 | def forward(model_name, *args, queues=None, **kwargs): 189 | """ 190 | Sends data to consumer (calls their "forward" method), and returns the result 191 | """ 192 | error_msg = f'No model named {model_name}. ' \ 193 | 'The available models are: {}. 
Make sure to activate it in the configs files' 194 | if not config.multiprocessing: 195 | try: 196 | out = consumers[model_name](*args, **kwargs) 197 | except KeyError as e: 198 | raise KeyError(error_msg.format(list(consumers.keys()))) from e 199 | else: 200 | if queues is None: 201 | consumer_queues_in, queue_results = None, None 202 | else: 203 | consumer_queues_in, queue_results = queues 204 | try: 205 | if consumer_queues_in is not None: 206 | consumer_queue_in = consumer_queues_in[model_name] 207 | else: 208 | consumer_queue_in = queues_in[model_name] 209 | except KeyError as e: 210 | options = list(consumer_queues_in.keys()) if consumer_queues_in is not None else list(queues_in.keys()) 211 | raise KeyError(error_msg.format(options)) from e 212 | if queue_results is None: 213 | # print('No queue exists to get results. Creating a new one, but this is inefficient. ' 214 | # 'Consider providing an existing queue for the process') 215 | queue_results = manager.Queue() # To get outputs 216 | consumer_queue_in.put([(args, kwargs), queue_results]) 217 | out = queue_results.get() # Wait for result 218 | return out 219 | 220 | 221 | def collate(batch_inputs, fn): 222 | """ 223 | Combine a list of inputs into a single dictionary. The dictionary contains all the parameters of the 224 | function to be called. If the parameter is not defined in some samples, the default value is used. The 225 | value of the parameters is always a list. 226 | """ 227 | # Separate into args and kwargs 228 | args_input, kwarg_input = list(zip(*batch_inputs)) 229 | full_arg_spec = inspect.getfullargspec(fn) 230 | if full_arg_spec.defaults is None: 231 | default_dict = {} 232 | else: 233 | default_dict = dict(zip(full_arg_spec.args[-len(full_arg_spec.defaults):], full_arg_spec.defaults)) 234 | if 'process_name' in default_dict: # process_name is a special parameter filled in later 235 | del default_dict['process_name'] 236 | 237 | args_list = full_arg_spec.args[1:] # Remove self 238 | 239 | # process_name is a special parameter filled in later 240 | if 'process_name' in args_list: 241 | assert args_list[-1] == 'process_name', 'process_name must be the last argument' 242 | args_list.remove('process_name') 243 | 244 | kwargs_output = {k: [] for k in args_list} 245 | for i, (args, kwargs) in enumerate(zip(args_input, kwarg_input)): 246 | if len(args) + len(kwargs) > len(args_list): 247 | raise Exception( 248 | f'You provided more arguments than the function {fn.__name__} accepts, or some kwargs/args ' 249 | f'overlap. The arguments are: {args_list}') 250 | for j, arg_name in enumerate(args_list): 251 | if len(args) > j: 252 | kwargs_output[arg_name].append(args[j]) 253 | elif arg_name in kwargs: 254 | kwargs_output[arg_name].append(kwargs[arg_name]) 255 | else: 256 | assert arg_name in default_dict, f'You did not provide a value for the argument {arg_name}.' 257 | kwargs_output[arg_name].append(default_dict[arg_name]) 258 | 259 | return kwargs_output 260 | --------------------------------------------------------------------------------
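To make the shape of collate's output concrete, the following is a small, hedged sketch that is not part of the repository: it runs collate against a dummy forward signature. DummyModel, its parameter names, and the string inputs are hypothetical, chosen only to show how the queued (args, kwargs) pairs gathered by make_fn_process are merged into per-argument lists. Importing vision_processes also initializes whichever models are enabled in the config, so the snippet assumes a config where that is acceptable.

# Hedged illustration of collate(); DummyModel and its arguments are made up.
from vision_processes import collate


class DummyModel:
    # process_name must be the last argument, as collate asserts.
    def forward(self, image, question='describe the image', process_name='dummy'):
        pass


# Two queued calls, in the (args, kwargs) form used by forward()/make_fn_process.
batch_inputs = [
    (('img_0',), {'question': 'what color is the cat?'}),
    (('img_1',), {}),  # a missing kwarg falls back to the declared default
]

batched_kwargs = collate(batch_inputs, DummyModel.forward)
print(batched_kwargs)
# {'image': ['img_0', 'img_1'],
#  'question': ['what color is the cat?', 'describe the image']}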