├── .gitattributes ├── .gitignore ├── LICENSE ├── Quickstart.ipynb ├── README.md ├── Tables.ipynb ├── Visual_Feature_Engineering.ipynb ├── __init__.py ├── clip_masking_lvis_image_ids.yml ├── datasets ├── coco_wrapper.py ├── pascal_classes.json ├── pascal_zeroshot.py ├── pfe_dataset.py ├── phrasecut.py └── utils.py ├── environment.yml ├── evaluation_utils.py ├── example_image.jpg ├── experiments ├── ablation.yaml ├── coco.yaml ├── pascal_0shot.yaml ├── pascal_1shot.yaml └── phrasecut.yaml ├── general_utils.py ├── metrics.py ├── models ├── clipseg.py └── vitseg.py ├── overview.png ├── sample_rd64.png ├── sample_rd64_refined.png ├── score.py ├── setup.py ├── supplementary.pdf └── training.py /.gitattributes: -------------------------------------------------------------------------------- 1 | weights/rd16-uni.pth filter=lfs diff=lfs merge=lfs -text 2 | weights/rd64-uni.pth filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | This license does not apply to the model weights. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /Quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import requests\n", 11 | "\n", 12 | "! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n", 13 | "! unzip -d weights -j weights.zip\n", 14 | "from models.clipseg import CLIPDensePredT\n", 15 | "from PIL import Image\n", 16 | "from torchvision import transforms\n", 17 | "from matplotlib import pyplot as plt\n", 18 | "\n", 19 | "# load model\n", 20 | "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n", 21 | "model.eval();\n", 22 | "\n", 23 | "# non-strict, because we only stored decoder weights (not CLIP weights)\n", 24 | "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "Load and normalize `example_image.jpg`. You can also load an image from a URL."
32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# load and normalize image\n", 41 | "input_image = Image.open('example_image.jpg')\n", 42 | "\n", 43 | "# or load from URL...\n", 44 | "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n", 45 | "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n", 46 | "\n", 47 | "transform = transforms.Compose([\n", 48 | " transforms.ToTensor(),\n", 49 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 50 | " transforms.Resize((352, 352)),\n", 51 | "])\n", 52 | "img = transform(input_image).unsqueeze(0)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "Predict and visualize (this might take a few seconds if running without GPU support)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n", 69 | "\n", 70 | "# predict\n", 71 | "with torch.no_grad():\n", 72 | " preds = model(img.repeat(4,1,1,1), prompts)[0]\n", 73 | "\n", 74 | "# visualize prediction\n", 75 | "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n", 76 | "[a.axis('off') for a in ax.flatten()]\n", 77 | "ax[0].imshow(input_image)\n", 78 | "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n", 79 | "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];" 80 | ] 81 | } 82 | ], 83 | "metadata": { 84 | "interpreter": { 85 | "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586" 86 | }, 87 | "kernelspec": { 88 | "display_name": "Python 3", 89 | "language": "python", 90 | "name": "python3" 91 | }, 92 | "language_info": { 93 | "codemirror_mode": { 94 | "name": "ipython", 95 | "version": 3 96 | }, 97 | "file_extension": ".py", 98 | "mimetype": "text/x-python", 99 | "name": "python", 
100 | "nbconvert_exporter": "python", 101 | "pygments_lexer": "ipython3", 102 | "version": "3.8.10" 103 | } 104 | }, 105 | "nbformat": 4, 106 | "nbformat_minor": 4 107 | } 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Segmentation Using Text and Image Prompts 2 | This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003). 3 | 4 | **November 2022:** CLIPSeg has been integrated into the [HuggingFace Transformers library](https://huggingface.co/docs/transformers/main/en/model_doc/clipseg). Thank you, [NielsRogge](https://github.com/NielsRogge)! 5 | **September 2022:** We released new weights for fine-grained predictions (see below for details). 6 | **March 2022:** The paper has been accepted to CVPR 2022! 7 | 8 | 9 | drawing 10 | 11 | The system can create segmentation models without training, based on either: 12 | - An arbitrary text query 13 | - Or an image with a mask highlighting stuff or an object. 14 | 15 | ### Quick Start 16 | 17 | In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you have downloaded the `rd64-uni.pth` weights, either manually or via the git-lfs extension. 18 | It can also be used interactively on [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb) 19 | (please note that the VM does not use a GPU, so inference takes a few seconds). 20 | 21 | 22 | ### Dependencies 23 | This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`). 24 | Additional dependencies are hidden for double blind review.
25 | 26 | 27 | ### Datasets 28 | 29 | * `PhraseCut` and `PhraseCutPlus`: Referring expression dataset 30 | * `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation 31 | * `PascalZeroShot`: Wrapper class for PascalZeroShot 32 | * `COCOWrapper`: Wrapper class for COCO. 33 | 34 | ### Models 35 | 36 | * `CLIPDensePredT`: CLIPSeg model with transformer-based decoder. 37 | * `ViTDensePredT`: ViTSeg baseline (ImageNet-pretrained ViT backbone instead of CLIP) with transformer-based decoder. 38 | 39 | ### Third Party Dependencies 40 | Some of the datasets require third-party dependencies. Run the following commands in the `third_party` folder. 41 | ```bash 42 | git clone https://github.com/cvlab-yonsei/JoEm 43 | git clone https://github.com/Jia-Research-Lab/PFENet.git 44 | git clone https://github.com/ChenyunWu/PhraseCutDataset.git 45 | git clone https://github.com/juhongm999/hsnet.git 46 | ``` 47 | 48 | ### Weights 49 | 50 | The MIT license does not apply to these weights. 51 | 52 | We provide three sets of model weights: two for D=64 (~4MB each) and one for D=16 (~1MB). 53 | ``` 54 | wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip 55 | unzip -d weights -j weights.zip 56 | ``` 57 | 58 | #### New Fine-grained Weights 59 | We introduced a more complex module for transforming tokens into predictions that allows for more refined predictions (in contrast to the square-like predictions of the other weights). The corresponding weights, `rd64-uni-refined.pth`, are included in the weight download above. 60 | They can be loaded as follows: 61 | ```python 62 | model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True) 63 | model.load_state_dict(torch.load('weights/rd64-uni-refined.pth'), strict=False) 64 | ``` 65 | 66 | See below for a direct comparison of the new fine-grained weights (top) and the old weights (bottom).
67 | drawing 68 | drawing 69 | 70 | 71 | 72 | ### Training and Evaluation 73 | 74 | To train, use the `training.py` script with an experiment file and an experiment id as parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment, which is defined by the `configuration` and the first `individual_configurations` parameters. Model weights will be written to `logs/`. 75 | 76 | For evaluation, use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will evaluate the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`. 77 | 78 | 79 | ### Usage of PFENet Wrappers 80 | 81 | In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder. 82 | `git clone https://github.com/Jia-Research-Lab/PFENet.git ` 83 | 84 | 85 | ### License 86 | 87 | The source code files in this repository (excluding model weights) are released under the MIT license. 88 | 89 | ### Citation 90 | ``` 91 | @InProceedings{lueddecke22_cvpr, 92 | author = {L\"uddecke, Timo and Ecker, Alexander}, 93 | title = {Image Segmentation Using Text and Image Prompts}, 94 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 95 | month = {June}, 96 | year = {2022}, 97 | pages = {7086-7096} 98 | } 99 | 100 | ``` 101 | -------------------------------------------------------------------------------- /Tables.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import clip\n", 13 | "from evaluation_utils import norm, denorm\n", 14 | "from general_utils import *\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# PhraseCut" 22 | ] 23 | }, 24 | { 25 |
"cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n", 49 | "tab1 = pc[['name'] + cols]\n", 50 | "for k in cols:\n", 51 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 52 | "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n", 53 | "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n", 54 | "print(tab1.to_latex(header=False, index=False))" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "For 0.1 threshold" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n", 71 | "tab1 = pc[['name'] + cols]\n", 72 | "for k in cols:\n", 73 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 74 | "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n", 75 | "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n", 76 | "print(tab1.to_latex(header=False, index=False))" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "# One-shot" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "### Pascal" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "pas = 
experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 118 | "tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n", 119 | "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 120 | "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 121 | "\n", 122 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 123 | "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n", 124 | "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 125 | "\n", 126 | "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n", 127 | "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n", 128 | "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "#### Pascal Zero-shot (in one-shot setting)\n", 136 | "\n", 137 | "Using the same setting as one-shot (hence different from the other zero-shot benchmark)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 147 | "tab1 = pas[['pas_t_miou_0.3', 
'pas_t_biniou_0.3', 'pas_t_ap']]\n", 148 | "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 149 | "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 150 | "\n", 151 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 152 | "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n", 153 | "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 154 | "\n", 155 | "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n", 156 | "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n", 157 | "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "# without fixed thresholds...\n", 167 | "\n", 168 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 169 | "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n", 170 | "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 171 | "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 172 | "\n", 173 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 174 | "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n", 175 | "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "### COCO" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | 
"outputs": [], 190 | "source": [ 191 | "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n", 201 | "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n", 202 | "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n", 203 | "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n", 204 | "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 205 | "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n", 206 | "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "# Zero-shot" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "\n", 232 | "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n", 233 | "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n", 234 | "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n", 235 | "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + 
tab1[5:6].values[0].tolist()), '\\\\\\\\')" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "# Ablation" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n", 261 | "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n", 262 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 263 | "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "# Generalization" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "generalization = experiment('experiments/generalize.yaml').dataframe()" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values" 307 | ] 308 | }, 
309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "print(\n", 316 | " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n", 317 | " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n", 318 | " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n", 319 | " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n", 320 | ")" 321 | ] 322 | } 323 | ], 324 | "metadata": { 325 | "interpreter": { 326 | "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586" 327 | }, 328 | "kernelspec": { 329 | "display_name": "env2", 330 | "language": "python", 331 | "name": "env2" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.8.8" 344 | } 345 | }, 346 | "nbformat": 4, 347 | "nbformat_minor": 4 348 | } 349 | -------------------------------------------------------------------------------- /Visual_Feature_Engineering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Systematic" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "\n", 19 | "import clip\n", 20 | "from evaluation_utils import norm, denorm\n", 21 | "from general_utils import *\n", 22 | "from datasets.lvis_oneshot3 import LVIS_OneShot3\n", 23 | "\n", 24 | "clip_device = 'cuda'\n", 25 | "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n", 26 | 
"clip_model.eval();\n", 27 | "\n", 28 | "from models.clipseg import CLIPDensePredTMasked\n", 29 | "\n", 30 | "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n", 31 | "clip_mask_model.eval();" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n", 41 | " text_class_labels=True, image_size=352, min_area=0.1,\n", 42 | " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "plot_data(lvis)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "from collections import defaultdict\n", 61 | "import json\n", 62 | "\n", 63 | "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n", 64 | "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n", 65 | "\n", 66 | "objects_per_image = defaultdict(lambda : set())\n", 67 | "for ann in lvis_raw['annotations']:\n", 68 | " objects_per_image[ann['image_id']].add(ann['category_id'])\n", 69 | " \n", 70 | "for ann in lvis_val_raw['annotations']:\n", 71 | " objects_per_image[ann['image_id']].add(ann['category_id']) \n", 72 | " \n", 73 | "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n", 74 | "\n", 75 | "del lvis_raw, lvis_val_raw" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "#bs = 32\n", 85 | "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | 
"source": [ 94 | "from general_utils import get_batch\n", 95 | "from functools import partial\n", 96 | "from evaluation_utils import img_preprocess\n", 97 | "import torch\n", 98 | "\n", 99 | "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n", 100 | "\n", 101 | " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n", 102 | "\n", 103 | " all_prompts = []\n", 104 | " \n", 105 | " with torch.no_grad():\n", 106 | " valid_sims = []\n", 107 | " torch.manual_seed(571)\n", 108 | " \n", 109 | " if type(batches_or_dataset) == list:\n", 110 | " loader = batches_or_dataset # already loaded\n", 111 | " max_iter = float('inf')\n", 112 | " else:\n", 113 | " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n", 114 | " max_iter = 50\n", 115 | " \n", 116 | " global batch\n", 117 | " for i_batch, (batch, batch_y) in enumerate(loader):\n", 118 | " \n", 119 | " if i_batch >= max_iter: break\n", 120 | " \n", 121 | " processed_batch = process(batch)\n", 122 | " if type(processed_batch) == dict:\n", 123 | " \n", 124 | " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n", 125 | " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n", 126 | " else:\n", 127 | " processed_batch = process(batch).to(clip_device)\n", 128 | " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n", 129 | " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n", 130 | " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n", 131 | " \n", 132 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", 133 | " bs = len(batch[0])\n", 134 | " for j in range(bs):\n", 135 | " \n", 136 | " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n", 137 | " support_image = basename(lvis.samples[c][sid])\n", 138 | " \n", 139 | " img_objs = [o for 
o in objects_per_image[int(support_image)]]\n", 140 | " img_objs = [o.replace('_', ' ') for o in img_objs]\n", 141 | " \n", 142 | " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n", 143 | " if o != batch_y[2][j]]\n", 144 | " \n", 145 | " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n", 146 | " all_prompts += [prompts]\n", 147 | " \n", 148 | " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n", 149 | " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n", 150 | "\n", 151 | " global logits\n", 152 | " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n", 153 | "\n", 154 | " global sim\n", 155 | " sim = torch.softmax(logits, dim=-1)\n", 156 | " \n", 157 | " valid_sims += [sim]\n", 158 | " \n", 159 | " #valid_sims = torch.stack(valid_sims)\n", 160 | " return valid_sims, all_prompts\n", 161 | " \n", 162 | "\n", 163 | "def new_img_preprocess(x):\n", 164 | " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n", 165 | " \n", 166 | "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n", 167 | "get_similarities(lvis, lambda x: x[1]);" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "preprocessing_functions = [\n", 177 | "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n", 178 | "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n", 179 | "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n", 180 | "# ['colorize object red', partial(img_preprocess, colorize=True)],\n", 181 | "# ['add red outline', partial(img_preprocess, outline=True)],\n", 182 | " \n", 183 | "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n", 184 | "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n", 185 | 
"# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n", 186 | "# ['BG blur', partial(img_preprocess, blur=3)],\n", 187 | "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n", 188 | " \n", 189 | "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n", 190 | "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n", 191 | " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n", 192 | " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n", 193 | "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n", 194 | "]\n", 195 | "\n", 196 | "preprocessing_functions = preprocessing_functions\n", 197 | "\n", 198 | "base, base_p = get_similarities(lvis, lambda x: x[1])\n", 199 | "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "for j in range(1):\n", 218 | " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "from pandas import DataFrame\n", 228 | "tab = dict()\n", 229 | "for j, (name, _) in enumerate(preprocessing_functions):\n", 230 | " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n", 231 | " \n", 232 | " \n", 233 | "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) " 
234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "# Visual" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "from evaluation_utils import denorm, norm" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "def load_sample(filename, filename2):\n", 259 | " from os.path import join\n", 260 | " bp = expanduser('~/cloud/resources/sample_images')\n", 261 | " tf = transforms.Compose([\n", 262 | " transforms.ToTensor(),\n", 263 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 264 | " transforms.Resize(224),\n", 265 | " transforms.CenterCrop(224)\n", 266 | " ])\n", 267 | " tf2 = transforms.Compose([\n", 268 | " transforms.ToTensor(),\n", 269 | " transforms.Resize(224),\n", 270 | " transforms.CenterCrop(224)\n", 271 | " ])\n", 272 | " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n", 273 | " inp1[1] = inp1[1].unsqueeze(0)\n", 274 | " inp1[2] = inp1[2][:1] \n", 275 | " return inp1\n", 276 | "\n", 277 | "def all_preprocessing(inp1):\n", 278 | " return [\n", 279 | " img_preprocess(inp1),\n", 280 | " img_preprocess(inp1, colorize=True),\n", 281 | " img_preprocess(inp1, outline=True), \n", 282 | " img_preprocess(inp1, blur=3),\n", 283 | " img_preprocess(inp1, bg_fac=0.1),\n", 284 | " #img_preprocess(inp1, bg_fac=0.5),\n", 285 | " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n", 286 | " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n", 287 | " ]\n", 288 | "\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "from torchvision import transforms\n", 298 | "from PIL import Image\n", 299 | "from matplotlib import pyplot as plt\n", 300 | "from 
evaluation_utils import img_preprocess\n", 301 | "import clip\n", 302 | "\n", 303 | "images_queries = [\n", 304 | " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n", 305 | " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n", 306 | "]\n", 307 | "\n", 308 | "\n", 309 | "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n", 310 | "\n", 311 | "for j, (images, objects) in enumerate(images_queries):\n", 312 | " \n", 313 | " joint_image = all_preprocessing(images)\n", 314 | " \n", 315 | " joint_image = torch.stack(joint_image)[:,0]\n", 316 | " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n", 317 | " image_features = clip_model.encode_image(joint_image)\n", 318 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", 319 | " \n", 320 | " prompts = [f'a photo of a {obj}'for obj in objects]\n", 321 | " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n", 322 | " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n", 323 | " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n", 324 | " sim = torch.softmax(logits, dim=-1).detach().cpu()\n", 325 | "\n", 326 | " for i, img in enumerate(joint_image):\n", 327 | " ax[2*j, i].axis('off')\n", 328 | " \n", 329 | " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n", 330 | " ax[2*j+ 1, i].grid(True)\n", 331 | " \n", 332 | " ax[2*j + 1, i].set_ylim(0,1)\n", 333 | " ax[2*j + 1, i].set_yticklabels([])\n", 334 | " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n", 335 | "# ax[1, i].set_xticklabels(objects, rotation=90)\n", 336 | " for k in range(len(sim[i])):\n", 337 | " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n", 338 | " ax[2*j + 1, i].text(k, 0.07, objects[k], 
rotation=90, ha='center', fontsize=15)\n", 339 | "\n", 340 | "plt.tight_layout()\n", 341 | "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')" 342 | ] 343 | } 344 | ], 345 | "metadata": { 346 | "kernelspec": { 347 | "display_name": "env2", 348 | "language": "python", 349 | "name": "env2" 350 | }, 351 | "language_info": { 352 | "codemirror_mode": { 353 | "name": "ipython", 354 | "version": 3 355 | }, 356 | "file_extension": ".py", 357 | "mimetype": "text/x-python", 358 | "name": "python", 359 | "nbconvert_exporter": "python", 360 | "pygments_lexer": "ipython3", 361 | "version": "3.8.8" 362 | } 363 | }, 364 | "nbformat": 4, 365 | "nbformat_minor": 4 366 | } 367 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timojl/clipseg/77c4e2dc853880df46413d4a870a725e961f2f73/__init__.py -------------------------------------------------------------------------------- /datasets/coco_wrapper.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from types import new_class 3 | import torch 4 | import numpy as np 5 | import os 6 | import json 7 | 8 | from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename 9 | from random import shuffle, seed as set_seed 10 | from PIL import Image 11 | 12 | from itertools import combinations 13 | from torchvision import transforms 14 | from torchvision.transforms.transforms import Resize 15 | 16 | from datasets.utils import blend_image_segmentation 17 | from general_utils import get_from_repository 18 | 19 | COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 
20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'} 20 | 21 | class COCOWrapper(object): 22 | 23 | def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0, 24 | with_class_label=False): 25 | super().__init__() 26 | 27 | self.mask = mask 28 | self.with_class_label = with_class_label 29 | self.negative_prob = negative_prob 30 | 31 | from third_party.hsnet.data.coco import DatasetCOCO 32 | 33 | get_from_repository('COCO-20i', ['COCO-20i.tar']) 34 | 35 | foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl') 36 | 37 | def build_img_metadata_classwise(self): 38 | with open(foldpath % (self.split, self.fold), 'rb') as f: 39 | img_metadata_classwise = pickle.load(f) 40 | return img_metadata_classwise 41 | 42 | 43 | DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise 44 | # DatasetCOCO.read_mask = read_mask 45 | 46 | mean = [0.485, 0.456, 0.406] 47 | std = [0.229, 0.224, 0.225] 48 | transform = transforms.Compose([ 49 | transforms.Resize((image_size, image_size)), 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean, std) 52 | ]) 53 | 54 | self.coco 
= DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False) 55 | 56 | self.all_classes = [self.coco.class_ids] 57 | self.coco.base_path = join(expanduser('~/datasets/COCO-20i')) 58 | 59 | def __len__(self): 60 | return len(self.coco) 61 | 62 | def __getitem__(self, i): 63 | sample = self.coco[i] 64 | 65 | label_name = COCO_CLASSES[int(sample['class_id'])] 66 | 67 | img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0] 68 | 69 | if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob: 70 | new_class_id = sample['class_id'] 71 | while new_class_id == sample['class_id']: 72 | sample2 = self.coco[torch.randint(0, len(self), (1,)).item()] 73 | new_class_id = sample2['class_id'] 74 | img_s = sample2['support_imgs'][0] 75 | seg_s = torch.zeros_like(seg_s) 76 | 77 | mask = self.mask 78 | if mask == 'separate': 79 | supp = (img_s, seg_s) 80 | elif mask == 'text_label': 81 | # DEPRECATED 82 | supp = [int(sample['class_id'])] 83 | elif mask == 'text': 84 | supp = [label_name] 85 | else: 86 | if mask.startswith('text_and_'): 87 | mask = mask[9:] 88 | label_add = [label_name] 89 | else: 90 | label_add = [] 91 | 92 | supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask) 93 | 94 | if self.with_class_label: 95 | label = (torch.zeros(0), sample['class_id'],) 96 | else: 97 | label = (torch.zeros(0), ) 98 | 99 | return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label -------------------------------------------------------------------------------- /datasets/pascal_classes.json: -------------------------------------------------------------------------------- 1 | [{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, 
"synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}] -------------------------------------------------------------------------------- /datasets/pascal_zeroshot.py: -------------------------------------------------------------------------------- 1 | from os.path import expanduser 2 | import torch 3 | import json 4 | import torchvision 5 | from general_utils import get_from_repository 6 | from general_utils import log 7 | from torchvision import transforms 8 | 9 | PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'], 10 | ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'], 11 | ['chair.n.01', 'pot_plant.n.01']] 12 | 13 | 14 | class PascalZeroShot(object): 15 | 16 | def __init__(self, split, n_unseen, image_size=224) -> None: 17 | super().__init__() 18 | 19 | import sys 20 | sys.path.append('third_party/JoEm') 21 | from third_party.JoEm.data_loader.dataset import VOCSegmentation 22 | from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC 23 | 24 | self.pascal_classes = VOC 25 | self.image_size = image_size 26 | 27 | self.transform = transforms.Compose([ 28 | transforms.Resize((image_size, image_size)), 29 | ]) 30 | 31 | if split == 'train': 32 | self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen), 33 | split=split, transform=True, transform_args=dict(base_size=312, crop_size=312), 34 | ignore_bg=False, ignore_unseen=False, remv_unseen_img=True) 35 | elif split == 'val': 36 | self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen), 37 | split=split, transform=False, 38 | ignore_bg=False, ignore_unseen=False) 39 | 40 | 
self.unseen_idx = get_unseen_idx(n_unseen) 41 | 42 | def __len__(self): 43 | return len(self.voc) 44 | 45 | def __getitem__(self, i): 46 | 47 | sample = self.voc[i] 48 | label = sample['label'].long() 49 | all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255] 50 | class_indices = [l for l in all_labels] 51 | class_names = [self.pascal_classes[l] for l in all_labels] 52 | 53 | image = self.transform(sample['image']) 54 | 55 | label = transforms.Resize((self.image_size, self.image_size), 56 | interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0] 57 | 58 | return (image,), (label, ) 59 | 60 | 61 | -------------------------------------------------------------------------------- /datasets/pfe_dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import expanduser 2 | import torch 3 | import json 4 | from general_utils import get_from_repository 5 | from datasets.utils import blend_image_segmentation 6 | from general_utils import log 7 | 8 | PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))} 9 | 10 | 11 | class PFEPascalWrapper(object): 12 | 13 | def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None): 14 | import sys 15 | # sys.path.append(expanduser('~/projects/new_one_shot')) 16 | from third_party.PFENet.util.dataset import SemData 17 | 18 | get_from_repository('PascalVOC2012', ['Pascal5i.tar']) 19 | 20 | self.p_negative = p_negative 21 | self.size = size 22 | self.mode = mode 23 | self.image_size = image_size 24 | 25 | if label_support in {True, False}: 26 | log.warning('label_support argument is deprecated. 
Use mask instead.') 27 | #raise ValueError() 28 | 29 | self.mask = mask 30 | 31 | value_scale = 255 32 | mean = [0.485, 0.456, 0.406] 33 | mean = [item * value_scale for item in mean] 34 | std = [0.229, 0.224, 0.225] 35 | std = [item * value_scale for item in std] 36 | 37 | import third_party.PFENet.util.transform as transform 38 | 39 | if mode == 'val': 40 | data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt') 41 | 42 | data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else [] 43 | data_transform += [ 44 | transform.ToTensor(), 45 | transform.Normalize(mean=mean, std=std) 46 | ] 47 | 48 | 49 | elif mode == 'train': 50 | data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt') 51 | 52 | assert image_size != 'original' 53 | 54 | data_transform = [ 55 | transform.RandScale([0.9, 1.1]), 56 | transform.RandRotate([-10, 10], padding=mean, ignore_label=255), 57 | transform.RandomGaussianBlur(), 58 | transform.RandomHorizontalFlip(), 59 | transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255), 60 | transform.ToTensor(), 61 | transform.Normalize(mean=mean, std=std) 62 | ] 63 | 64 | data_transform = transform.Compose(data_transform) 65 | 66 | self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'), 67 | data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False) 68 | 69 | self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list 70 | 71 | # verify that subcls_list always has length 1 72 | # assert len(set([len(d[4]) for d in self.dataset])) == 1 73 | 74 | print('actual length', len(self.dataset.data_list)) 75 | 76 | def __len__(self): 77 | if self.mode == 'val': 78 | return len(self.dataset.data_list) 79 | else: 80 | return len(self.dataset.data_list) 81 | 82 | def __getitem__(self, index): 83 | if self.dataset.mode == 'train': 
84 | image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)] 85 | elif self.dataset.mode == 'val': 86 | image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)] 87 | ori_label = torch.from_numpy(ori_label).unsqueeze(0) 88 | 89 | if self.image_size != 'original': 90 | longerside = max(ori_label.size(1), ori_label.size(2)) 91 | backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255 92 | backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label 93 | label = backmask.clone().long() 94 | else: 95 | label = label.unsqueeze(0) 96 | 97 | # assert label.shape == (473, 473) 98 | 99 | if self.p_negative > 0: 100 | if torch.rand(1).item() < self.p_negative: 101 | while True: 102 | idx = torch.randint(0, len(self.dataset.data_list), (1,)).item() 103 | _, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx] 104 | if subcls_list[0] != subcls_list_tmp[0]: 105 | break 106 | 107 | s_x = s_x[0] 108 | s_y = (s_y == 1)[0] 109 | label_fg = (label == 1).float() 110 | val_mask = (label != 255).float() 111 | 112 | class_id = self.class_list[subcls_list[0]] 113 | 114 | label_name = PASCAL_CLASSES[class_id][0] 115 | label_add = () 116 | mask = self.mask 117 | 118 | if mask == 'text': 119 | support = ('a photo of a ' + label_name + '.',) 120 | elif mask == 'separate': 121 | support = (s_x, s_y) 122 | else: 123 | if mask.startswith('text_and_'): 124 | label_add = (label_name,) 125 | mask = mask[9:] 126 | 127 | support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],) 128 | 129 | return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0]) 130 | -------------------------------------------------------------------------------- /datasets/phrasecut.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import os 5 | 6 | from os.path import join, isdir, isfile, expanduser 7 | 
from PIL import Image 8 | 9 | from torchvision import transforms 10 | from torchvision.transforms.transforms import Resize 11 | 12 | from torch.nn import functional as nnf 13 | from general_utils import get_from_repository 14 | 15 | from skimage.draw import polygon2mask 16 | 17 | PASCAL_SYNSETS = ['person.n.01', 'bird.n.01', 'cat.n.01', 'cattle.n.01', 'dog.n.01', 'horse.n.01', 'sheep.n.01', 'aeroplane.n.01', 'bicycle.n.01', 18 | 'boat.n.01', 'bus.n.01', 'car.n.01', 'motorcycle.n.01', 'train.n.01', 'bottle.n.01', 'chair.n.01', 'kitchen_table.n.01', 19 | 'breakfast_table.n.01', 'trestle_table.n.01','pot_plant.n.01', 'sofa.n.01', 'television.n.03'] 20 | 21 | PASCAL_5I_SYNSETS_ORDERED = [ 22 | 'aeroplane.n.01', 'bicycle.n.01', 'bird.n.01', 'vessel.n.02', 'bottle.n.01', 'bus.n.01', 'car.n.01', 23 | 'cat.n.01', 'chair.n.01', 'cattle.n.01', 'table.n.02', 'dog.n.01', 'horse.n.01', 'motorcycle.n.01', 24 | 'person.n.01', 'pot_plant.n.01', 'sheep.n.01', 'sofa.n.01', 'train.n.01', 'television.n.03'] 25 | 26 | # Pascal-5i classes 27 | PASCAL_5I_CLASS_IDS = { 28 | 3: list(range(1, 16)), 29 | 2: list(range(1, 11)) + list(range(16, 21)), 30 | 1: list(range(1, 6)) + list(range(11, 21)), 31 | 0: list(range(6, 21)) 32 | } 33 | 34 | 35 | def traverse_lemmas(synset): 36 | """ list of all lemma names of hypernyms """ 37 | 38 | out = synset.lemma_names() 39 | for h in synset.hypernyms(): 40 | out += traverse_lemmas(h) 41 | 42 | return out 43 | 44 | 45 | def traverse_lemmas_hypo(synset, max_depth=None, depth=0): 46 | """ list of all lemma names of hypernyms """ 47 | 48 | if synset.name() == 'person.n.01': 49 | return ['person', 'human', 'man', 'woman', 'toddler', 'baby', 'body', 'child', 'infant'] 50 | 51 | out = [l.name() for l in synset.lemmas()] 52 | 53 | # print(' '*depth, synset.name()) 54 | 55 | if max_depth is not None and depth >= max_depth: 56 | return out 57 | 58 | # print(' '*(3*depth), out) 59 | for h in synset.hyponyms(): 60 | out += traverse_lemmas_hypo(h, max_depth, 
depth+1) 61 | 62 | return out 63 | 64 | 65 | 66 | def random_crop_slices(origin_size, target_size): 67 | """Gets slices of a random crop. """ 68 | assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}' 69 | 70 | offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item() # range: 0 <= value < high 71 | offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item() 72 | 73 | return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1]) 74 | 75 | 76 | def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None): 77 | 78 | best_crops = [] 79 | best_crop_not_ok = float('-inf'), None, None 80 | min_sum = 0 81 | 82 | seg = seg.astype('bool') 83 | 84 | if min_frac is not None: 85 | #min_sum = seg.sum() * min_frac 86 | min_sum = seg.shape[0] * seg.shape[1] * min_frac 87 | 88 | for iteration in range(iterations): 89 | sl_y, sl_x = random_crop_slices(seg.shape, image_size) 90 | seg_ = seg[sl_y, sl_x] 91 | sum_seg_ = seg_.sum() 92 | 93 | if sum_seg_ > min_sum: 94 | 95 | if best_of is None: 96 | return sl_y, sl_x, False 97 | else: 98 | best_crops += [(sum_seg_, sl_y, sl_x)] 99 | if len(best_crops) >= best_of: 100 | best_crops.sort(key=lambda x:x[0], reverse=True) 101 | sl_y, sl_x = best_crops[0][1:] 102 | 103 | return sl_y, sl_x, False 104 | 105 | else: 106 | if sum_seg_ > best_crop_not_ok[0]: 107 | best_crop_not_ok = sum_seg_, sl_y, sl_x 108 | 109 | else: 110 | # return best segmentation found 111 | return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,) 112 | 113 | 114 | class PhraseCut(object): 115 | 116 | def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True, 117 | min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None): 118 | super().__init__() 119 | 120 | self.negative_prob = negative_prob 121 | self.image_size = image_size 122 | 
self.with_visual = with_visual 123 | self.only_visual = only_visual 124 | self.phrase_form = '{}' 125 | self.mask = mask 126 | self.aug_crop = aug_crop 127 | 128 | if aug_color: 129 | self.aug_color = transforms.Compose([ 130 | transforms.ColorJitter(0.5, 0.5, 0.2, 0.05), 131 | ]) 132 | else: 133 | self.aug_color = None 134 | 135 | get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([ 136 | isdir(join(local_dir, 'VGPhraseCut_v0')), 137 | isdir(join(local_dir, 'VGPhraseCut_v0', 'images')), 138 | isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')), 139 | len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249} 140 | ])) 141 | 142 | from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader 143 | self.refvg_loader = RefVGLoader(split=split) 144 | 145 | # img_ids where the size in the annotations does not match actual size 146 | invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530, 147 | 150333, 286065, 285814, 498187, 285761, 498042]) 148 | 149 | mean = [0.485, 0.456, 0.406] 150 | std = [0.229, 0.224, 0.225] 151 | self.normalize = transforms.Normalize(mean, std) 152 | 153 | self.sample_ids = [(i, j) 154 | for i in self.refvg_loader.img_ids 155 | for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases'])) 156 | if i not in invalid_img_ids] 157 | 158 | 159 | # self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']])) 160 | 161 | from nltk.stem import WordNetLemmatizer 162 | wnl = WordNetLemmatizer() 163 | 164 | # Filter by class (if remove_classes is set) 165 | if remove_classes is None: 166 | pass 167 | else: 168 | from nltk.corpus import wordnet 169 | 170 | print('remove pascal classes...') 171 | 172 | get_data = self.refvg_loader.get_img_ref_data # shortcut 173 | keep_sids = None 174 | 175 | if remove_classes[0] == 'pas5i': 176 | subset_id = 
remove_classes[1] 177 | 178 | avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]] 179 | 180 | 181 | elif remove_classes[0] == 'zs': 182 | stop = remove_classes[1] 183 | 184 | from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS 185 | 186 | avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set] 187 | print(avoid) 188 | 189 | elif remove_classes[0] == 'aff': 190 | # avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02'] 191 | # all_lemmas = set(['drink', 'sit', 'ride']) 192 | avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting', 193 | 'ride', 'rides', 'riding', 194 | 'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven', 195 | 'swim', 'swims', 'swimming', 196 | 'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears'] 197 | keep_sids = [(i, j) for i, j in self.sample_ids if 198 | all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))] 199 | 200 | print('avoid classes:', avoid) 201 | 202 | 203 | if keep_sids is None: 204 | all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)] 205 | all_lemmas = list(set(all_lemmas)) 206 | all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas] 207 | all_lemmas = set(all_lemmas) 208 | 209 | # divide into multi word and single word 210 | all_lemmas_s = set(l for l in all_lemmas if ' ' not in l) 211 | all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s) 212 | 213 | # new3 214 | phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids] 215 | remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases) 216 | if any(l in phrase for l in all_lemmas_m) or 217 | len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0 218 | ) 219 | keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids] 220 | 221 | print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}') 222 | removed_ids = set(self.sample_ids) - 
set(keep_sids) 223 | 224 | print('Examples of removed', len(removed_ids)) 225 | for i, j in list(removed_ids)[:20]: 226 | print(i, get_data(i)['phrases'][j]) 227 | 228 | self.sample_ids = keep_sids 229 | 230 | from itertools import groupby 231 | samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j)) 232 | for i, j in self.sample_ids] 233 | samples_by_phrase = sorted(samples_by_phrase) 234 | samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0]) 235 | 236 | self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase} 237 | 238 | self.all_phrases = list(set(self.samples_by_phrase.keys())) 239 | 240 | 241 | if self.only_visual: 242 | assert self.with_visual 243 | self.sample_ids = [(i, j) for i, j in self.sample_ids 244 | if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1] 245 | 246 | # Filter by size (if min_size is set) 247 | sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids] 248 | image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids] 249 | #self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes] 250 | self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)] 251 | 252 | if min_size: 253 | print('filter by size') 254 | 255 | self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size] 256 | 257 | self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/')) 258 | 259 | def __len__(self): 260 | return len(self.sample_ids) 261 | 262 | 263 | def load_sample(self, sample_i, j): 264 | 265 | img_ref_data = self.refvg_loader.get_img_ref_data(sample_i) 266 | 267 | polys_phrase0 = img_ref_data['gt_Polygons'][j] 268 | phrase = img_ref_data['phrases'][j] 269 | phrase = 
self.phrase_form.format(phrase) 270 | 271 | masks = [] 272 | for polys in polys_phrase0: 273 | for poly in polys: 274 | poly = [p[::-1] for p in poly] # swap x,y 275 | masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)] 276 | 277 | seg = np.stack(masks).max(0) 278 | img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg'))) 279 | 280 | min_shape = min(img.shape[:2]) 281 | 282 | if self.aug_crop: 283 | sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05) 284 | else: 285 | sly, slx = slice(0, None), slice(0, None) 286 | 287 | seg = seg[sly, slx] 288 | img = img[sly, slx] 289 | 290 | seg = seg.astype('uint8') 291 | seg = torch.from_numpy(seg).view(1, 1, *seg.shape) 292 | 293 | if img.ndim == 2: 294 | img = np.dstack([img] * 3) 295 | 296 | img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float() 297 | 298 | seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0] 299 | img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0] 300 | 301 | # img = img.permute([2,0, 1]) 302 | img = img / 255.0 303 | 304 | if self.aug_color is not None: 305 | img = self.aug_color(img) 306 | 307 | img = self.normalize(img) 308 | 309 | 310 | 311 | return img, seg, phrase 312 | 313 | def __getitem__(self, i): 314 | 315 | sample_i, j = self.sample_ids[i] 316 | 317 | img, seg, phrase = self.load_sample(sample_i, j) 318 | 319 | if self.negative_prob > 0: 320 | if torch.rand((1,)).item() < self.negative_prob: 321 | 322 | new_phrase = None 323 | while new_phrase is None or new_phrase == phrase: 324 | idx = torch.randint(0, len(self.all_phrases), (1,)).item() 325 | new_phrase = self.all_phrases[idx] 326 | phrase = new_phrase 327 | seg = torch.zeros_like(seg) 328 | 329 | if self.with_visual: 330 | # find a corresponding visual image 331 | if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1: 332 | idx = 
torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item() 333 | other_sample = self.samples_by_phrase[phrase][idx] 334 | #print(other_sample) 335 | img_s, seg_s, _ = self.load_sample(*other_sample) 336 | 337 | from datasets.utils import blend_image_segmentation 338 | 339 | if self.mask in {'separate', 'text_and_separate'}: 340 | # assert img.shape[1:] == img_s.shape[1:] == seg_s.shape == seg.shape[1:] 341 | add_phrase = [phrase] if self.mask == 'text_and_separate' else [] 342 | vis_s = add_phrase + [img_s, seg_s, True] 343 | else: 344 | if self.mask.startswith('text_and_'): 345 | mask_mode = self.mask[9:] 346 | label_add = [phrase] 347 | else: 348 | mask_mode = self.mask 349 | label_add = [] 350 | 351 | masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0]) 352 | vis_s = label_add + [masked_img_s, True] 353 | 354 | else: 355 | # phrase is unique 356 | vis_s = torch.zeros_like(img) 357 | 358 | if self.mask in {'separate', 'text_and_separate'}: 359 | add_phrase = [phrase] if self.mask == 'text_and_separate' else [] 360 | vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False] 361 | elif self.mask.startswith('text_and_'): 362 | vis_s = [phrase, vis_s, False] 363 | else: 364 | vis_s = [vis_s, False] 365 | else: 366 | assert self.mask == 'text' 367 | vis_s = [phrase] 368 | 369 | seg = seg.unsqueeze(0).float() 370 | 371 | data_x = (img,) + tuple(vis_s) 372 | 373 | return data_x, (seg, torch.zeros(0), i) 374 | 375 | 376 | class PhraseCutPlus(PhraseCut): 377 | 378 | def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None): 379 | super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color, aug_crop=aug_crop, min_size=min_size, 380 | remove_classes=remove_classes, with_visual=True, only_visual=only_visual, mask=mask) 381 | 
-------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def blend_image_segmentation(img, seg, mode, image_size=224): 7 | 8 | 9 | if mode in {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}: 10 | if isinstance(img, np.ndarray): 11 | img = torch.from_numpy(img) 12 | 13 | if isinstance(seg, np.ndarray): 14 | seg = torch.from_numpy(seg) 15 | 16 | if mode == 'overlay': 17 | out = img * seg 18 | out = [out.astype('float32')] 19 | elif mode == 'highlight': 20 | out = img * seg[None, :, :] * 0.85 + 0.15 * img 21 | out = [out.astype('float32')] 22 | elif mode == 'highlight2': 23 | img = img / 2 24 | out = (img+0.1) * seg[None, :, :] + 0.3 * img 25 | out = [out.astype('float32')] 26 | elif mode == 'blur_highlight': 27 | from evaluation_utils import img_preprocess 28 | out = [img_preprocess((None, [img], [seg]), blur=1, bg_fac=0.5).numpy()[0] - 0.01] 29 | elif mode == 'blur3_highlight': 30 | from evaluation_utils import img_preprocess 31 | out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.5).numpy()[0] - 0.01] 32 | elif mode == 'blur3_highlight01': 33 | from evaluation_utils import img_preprocess 34 | out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.1).numpy()[0] - 0.01] 35 | elif mode == 'blur_highlight_random': 36 | from evaluation_utils import img_preprocess 37 | out = [img_preprocess((None, [img], [seg]), blur=0 + torch.randint(0, 3, (1,)).item(), bg_fac=0.1 + 0.8*torch.rand(1).item()).numpy()[0] - 0.01] 38 | elif mode == 'crop': 39 | from evaluation_utils import img_preprocess 40 | out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()] 41 | elif mode == 'crop_blur_highlight': 42 | from evaluation_utils import img_preprocess 43 | out = [img_preprocess((None, [img], [seg]), 
blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)[0].numpy()] 44 | elif mode == 'crop_blur_highlight352': 45 | from evaluation_utils import img_preprocess 46 | out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()] 47 | elif mode == 'shape': 48 | out = [np.stack([seg[:, :]]*3).astype('float32')] 49 | elif mode == 'concat': 50 | out = [np.concatenate([img, seg[None, :, :]]).astype('float32')] 51 | elif mode == 'image_only': 52 | out = [img.astype('float32')] 53 | elif mode == 'image_black': 54 | out = [img.astype('float32')*0] 55 | elif mode is None: 56 | out = [img.astype('float32')] 57 | elif mode == 'separate': 58 | out = [img.astype('float32'), seg.astype('int64')] 59 | elif mode == 'separate_img_black': 60 | out = [img.astype('float32')*0, seg.astype('int64')] 61 | elif mode == 'separate_seg_ones': 62 | out = [img.astype('float32'), np.ones_like(seg).astype('int64')] 63 | elif mode == 'separate_both_black': 64 | out = [img.astype('float32')*0, seg.astype('int64')*0] 65 | else: 66 | raise ValueError(f'invalid mode: {mode}') 67 | 68 | return out -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: clipseg-environment 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - numpy 7 | - scipy 8 | - matplotlib-base 9 | - pip 10 | - pip: 11 | - --find-links https://download.pytorch.org/whl/torch_stable.html 12 | - torch==1.12.1+cpu 13 | - torchvision==0.13.1+cpu 14 | - opencv-python 15 | - git+https://github.com/openai/CLIP.git 16 | -------------------------------------------------------------------------------- /evaluation_utils.py: -------------------------------------------------------------------------------- 1 | from torch.functional import Tensor 2 | from general_utils import load_model 3 | from torch.utils.data import DataLoader 4 | import 
torch 5 | import numpy as np 6 | 7 | def denorm(img): 8 | 9 | np_input = False 10 | if isinstance(img, np.ndarray): 11 | img = torch.from_numpy(img) 12 | np_input = True 13 | 14 | mean = torch.Tensor([0.485, 0.456, 0.406]) 15 | std = torch.Tensor([0.229, 0.224, 0.225]) 16 | 17 | img_denorm = (img*std[:,None,None]) + mean[:,None,None] 18 | 19 | if np_input: 20 | img_denorm = np.clip(img_denorm.numpy(), 0, 1) 21 | else: 22 | img_denorm = torch.clamp(img_denorm, 0, 1) 23 | 24 | return img_denorm 25 | 26 | 27 | def norm(img): 28 | mean = torch.Tensor([0.485, 0.456, 0.406]) 29 | std = torch.Tensor([0.229, 0.224, 0.225]) 30 | return (img - mean[:,None,None]) / std[:,None,None] 31 | 32 | 33 | def fast_iou_curve(p, g): 34 | 35 | g = g[p.sort().indices] 36 | p = torch.sigmoid(p.sort().values) 37 | 38 | scores = [] 39 | vals = np.linspace(0, 1, 50) 40 | 41 | for q in vals: 42 | 43 | n = int(len(g) * q) 44 | 45 | valid = torch.where(p > q)[0] 46 | if len(valid) > 0: 47 | n = int(valid[0]) 48 | else: 49 | n = len(g) 50 | 51 | fn = g[:n].sum() 52 | tn = n - fn 53 | tp = g[n:].sum() 54 | fp = len(g) - n - tp 55 | 56 | iou = tp / (tp + fn + fp) 57 | 58 | precision = tp / (tp + fp) 59 | recall = tp / (tp + fn) 60 | 61 | scores += [iou] 62 | 63 | return vals, scores 64 | 65 | 66 | def fast_rp_curve(p, g): 67 | 68 | g = g[p.sort().indices] 69 | p = torch.sigmoid(p.sort().values) 70 | 71 | precisions, recalls = [], [] 72 | vals = np.linspace(p.min(), p.max(), 250) 73 | 74 | for q in p[::100000]: 75 | 76 | n = int(len(g) * q) 77 | 78 | valid = torch.where(p > q)[0] 79 | if len(valid) > 0: 80 | n = int(valid[0]) 81 | else: 82 | n = len(g) 83 | 84 | fn = g[:n].sum() 85 | tn = n - fn 86 | tp = g[n:].sum() 87 | fp = len(g) - n - tp 88 | 89 | iou = tp / (tp + fn + fp) 90 | 91 | precision = tp / (tp + fp) 92 | recall = tp / (tp + fn) 93 | 94 | precisions += [precision] 95 | recalls += [recall] 96 | 97 | return recalls, precisions 98 | 99 | 100 | # Image processing 101 | 102 | def 
img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2, 103 | brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224): 104 | import cv2 105 | 106 | rw = rect_width 107 | 108 | out = [] 109 | for img, mask in zip(batch[1], batch[2]): 110 | 111 | img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img) 112 | mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask) 113 | 114 | img *= brightness 115 | img_bl = img 116 | if blur > 0: # best 5 117 | img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1) 118 | 119 | if grayscale: 120 | img_bl = img_bl[1][None] 121 | 122 | #img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl 123 | # img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask) 124 | img_inp = img*mask + (bg_fac) * img_bl * (1-mask) 125 | 126 | if rect: 127 | _, bbox = crop_mask(img, mask, context=0.1) 128 | img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None] 129 | img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None] 130 | img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None] 131 | img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None] 132 | 133 | 134 | if center_context is not None: 135 | img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size) 136 | 137 | if colorize: 138 | img_gray = denorm(img) 139 | img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY) 140 | img_gray = torch.stack([torch.from_numpy(img_gray)]*3) 141 | img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask) 142 | img_inp = norm(img_inp) 143 | 144 | if outline: 145 | cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, 
cv2.CHAIN_APPROX_SIMPLE) 146 | outline_img = np.zeros(mask.shape, dtype=np.uint8) 147 | cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255)) 148 | outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255. 149 | img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img) 150 | img_inp = norm(img_inp) 151 | 152 | out += [img_inp] 153 | 154 | return torch.stack(out) 155 | 156 | 157 | def object_crop(img, mask, context=0.0, square=False, image_size=224): 158 | img_crop, bbox = crop_mask(img, mask, context=context, square=square) 159 | img_crop = pad_to_square(img_crop, channel_dim=0) 160 | img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0) 161 | return img_crop 162 | 163 | 164 | def crop_mask(img, mask, context=0.0, square=False): 165 | 166 | assert img.shape[1:] == mask.shape 167 | 168 | bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()] 169 | bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()] 170 | bbox = [int(x) for x in bbox] 171 | 172 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 173 | 174 | # square mask 175 | if square: 176 | bbox[0] = int(max(0, bbox[0] - context * height)) 177 | bbox[1] = int(min(mask.size(0), bbox[1] + context * height)) 178 | bbox[2] = int(max(0, bbox[2] - context * width)) 179 | bbox[3] = int(min(mask.size(1), bbox[3] + context * width)) 180 | 181 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 182 | if height > width: 183 | bbox[2] = int(max(0, (bbox[2] - 0.5*height))) 184 | bbox[3] = bbox[2] + height 185 | else: 186 | bbox[0] = int(max(0, (bbox[0] - 0.5*width))) 187 | bbox[1] = bbox[0] + width 188 | else: 189 | bbox[0] = int(max(0, bbox[0] - context * height)) 190 | bbox[1] = int(min(mask.size(0), bbox[1] + context * height)) 191 | bbox[2] = int(max(0, bbox[2] - context * width)) 192 | bbox[3] = 
int(min(mask.size(1), bbox[3] + context * width)) 193 | 194 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 195 | img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]] 196 | return img_crop, bbox 197 | 198 | 199 | def pad_to_square(img, channel_dim=2, fill=0): 200 | """ 201 | 202 | 203 | add padding such that a squared image is returned """ 204 | 205 | from torchvision.transforms.functional import pad 206 | 207 | if channel_dim == 2: 208 | img = img.permute(2, 0, 1) 209 | elif channel_dim == 0: 210 | pass 211 | else: 212 | raise ValueError('invalid channel_dim') 213 | 214 | h, w = img.shape[1:] 215 | pady1 = pady2 = padx1 = padx2 = 0 216 | 217 | if h > w: 218 | padx1 = (h - w) // 2 219 | padx2 = h - w - padx1 220 | elif w > h: 221 | pady1 = (w - h) // 2 222 | pady2 = w - h - pady1 223 | 224 | img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant') 225 | 226 | if channel_dim == 2: 227 | img_padded = img_padded.permute(1, 2, 0) 228 | 229 | return img_padded 230 | 231 | 232 | # qualitative 233 | 234 | def split_sentence(inp, limit=9): 235 | t_new, current_len = [], 0 236 | for k, t in enumerate(inp.split(' ')): 237 | current_len += len(t) + 1 238 | t_new += [t+' '] 239 | # not last 240 | if current_len > limit and k != len(inp.split(' ')) - 1: 241 | current_len = 0 242 | t_new += ['\n'] 243 | 244 | t_new = ''.join(t_new) 245 | return t_new 246 | 247 | 248 | from matplotlib import pyplot as plt 249 | 250 | 251 | def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None): 252 | 253 | row_off = 0 if labels is None else 1 254 | _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2))) 255 | [a.axis('off') for a in ax.flatten()] 256 | 257 | if labels is not None: 258 | for j in range(len(labels)): 259 | t_new = split_sentence(labels[j], limit=6) 260 | ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale) 261 
| 262 | 263 | for i in range(len(imgs)): 264 | ax[i + row_off,0].imshow(imgs[i]) 265 | for j in range(len(preds)): 266 | img = preds[j][i][0].detach().cpu().numpy() 267 | 268 | if gt_labels is not None and labels[j] == gt_labels[i]: 269 | print(j, labels[j], gt_labels[i]) 270 | edgecolor = 'red' 271 | if aps is not None: 272 | ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8) 273 | else: 274 | edgecolor = 'k' 275 | 276 | rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none", 277 | edgecolor=edgecolor, linewidth=3) 278 | ax[i + row_off,1 + j].add_patch(rect) 279 | 280 | if vmax is None: 281 | this_vmax = 1 282 | elif vmax == 'per_prompt': 283 | this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))]) 284 | elif vmax == 'per_image': 285 | this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))]) 286 | 287 | ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap) 288 | 289 | 290 | # ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max()) 291 | plt.tight_layout() 292 | plt.subplots_adjust(wspace=0.05, hspace=0.05) -------------------------------------------------------------------------------- /example_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timojl/clipseg/77c4e2dc853880df46413d4a870a725e961f2f73/example_image.jpg -------------------------------------------------------------------------------- /experiments/ablation.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 # <-########################################## 16 | 
val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.phrasecut.PhraseCut # <----------------- 20 | split_mode: pascal_test 21 | split: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | negative_prob: 0.2 25 | mix_text_max: 0.5 26 | 27 | # general 28 | mix: True # <----------------- 29 | prompt: shuffle+ 30 | norm_cond: True 31 | mix_text_min: 0.0 32 | with_visual: True 33 | 34 | # model 35 | version: 'ViT-B/16' 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False # <-########################################## 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | batch_size: 32 48 | sigmoid: True 49 | split: test 50 | label_support: True 51 | 52 | test_configuration: 53 | 54 | - 55 | name: pc 56 | metric: metrics.FixedIntervalMetrics 57 | test_dataset: phrasecut 58 | mask: text 59 | 60 | - 61 | name: pc-vis 62 | metric: metrics.FixedIntervalMetrics 63 | test_dataset: phrasecut 64 | mask: crop_blur_highlight352 65 | with_visual: True 66 | visual_only: True 67 | 68 | 69 | columns: [name, 70 | pc_fgiou_best, pc_miou_best, pc_fgiou_0.5, 71 | pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5, 72 | duration] 73 | 74 | 75 | individual_configurations: 76 | 77 | - {name: rd64-uni} 78 | - {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003} 79 | - {name: rd64-no-negatives, negative_prob: 0.0} 80 | - {name: rd64-neg0.5, negative_prob: 0.5} 81 | - {name: rd64-no-visual, with_visual: False, mix: False} 82 | - {name: rd16-uni, reduce_dim: 16} 83 | - {name: rd64-layer3, extract_layers: [3], depth: 1} 84 | - {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}} -------------------------------------------------------------------------------- /experiments/coco.yaml: -------------------------------------------------------------------------------- 1 
| configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.coco_wrapper.COCOWrapper 20 | # split_mode: pascal_test 21 | split: train 22 | mask: text_and_blur3_highlight01 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | 28 | # general 29 | mix: True 30 | prompt: shuffle+ 31 | norm_cond: True 32 | mix_text_min: 0.0 33 | 34 | # model 35 | out: 1 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | # max_iterations: 10 48 | batch_size: 8 49 | sigmoid: True 50 | test_dataset: coco 51 | metric: metrics.FixedIntervalMetrics 52 | 53 | test_configuration: 54 | 55 | - 56 | name: coco_t 57 | mask: text 58 | 59 | - 60 | name: coco_h 61 | mask: blur3_highlight01 62 | 63 | - 64 | name: coco_h2 65 | mask: crop_blur_highlight352 66 | 67 | 68 | columns: [i, name, 69 | coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5, 70 | coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5, 71 | coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t, 72 | train_loss, duration, date 73 | ] 74 | 75 | individual_configurations: 76 | 77 | 78 | - {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 79 | - {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 80 | - {name: rd64-7K-vit16-cbh-coco-2, version: 
'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 81 | - {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 82 | 83 | 84 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 85 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 86 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 87 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 88 | 89 | 90 | # ViT 91 | - {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 92 | - {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 93 | - {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 94 | - {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 95 | 96 | 97 | # BASELINE 98 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, 
reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 99 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 100 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 101 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} -------------------------------------------------------------------------------- /experiments/pascal_0shot.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | # val_metric_class: metrics.BinaryIoU 8 | # use_val_metric: BIoU_fg 9 | 10 | trainer: experiment_setup.train_loop 11 | scorer: experiment_setup.score 12 | model: models.clipseg.CLIPDensePredT 13 | 14 | lr_scheduler: cosine 15 | T_max: 20000 16 | eta_min: 0.0001 17 | 18 | max_iterations: 20000 # <-########################################## 19 | val_interval: null 20 | 21 | # dataset 22 | dataset: datasets.phrasecut.PhraseCut # <----------------- 23 | split_mode: pascal_test 24 | split: train 25 | image_size: 352 26 | normalize: True 27 | pre_crop_image_size: [sample, 1, 1.5] 28 | aug: 1new 29 | 30 | # new, not 31 | with_visual: True 32 | 33 | # general 34 | mix: False # <----------------- 35 | prompt: shuffle+ 36 | norm_cond: True 37 | mix_text_min: 0.0 38 | 39 | # model 40 | out: 1 41 | extract_layers: 
[3, 7, 9] 42 | reduce_dim: 64 43 | depth: 3 44 | fix_shift: False # <-########################################## 45 | 46 | loss: torch.nn.functional.binary_cross_entropy_with_logits 47 | amp: True 48 | 49 | test_configuration_common: 50 | normalize: True 51 | image_size: 352 52 | batch_size: 32 53 | # max_iterations: 150 54 | 55 | test_configuration: 56 | test_dataset: pascal_zs 57 | 58 | columns: [name, pas_zs_seen, pas_zs_unseen, duration, date] 59 | 60 | - {name: rd64-uni-zs5, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, remove_classes: [zs, 5], negative_prob: 0.2, mix: True, mix_text_max: 0.5, mask: text_and_crop_blur_highlight352} 61 | - {name: rd64-uni-zs2, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, remove_classes: [zs, 2], negative_prob: 0.2, mix: True, mix_text_max: 0.5, mask: text_and_crop_blur_highlight352} 62 | -------------------------------------------------------------------------------- /experiments/pascal_1shot.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 # <-########################################## 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.phrasecut.PhraseCut 20 | split_mode: pascal_test 21 | mode: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | with_visual: True 28 | split: train 29 | 30 | # general 31 | mix: True 32 | prompt: shuffle+ 33 | norm_cond: True 34 | mix_text_min: 0.0 35 | 36 | # model 37 | out: 1 38 | version: 'ViT-B/16' 39 | extract_layers: [3, 7, 9] 40 | reduce_dim: 64 41 | depth: 3 42 | 43 | loss: 
torch.nn.functional.binary_cross_entropy_with_logits 44 | amp: True 45 | 46 | test_configuration_common: 47 | normalize: True 48 | image_size: 352 49 | metric: metrics.FixedIntervalMetrics 50 | batch_size: 1 51 | test_dataset: pascal 52 | sigmoid: True 53 | # max_iterations: 250 54 | 55 | test_configuration: 56 | 57 | - 58 | name: pas_t 59 | mask: text 60 | 61 | - 62 | name: pas_h 63 | mask: blur3_highlight01 64 | 65 | - 66 | name: pas_h2 67 | mask: crop_blur_highlight352 68 | 69 | 70 | columns: [name, 71 | pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct, 72 | pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct, 73 | pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t, 74 | train_loss, duration, date 75 | ] 76 | 77 | individual_configurations: 78 | 79 | - {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}} 80 | - {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}} 81 | - {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}} 82 | - {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}} 83 | 84 | 85 | - {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}} 86 | - {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}} 87 | - {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}} 88 | - {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], 
custom_threshold: 0.28}} 89 | 90 | 91 | # baseline 92 | - {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}} 93 | - {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}} 94 | - {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}} 95 | - {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}} 96 | 97 | # ViT 98 | - {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}} 99 | - {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}} 100 | - {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}} 101 | - {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}} 102 | -------------------------------------------------------------------------------- /experiments/phrasecut.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 
0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.phrasecut.PhraseCut # <----------------- 20 | split_mode: pascal_test 21 | split: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | 28 | # general 29 | mix: False # <----------------- 30 | prompt: shuffle+ 31 | norm_cond: True 32 | mix_text_min: 0.0 33 | 34 | # model 35 | out: 1 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | batch_size: 32 48 | # max_iterations: 5 49 | # max_iterations: 150 50 | 51 | test_configuration: 52 | 53 | - 54 | name: pc # old: phrasecut 55 | metric: metrics.FixedIntervalMetrics 56 | test_dataset: phrasecut 57 | split: test 58 | mask: text 59 | label_support: True 60 | sigmoid: True 61 | 62 | 63 | columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date] 64 | 65 | 66 | individual_configurations: 67 | 68 | # important ones 69 | 70 | 71 | - {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5} 72 | 73 | # this is almost the same training setting as the refined model except for transformer dropout of 0.1 (currently not implemented in the model) 74 | - {name: rd64-uni-refined, version: 'ViT-B/16', reduce_dim: 64, negative_prob: 0.2, complex_trans_conv: True, with_visual: True, mix: True, mix_text_max: 0.5, T_max: 50000, max_iterations: 50000} 75 | 76 | 77 | # this was accidentally trained using old mask 78 | - {name: rd128-vit16-phrasecut, version:
'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01} 79 | - {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False} 80 | # this was accidentally trained using old mask 81 | - {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01} 82 | 83 | - {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003} 84 | - {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001} 85 | -------------------------------------------------------------------------------- /general_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import inspect 3 | import torch 4 | import os 5 | import sys 6 | import yaml 7 | from shutil import copy, copytree 8 | from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename 9 | 10 | 11 | class Logger(object): 12 | 13 | def __getattr__(self, k): 14 | return print 15 | 16 | log = Logger() 17 | 18 | def training_config_from_cli_args(): 19 | experiment_name = sys.argv[1] 20 | experiment_id = int(sys.argv[2]) 21 | 22 | yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader) 23 | 24 | config = yaml_config['configuration'] 25 | config = {**config, **yaml_config['individual_configurations'][experiment_id]} 26 | config = AttributeDict(config) 27 | return config 28 | 29 | 30 | def score_config_from_cli_args(): 31 | experiment_name = sys.argv[1] 32 | experiment_id = int(sys.argv[2]) 33 | 34 | 35 | yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader) 36 | 37 | config = yaml_config['test_configuration_common'] 38 | 39 | if
type(yaml_config['test_configuration']) == list: 40 | test_id = int(sys.argv[3]) 41 | config = {**config, **yaml_config['test_configuration'][test_id]} 42 | else: 43 | config = {**config, **yaml_config['test_configuration']} 44 | 45 | if 'test_configuration' in yaml_config['individual_configurations'][experiment_id]: 46 | config = {**config, **yaml_config['individual_configurations'][experiment_id]['test_configuration']} 47 | 48 | train_checkpoint_id = yaml_config['individual_configurations'][experiment_id]['name'] 49 | 50 | config = AttributeDict(config) 51 | return config, train_checkpoint_id 52 | 53 | 54 | def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository', 55 | local_dir='~/datasets'): 56 | """ copies files from repository to local folder. 57 | 58 | repo_files: list of filenames or list of tuples [filename, target path] 59 | 60 | e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar']) 61 | will create a folder 'MyDataset' in local_dir, and extract the content of 62 | '/data/dataset1.tar' to /MyDataset/other/path. 
63 | """ 64 | 65 | local_dir = realpath(join(expanduser(local_dir), local_name)) 66 | 67 | dataset_exists = True 68 | 69 | # check if folder is available 70 | if not isdir(local_dir): 71 | dataset_exists = False 72 | 73 | if integrity_check is not None: 74 | try: 75 | integrity_ok = integrity_check(local_dir) 76 | except BaseException: 77 | integrity_ok = False 78 | 79 | if integrity_ok: 80 | log.hint('Passed custom integrity check') 81 | else: 82 | log.hint('Custom integrity check failed') 83 | 84 | dataset_exists = dataset_exists and integrity_ok 85 | 86 | if not dataset_exists: 87 | 88 | repo_dir = realpath(expanduser(repo_dir)) 89 | 90 | for i, filename in enumerate(repo_files): 91 | 92 | if type(filename) == str: 93 | origin, target = filename, filename 94 | archive_target = join(local_dir, basename(origin)) 95 | extract_target = join(local_dir) 96 | else: 97 | origin, target = filename 98 | archive_target = join(local_dir, dirname(target), basename(origin)) 99 | extract_target = join(local_dir, dirname(target)) 100 | 101 | archive_origin = join(repo_dir, origin) 102 | 103 | log.hint(f'copy: {archive_origin} to {archive_target}') 104 | 105 | # make sure the path exists 106 | os.makedirs(dirname(archive_target), exist_ok=True) 107 | 108 | if os.path.isfile(archive_target): 109 | # only copy if size differs 110 | if os.path.getsize(archive_target) != os.path.getsize(archive_origin): 111 | log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. 
origin {os.path.getsize(archive_origin)}') 112 | copy(archive_origin, archive_target) 113 | else: 114 | copy(archive_origin, archive_target) 115 | 116 | extract_archive(archive_target, extract_target, noarchive_ok=True) 117 | 118 | # concurrent processes might have deleted the file 119 | if os.path.isfile(archive_target): 120 | os.remove(archive_target) 121 | 122 | 123 | def extract_archive(filename, target_folder=None, noarchive_ok=False): 124 | from subprocess import run, PIPE 125 | 126 | if filename.endswith('.tgz') or filename.endswith('.tar'): 127 | command = f'tar -xf {filename}' 128 | command += f' -C {target_folder}' if target_folder is not None else '' 129 | elif filename.endswith('.tar.gz'): 130 | command = f'tar -xzf {filename}' 131 | command += f' -C {target_folder}' if target_folder is not None else '' 132 | elif filename.endswith('zip'): 133 | command = f'unzip {filename}' 134 | command += f' -d {target_folder}' if target_folder is not None else '' 135 | else: 136 | if noarchive_ok: 137 | return 138 | else: 139 | raise ValueError(f'unsuppored file ending of {filename}') 140 | 141 | log.hint(command) 142 | result = run(command.split(), stdout=PIPE, stderr=PIPE) 143 | if result.returncode != 0: 144 | print(result.stdout, result.stderr) 145 | 146 | 147 | class AttributeDict(dict): 148 | """ 149 | An extended dictionary that allows access to elements as atttributes and counts 150 | these accesses. This way, we know if some attributes were never used. 
151 | """ 152 | 153 | def __init__(self, *args, **kwargs): 154 | from collections import Counter 155 | super().__init__(*args, **kwargs) 156 | self.__dict__['counter'] = Counter() 157 | 158 | def __getitem__(self, k): 159 | self.__dict__['counter'][k] += 1 160 | return super().__getitem__(k) 161 | 162 | def __getattr__(self, k): 163 | self.__dict__['counter'][k] += 1 164 | return super().get(k) 165 | 166 | def __setattr__(self, k, v): 167 | return super().__setitem__(k, v) 168 | 169 | def __delattr__(self, k, v): 170 | return super().__delitem__(k, v) 171 | 172 | def unused_keys(self, exceptions=()): 173 | return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions] 174 | 175 | def assume_no_unused_keys(self, exceptions=()): 176 | if len(self.unused_keys(exceptions=exceptions)) > 0: 177 | log.warning('Unused keys:', self.unused_keys(exceptions=exceptions)) 178 | 179 | 180 | def get_attribute(name): 181 | import importlib 182 | 183 | if name is None: 184 | raise ValueError('The provided attribute is None') 185 | 186 | name_split = name.split('.') 187 | mod = importlib.import_module('.'.join(name_split[:-1])) 188 | return getattr(mod, name_split[-1]) 189 | 190 | 191 | 192 | def filter_args(input_args, default_args): 193 | 194 | updated_args = {k: input_args[k] if k in input_args else v for k, v in default_args.items()} 195 | used_args = {k: v for k, v in input_args.items() if k in default_args} 196 | unused_args = {k: v for k, v in input_args.items() if k not in default_args} 197 | 198 | return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args) 199 | 200 | 201 | def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False): 202 | 203 | config = json.load(open(join('logs', checkpoint_id, 'config.json'))) 204 | 205 | if model_args != 'from_config' and type(model_args) != dict: 206 | raise ValueError('model_args must either be "from_config" or a dictionary of 
values') 207 | 208 | model_cls = get_attribute(config['model']) 209 | 210 | # load model 211 | if model_args == 'from_config': 212 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) 213 | 214 | model = model_cls(**model_args) 215 | 216 | if weights_file is None: 217 | weights_file = realpath(join('logs', checkpoint_id, 'weights.pth')) 218 | else: 219 | weights_file = realpath(join('logs', checkpoint_id, weights_file)) 220 | 221 | if isfile(weights_file): 222 | weights = torch.load(weights_file) 223 | for _, w in weights.items(): 224 | assert not torch.any(torch.isnan(w)), 'weights contain NaNs' 225 | model.load_state_dict(weights, strict=strict) 226 | else: 227 | raise FileNotFoundError(f'model checkpoint {weights_file} was not found') 228 | 229 | if with_config: 230 | return model, config 231 | 232 | return model 233 | 234 | 235 | class TrainingLogger(object): 236 | 237 | def __init__(self, model, log_dir, config=None, *args): 238 | super().__init__() 239 | self.model = model 240 | self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None 241 | 242 | os.makedirs('logs/', exist_ok=True) 243 | os.makedirs(self.base_path, exist_ok=True) 244 | 245 | if config is not None: 246 | json.dump(config, open(join(self.base_path, 'config.json'), 'w')) 247 | 248 | def iter(self, i, **kwargs): 249 | if i % 100 == 0 and 'loss' in kwargs: 250 | loss = kwargs['loss'] 251 | print(f'iteration {i}: loss {loss:.4f}') 252 | 253 | def save_weights(self, only_trainable=False, weight_file='weights.pth'): 254 | if self.model is None: 255 | raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.') 256 | 257 | weights_path = join(self.base_path, weight_file) 258 | 259 | weight_dict = self.model.state_dict() 260 | 261 | if only_trainable: 262 | weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad} 263 | 264 | torch.save(weight_dict, weights_path) 265 | 
log.info(f'Saved weights to {weights_path}') 266 | 267 | def __enter__(self): 268 | return self 269 | 270 | def __exit__(self, type, value, traceback): 271 | """ automatically stop processes if used in a context manager """ 272 | pass -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from torch.functional import Tensor 2 | from general_utils import log 3 | from collections import defaultdict 4 | import numpy as np 5 | 6 | import torch 7 | from torch.nn import functional as nnf 8 | 9 | 10 | class BaseMetric(object): 11 | 12 | def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True, 13 | eval_validation=True): 14 | self._names = tuple(metric_names) 15 | self._eval_intermediate = eval_intermediate 16 | self._eval_validation = eval_validation 17 | 18 | self._pred_range = pred_range 19 | self._pred_index = pred_index 20 | self._gt_index = gt_index 21 | 22 | self.predictions = [] 23 | self.ground_truths = [] 24 | 25 | def eval_intermediate(self): 26 | return self._eval_intermediate 27 | 28 | def eval_validation(self): 29 | return self._eval_validation 30 | 31 | def names(self): 32 | return self._names 33 | 34 | def add(self, predictions, ground_truth): 35 | raise NotImplementedError 36 | 37 | def value(self): 38 | raise NotImplementedError 39 | 40 | def scores(self): 41 | # similar to value but returns dict 42 | value = self.value() 43 | if type(value) == dict: 44 | return value 45 | else: 46 | assert type(value) in {list, tuple} 47 | return list(zip(self.names(), self.value())) 48 | 49 | def _get_pred_gt(self, predictions, ground_truth): 50 | pred = predictions[self._pred_index] 51 | gt = ground_truth[self._gt_index] 52 | 53 | if self._pred_range is not None: 54 | pred = pred[:, self._pred_range[0]: self._pred_range[1]] 55 | 56 | return pred, gt 57 | 58 | 59 | class FixedIntervalMetrics(BaseMetric): 
60 | 61 | def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None, 62 | resize_pred=None, n_values=51, custom_threshold=None): 63 | 64 | 65 | super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh')) 66 | self.intersections = [] 67 | self.unions = [] 68 | # self.threshold = threshold 69 | self.sigmoid = sigmoid 70 | self.resize_to = resize_to 71 | self.resize_pred = resize_pred # resize prediction to match ground truth 72 | self.class_count = defaultdict(lambda: 0) 73 | self.per_class = defaultdict(lambda : [0,0]) 74 | self.ignore_mask = ignore_mask 75 | self.custom_threshold = custom_threshold 76 | 77 | self.scores_ap = [] 78 | self.scores_iou = [] 79 | self.gts, self.preds = [], [] 80 | self.classes = [] 81 | 82 | # [1:-1] ignores 0 and 1 83 | self.threshold_values = np.linspace(0, 1, n_values)[1:-1] 84 | 85 | self.metrics = dict(tp=[], fp=[], fn=[], tn=[]) 86 | 87 | def add(self, pred, gt): 88 | 89 | pred_batch = pred[0].cpu() 90 | 91 | if self.sigmoid: 92 | pred_batch = torch.sigmoid(pred_batch) 93 | 94 | gt_batch = gt[0].cpu() 95 | mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch)) 96 | cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch) 97 | 98 | if self.resize_to is not None: 99 | gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest') 100 | pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False) 101 | 102 | if isinstance(cls_batch, torch.Tensor): 103 | cls_batch = cls_batch.cpu().numpy().tolist() 104 | 105 | assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}' 106 | 107 | for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch): 108 | 109 | if self.resize_pred: 110 | predictions = nnf.interpolate(predictions.unsqueeze(0).float(), 
size=ground_truth.size()[-2:], mode='bilinear', align_corners=True) 111 | 112 | p = predictions.flatten() 113 | g = ground_truth.flatten() 114 | 115 | assert len(p) == len(g) 116 | 117 | if mask is not None: 118 | m = mask.flatten().bool() 119 | p = p[m] 120 | g = g[m] 121 | 122 | p_sorted = p.sort() 123 | p = p_sorted.values 124 | g = g[p_sorted.indices] 125 | 126 | tps, fps, fns, tns = [], [], [], [] 127 | for thresh in self.threshold_values: 128 | 129 | valid = torch.where(p > thresh)[0] 130 | if len(valid) > 0: 131 | n = int(valid[0]) 132 | else: 133 | n = len(g) 134 | 135 | fn = int(g[:n].sum()) 136 | tp = int(g[n:].sum()) 137 | fns += [fn] 138 | tns += [n - fn] 139 | tps += [tp] 140 | fps += [len(g) - n - tp] 141 | 142 | self.metrics['tp'] += [tps] 143 | self.metrics['fp'] += [fps] 144 | self.metrics['fn'] += [fns] 145 | self.metrics['tn'] += [tns] 146 | 147 | self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls] 148 | 149 | def value(self): 150 | 151 | import time 152 | t_start = time.time() 153 | 154 | if set(self.classes) == set([None]): 155 | all_classes = None 156 | log.warning('classes were not provided, cannot compute mIoU') 157 | else: 158 | all_classes = set(int(c) for c in self.classes) 159 | # log.info(f'compute metrics for {len(all_classes)} classes') 160 | 161 | summed = {k: [sum([self.metrics[k][i][j] 162 | for i in range(len(self.metrics[k]))]) 163 | for j in range(len(self.threshold_values))] 164 | for k in self.metrics.keys()} 165 | 166 | if all_classes is not None: 167 | 168 | assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn']) 169 | # group by class 170 | metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes} 171 | for i in range(len(self.metrics['tp'])): 172 | for k in self.metrics.keys(): 173 | metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]] 174 | 175 | # sum over all instances within the classes 176 | summed_by_cls = {k: {c: 
np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()} 177 | 178 | 179 | # Compute average precision 180 | 181 | assert (np.array(summed['fp']) + np.array(summed['tp']) ).sum(), 'no predictions is made' 182 | 183 | # only consider values where a prediction is made 184 | precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(len(self.threshold_values)) 185 | if summed['tp'][j] + summed['fp'][j] > 0] 186 | recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values)) 187 | if summed['tp'][j] + summed['fp'][j] > 0] 188 | 189 | # remove duplicate recall-precision-pairs (and sort by recall value) 190 | recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0])) 191 | 192 | from scipy.integrate import simps 193 | ap = simps(precisions, recalls) 194 | 195 | # Compute best IoU 196 | fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))] 197 | 198 | biniou_scores = [ 199 | 0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) + 200 | 0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j])) 201 | for j in range(len(self.threshold_values)) 202 | ] 203 | 204 | index_0p5 = self.threshold_values.tolist().index(0.5) 205 | index_0p1 = self.threshold_values.tolist().index(0.1) 206 | index_0p2 = self.threshold_values.tolist().index(0.2) 207 | index_0p3 = self.threshold_values.tolist().index(0.3) 208 | 209 | if self.custom_threshold is not None: 210 | index_ct = self.threshold_values.tolist().index(self.custom_threshold) 211 | 212 | if all_classes is not None: 213 | # mean IoU 214 | mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j]) 215 | for c in all_classes]) 216 | for j in range(len(self.threshold_values))] 
217 | 218 | mean_iou_dict = { 219 | 'miou_best': max(mean_ious) if all_classes is not None else None, 220 | 'miou_0.5': mean_ious[index_0p5] if all_classes is not None else None, 221 | 'miou_0.1': mean_ious[index_0p1] if all_classes is not None else None, 222 | 'miou_0.2': mean_ious[index_0p2] if all_classes is not None else None, 223 | 'miou_0.3': mean_ious[index_0p3] if all_classes is not None else None, 224 | 'miou_best_t': self.threshold_values[np.argmax(mean_ious)], 225 | 'mean_iou_ct': mean_ious[index_ct] if all_classes is not None and self.custom_threshold is not None else None, 226 | 'mean_iou_scores': mean_ious, 227 | } 228 | 229 | print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s') 230 | 231 | return { 232 | 'ap': ap, 233 | 234 | # fgiou 235 | 'fgiou_best': max(fgiou_scores), 236 | 'fgiou_0.5': fgiou_scores[index_0p5], 237 | 'fgiou_0.1': fgiou_scores[index_0p1], 238 | 'fgiou_0.2': fgiou_scores[index_0p2], 239 | 'fgiou_0.3': fgiou_scores[index_0p3], 240 | 'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)], 241 | 242 | # mean iou 243 | 244 | 245 | # biniou 246 | 'biniou_best': max(biniou_scores), 247 | 'biniou_0.5': biniou_scores[index_0p5], 248 | 'biniou_0.1': biniou_scores[index_0p1], 249 | 'biniou_0.2': biniou_scores[index_0p2], 250 | 'biniou_0.3': biniou_scores[index_0p3], 251 | 'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)], 252 | 253 | # custom threshold 254 | 'fgiou_ct': fgiou_scores[index_ct] if self.custom_threshold is not None else None, 255 | 'biniou_ct': biniou_scores[index_ct] if self.custom_threshold is not None else None, 256 | 'ct': self.custom_threshold, 257 | 258 | # statistics 259 | 'fgiou_scores': fgiou_scores, 260 | 'biniou_scores': biniou_scores, 261 | 'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))), 262 | 'summed_statistics': summed, 263 | 'summed_by_cls_statistics': summed_by_cls, 264 | 265 | 
**mean_iou_dict 266 | } 267 | 268 | # ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh' 269 | 270 | # return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls} 271 | 272 | -------------------------------------------------------------------------------- /models/clipseg.py: -------------------------------------------------------------------------------- 1 | import math 2 | from os.path import basename, dirname, join, isfile 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as nnf 6 | from torch.nn.modules.activation import ReLU 7 | 8 | 9 | def get_prompt_list(prompt): 10 | if prompt == 'plain': 11 | return ['{}'] 12 | elif prompt == 'fixed': 13 | return ['a photo of a {}.'] 14 | elif prompt == 'shuffle': 15 | return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] 16 | elif prompt == 'shuffle+': 17 | return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', 18 | 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', 19 | 'a bad photo of a {}.', 'a photo of the {}.'] 20 | else: 21 | raise ValueError('Invalid value for prompt') 22 | 23 | 24 | def forward_multihead_attention(x, b, with_aff=False, attn_mask=None): 25 | """ 26 | Simplified version of multihead attention (taken from torch source code but without tons of if clauses). 27 | The mlp and layer norm come from CLIP. 28 | x: input. 29 | b: multihead attention module. 
30 | """ 31 | 32 | x_ = b.ln_1(x) 33 | q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1) 34 | tgt_len, bsz, embed_dim = q.size() 35 | 36 | head_dim = embed_dim // b.attn.num_heads 37 | scaling = float(head_dim) ** -0.5 38 | 39 | q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) 40 | k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) 41 | v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) 42 | 43 | q = q * scaling 44 | 45 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2 46 | if attn_mask is not None: 47 | 48 | 49 | attn_mask_type, attn_mask = attn_mask 50 | n_heads = attn_output_weights.size(0) // attn_mask.size(0) 51 | attn_mask = attn_mask.repeat(n_heads, 1) 52 | 53 | if attn_mask_type == 'cls_token': 54 | # the mask only affects similarities compared to the readout-token. 55 | attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...] 
56 | # attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0] 57 | 58 | if attn_mask_type == 'all': 59 | # print(attn_output_weights.shape, attn_mask[:, None].shape) 60 | attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None] 61 | 62 | 63 | attn_output_weights = torch.softmax(attn_output_weights, dim=-1) 64 | 65 | attn_output = torch.bmm(attn_output_weights, v) 66 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 67 | attn_output = b.attn.out_proj(attn_output) 68 | 69 | x = x + attn_output 70 | x = x + b.mlp(b.ln_2(x)) 71 | 72 | if with_aff: 73 | return x, attn_output_weights 74 | else: 75 | return x 76 | 77 | 78 | class CLIPDenseBase(nn.Module): 79 | 80 | def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens): 81 | super().__init__() 82 | 83 | import clip 84 | 85 | # prec = torch.FloatTensor 86 | self.clip_model, _ = clip.load(version, device='cpu', jit=False) 87 | self.model = self.clip_model.visual 88 | 89 | # if not None, scale conv weights such that we obtain n_tokens. 
90 | self.n_tokens = n_tokens 91 | 92 | for p in self.clip_model.parameters(): 93 | p.requires_grad_(False) 94 | 95 | # conditional 96 | if reduce_cond is not None: 97 | self.reduce_cond = nn.Linear(512, reduce_cond) 98 | for p in self.reduce_cond.parameters(): 99 | p.requires_grad_(False) 100 | else: 101 | self.reduce_cond = None 102 | 103 | self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) 104 | self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) 105 | 106 | self.reduce = nn.Linear(768, reduce_dim) 107 | 108 | self.prompt_list = get_prompt_list(prompt) 109 | 110 | # precomputed prompts 111 | import pickle 112 | if isfile('precomputed_prompt_vectors.pickle'): 113 | precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) 114 | self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} 115 | else: 116 | self.precomputed_prompts = dict() 117 | 118 | def rescaled_pos_emb(self, new_size): 119 | assert len(new_size) == 2 120 | 121 | a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) 122 | b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T 123 | return torch.cat([self.model.positional_embedding[:1], b]) 124 | 125 | def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): 126 | 127 | 128 | with torch.no_grad(): 129 | 130 | inp_size = x_inp.shape[2:] 131 | 132 | if self.n_tokens is not None: 133 | stride2 = x_inp.shape[2] // self.n_tokens 134 | conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True) 135 | x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation) 136 | else: 137 | x = self.model.conv1(x_inp) # shape = [*, width, grid, grid] 138 | 139 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 140 | x = x.permute(0, 2, 1) # 
shape = [*, grid ** 2, width] 141 | 142 | x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 143 | 144 | standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197 145 | 146 | if x.shape[1] != standard_n_tokens: 147 | new_shape = int(math.sqrt(x.shape[1]-1)) 148 | x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:] 149 | else: 150 | x = x + self.model.positional_embedding.to(x.dtype) 151 | 152 | x = self.model.ln_pre(x) 153 | 154 | x = x.permute(1, 0, 2) # NLD -> LND 155 | 156 | activations, affinities = [], [] 157 | for i, res_block in enumerate(self.model.transformer.resblocks): 158 | 159 | if mask is not None: 160 | mask_layer, mask_type, mask_tensor = mask 161 | if mask_layer == i or mask_layer == 'all': 162 | # import ipdb; ipdb.set_trace() 163 | size = int(math.sqrt(x.shape[0] - 1)) 164 | 165 | attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size)) 166 | 167 | else: 168 | attn_mask = None 169 | else: 170 | attn_mask = None 171 | 172 | x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask) 173 | 174 | if i in extract_layers: 175 | affinities += [aff_per_head] 176 | 177 | #if self.n_tokens is not None: 178 | # activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)] 179 | #else: 180 | activations += [x] 181 | 182 | if len(extract_layers) > 0 and i == max(extract_layers) and skip: 183 | print('early skip') 184 | break 185 | 186 | x = x.permute(1, 0, 2) # LND -> NLD 187 | x = self.model.ln_post(x[:, 0, :]) 188 | 189 | if self.model.proj is not None: 190 | x = x @ self.model.proj 191 | 192 | return x, activations, affinities 193 | 194 | def sample_prompts(self, words, prompt_list=None): 195 | 196 | prompt_list = prompt_list if prompt_list is not None else 
self.prompt_list 197 | 198 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) 199 | prompts = [prompt_list[i] for i in prompt_indices] 200 | return [promt.format(w) for promt, w in zip(prompts, words)] 201 | 202 | def get_cond_vec(self, conditional, batch_size): 203 | # compute conditional from a single string 204 | if conditional is not None and type(conditional) == str: 205 | cond = self.compute_conditional(conditional) 206 | cond = cond.repeat(batch_size, 1) 207 | 208 | # compute conditional from string list/tuple 209 | elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: 210 | assert len(conditional) == batch_size 211 | cond = self.compute_conditional(conditional) 212 | 213 | # use conditional directly 214 | elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: 215 | cond = conditional 216 | 217 | # compute conditional from image 218 | elif conditional is not None and type(conditional) == torch.Tensor: 219 | with torch.no_grad(): 220 | cond, _, _ = self.visual_forward(conditional) 221 | else: 222 | raise ValueError('invalid conditional') 223 | return cond 224 | 225 | def compute_conditional(self, conditional): 226 | import clip 227 | 228 | dev = next(self.parameters()).device 229 | 230 | if type(conditional) in {list, tuple}: 231 | text_tokens = clip.tokenize(conditional).to(dev) 232 | cond = self.clip_model.encode_text(text_tokens) 233 | else: 234 | if conditional in self.precomputed_prompts: 235 | cond = self.precomputed_prompts[conditional].float().to(dev) 236 | else: 237 | text_tokens = clip.tokenize([conditional]).to(dev) 238 | cond = self.clip_model.encode_text(text_tokens)[0] 239 | 240 | if self.shift_vector is not None: 241 | return cond + self.shift_vector 242 | else: 243 | return cond 244 | 245 | 246 | def clip_load_untrained(version): 247 | assert version == 'ViT-B/16' 248 | from clip.model import CLIP 249 | from 
clip.clip import _MODELS, _download 250 | model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval() 251 | state_dict = model.state_dict() 252 | 253 | vision_width = state_dict["visual.conv1.weight"].shape[0] 254 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 255 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 256 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 257 | image_resolution = vision_patch_size * grid_size 258 | embed_dim = state_dict["text_projection"].shape[1] 259 | context_length = state_dict["positional_embedding"].shape[0] 260 | vocab_size = state_dict["token_embedding.weight"].shape[0] 261 | transformer_width = state_dict["ln_final.weight"].shape[0] 262 | transformer_heads = transformer_width // 64 263 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 264 | 265 | return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, 266 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers) 267 | 268 | 269 | class CLIPDensePredT(CLIPDenseBase): 270 | 271 | def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', 272 | extra_blocks=0, reduce_cond=None, fix_shift=False, 273 | learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False, 274 | add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None, complex_trans_conv=False): 275 | 276 | super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) 277 | # device = 'cpu' 278 | 279 | self.extract_layers = extract_layers 280 | self.cond_layer = cond_layer 281 | self.limit_to_clip_only = limit_to_clip_only 282 | self.process_cond = None 283 | self.rev_activations = rev_activations 284 | 285 | depth = len(extract_layers) 286 | 287 | if add_calibration: 288 | 
self.calibration_conds = 1 289 | 290 | self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None 291 | 292 | self.add_activation1 = True 293 | 294 | self.version = version 295 | 296 | self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] 297 | 298 | if fix_shift: 299 | # self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False) 300 | self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False) 301 | # self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False) 302 | else: 303 | self.shift_vector = None 304 | 305 | if trans_conv is None: 306 | trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] 307 | else: 308 | # explicitly define transposed conv kernel size 309 | trans_conv_ks = (trans_conv, trans_conv) 310 | 311 | if not complex_trans_conv: 312 | self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) 313 | else: 314 | assert trans_conv_ks[0] == trans_conv_ks[1] 315 | 316 | tp_kernels = (trans_conv_ks[0] // 4, trans_conv_ks[0] // 4) 317 | 318 | self.trans_conv = nn.Sequential( 319 | nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1), 320 | nn.ReLU(), 321 | nn.ConvTranspose2d(reduce_dim, reduce_dim // 2, kernel_size=tp_kernels[0], stride=tp_kernels[0]), 322 | nn.ReLU(), 323 | nn.ConvTranspose2d(reduce_dim // 2, 1, kernel_size=tp_kernels[1], stride=tp_kernels[1]), 324 | ) 325 | 326 | # self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) 327 | 328 | assert len(self.extract_layers) == depth 329 | 330 | self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) 331 | self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) 332 | 
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) 333 | 334 | # refinement and trans conv 335 | 336 | if learn_trans_conv_only: 337 | for p in self.parameters(): 338 | p.requires_grad_(False) 339 | 340 | for p in self.trans_conv.parameters(): 341 | p.requires_grad_(True) 342 | 343 | self.prompt_list = get_prompt_list(prompt) 344 | 345 | 346 | def forward(self, inp_image, conditional=None, return_features=False, mask=None): 347 | 348 | assert type(return_features) == bool 349 | 350 | inp_image = inp_image.to(self.model.positional_embedding.device) 351 | 352 | if mask is not None: 353 | raise ValueError('mask not supported') 354 | 355 | # x_inp = normalize(inp_image) 356 | x_inp = inp_image 357 | 358 | bs, dev = inp_image.shape[0], x_inp.device 359 | 360 | cond = self.get_cond_vec(conditional, bs) 361 | 362 | visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) 363 | 364 | activation1 = activations[0] 365 | activations = activations[1:] 366 | 367 | _activations = activations[::-1] if not self.rev_activations else activations 368 | 369 | a = None 370 | for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)): 371 | 372 | if a is not None: 373 | a = reduce(activation) + a 374 | else: 375 | a = reduce(activation) 376 | 377 | if i == self.cond_layer: 378 | if self.reduce_cond is not None: 379 | cond = self.reduce_cond(cond) 380 | 381 | a = self.film_mul(cond) * a + self.film_add(cond) 382 | 383 | a = block(a) 384 | 385 | for block in self.extra_blocks: 386 | a = a + block(a) 387 | 388 | a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens 389 | 390 | size = int(math.sqrt(a.shape[2])) 391 | 392 | a = a.view(bs, a.shape[1], size, size) 393 | 394 | a = self.trans_conv(a) 395 | 396 | if self.n_tokens is not None: 397 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True) 398 | 
399 | if self.upsample_proj is not None: 400 | a = self.upsample_proj(a) 401 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') 402 | 403 | if return_features: 404 | return a, visual_q, cond, [activation1] + activations 405 | else: 406 | return a, 407 | 408 | 409 | 410 | class CLIPDensePredTMasked(CLIPDensePredT): 411 | 412 | def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, 413 | prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False, 414 | refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None): 415 | 416 | super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim, 417 | n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond, 418 | fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only, 419 | limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration, 420 | n_tokens=n_tokens) 421 | 422 | def visual_forward_masked(self, img_s, seg_s): 423 | return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s)) 424 | 425 | def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False): 426 | 427 | if seg_s is None: 428 | cond = cond_or_img_s 429 | else: 430 | img_s = cond_or_img_s 431 | 432 | with torch.no_grad(): 433 | cond, _, _ = self.visual_forward_masked(img_s, seg_s) 434 | 435 | return super().forward(img_q, cond, return_features=return_features) 436 | 437 | 438 | 439 | class CLIPDenseBaseline(CLIPDenseBase): 440 | 441 | def __init__(self, version='ViT-B/32', cond_layer=0, 442 | extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed', 443 | reduce_cond=None, limit_to_clip_only=False, n_tokens=None): 444 | 445 | super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) 446 | device = 'cpu' 447 | 448 | # self.cond_layer = cond_layer 449 | self.extract_layer = 
extract_layer 450 | self.limit_to_clip_only = limit_to_clip_only 451 | self.shift_vector = None 452 | 453 | self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] 454 | 455 | assert reduce2_dim is not None 456 | 457 | self.reduce2 = nn.Sequential( 458 | nn.Linear(reduce_dim, reduce2_dim), 459 | nn.ReLU(), 460 | nn.Linear(reduce2_dim, reduce_dim) 461 | ) 462 | 463 | trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] 464 | self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) 465 | 466 | 467 | def forward(self, inp_image, conditional=None, return_features=False): 468 | 469 | inp_image = inp_image.to(self.model.positional_embedding.device) 470 | 471 | # x_inp = normalize(inp_image) 472 | x_inp = inp_image 473 | 474 | bs, dev = inp_image.shape[0], x_inp.device 475 | 476 | cond = self.get_cond_vec(conditional, bs) 477 | 478 | visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer]) 479 | 480 | a = activations[0] 481 | a = self.reduce(a) 482 | a = self.film_mul(cond) * a + self.film_add(cond) 483 | 484 | if self.reduce2 is not None: 485 | a = self.reduce2(a) 486 | 487 | # the original model would execute a transformer block here 488 | 489 | a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens 490 | 491 | size = int(math.sqrt(a.shape[2])) 492 | 493 | a = a.view(bs, a.shape[1], size, size) 494 | a = self.trans_conv(a) 495 | 496 | if return_features: 497 | return a, visual_q, cond, activations 498 | else: 499 | return a, 500 | 501 | 502 | class CLIPSegMultiLabel(nn.Module): 503 | 504 | def __init__(self, model) -> None: 505 | super().__init__() 506 | 507 | from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC 508 | 509 | self.pascal_classes = VOC 510 | 511 | from models.clipseg import CLIPDensePredT 512 | from general_utils import load_model 513 | # self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False) 514 | 
self.clipseg = load_model(model, strict=False) 515 | 516 | self.clipseg.eval() 517 | 518 | def forward(self, x): 519 | 520 | bs = x.shape[0] 521 | out = torch.ones(21, bs, 352, 352).to(x.device) * -10 522 | 523 | for class_id, class_name in enumerate(self.pascal_classes): 524 | 525 | fac = 3 if class_name == 'background' else 1 526 | 527 | with torch.no_grad(): 528 | pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac 529 | 530 | out[class_id] += pred 531 | 532 | 533 | out = out.permute(1, 0, 2, 3) 534 | 535 | return out 536 | 537 | # construct output tensor 538 | 539 | -------------------------------------------------------------------------------- /models/vitseg.py: -------------------------------------------------------------------------------- 1 | import math 2 | from posixpath import basename, dirname, join 3 | # import clip 4 | from clip.model import convert_weights 5 | import torch 6 | import json 7 | from torch import nn 8 | from torch.nn import functional as nnf 9 | from torch.nn.modules import activation 10 | from torch.nn.modules.activation import ReLU 11 | from torchvision import transforms 12 | 13 | normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 14 | 15 | from torchvision.models import ResNet 16 | 17 | 18 | def process_prompts(conditional, prompt_list, conditional_map): 19 | # DEPRECATED 20 | 21 | # randomly sample a synonym 22 | words = [conditional_map[int(i)] for i in conditional] 23 | words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words] 24 | words = [w.replace('_', ' ') for w in words] 25 | 26 | if prompt_list is not None: 27 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) 28 | prompts = [prompt_list[i] for i in prompt_indices] 29 | else: 30 | prompts = ['a photo of {}'] * (len(words)) 31 | 32 | return [promt.format(w) for promt, w in zip(prompts, words)] 33 | 34 | 35 | 
class VITDenseBase(nn.Module):
    """Shared helpers for the timm-ViT based dense prediction models.

    Subclasses set ``self.model`` (a timm ViT), ``self.clip_model``,
    ``self.token_shape``, ``self.prompt_list`` and ``self.precomputed_prompts``.
    """

    def rescaled_pos_emb(self, new_size):
        """Bicubically interpolate the patch positional embeddings to a ``new_size`` grid.

        NOTE(review): this reads ``self.model.positional_embedding`` (a CLIP
        attribute); the timm model built by ``VITDensePredT`` exposes
        ``pos_embed`` instead, so this helper looks stale here -- confirm
        before relying on it.
        """
        assert len(new_size) == 2

        a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
        b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0] * new_size[1]).T
        return torch.cat([self.model.positional_embedding[:1], b])

    def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
        """Run the frozen ViT on ``x_inp`` and collect intermediate activations.

        The input is resized to 384x384 (``nnf.interpolate`` default mode)
        before patch embedding. ``skip`` and ``mask`` are accepted only for
        interface compatibility with the CLIP-based models and are ignored.

        Returns:
            ``(x, activations, None)`` where ``x`` is the head output for the
            cls token and ``activations`` holds the outputs of the blocks whose
            index is in ``extract_layers``, permuted with ``permute(1, 0, 2)``
            to match the token-first layout used by the CLIP models.
        """
        with torch.no_grad():

            x_inp = nnf.interpolate(x_inp, (384, 384))

            x = self.model.patch_embed(x_inp)
            cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            if self.model.dist_token is None:
                x = torch.cat((cls_token, x), dim=1)
            else:
                x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = self.model.pos_drop(x + self.model.pos_embed)

            activations = []
            for i, block in enumerate(self.model.blocks):
                x = block(x)

                if i in extract_layers:
                    # permute to be compatible with CLIP
                    activations += [x.permute(1, 0, 2)]

            x = self.model.norm(x)
            x = self.model.head(self.model.pre_logits(x[:, 0]))

        return x, activations, None

    def sample_prompts(self, words, prompt_list=None):
        """Format each word into a prompt template sampled uniformly with replacement.

        Args:
            words: sequence of strings to insert into the templates.
            prompt_list: templates containing ``{}``; defaults to ``self.prompt_list``.
        """
        prompt_list = prompt_list if prompt_list is not None else self.prompt_list

        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        prompts = [prompt_list[i] for i in prompt_indices]
        return [prompt.format(w) for prompt, w in zip(prompts, words)]

    def get_cond_vec(self, conditional, batch_size):
        """Turn ``conditional`` into a batch of conditioning vectors.

        Accepted forms:
            * ``str`` -- one prompt, encoded once and repeated ``batch_size`` times;
            * list/tuple of ``str`` -- one prompt per batch element (length must
              equal ``batch_size``);
            * 2-D tensor -- used as-is;
            * any other tensor -- treated as support images and encoded with
              ``visual_forward``.

        Raises:
            ValueError: for ``None`` or any other type.
        """
        # compute conditional from a single string
        if isinstance(conditional, str):
            cond = self.compute_conditional(conditional)
            cond = cond.repeat(batch_size, 1)

        # compute conditional from string list/tuple
        elif isinstance(conditional, (list, tuple)) and isinstance(conditional[0], str):
            assert len(conditional) == batch_size
            cond = self.compute_conditional(conditional)

        # use conditional directly
        elif isinstance(conditional, torch.Tensor) and conditional.ndim == 2:
            cond = conditional

        # compute conditional from image
        elif isinstance(conditional, torch.Tensor):
            with torch.no_grad():
                cond, _, _ = self.visual_forward(conditional)
        else:
            raise ValueError('invalid conditional')
        return cond

    def compute_conditional(self, conditional):
        """Encode text prompt(s) with the CLIP text encoder.

        A list/tuple is tokenized and encoded as one batch. A single string is
        first looked up in ``self.precomputed_prompts``; on a miss it is
        encoded and the batch dimension is dropped (``[0]``).
        """
        import clip

        dev = next(self.parameters()).device

        if isinstance(conditional, (list, tuple)):
            text_tokens = clip.tokenize(conditional).to(dev)
            cond = self.clip_model.encode_text(text_tokens)
        else:
            if conditional in self.precomputed_prompts:
                cond = self.precomputed_prompts[conditional].float().to(dev)
            else:
                text_tokens = clip.tokenize([conditional]).to(dev)
                cond = self.clip_model.encode_text(text_tokens)[0]

        return cond


class VITDensePredT(VITDenseBase):
    """Dense prediction decoder on a frozen timm ``vit_base_patch16_384`` backbone.

    Activations from ``extract_layers`` are reduced to ``reduce_dim``, FiLM
    conditioned at ``cond_layer`` and decoded to a 1-channel map with a
    transposed convolution.
    """

    def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
                 depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
                 learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
                 add_calibration=False, process_cond=None, not_pretrained=False):
        # NOTE(review): fix_shift, refine and not_pretrained are accepted but unused here.
        super().__init__()

        self.extract_layers = extract_layers
        self.cond_layer = cond_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.process_cond = None

        if add_calibration:
            self.calibration_conds = 1

        self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None

        self.add_activation1 = True

        import timm
        self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
        self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)

        # NOTE(review): this also freezes the freshly initialised head above -- confirm intended.
        for p in self.model.parameters():
            p.requires_grad_(False)

        import clip
        self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)

        self.token_shape = (14, 14)

        # conditional
        if reduce_cond is not None:
            self.reduce_cond = nn.Linear(512, reduce_cond)
            for p in self.reduce_cond.parameters():
                p.requires_grad_(False)
        else:
            self.reduce_cond = None

        self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
        self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)

        assert len(self.extract_layers) == depth

        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
        self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
        self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])

        trans_conv_ks = (16, 16)
        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)

        # refinement and trans conv

        if learn_trans_conv_only:
            for p in self.parameters():
                p.requires_grad_(False)

            for p in self.trans_conv.parameters():
                p.requires_grad_(True)

        # NOTE(review): an unrecognised `prompt` leaves self.prompt_list unset.
        if prompt == 'fixed':
            self.prompt_list = ['a photo of a {}.']
        elif prompt == 'shuffle':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
        elif prompt == 'shuffle+':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
                                'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
                                'a bad photo of a {}.', 'a photo of the {}.']
        elif prompt == 'shuffle_clip':
            from models.clip_prompts import imagenet_templates
            self.prompt_list = imagenet_templates

        if process_cond is not None:
            if process_cond == 'clamp' or process_cond[0] == 'clamp':

                val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2

                def clamp_vec(x):
                    return torch.clamp(x, -val, val)

                self.process_cond = clamp_vec

            elif process_cond.endswith('.pth'):

                shift = torch.load(process_cond)

                def add_shift(x):
                    return x + shift.to(x.device)

                self.process_cond = add_shift

        import pickle
        # Close the file deterministically. NOTE: pickle can execute arbitrary
        # code -- only load this file from a trusted source.
        with open('precomputed_prompt_vectors.pickle', 'rb') as f:
            precomp = pickle.load(f)
        self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}

    def forward(self, inp_image, conditional=None, return_features=False, mask=None):
        """Predict a segmentation logit map for ``inp_image`` given ``conditional``.

        Args:
            inp_image: image batch; ``shape[0]`` is the batch size and
                ``shape[2:]`` the spatial size the output is resized to.
            conditional: anything accepted by ``get_cond_vec``.
            return_features: if True, also return the query embedding, the
                conditional vector and all extracted activations.
            mask: unsupported; must be None.

        Returns:
            ``(a,)`` or ``(a, visual_q, cond, activations)``.

        Raises:
            ValueError: if ``mask`` is given.
        """
        assert isinstance(return_features, bool)

        if mask is not None:
            raise ValueError('mask not supported')

        x_inp = inp_image

        bs = inp_image.shape[0]

        inp_image_size = inp_image.shape[2:]

        cond = self.get_cond_vec(conditional, bs)

        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))

        activation1 = activations[0]
        activations = activations[1:]

        a = None
        # deepest extracted layer first
        for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):

            if a is not None:
                a = reduce(activation) + a
            else:
                a = reduce(activation)

            if i == self.cond_layer:
                if self.reduce_cond is not None:
                    cond = self.reduce_cond(cond)

                a = self.film_mul(cond) * a + self.film_add(cond)

            a = block(a)

        for block in self.extra_blocks:
            a = a + block(a)

        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        size = int(math.sqrt(a.shape[2]))

        a = a.view(bs, a.shape[1], size, size)

        if self.trans_conv is not None:
            a = self.trans_conv(a)

        if self.upsample_proj is not None:
            a = self.upsample_proj(a)
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')

        a = nnf.interpolate(a, inp_image_size)

        if return_features:
            return a, visual_q, cond, [activation1] + activations
        else:
            return a,

# --------------------------------------------------------------------------------
# /overview.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/timojl/clipseg/77c4e2dc853880df46413d4a870a725e961f2f73/overview.png
# --------------------------------------------------------------------------------
# /sample_rd64.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/timojl/clipseg/77c4e2dc853880df46413d4a870a725e961f2f73/sample_rd64.png
# --------------------------------------------------------------------------------
# /sample_rd64_refined.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/timojl/clipseg/77c4e2dc853880df46413d4a870a725e961f2f73/sample_rd64_refined.png
# --------------------------------------------------------------------------------
# /score.py:
# --------------------------------------------------------------------------------
from torch.functional import Tensor

import torch
import inspect
import json
import yaml
import time
import sys

from general_utils import log

import numpy as np
from os.path import expanduser, join, isfile, realpath

from torch.utils.data import DataLoader

from metrics import FixedIntervalMetrics

from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args


DATASET_CACHE = dict()


def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False):
    """Instantiate the model of run ``logs/<checkpoint_id>`` and load its weights.

    Note: this definition shadows the ``load_model`` imported from ``general_utils``.

    Args:
        checkpoint_id: run directory name under ``logs/``.
        weights_file: file name inside the run directory (default ``weights.pth``).
        strict: forwarded to ``load_state_dict``.
        model_args: ``'from_config'`` or an explicit dict of constructor kwargs.
        with_config: if True, also return the run's config dict.
        ignore_weights: skip loading weights (and the missing-file check).

    Raises:
        ValueError: if ``model_args`` is neither ``'from_config'`` nor a dict.
        FileNotFoundError: if the weights file is missing and not ignored.
    """
    with open(join('logs', checkpoint_id, 'config.json')) as f:
        config = json.load(f)

    if model_args != 'from_config' and type(model_args) != dict:
        raise ValueError('model_args must either be "from_config" or a dictionary of values')

    model_cls = get_attribute(config['model'])

    # load model
    if model_args == 'from_config':
        _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)

    model = model_cls(**model_args)

    if weights_file is None:
        weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
    else:
        weights_file = realpath(join('logs', checkpoint_id, weights_file))

    if isfile(weights_file) and not ignore_weights:
        weights = torch.load(weights_file)
        for w in weights.values():
            assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
        model.load_state_dict(weights, strict=strict)
    else:
        if not ignore_weights:
            raise FileNotFoundError(f'model checkpoint {weights_file} was not found')

    if with_config:
        return model, config

    return model


def compute_shift2(model, datasets, seed=123, repetitions=1):
    """Return the binarization threshold with the best foreground IoU on ``datasets``.

    Each dataset is run with batch size 1 for at most
    ``repetitions * len(dataset.dataset.data_list)`` samples; the threshold is
    picked from ``np.linspace(0, 1, 51)[1:-1]`` by ``fgiou_scores`` of
    ``FixedIntervalMetrics``.
    """
    model.eval()
    model.cuda()

    import random
    random.seed(seed)

    preds, gts = [], []
    for dataset in datasets:

        loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)

        max_iterations = int(repetitions * len(dataset.dataset.data_list))

        with torch.no_grad():

            i = 0
            for data_x, data_y in loader:

                data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x]
                data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y]

                pred, = model(data_x[0], data_x[1], data_x[2])
                preds += [pred.detach()]
                gts += [data_y]

                i += 1
                if max_iterations and i >= max_iterations:
                    break

    n_values = 51
    thresholds = np.linspace(0, 1, n_values)[1:-1]
    metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)

    for p, y in zip(preds, gts):
        metric.add(p.unsqueeze(1), y)

    best_idx = np.argmax(metric.value()['fgiou_scores'])
    best_thresh = thresholds[best_idx]

    return best_thresh


def get_cached_pascal_pfe(split, config):
    """Return the PFE Pascal validation dataset for ``split``, memoised in ``DATASET_CACHE``."""
    from datasets.pfe_dataset import PFEPascalWrapper
    key = (split, config.image_size, config.label_support, config.mask)
    try:
        dataset = DATASET_CACHE[key]
    except KeyError:
        dataset = PFEPascalWrapper(mode='val', split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support)
        DATASET_CACHE[key] = dataset
    return dataset


def main():
    """CLI entry point: score a checkpoint and print every scalar metric."""
    config, train_checkpoint_id = score_config_from_cli_args()

    metrics = score(config, train_checkpoint_id, None)

    for dataset, dataset_metrics in metrics.items():
        for k, v in dataset_metrics.items():
            # np.floating: the pascal branch returns np.mean(...) values, which
            # an exact `type(v) in {float, int}` check silently skipped.
            if isinstance(v, (int, float, np.floating)):
                print(dataset, f'{k:<16} {v:.3f}')


def score(config, train_checkpoint_id, train_config):
    """Evaluate a trained checkpoint on the benchmark selected by ``config.test_dataset``.

    Args:
        config: evaluation config mapping (wrapped in ``AttributeDict``).
            ``test_dataset`` is one of ``'pascal'``, ``'coco'``, ``'phrasecut'``,
            ``'pascal_zs'``, ``'same_as_training'``, ``'affordance'``.
        train_checkpoint_id: run directory name under ``logs/``.
        train_config: ignored; the training config is always re-read from
            ``logs/<train_checkpoint_id>/config.json``.

    Returns:
        ``{score_name: {metric_name: value}}``; ``score_name`` is ``config.name``
        when given.

    Raises:
        ValueError: for an unknown ``test_dataset``.
    """
    config = AttributeDict(config)

    print(config)

    # use training dataset and loss
    with open(f'logs/{train_checkpoint_id}/config.json') as f:
        train_config = AttributeDict(json.load(f))

    cp_str = f'_{config.iteration_cp}' if config.iteration_cp is not None else ''

    model_cls = get_attribute(train_config['model'])

    _, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters)

    model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}}

    strict_models = {'ConditionBase4', 'PFENetWrapper'}
    model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args,
                       weights_file=f'weights{cp_str}.pth')

    model.eval()
    model.cuda()

    metric_args = dict()

    if 'threshold' in config:
        if config.metric.split('.')[-1] == 'SkLearnMetrics':
            metric_args['threshold'] = config.threshold

    if 'resize_to' in config:
        metric_args['resize_to'] = config.resize_to

    if 'sigmoid' in config:
        metric_args['sigmoid'] = config.sigmoid

    if 'custom_threshold' in config:
        metric_args['custom_threshold'] = config.custom_threshold

    # Log calls below pass one preformatted message: extra positional args
    # without %-placeholders break standard `logging` message formatting.

    if config.test_dataset == 'pascal':

        # assume that if no split is specified in train_config, test on all splits,
        if 'splits' in config:
            splits = config.splits
        else:
            if 'split' in train_config and isinstance(train_config.split, int):
                # unless train_config has a split set, in that case assume train mode in training
                splits = [train_config.split]
                assert train_config.mode == 'train'
            else:
                splits = [0, 1, 2, 3]

        log.info(f'Test on these splits {splits}')

        scores = dict()
        for split in splits:

            shift = config.shift if 'shift' in config else 0

            # automatic shift
            if shift == 'auto':
                shift_compute_t = time.time()
                shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac)
                log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s')

            dataset = get_cached_pascal_pfe(split, config)

            eval_start_t = time.time()

            loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)

            assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1'

            metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args)

            with torch.no_grad():

                i = 0
                for data_x, data_y in loader:

                    data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
                    data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]

                    if config.mask == 'separate':  # for old CondBase model
                        pred, = model(data_x[0], data_x[1], data_x[2])
                    else:
                        pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)

                    metric.add(pred.unsqueeze(1) + shift, data_y)

                    i += 1
                    if config.max_iterations and i >= config.max_iterations:
                        break

            log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.')

            print(metric.value()['mean_iou_scores'])

            scores[split] = metric.scores()

            log.info(f'Completed split {split}')

        key_prefix = config['name'] if 'name' in config else 'pas'

        all_keys = set.intersection(*[set(v.keys()) for v in scores.values()])

        # np.float (an alias of float) was removed in NumPy 1.24; np.floating
        # also admits numpy scalar scores such as np.float32.
        valid_keys = [k for k in all_keys
                      if all(v[k] is not None and isinstance(v[k], (int, float, np.floating)) for v in scores.values())]

        return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}}

    elif config.test_dataset == 'coco':
        from datasets.coco_wrapper import COCOWrapper

        coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask,
                                   with_class_label=True)

        log.info(f'Dataset length {len(coco_dataset)}')
        loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)

        metric = get_attribute(config.metric)(resize_pred=True, **metric_args)

        shift = config.shift if 'shift' in config else 0

        with torch.no_grad():

            i = 0
            for data_x, data_y in loader:
                data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
                data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]

                if config.mask == 'separate':  # for old CondBase model
                    pred, = model(data_x[0], data_x[1], data_x[2])
                else:
                    pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)

                metric.add([pred + shift], data_y)

                i += 1
                if config.max_iterations and i >= config.max_iterations:
                    break

        key_prefix = config['name'] if 'name' in config else 'coco'
        return {key_prefix: metric.scores()}

    elif config.test_dataset == 'phrasecut':
        from datasets.phrasecut import PhraseCut

        only_visual = config.only_visual is not None and config.only_visual
        with_visual = config.with_visual is not None and config.with_visual

        dataset = PhraseCut('test',
                            image_size=train_config.image_size,
                            mask=config.mask,
                            with_visual=with_visual, only_visual=only_visual, aug_crop=False,
                            aug_color=False)

        loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
        metric = get_attribute(config.metric)(resize_pred=True, **metric_args)

        shift = config.shift if 'shift' in config else 0

        with torch.no_grad():

            i = 0
            for data_x, data_y in loader:
                data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
                data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]

                pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
                metric.add([pred + shift], data_y)

                i += 1
                if config.max_iterations and i >= config.max_iterations:
                    break

        key_prefix = config['name'] if 'name' in config else 'phrasecut'
        return {key_prefix: metric.scores()}

    elif config.test_dataset == 'pascal_zs':
        from third_party.JoEm.model.metric import Evaluator
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
        from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS

        from models.clipseg import CLIPSegMultiLabel

        n_unseen = train_config.remove_classes[1]

        pz = PascalZeroShot('val', n_unseen, image_size=352)
        m = CLIPSegMultiLabel(model=train_config.name).cuda()
        m.eval()

        print(len(pz), n_unseen)
        print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set])

        print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)])
        print('seen', [VOC[i] for i in get_seen_idx(n_unseen)])

        loader = DataLoader(pz, batch_size=8)
        evaluator = Evaluator(21, get_unseen_idx(n_unseen), get_seen_idx(n_unseen))

        for i, (data_x, data_y) in enumerate(loader):
            pred = m(data_x[0].cuda())
            evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy())

            if config.max_iter is not None and i > config.max_iter:
                break

        scores = evaluator.Mean_Intersection_over_Union()
        key_prefix = config['name'] if 'name' in config else 'pas_zs'

        return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}}

    elif config.test_dataset in {'same_as_training', 'affordance'}:
        loss_fn = get_attribute(train_config.loss)

        metric_cls = get_attribute(config.metric)
        metric = metric_cls(**metric_args)

        if config.test_dataset == 'same_as_training':
            dataset_cls = get_attribute(train_config.dataset)
            # was left undefined here, raising NameError when the default score
            # name is built below (i.e. when 'name' is not in config)
            dataset_name = dataset_cls.__name__
        else:  # 'affordance'
            dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance')
            dataset_name = 'aff'

        _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)

        dataset_args['image_size'] = train_config.image_size  # explicitly use training image size for evaluation

        if model.__class__.__name__ == 'PFENetWrapper':
            dataset_args['image_size'] = config.image_size

        log.info(f'init dataset {dataset_cls}')
        dataset = dataset_cls(**dataset_args)

        log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}')

        data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle)

        # explicitly set prompts
        if config.prompt == 'plain':
            model.prompt_list = ['{}']
        elif config.prompt == 'fixed':
            model.prompt_list = ['a photo of a {}.']
        elif config.prompt == 'shuffle':
            model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
        elif config.prompt == 'shuffle_clip':
            from models.clip_prompts import imagenet_templates
            model.prompt_list = imagenet_templates

        config.assume_no_unused_keys(exceptions=['max_iterations'])

        t_start = time.time()

        with torch.no_grad():  # TODO: switch to inference_mode (torch 1.9)
            i, losses = 0, []
            for data_x, data_y in data_loader:

                data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
                data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]

                if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}:
                    pred, = model(data_x[0], data_x[1], data_x[2])
                else:
                    pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)

                loss = loss_fn(pred, data_y[0])

                metric.add([pred], data_y)

                losses += [float(loss)]

                i += 1
                if config.max_iterations and i >= config.max_iterations:
                    break

        scores = metric.scores()

        keys = set(scores.keys())
        if dataset.negative_prob > 0 and 'mIoU' in keys:
            keys.remove('mIoU')

        name_mask = dataset.mask.replace('text_label', 'txt')[:3]
        name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob)

        score_name = config.name if 'name' in config else f'{dataset_name}_{name_mask}{name_neg}'

        scores = {score_name: {k: v for k, v in scores.items() if k in keys}}
        scores[score_name].update({'test_loss': np.mean(losses)})

        log.info(f'Evaluation took {time.time() - t_start:.1f}s')

        return scores
    else:
        raise ValueError('invalid test dataset')


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
from setuptools import setup

with open("README.md", "r", encoding="utf-8") as readme_file:
    readme = readme_file.read()

requirements = [
    "numpy",
    "scipy",
    "matplotlib",
    "torch",
    "torchvision",
    "opencv-python",
    "CLIP @ git+https://github.com/openai/CLIP.git"
]

setup(
    name='clipseg',
    packages=['clipseg'],
    package_dir={'clipseg': 'models'},
    package_data={'clipseg': [
        "../weights/*.pth",
    ]},
    version='0.0.1',
    url='https://github.com/timojl/clipseg',
    python_requires='>=3.9',
    install_requires=requirements,
    description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".',
    long_description=readme,
    long_description_content_type="text/markdown",
)

# --------------------------------------------------------------------------------
# /supplementary.pdf:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/timojl/clipseg/77c4e2dc853880df46413d4a870a725e961f2f73/supplementary.pdf
# --------------------------------------------------------------------------------
# /training.py:
# --------------------------------------------------------------------------------
import torch
import inspect
import json
import yaml
import math
import os
import sys

from general_utils import log

import numpy as np
from functools import partial
from os.path import expanduser, join, isfile, basename
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LambdaLR
from contextlib import nullcontext
from torch.utils.data import DataLoader

from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args


def cosine_warmup_lr(i, warmup=10, max_iter=90):
    """Return the LR multiplier for step ``i`` of a linear-warmup + cosine schedule.

    Meant to be used as the ``lr_lambda`` of ``torch.optim.lr_scheduler.LambdaLR``.

    Args:
        i: Current scheduler step (0-based).
        warmup: Number of warmup steps; during warmup the multiplier grows
            linearly as ``(i + 1) / (warmup + 1)``.
        max_iter: Step at which the cosine decay reaches 0.

    Returns:
        A float multiplier. After warmup it is ``0.5 + 0.5 * cos(pi * progress)``
        with ``progress = (i - warmup) / (max_iter - warmup)`` clamped to ``[0, 1]``,
        so it stays at 0 past ``max_iter`` instead of rising again.
    """
    if i < warmup:
        return (i + 1) / (warmup + 1)
    # max(1, ...) avoids a ZeroDivisionError when max_iter == warmup.
    decay_steps = max(1, max_iter - warmup)
    progress = min(1.0, (i - warmup) / decay_steps)
    return 0.5 + 0.5 * math.cos(math.pi * progress)


def validate(model, dataset, config):
    """Evaluate ``model`` on ``dataset`` and return the mean loss and optional metric scores.

    The model is put into eval mode and moved to CUDA; it is NOT switched back
    to train mode here, so the caller must do that.

    Args:
        model: Model exposing ``sample_prompts`` and a forward that, with
            ``return_features=True``, returns a 4-tuple whose first element is
            the prediction.
        dataset: Dataset yielding ``(data_x, data_y)``; ``data_x[0]`` is fed to
            the model and ``data_x[1]`` is passed to ``model.sample_prompts``.
        config: Training config; reads ``loss``, ``val_metric_class``,
            ``use_val_metric`` and ``val_max_iterations``.

    Returns:
        ``(mean_loss, metric_scores, maximize)``. When ``config.use_val_metric``
        is None this is ``(mean_loss, {}, False)``; otherwise ``metric_scores``
        maps metric names to values (empty if no metric class is configured)
        and ``maximize`` is True.
    """
    data_loader = DataLoader(dataset, batch_size=4, shuffle=False)

    metric_class, use_metric = config.val_metric_class, config.use_val_metric
    loss_fn = get_attribute(config.loss)

    model.eval()
    model.cuda()

    # Always bind `metric`: it used to be assigned only when a metric class was
    # configured, so the `metric is not None` check below raised NameError.
    metric = get_attribute(metric_class)() if metric_class is not None else None
    max_batches = config.val_max_iterations

    losses = []
    with torch.no_grad():
        for batch_idx, (data_x, data_y) in enumerate(data_loader):

            data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
            data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]

            # Use one fixed prompt template instead of the model's default prompt
            # sampling (presumably randomized -- confirm in models/clipseg.py).
            prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
            pred, _, _, _ = model(data_x[0], prompts, return_features=True)

            if metric is not None:
                metric.add([pred], data_y)

            losses.append(float(loss_fn(pred, data_y[0])))

            # Stop after exactly `val_max_iterations` batches (the previous
            # `i > max` check evaluated one batch too many).
            if max_batches is not None and batch_idx + 1 >= max_batches:
                break

    mean_loss = np.mean(losses)
    if use_metric is None:
        return mean_loss, {}, False

    metric_scores = dict(zip(metric.names(), metric.value())) if metric is not None else {}
    return mean_loss, metric_scores, True


def _build_optimizer(config, model):
    """Instantiate ``config.optimizer`` over all model parameters with learning rate ``config.lr``.

    For ``torch.optim.SGD`` the optional ``config.momentum`` is forwarded (0 if absent).
    """
    opt_cls = get_attribute(config.optimizer)
    # This used to compare the misspelled `config.optimize`, so momentum never reached SGD.
    if config.optimizer == 'torch.optim.SGD':
        opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
    else:
        opt_args = {}
    return opt_cls(model.parameters(), lr=config.lr, **opt_args)


def _build_lr_scheduler(config, opt):
    """Return the LR scheduler selected by ``config.lr_scheduler`` ('cosine', 'warmup_cosine'), or None."""
    if config.lr_scheduler == 'cosine':
        assert config.T_max is not None and config.eta_min is not None
        return torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
    if config.lr_scheduler == 'warmup_cosine':
        return LambdaLR(opt, partial(cosine_warmup_lr, max_iter=config.max_iterations, warmup=config.warmup))
    return None


def _mixed_conditional(model, config, dataset, data_x):
    """Blend text and visual-support conditionals with a random per-sample weight.

    Call this inside the autocast context. The text weight of each batch element
    is drawn uniformly from ``[config.mix_text_min, config.mix_text_max or 1]``.
    For the ``PhraseCut`` dataset, samples whose support image is flagged as
    invalid get full text weight.
    """
    # data_x[1] = text label
    prompts = model.sample_prompts(data_x[1])
    text_cond = model.compute_conditional(prompts)

    is_masked_model = model.__class__.__name__ == 'CLIPDensePredTMasked'
    if is_masked_model:
        # when mask=='separate'; data_x[3] presumably holds the support mask -- confirm
        visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
    else:
        # data_x[2] = visual prompt
        visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())

    max_txt = config.mix_text_max if config.mix_text_max is not None else 1
    n_samples = text_cond.shape[0]

    # one weight per batch element, shaped (n_samples, 1) so it broadcasts
    # against the conditional vectors (assumed (n_samples, dim) -- confirm)
    text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((n_samples,))[:, None]
    text_weights = text_weights.cuda()

    if dataset.__class__.__name__ == 'PhraseCut':
        # give full weight to text where support_image is invalid
        visual_is_valid = data_x[4] if is_masked_model else data_x[3]
        text_weights = torch.max(text_weights[:, 0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)

    return text_cond * text_weights + visual_s_cond * (1 - text_weights)


def main():
    """Train the model described by the CLI/YAML training config.

    Builds model, datasets, optimizer and LR scheduler from the config, then
    loops over the training data until ``config.max_iterations`` steps are done,
    at which point it writes ``weights.pth`` (unless already written) and calls
    ``sys.exit(0)``. Every ``config.val_interval`` steps it validates and saves
    weights when the validation metric (or, without a metric, the loss)
    improves. A NaN/Inf training loss aborts with ``sys.exit(-1)``.

    Raises:
        ValueError: if the training dataset is empty (the epoch loop would
            otherwise never terminate).
    """
    config = training_config_from_cli_args()

    val_interval, best_val_loss = config.val_interval, float('inf')
    # None until the first validation, so both maximized and minimized metrics
    # work (a fixed -inf start made the minimize branch unreachable).
    best_val_score = None

    model_cls = get_attribute(config.model)
    _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
    model = model_cls(**model_args).cuda()

    dataset_cls = get_attribute(config.dataset)
    _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)

    dataset = dataset_cls(**dataset_args)

    log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')
    if len(dataset) == 0:
        raise ValueError('training dataset is empty')

    if val_interval is not None:
        # config keys `val_<name>` (except val_interval) override dataset argument `<name>`
        dataset_val_args = {k[4:]: v for k, v in config.items() if k.startswith('val_') and k != 'val_interval'}
        _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
        val_args = {**dataset_args, 'split': 'val', 'aug': 0, **dataset_val_args}
        print('val args', val_args)

        dataset_val = dataset_cls(**val_args)

    opt = _build_optimizer(config, model)
    lr_scheduler = _build_lr_scheduler(config, opt)

    batch_size, max_iterations = config.batch_size, config.max_iterations

    loss_fn = get_attribute(config.loss)

    if config.amp:
        log.info('Using AMP')
        autocast_fn = autocast
        scaler = GradScaler()
    else:
        autocast_fn, scaler = nullcontext, None

    save_only_trainable = True
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)

    # disable config when hyperparam. opt. to avoid writing logs.
    tracker_config = config if not config.hyperparameter_optimization else None

    with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:

        i = 0
        while True:
            for data_x, data_y in data_loader:

                if config.mix:
                    # randomly mix text and visual support conditionals
                    assert config.mask.startswith('text_and')
                    with autocast_fn():
                        cond = _mixed_conditional(model, config, dataset, data_x)
                elif model.__class__.__name__ == 'CLIPDensePredTMasked':
                    # compute conditional vector using CLIP masking
                    with autocast_fn():
                        assert config.mask == 'separate'
                        cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
                else:
                    cond = data_x[1]
                    if isinstance(cond, torch.Tensor):
                        cond = cond.cuda()

                with autocast_fn():
                    pred, _, _, _ = model(data_x[0].cuda(), cond, return_features=True)
                    loss = loss_fn(pred, data_y[0].cuda())

                if torch.isnan(loss) or torch.isinf(loss):
                    # abort the whole run rather than stepping the optimizer on a non-finite loss
                    log.warning('Training stopped due to inf/nan loss.')
                    sys.exit(-1)

                # placeholder for auxiliary losses; always 0 for now (reported at validation)
                extra_loss = 0

                opt.zero_grad()

                if scaler is None:
                    loss.backward()
                    opt.step()
                else:
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()

                if lr_scheduler is not None:
                    lr_scheduler.step()
                    if i % 2000 == 0:
                        current_lr = opt.param_groups[0]['lr']
                        log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')

                logger.iter(i=i, loss=loss)
                i += 1

                if i >= max_iterations:

                    if not isfile(join(logger.base_path, 'weights.pth')):
                        # only write if no weights were already written
                        logger.save_weights(only_trainable=save_only_trainable)

                    sys.exit(0)

                if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
                    logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')

                if val_interval is not None and i % val_interval == val_interval - 1:

                    val_loss, val_scores, maximize = validate(model, dataset_val, config)

                    if len(val_scores) > 0:

                        score_str = ', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())

                        score = val_scores[config.use_val_metric]
                        if best_val_score is None:
                            improved = True
                        elif maximize:
                            improved = score > best_val_score
                        else:
                            improved = score < best_val_score

                        if improved:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_score = score

                    else:
                        score_str = ''
                        # if no score is used, fall back to loss
                        if val_loss < best_val_loss:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_loss = val_loss

                    log.info(f'Validation loss: {val_loss}' + score_str)
                    logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
                    # validate() leaves the model in eval mode
                    model.train()

            print('epoch complete')


if __name__ == '__main__':
    main()