├── .github └── README.md ├── docs └── txt2mask │ ├── CHANGELOG.md │ └── README_LINKED.md ├── repositories └── clipseg │ ├── .gitattributes │ ├── .gitignore │ ├── LICENSE │ ├── Quickstart.ipynb │ ├── Readme.md │ ├── Tables.ipynb │ ├── Visual_Feature_Engineering.ipynb │ ├── datasets │ ├── coco_wrapper.py │ ├── pascal_classes.json │ ├── pascal_zeroshot.py │ ├── pfe_dataset.py │ ├── phrasecut.py │ └── utils.py │ ├── environment.yml │ ├── eval.py │ ├── evaluation_utils.py │ ├── example_image.jpg │ ├── experiments │ ├── ablation.yaml │ ├── coco.yaml │ ├── pascal_1shot.yaml │ └── phrasecut.yaml │ ├── general_utils.py │ ├── metrics.py │ ├── models │ ├── clipseg.py │ └── vitseg.py │ ├── overview.png │ ├── score.py │ ├── setup.py │ ├── supplementary.pdf │ └── training.py └── scripts └── txt2mask.py /.github/README.md: -------------------------------------------------------------------------------- 1 | ../docs/txt2mask/README_LINKED.md -------------------------------------------------------------------------------- /docs/txt2mask/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | ## 0.1.1 - 23 September 2022 5 | ### Changed 6 | - Download URL for weights replaced as the Github mirror was prone to hitting the download limit 7 | 8 | ## 0.1.0 - 22 September 2022 9 | ### Added 10 | - New "Brush mask mode" option that allows you to specify what to do with the normal mask, i.e. the mask created using the brush tool or uploaded image 11 | - Added debug mode 12 | 13 | ## 0.0.7 - 21 September 2022 14 | ### Changed 15 | - Fixed issues related to negative mask prompt with images that are not 512x512 16 | 17 | ## 0.0.6 - 21 September 2022 18 | ### Added 19 | - New "Negative mask prompt" field that allows you to subtract areas from your selection 20 | 21 | ### Changed 22 | - Increased mask padding cap from 128 to 500, the padding is added to the image's original dimensions which means you might need a ton of padding if you're working with a large image 23 | 24 | ## 0.0.5 - 20 September 2022 25 | ### Added 26 | - New "Mask Padding" option that allows you to easily expand the boundaries of the selected mask region 27 | 28 | ## 0.0.4 - 19 September 2022 29 | ### Added 30 | - Support for mask prompt delimiter, use `|` to specify multiple selctions which will be stacked additively to produce the final mask image 31 | 32 | ### Changed 33 | - Fixed a possible issue with the clipseg weight downloader 34 | - Fix for crash related to inpainting at full resolution 35 | 36 | ## 0.0.3 - 19 September 2022 37 | ### Added 38 | - Added option to show mask in output 39 | 40 | ## 0.0.2 - 18 September 2022 41 | ### Changed 42 | - Readme and Changelog moved to `docs/txt2mask` so as not to conflict with Automatic's repo 43 | 44 | ### Removed 45 | - Model weights are no longer included with the repo as they will not download properly through Github's website, instead the model files will be fetched by the script if they do not exist on your filesystem 46 | 47 | ## 0.0.1 - 17 September 2022 48 | ### Added 49 | - Initial release -------------------------------------------------------------------------------- /docs/txt2mask/README_LINKED.md: -------------------------------------------------------------------------------- 1 | > **Warning** 2 | > 3 | > This script has been superseded by my new extension, [Unprompted](https://github.com/ThereforeGames/unprompted), which has [the same powerful masking 
features](https://github.com/ThereforeGames/unprompted/blob/main/docs/MANUAL.md#txt2mask) and a lot more. I do not plan on updating the original script and I cannot guarantee that it will continue working in new versions of the A1111 WebUI. Thank you for understanding. 4 | 5 | # txt2mask for Stable Diffusion 6 | Automatically create masks for inpainting with Stable Diffusion using natural language. 7 | 8 | ## Introduction 9 | 10 | txt2mask is an addon for [AUTOMATIC1111's Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) that allows you to enter a text string in img2img mode, which automatically creates an image mask. It is powered by [clipseg](https://github.com/timojl/clipseg). No more messing around with that temperamental brush tool. 😅 11 | 12 | This script is still under active development. 13 | 14 | ![image](https://user-images.githubusercontent.com/95403634/190878562-d020887c-ccb0-411c-ab37-38e2115552eb.png) 15 | 16 | ## Installation 17 | 18 | Simply clone or download this repo and place the files in the base directory of Automatic's web UI. 19 | 20 | ## Usage 21 | 22 | From the img2img screen, select txt2mask as your active script: 23 | 24 | ![image](https://user-images.githubusercontent.com/95403634/190878234-43134aff-0843-4caf-a0ea-146d6e1891dc.png) 25 | 26 | In the `Mask Prompt` field, enter the text to search for within your image. (In the case of the topmost screenshot, this value would be 'business suit' and the prompt box at the top of your UI would say 'sci-fi battle suit.') 27 | 28 | Adjust the `Mask Precision` field to raise or lower the confidence threshold that decides what gets masked. Lowering this value too much may cause it to select more than you intend. 29 | 30 | Press Generate. That's it! 31 | 32 | ## Advanced Features & Tips 33 | 34 | - The Mask Prompt allows you to search for multiple objects by using `|` as a delimiter. For example, if you enter `a face|a tree|a flower` then clipseg will process these three items independently and stack the resulting submasks into one final mask. This will likely yield a better result than had you searched for `a face and a tree and a flower`. 35 | - You can use the `Mask Padding` option to increase the boundaries of your selection. For example, if you enter `a red shirt` as your prompt but find that it's not quite selecting the whole shirt, and `Mask Precision` isn't helping, then padding may be a good way to address the issue. 36 | - Use the `Negative mask prompt` to subtract from areas selected by `Mask prompt`. For example, if your prompt is `a face` and the negative prompt is `eyes` then the resulting mask will select a face without selecting the eyes. 37 | - **(NEW)** You can combine your text mask with the brush tool or uploaded image mask using the `Brush mask mode` setting. Get the best of both worlds. 38 | - In general, less is more for masking: instead of trying to mask "a one-armed man doing a backflip off a barn" you will probably have more luck writing "a man."
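
## Under the Hood

If you want to understand (or reuse) the masking step outside the Web UI, the sketch below shows how a mask prompt can become a mask image with clipseg. It is **not** the script's actual implementation; it only mirrors the clipseg calls from `repositories/clipseg/Quickstart.ipynb`, and the `text_to_mask` helper, the `|`-splitting logic, and the 0-255 `precision` cutoff are illustrative assumptions (padding and brush-mask blending are omitted).

```python
import torch
from PIL import Image
from torchvision import transforms
from models.clipseg import CLIPDensePredT

def text_to_mask(image_path, prompt, negative_prompt="", precision=100):
    # Load clipseg as in Quickstart.ipynb; only decoder weights are stored,
    # hence strict=False.
    model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
    model.eval()
    model.load_state_dict(
        torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')),
        strict=False)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize((352, 352)),
    ])
    image = Image.open(image_path).convert('RGB')
    img = transform(image).unsqueeze(0)

    def combined(prompts):
        # Each sub-prompt is processed independently; the resulting submasks
        # are stacked additively via an element-wise max.
        with torch.no_grad():
            preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
        return torch.sigmoid(preds[:, 0]).max(dim=0).values

    mask = combined(prompt.split('|'))
    if negative_prompt:
        # The negative prompt subtracts its own selection from the mask.
        mask = torch.clamp(mask - combined(negative_prompt.split('|')), 0, 1)

    # Treat "precision" as a 0-255 cutoff on the sigmoid confidence map.
    binary = ((mask * 255 >= precision).to(torch.uint8) * 255).numpy()
    return Image.fromarray(binary, mode='L').resize(image.size)
```

Stacking the per-prompt sigmoids with an element-wise max is what makes `a face|a tree|a flower` behave like three independent searches rather than one diluted one.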
-------------------------------------------------------------------------------- /repositories/clipseg/.gitattributes: -------------------------------------------------------------------------------- 1 | weights/rd16-uni.pth filter=lfs diff=lfs merge=lfs -text 2 | weights/rd64-uni.pth filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /repositories/clipseg/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /repositories/clipseg/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | This license does not apply to the model weights. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /repositories/clipseg/Quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import requests\n", 11 | "\n", 12 | "from models.clipseg import CLIPDensePredT\n", 13 | "from PIL import Image\n", 14 | "from torchvision import transforms\n", 15 | "from matplotlib import pyplot as plt\n", 16 | "\n", 17 | "# load model\n", 18 | "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n", 19 | "model.eval();\n", 20 | "\n", 21 | "# non-strict, because we only stored decoder weights (not CLIP weights)\n", 22 | "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "Load and normalize `example_image.jpg`. You can also load through an URL." 
30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# load and normalize image\n", 39 | "input_image = Image.open('example_image.jpg')\n", 40 | "\n", 41 | "# or load from URL...\n", 42 | "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n", 43 | "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n", 44 | "\n", 45 | "transform = transforms.Compose([\n", 46 | " transforms.ToTensor(),\n", 47 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 48 | " transforms.Resize((352, 352)),\n", 49 | "])\n", 50 | "img = transform(input_image).unsqueeze(0)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "Predict and visualize (this might take a few seconds if running without GPU support)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n", 67 | "\n", 68 | "# predict\n", 69 | "with torch.no_grad():\n", 70 | " preds = model(img.repeat(4,1,1,1), prompts)[0]\n", 71 | "\n", 72 | "# visualize prediction\n", 73 | "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n", 74 | "[a.axis('off') for a in ax.flatten()]\n", 75 | "ax[0].imshow(input_image)\n", 76 | "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n", 77 | "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];" 78 | ] 79 | } 80 | ], 81 | "metadata": { 82 | "interpreter": { 83 | "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586" 84 | }, 85 | "kernelspec": { 86 | "display_name": "Python 3", 87 | "language": "python", 88 | "name": "python3" 89 | }, 90 | "language_info": { 91 | "codemirror_mode": { 92 | "name": "ipython", 93 | "version": 3 94 | }, 95 | "file_extension": ".py", 96 | "mimetype": "text/x-python", 97 | "name": "python", 98 | "nbconvert_exporter": "python", 99 | "pygments_lexer": "ipython3", 100 | "version": "3.8.10" 101 | } 102 | }, 103 | "nbformat": 4, 104 | "nbformat_minor": 4 105 | } 106 | -------------------------------------------------------------------------------- /repositories/clipseg/Readme.md: -------------------------------------------------------------------------------- 1 | # Image Segmentation Using Text and Image Prompts 2 | This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003). 3 | 4 | **The Paper has been accepted to CVPR 2022!** 5 | 6 | drawing 7 | 8 | The systems allows to create segmentation models without training based on: 9 | - An arbitrary text query 10 | - Or an image with a mask highlighting stuff or an object. 11 | 12 | ### Quick Start 13 | 14 | In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you downloaded the `rd64-uni.pth` weights, either manually or via git lfs extension. 15 | It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb) 16 | (please note that the VM does not use a GPU, thus inference takes a few seconds). 17 | 18 | 19 | ### Dependencies 20 | This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`). 21 | Additional dependencies are hidden for double blind review. 
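
One possible way to set this up is sketched below (assuming the repo's `environment.yml` defines a named conda environment; the manual pip route does not pin versions, so the torch/torchvision versions you need may differ):

```bash
# Option 1: conda environment from the repo's environment.yml
conda env create -f environment.yml

# Option 2: manual install (versions not pinned here)
pip install torch torchvision
pip install git+https://github.com/openai/CLIP.git
```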
22 | 23 | 24 | ### Datasets 25 | 26 | * `PhraseCut` and `PhraseCutPlus`: Referring expression dataset 27 | * `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation 28 | * `PascalZeroShot`: Wrapper class for PascalZeroShot 29 | * `COCOWrapper`: Wrapper class for COCO. 30 | 31 | ### Models 32 | 33 | * `CLIPDensePredT`: CLIPSeg model with transformer-based decoder. 34 | * `ViTDensePredT`: same transformer-based decoder, but on top of an ImageNet-trained ViT backbone instead of CLIP (the ViTSeg baseline). 35 | 36 | ### Third Party Dependencies 37 | Some of the datasets require third-party dependencies. Run the following commands in the `third_party` folder. 38 | ```bash 39 | git clone https://github.com/cvlab-yonsei/JoEm 40 | git clone https://github.com/Jia-Research-Lab/PFENet.git 41 | git clone https://github.com/ChenyunWu/PhraseCutDataset.git 42 | git clone https://github.com/juhongm999/hsnet.git 43 | ``` 44 | 45 | ### Weights 46 | 47 | The MIT license does not apply to these weights. 48 | 49 | - [CLIPSeg-D64](https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth) (4.1MB, without CLIP weights) 50 | - [CLIPSeg-D16](https://github.com/timojl/clipseg/raw/master/weights/rd16-uni.pth) (1.1MB, without CLIP weights) 51 | 52 | ### Training and Evaluation 53 | 54 | To train, use the `training.py` script with an experiment file and experiment id as parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment, which is defined by the `configuration` entry and the first entry of `individual_configurations`. Model weights will be written to `logs/`. 55 | 56 | For evaluation use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will evaluate the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`. 57 | 58 | 59 | ### Usage of PFENet Wrappers 60 | 61 | In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder. 62 | `git clone https://github.com/Jia-Research-Lab/PFENet.git ` 63 | 64 | 65 | ### License 66 | 67 | The source code files in this repository (excluding model weights) are released under the MIT license.
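
For convenience, the weight links and example commands from the Weights and Training and Evaluation sections above combined into one sequence (the `weights/` target directory matches what `Quickstart.ipynb` expects; experiment indices are the ones used in the examples above):

```bash
# fetch the pre-trained decoder weights listed above
mkdir -p weights
wget -P weights https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth
wget -P weights https://github.com/timojl/clipseg/raw/master/weights/rd16-uni.pth

# train the first PhraseCut experiment; checkpoints end up in logs/
python training.py phrasecut.yaml 0

# evaluate configuration 0 of phrasecut.yaml's test_configuration
python score.py phrasecut.yaml 0 0
```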
68 | 69 | ### Citation 70 | ``` 71 | @InProceedings{lueddecke22_cvpr, 72 | author = {L\"uddecke, Timo and Ecker, Alexander}, 73 | title = {Image Segmentation Using Text and Image Prompts}, 74 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 75 | month = {June}, 76 | year = {2022}, 77 | pages = {7086-7096} 78 | } 79 | 80 | ``` 81 | -------------------------------------------------------------------------------- /repositories/clipseg/Tables.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import clip\n", 13 | "from evaluation_utils import norm, denorm\n", 14 | "from general_utils import *\n", 15 | "from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "# PhraseCut" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n", 50 | "tab1 = pc[['name'] + cols]\n", 51 | "for k in cols:\n", 52 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 53 | "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n", 54 | "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n", 55 | "print(tab1.to_latex(header=False, index=False))" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "For 0.1 threshold" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n", 72 | "tab1 = pc[['name'] + cols]\n", 73 | "for k in cols:\n", 74 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 75 | "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n", 76 | "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n", 77 | "print(tab1.to_latex(header=False, index=False))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "# One-shot" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "### Pascal" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | 
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 119 | "tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n", 120 | "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 121 | "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 122 | "\n", 123 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 124 | "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n", 125 | "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 126 | "\n", 127 | "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n", 128 | "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n", 129 | "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "#### Pascal Zero-shot (in one-shot setting)\n", 137 | "\n", 138 | "Using the same setting as one-shot (hence different from the other zero-shot benchmark)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 148 | "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n", 149 | "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 150 | "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 151 | "\n", 152 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 153 | "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n", 154 | "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 155 | "\n", 156 | "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n", 157 | "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n", 158 | "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "# without fixed thresholds...\n", 168 | "\n", 169 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 170 | "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n", 171 | "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 172 | "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 173 | "\n", 174 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 175 | "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n", 176 | "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "### COCO" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 
192 | "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n", 202 | "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n", 203 | "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n", 204 | "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n", 205 | "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 206 | "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n", 207 | "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "# Zero-shot" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "\n", 233 | "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n", 234 | "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n", 235 | "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n", 236 | "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "# Ablation" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n", 262 | "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n", 263 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 264 | "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "# Generalization" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 
| "source": [ 298 | "generalization = experiment('experiments/generalize.yaml').dataframe()" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "print(\n", 317 | " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n", 318 | " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n", 319 | " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n", 320 | " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n", 321 | ")" 322 | ] 323 | } 324 | ], 325 | "metadata": { 326 | "interpreter": { 327 | "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586" 328 | }, 329 | "kernelspec": { 330 | "display_name": "env2", 331 | "language": "python", 332 | "name": "env2" 333 | }, 334 | "language_info": { 335 | "codemirror_mode": { 336 | "name": "ipython", 337 | "version": 3 338 | }, 339 | "file_extension": ".py", 340 | "mimetype": "text/x-python", 341 | "name": "python", 342 | "nbconvert_exporter": "python", 343 | "pygments_lexer": "ipython3", 344 | "version": "3.8.8" 345 | } 346 | }, 347 | "nbformat": 4, 348 | "nbformat_minor": 4 349 | } 350 | -------------------------------------------------------------------------------- /repositories/clipseg/Visual_Feature_Engineering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Systematic" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "\n", 19 | "import clip\n", 20 | "from evaluation_utils import norm, denorm\n", 21 | "from general_utils import *\n", 22 | "from datasets.lvis_oneshot3 import LVIS_OneShot3\n", 23 | "\n", 24 | "clip_device = 'cuda'\n", 25 | "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n", 26 | "clip_model.eval();\n", 27 | "\n", 28 | "from models.clipseg import CLIPDensePredTMasked\n", 29 | "\n", 30 | "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n", 31 | "clip_mask_model.eval();" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n", 41 | " text_class_labels=True, image_size=352, min_area=0.1,\n", 42 | " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "plot_data(lvis)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "from collections import defaultdict\n", 61 | "import json\n", 62 | "\n", 63 | "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n", 64 | "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n", 65 | "\n", 66 | 
"objects_per_image = defaultdict(lambda : set())\n", 67 | "for ann in lvis_raw['annotations']:\n", 68 | " objects_per_image[ann['image_id']].add(ann['category_id'])\n", 69 | " \n", 70 | "for ann in lvis_val_raw['annotations']:\n", 71 | " objects_per_image[ann['image_id']].add(ann['category_id']) \n", 72 | " \n", 73 | "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n", 74 | "\n", 75 | "del lvis_raw, lvis_val_raw" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "#bs = 32\n", 85 | "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "from general_utils import get_batch\n", 95 | "from functools import partial\n", 96 | "from evaluation_utils import img_preprocess\n", 97 | "import torch\n", 98 | "\n", 99 | "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n", 100 | "\n", 101 | " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n", 102 | "\n", 103 | " all_prompts = []\n", 104 | " \n", 105 | " with torch.no_grad():\n", 106 | " valid_sims = []\n", 107 | " torch.manual_seed(571)\n", 108 | " \n", 109 | " if type(batches_or_dataset) == list:\n", 110 | " loader = batches_or_dataset # already loaded\n", 111 | " max_iter = float('inf')\n", 112 | " else:\n", 113 | " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n", 114 | " max_iter = 50\n", 115 | " \n", 116 | " global batch\n", 117 | " for i_batch, (batch, batch_y) in enumerate(loader):\n", 118 | " \n", 119 | " if i_batch >= max_iter: break\n", 120 | " \n", 121 | " processed_batch = process(batch)\n", 122 | " if type(processed_batch) == dict:\n", 123 | " \n", 124 | " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n", 125 | " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n", 126 | " else:\n", 127 | " processed_batch = process(batch).to(clip_device)\n", 128 | " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n", 129 | " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n", 130 | " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n", 131 | " \n", 132 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", 133 | " bs = len(batch[0])\n", 134 | " for j in range(bs):\n", 135 | " \n", 136 | " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n", 137 | " support_image = basename(lvis.samples[c][sid])\n", 138 | " \n", 139 | " img_objs = [o for o in objects_per_image[int(support_image)]]\n", 140 | " img_objs = [o.replace('_', ' ') for o in img_objs]\n", 141 | " \n", 142 | " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n", 143 | " if o != batch_y[2][j]]\n", 144 | " \n", 145 | " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n", 146 | " all_prompts += [prompts]\n", 147 | " \n", 148 | " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n", 149 | " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n", 150 | "\n", 151 | " global logits\n", 152 | " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n", 153 | "\n", 154 | " global sim\n", 155 | " sim = 
torch.softmax(logits, dim=-1)\n", 156 | " \n", 157 | " valid_sims += [sim]\n", 158 | " \n", 159 | " #valid_sims = torch.stack(valid_sims)\n", 160 | " return valid_sims, all_prompts\n", 161 | " \n", 162 | "\n", 163 | "def new_img_preprocess(x):\n", 164 | " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n", 165 | " \n", 166 | "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n", 167 | "get_similarities(lvis, lambda x: x[1]);" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "preprocessing_functions = [\n", 177 | "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n", 178 | "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n", 179 | "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n", 180 | "# ['colorize object red', partial(img_preprocess, colorize=True)],\n", 181 | "# ['add red outline', partial(img_preprocess, outline=True)],\n", 182 | " \n", 183 | "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n", 184 | "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n", 185 | "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n", 186 | "# ['BG blur', partial(img_preprocess, blur=3)],\n", 187 | "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n", 188 | " \n", 189 | "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n", 190 | "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n", 191 | " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n", 192 | " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n", 193 | "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n", 194 | "]\n", 195 | "\n", 196 | "preprocessing_functions = preprocessing_functions\n", 197 | "\n", 198 | "base, base_p = get_similarities(lvis, lambda x: x[1])\n", 199 | "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "for j in range(1):\n", 218 | " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "from pandas import DataFrame\n", 228 | "tab = dict()\n", 229 | "for j, (name, _) in enumerate(preprocessing_functions):\n", 230 | " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n", 231 | " \n", 232 | " \n", 233 | "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) " 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "# Visual" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "from 
evaluation_utils import denorm, norm" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "def load_sample(filename, filename2):\n", 259 | " from os.path import join\n", 260 | " bp = expanduser('~/cloud/resources/sample_images')\n", 261 | " tf = transforms.Compose([\n", 262 | " transforms.ToTensor(),\n", 263 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 264 | " transforms.Resize(224),\n", 265 | " transforms.CenterCrop(224)\n", 266 | " ])\n", 267 | " tf2 = transforms.Compose([\n", 268 | " transforms.ToTensor(),\n", 269 | " transforms.Resize(224),\n", 270 | " transforms.CenterCrop(224)\n", 271 | " ])\n", 272 | " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n", 273 | " inp1[1] = inp1[1].unsqueeze(0)\n", 274 | " inp1[2] = inp1[2][:1] \n", 275 | " return inp1\n", 276 | "\n", 277 | "def all_preprocessing(inp1):\n", 278 | " return [\n", 279 | " img_preprocess(inp1),\n", 280 | " img_preprocess(inp1, colorize=True),\n", 281 | " img_preprocess(inp1, outline=True), \n", 282 | " img_preprocess(inp1, blur=3),\n", 283 | " img_preprocess(inp1, bg_fac=0.1),\n", 284 | " #img_preprocess(inp1, bg_fac=0.5),\n", 285 | " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n", 286 | " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n", 287 | " ]\n", 288 | "\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "from torchvision import transforms\n", 298 | "from PIL import Image\n", 299 | "from matplotlib import pyplot as plt\n", 300 | "from evaluation_utils import img_preprocess\n", 301 | "import clip\n", 302 | "\n", 303 | "images_queries = [\n", 304 | " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n", 305 | " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n", 306 | "]\n", 307 | "\n", 308 | "\n", 309 | "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n", 310 | "\n", 311 | "for j, (images, objects) in enumerate(images_queries):\n", 312 | " \n", 313 | " joint_image = all_preprocessing(images)\n", 314 | " \n", 315 | " joint_image = torch.stack(joint_image)[:,0]\n", 316 | " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n", 317 | " image_features = clip_model.encode_image(joint_image)\n", 318 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", 319 | " \n", 320 | " prompts = [f'a photo of a {obj}'for obj in objects]\n", 321 | " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n", 322 | " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n", 323 | " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n", 324 | " sim = torch.softmax(logits, dim=-1).detach().cpu()\n", 325 | "\n", 326 | " for i, img in enumerate(joint_image):\n", 327 | " ax[2*j, i].axis('off')\n", 328 | " \n", 329 | " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n", 330 | " ax[2*j+ 1, i].grid(True)\n", 331 | " \n", 332 | " ax[2*j + 1, i].set_ylim(0,1)\n", 333 | " ax[2*j + 1, i].set_yticklabels([])\n", 334 | " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n", 335 | "# ax[1, i].set_xticklabels(objects, rotation=90)\n", 336 | " for k in 
range(len(sim[i])):\n", 337 | " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n", 338 | " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n", 339 | "\n", 340 | "plt.tight_layout()\n", 341 | "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')" 342 | ] 343 | } 344 | ], 345 | "metadata": { 346 | "kernelspec": { 347 | "display_name": "env2", 348 | "language": "python", 349 | "name": "env2" 350 | }, 351 | "language_info": { 352 | "codemirror_mode": { 353 | "name": "ipython", 354 | "version": 3 355 | }, 356 | "file_extension": ".py", 357 | "mimetype": "text/x-python", 358 | "name": "python", 359 | "nbconvert_exporter": "python", 360 | "pygments_lexer": "ipython3", 361 | "version": "3.8.8" 362 | } 363 | }, 364 | "nbformat": 4, 365 | "nbformat_minor": 4 366 | } 367 | -------------------------------------------------------------------------------- /repositories/clipseg/datasets/coco_wrapper.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from types import new_class 3 | import torch 4 | import numpy as np 5 | import os 6 | import json 7 | 8 | from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename 9 | from random import shuffle, seed as set_seed 10 | from PIL import Image 11 | 12 | from itertools import combinations 13 | from torchvision import transforms 14 | from torchvision.transforms.transforms import Resize 15 | 16 | from datasets.utils import blend_image_segmentation 17 | from general_utils import get_from_repository 18 | 19 | COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'} 20 | 21 | class COCOWrapper(object): 22 | 23 | def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0, 24 | with_class_label=False): 25 | super().__init__() 26 | 27 | self.mask = mask 28 | self.with_class_label = with_class_label 29 | self.negative_prob = negative_prob 30 | 31 | from third_party.hsnet.data.coco import DatasetCOCO 32 | 33 | get_from_repository('COCO-20i', ['COCO-20i.tar']) 34 | 35 | foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl') 36 | 37 | def build_img_metadata_classwise(self): 38 | with open(foldpath % (self.split, self.fold), 'rb') as f: 39 | img_metadata_classwise = pickle.load(f) 40 | return img_metadata_classwise 41 | 42 | 43 | 
DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise 44 | # DatasetCOCO.read_mask = read_mask 45 | 46 | mean = [0.485, 0.456, 0.406] 47 | std = [0.229, 0.224, 0.225] 48 | transform = transforms.Compose([ 49 | transforms.Resize((image_size, image_size)), 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean, std) 52 | ]) 53 | 54 | self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False) 55 | 56 | self.all_classes = [self.coco.class_ids] 57 | self.coco.base_path = join(expanduser('~/datasets/COCO-20i')) 58 | 59 | def __len__(self): 60 | return len(self.coco) 61 | 62 | def __getitem__(self, i): 63 | sample = self.coco[i] 64 | 65 | label_name = COCO_CLASSES[int(sample['class_id'])] 66 | 67 | img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0] 68 | 69 | if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob: 70 | new_class_id = sample['class_id'] 71 | while new_class_id == sample['class_id']: 72 | sample2 = self.coco[torch.randint(0, len(self), (1,)).item()] 73 | new_class_id = sample2['class_id'] 74 | img_s = sample2['support_imgs'][0] 75 | seg_s = torch.zeros_like(seg_s) 76 | 77 | mask = self.mask 78 | if mask == 'separate': 79 | supp = (img_s, seg_s) 80 | elif mask == 'text_label': 81 | # DEPRECATED 82 | supp = [int(sample['class_id'])] 83 | elif mask == 'text': 84 | supp = [label_name] 85 | else: 86 | if mask.startswith('text_and_'): 87 | mask = mask[9:] 88 | label_add = [label_name] 89 | else: 90 | label_add = [] 91 | 92 | supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask) 93 | 94 | if self.with_class_label: 95 | label = (torch.zeros(0), sample['class_id'],) 96 | else: 97 | label = (torch.zeros(0), ) 98 | 99 | return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label -------------------------------------------------------------------------------- /repositories/clipseg/datasets/pascal_classes.json: -------------------------------------------------------------------------------- 1 | [{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}] -------------------------------------------------------------------------------- /repositories/clipseg/datasets/pascal_zeroshot.py: -------------------------------------------------------------------------------- 1 | from os.path import expanduser 2 | import torch 3 | import json 4 | import torchvision 5 | from general_utils import get_from_repository 6 | from general_utils import log 7 | from torchvision import transforms 8 | 9 | PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'], 10 | ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'], 11 | ['chair.n.01', 'pot_plant.n.01']] 12 | 13 | 14 | class PascalZeroShot(object): 15 | 16 | def __init__(self, split, n_unseen, image_size=224) -> None: 17 | super().__init__() 18 | 19 | 
import sys 20 | sys.path.append('third_party/JoEm') 21 | from third_party.JoEm.data_loader.dataset import VOCSegmentation 22 | from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC 23 | 24 | self.pascal_classes = VOC 25 | self.image_size = image_size 26 | 27 | self.transform = transforms.Compose([ 28 | transforms.Resize((image_size, image_size)), 29 | ]) 30 | 31 | if split == 'train': 32 | self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen), 33 | split=split, transform=True, transform_args=dict(base_size=312, crop_size=312), 34 | ignore_bg=False, ignore_unseen=False, remv_unseen_img=True) 35 | elif split == 'val': 36 | self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen), 37 | split=split, transform=False, 38 | ignore_bg=False, ignore_unseen=False) 39 | 40 | self.unseen_idx = get_unseen_idx(n_unseen) 41 | 42 | def __len__(self): 43 | return len(self.voc) 44 | 45 | def __getitem__(self, i): 46 | 47 | sample = self.voc[i] 48 | label = sample['label'].long() 49 | all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255] 50 | class_indices = [l for l in all_labels] 51 | class_names = [self.pascal_classes[l] for l in all_labels] 52 | 53 | image = self.transform(sample['image']) 54 | 55 | label = transforms.Resize((self.image_size, self.image_size), 56 | interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0] 57 | 58 | return (image,), (label, ) 59 | 60 | 61 | -------------------------------------------------------------------------------- /repositories/clipseg/datasets/pfe_dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import expanduser 2 | import torch 3 | import json 4 | from general_utils import get_from_repository 5 | from datasets.lvis_oneshot3 import blend_image_segmentation 6 | from general_utils import log 7 | 8 | PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))} 9 | 10 | 11 | class PFEPascalWrapper(object): 12 | 13 | def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None): 14 | import sys 15 | # sys.path.append(expanduser('~/projects/new_one_shot')) 16 | from third_party.PFENet.util.dataset import SemData 17 | 18 | get_from_repository('PascalVOC2012', ['Pascal5i.tar']) 19 | 20 | self.p_negative = p_negative 21 | self.size = size 22 | self.mode = mode 23 | self.image_size = image_size 24 | 25 | if label_support in {True, False}: 26 | log.warning('label_support argument is deprecated. 
Use mask instead.') 27 | #raise ValueError() 28 | 29 | self.mask = mask 30 | 31 | value_scale = 255 32 | mean = [0.485, 0.456, 0.406] 33 | mean = [item * value_scale for item in mean] 34 | std = [0.229, 0.224, 0.225] 35 | std = [item * value_scale for item in std] 36 | 37 | import third_party.PFENet.util.transform as transform 38 | 39 | if mode == 'val': 40 | data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt') 41 | 42 | data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else [] 43 | data_transform += [ 44 | transform.ToTensor(), 45 | transform.Normalize(mean=mean, std=std) 46 | ] 47 | 48 | 49 | elif mode == 'train': 50 | data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt') 51 | 52 | assert image_size != 'original' 53 | 54 | data_transform = [ 55 | transform.RandScale([0.9, 1.1]), 56 | transform.RandRotate([-10, 10], padding=mean, ignore_label=255), 57 | transform.RandomGaussianBlur(), 58 | transform.RandomHorizontalFlip(), 59 | transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255), 60 | transform.ToTensor(), 61 | transform.Normalize(mean=mean, std=std) 62 | ] 63 | 64 | data_transform = transform.Compose(data_transform) 65 | 66 | self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'), 67 | data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False) 68 | 69 | self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list 70 | 71 | # verify that subcls_list always has length 1 72 | # assert len(set([len(d[4]) for d in self.dataset])) == 1 73 | 74 | print('actual length', len(self.dataset.data_list)) 75 | 76 | def __len__(self): 77 | if self.mode == 'val': 78 | return len(self.dataset.data_list) 79 | else: 80 | return len(self.dataset.data_list) 81 | 82 | def __getitem__(self, index): 83 | if self.dataset.mode == 'train': 84 | image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)] 85 | elif self.dataset.mode == 'val': 86 | image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)] 87 | ori_label = torch.from_numpy(ori_label).unsqueeze(0) 88 | 89 | if self.image_size != 'original': 90 | longerside = max(ori_label.size(1), ori_label.size(2)) 91 | backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255 92 | backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label 93 | label = backmask.clone().long() 94 | else: 95 | label = label.unsqueeze(0) 96 | 97 | # assert label.shape == (473, 473) 98 | 99 | if self.p_negative > 0: 100 | if torch.rand(1).item() < self.p_negative: 101 | while True: 102 | idx = torch.randint(0, len(self.dataset.data_list), (1,)).item() 103 | _, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx] 104 | if subcls_list[0] != subcls_list_tmp[0]: 105 | break 106 | 107 | s_x = s_x[0] 108 | s_y = (s_y == 1)[0] 109 | label_fg = (label == 1).float() 110 | val_mask = (label != 255).float() 111 | 112 | class_id = self.class_list[subcls_list[0]] 113 | 114 | label_name = PASCAL_CLASSES[class_id][0] 115 | label_add = () 116 | mask = self.mask 117 | 118 | if mask == 'text': 119 | support = ('a photo of a ' + label_name + '.',) 120 | elif mask == 'separate': 121 | support = (s_x, s_y) 122 | else: 123 | if mask.startswith('text_and_'): 124 | label_add = (label_name,) 125 | mask = mask[9:] 126 | 127 | support = (blend_image_segmentation(s_x, 
s_y.float(), mask)[0],) 128 | 129 | return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0]) 130 | -------------------------------------------------------------------------------- /repositories/clipseg/datasets/phrasecut.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import os 5 | 6 | from os.path import join, isdir, isfile, expanduser 7 | from PIL import Image 8 | 9 | from torchvision import transforms 10 | from torchvision.transforms.transforms import Resize 11 | 12 | from torch.nn import functional as nnf 13 | from general_utils import get_from_repository 14 | 15 | from skimage.draw import polygon2mask 16 | 17 | 18 | 19 | def random_crop_slices(origin_size, target_size): 20 | """Gets slices of a random crop. """ 21 | assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}' 22 | 23 | offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item() # range: 0 <= value < high 24 | offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item() 25 | 26 | return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1]) 27 | 28 | 29 | def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None): 30 | 31 | 32 | best_crops = [] 33 | best_crop_not_ok = float('-inf'), None, None 34 | min_sum = 0 35 | 36 | seg = seg.astype('bool') 37 | 38 | if min_frac is not None: 39 | #min_sum = seg.sum() * min_frac 40 | min_sum = seg.shape[0] * seg.shape[1] * min_frac 41 | 42 | for iteration in range(iterations): 43 | sl_y, sl_x = random_crop_slices(seg.shape, image_size) 44 | seg_ = seg[sl_y, sl_x] 45 | sum_seg_ = seg_.sum() 46 | 47 | if sum_seg_ > min_sum: 48 | 49 | if best_of is None: 50 | return sl_y, sl_x, False 51 | else: 52 | best_crops += [(sum_seg_, sl_y, sl_x)] 53 | if len(best_crops) >= best_of: 54 | best_crops.sort(key=lambda x:x[0], reverse=True) 55 | sl_y, sl_x = best_crops[0][1:] 56 | 57 | return sl_y, sl_x, False 58 | 59 | else: 60 | if sum_seg_ > best_crop_not_ok[0]: 61 | best_crop_not_ok = sum_seg_, sl_y, sl_x 62 | 63 | else: 64 | # return best segmentation found 65 | return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,) 66 | 67 | 68 | class PhraseCut(object): 69 | 70 | def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True, 71 | min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None): 72 | super().__init__() 73 | 74 | self.negative_prob = negative_prob 75 | self.image_size = image_size 76 | self.with_visual = with_visual 77 | self.only_visual = only_visual 78 | self.phrase_form = '{}' 79 | self.mask = mask 80 | self.aug_crop = aug_crop 81 | 82 | if aug_color: 83 | self.aug_color = transforms.Compose([ 84 | transforms.ColorJitter(0.5, 0.5, 0.2, 0.05), 85 | ]) 86 | else: 87 | self.aug_color = None 88 | 89 | get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([ 90 | isdir(join(local_dir, 'VGPhraseCut_v0')), 91 | isdir(join(local_dir, 'VGPhraseCut_v0', 'images')), 92 | isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')), 93 | len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249} 94 | ])) 95 | 96 | from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader 97 | self.refvg_loader = RefVGLoader(split=split) 98 | 99 | # img_ids where the size in the 
annotations does not match actual size 100 | invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530, 101 | 150333, 286065, 285814, 498187, 285761, 498042]) 102 | 103 | mean = [0.485, 0.456, 0.406] 104 | std = [0.229, 0.224, 0.225] 105 | self.normalize = transforms.Normalize(mean, std) 106 | 107 | self.sample_ids = [(i, j) 108 | for i in self.refvg_loader.img_ids 109 | for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases'])) 110 | if i not in invalid_img_ids] 111 | 112 | 113 | # self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']])) 114 | 115 | from nltk.stem import WordNetLemmatizer 116 | wnl = WordNetLemmatizer() 117 | 118 | # Filter by class (if remove_classes is set) 119 | if remove_classes is None: 120 | pass 121 | else: 122 | from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo 123 | from nltk.corpus import wordnet 124 | 125 | print('remove pascal classes...') 126 | 127 | get_data = self.refvg_loader.get_img_ref_data # shortcut 128 | keep_sids = None 129 | 130 | if remove_classes[0] == 'pas5i': 131 | subset_id = remove_classes[1] 132 | from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS 133 | avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]] 134 | 135 | 136 | elif remove_classes[0] == 'zs': 137 | stop = remove_classes[1] 138 | 139 | from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS 140 | 141 | avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set] 142 | print(avoid) 143 | 144 | elif remove_classes[0] == 'aff': 145 | # avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02'] 146 | # all_lemmas = set(['drink', 'sit', 'ride']) 147 | avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting', 148 | 'ride', 'rides', 'riding', 149 | 'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven', 150 | 'swim', 'swims', 'swimming', 151 | 'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears'] 152 | keep_sids = [(i, j) for i, j in self.sample_ids if 153 | all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))] 154 | 155 | print('avoid classes:', avoid) 156 | 157 | 158 | if keep_sids is None: 159 | all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)] 160 | all_lemmas = list(set(all_lemmas)) 161 | all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas] 162 | all_lemmas = set(all_lemmas) 163 | 164 | # divide into multi word and single word 165 | all_lemmas_s = set(l for l in all_lemmas if ' ' not in l) 166 | all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s) 167 | 168 | # new3 169 | phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids] 170 | remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases) 171 | if any(l in phrase for l in all_lemmas_m) or 172 | len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0 173 | ) 174 | keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids] 175 | 176 | print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}') 177 | removed_ids = set(self.sample_ids) - set(keep_sids) 178 | 179 | print('Examples of removed', len(removed_ids)) 180 | for i, j in list(removed_ids)[:20]: 181 | print(i, get_data(i)['phrases'][j]) 182 | 183 | self.sample_ids = keep_sids 184 | 185 | from itertools import groupby 186 | 
samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j)) 187 | for i, j in self.sample_ids] 188 | samples_by_phrase = sorted(samples_by_phrase) 189 | samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0]) 190 | 191 | self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase} 192 | 193 | self.all_phrases = list(set(self.samples_by_phrase.keys())) 194 | 195 | 196 | if self.only_visual: 197 | assert self.with_visual 198 | self.sample_ids = [(i, j) for i, j in self.sample_ids 199 | if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1] 200 | 201 | # Filter by size (if min_size is set) 202 | sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids] 203 | image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids] 204 | #self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes] 205 | self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)] 206 | 207 | if min_size: 208 | print('filter by size') 209 | 210 | self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size] 211 | 212 | self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/')) 213 | 214 | def __len__(self): 215 | return len(self.sample_ids) 216 | 217 | 218 | def load_sample(self, sample_i, j): 219 | 220 | img_ref_data = self.refvg_loader.get_img_ref_data(sample_i) 221 | 222 | polys_phrase0 = img_ref_data['gt_Polygons'][j] 223 | phrase = img_ref_data['phrases'][j] 224 | phrase = self.phrase_form.format(phrase) 225 | 226 | masks = [] 227 | for polys in polys_phrase0: 228 | for poly in polys: 229 | poly = [p[::-1] for p in poly] # swap x,y 230 | masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)] 231 | 232 | seg = np.stack(masks).max(0) 233 | img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg'))) 234 | 235 | min_shape = min(img.shape[:2]) 236 | 237 | if self.aug_crop: 238 | sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05) 239 | else: 240 | sly, slx = slice(0, None), slice(0, None) 241 | 242 | seg = seg[sly, slx] 243 | img = img[sly, slx] 244 | 245 | seg = seg.astype('uint8') 246 | seg = torch.from_numpy(seg).view(1, 1, *seg.shape) 247 | 248 | if img.ndim == 2: 249 | img = np.dstack([img] * 3) 250 | 251 | img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float() 252 | 253 | seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0] 254 | img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0] 255 | 256 | # img = img.permute([2,0, 1]) 257 | img = img / 255.0 258 | 259 | if self.aug_color is not None: 260 | img = self.aug_color(img) 261 | 262 | img = self.normalize(img) 263 | 264 | 265 | 266 | return img, seg, phrase 267 | 268 | def __getitem__(self, i): 269 | 270 | sample_i, j = self.sample_ids[i] 271 | 272 | img, seg, phrase = self.load_sample(sample_i, j) 273 | 274 | if self.negative_prob > 0: 275 | if torch.rand((1,)).item() < self.negative_prob: 276 | 277 | new_phrase = None 278 | while new_phrase is None or new_phrase == phrase: 279 | idx = torch.randint(0, len(self.all_phrases), (1,)).item() 280 | new_phrase = self.all_phrases[idx] 281 | phrase = new_phrase 282 | seg = 
torch.zeros_like(seg) 283 | 284 | if self.with_visual: 285 | # find a corresponding visual image 286 | if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1: 287 | idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item() 288 | other_sample = self.samples_by_phrase[phrase][idx] 289 | #print(other_sample) 290 | img_s, seg_s, _ = self.load_sample(*other_sample) 291 | 292 | from datasets.utils import blend_image_segmentation 293 | 294 | if self.mask in {'separate', 'text_and_separate'}: 295 | # assert img.shape[1:] == img_s.shape[1:] == seg_s.shape == seg.shape[1:] 296 | add_phrase = [phrase] if self.mask == 'text_and_separate' else [] 297 | vis_s = add_phrase + [img_s, seg_s, True] 298 | else: 299 | if self.mask.startswith('text_and_'): 300 | mask_mode = self.mask[9:] 301 | label_add = [phrase] 302 | else: 303 | mask_mode = self.mask 304 | label_add = [] 305 | 306 | masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0]) 307 | vis_s = label_add + [masked_img_s, True] 308 | 309 | else: 310 | # phrase is unique 311 | vis_s = torch.zeros_like(img) 312 | 313 | if self.mask in {'separate', 'text_and_separate'}: 314 | add_phrase = [phrase] if self.mask == 'text_and_separate' else [] 315 | vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False] 316 | elif self.mask.startswith('text_and_'): 317 | vis_s = [phrase, vis_s, False] 318 | else: 319 | vis_s = [vis_s, False] 320 | else: 321 | assert self.mask == 'text' 322 | vis_s = [phrase] 323 | 324 | seg = seg.unsqueeze(0).float() 325 | 326 | data_x = (img,) + tuple(vis_s) 327 | 328 | return data_x, (seg, torch.zeros(0), i) 329 | 330 | 331 | class PhraseCutPlus(PhraseCut): 332 | 333 | def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None): 334 | super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color, aug_crop=aug_crop, min_size=min_size, 335 | remove_classes=remove_classes, with_visual=True, only_visual=only_visual, mask=mask) -------------------------------------------------------------------------------- /repositories/clipseg/datasets/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def blend_image_segmentation(img, seg, mode, image_size=224): 7 | 8 | 9 | if mode in {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}: 10 | if isinstance(img, np.ndarray): 11 | img = torch.from_numpy(img) 12 | 13 | if isinstance(seg, np.ndarray): 14 | seg = torch.from_numpy(seg) 15 | 16 | if mode == 'overlay': 17 | out = img * seg 18 | out = [out.astype('float32')] 19 | elif mode == 'highlight': 20 | out = img * seg[None, :, :] * 0.85 + 0.15 * img 21 | out = [out.astype('float32')] 22 | elif mode == 'highlight2': 23 | img = img / 2 24 | out = (img+0.1) * seg[None, :, :] + 0.3 * img 25 | out = [out.astype('float32')] 26 | elif mode == 'blur_highlight': 27 | from evaluation_utils import img_preprocess 28 | out = [img_preprocess((None, [img], [seg]), blur=1, bg_fac=0.5).numpy()[0] - 0.01] 29 | elif mode == 'blur3_highlight': 30 | from evaluation_utils import img_preprocess 31 | out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.5).numpy()[0] - 0.01] 32 | elif mode == 'blur3_highlight01': 33 | from evaluation_utils import img_preprocess 34 | out = 
[img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.1).numpy()[0] - 0.01] 35 | elif mode == 'blur_highlight_random': 36 | from evaluation_utils import img_preprocess 37 | out = [img_preprocess((None, [img], [seg]), blur=0 + torch.randint(0, 3, (1,)).item(), bg_fac=0.1 + 0.8*torch.rand(1).item()).numpy()[0] - 0.01] 38 | elif mode == 'crop': 39 | from evaluation_utils import img_preprocess 40 | out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()] 41 | elif mode == 'crop_blur_highlight': 42 | from evaluation_utils import img_preprocess 43 | out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)[0].numpy()] 44 | elif mode == 'crop_blur_highlight352': 45 | from evaluation_utils import img_preprocess 46 | out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()] 47 | elif mode == 'shape': 48 | out = [np.stack([seg[:, :]]*3).astype('float32')] 49 | elif mode == 'concat': 50 | out = [np.concatenate([img, seg[None, :, :]]).astype('float32')] 51 | elif mode == 'image_only': 52 | out = [img.astype('float32')] 53 | elif mode == 'image_black': 54 | out = [img.astype('float32')*0] 55 | elif mode is None: 56 | out = [img.astype('float32')] 57 | elif mode == 'separate': 58 | out = [img.astype('float32'), seg.astype('int64')] 59 | elif mode == 'separate_img_black': 60 | out = [img.astype('float32')*0, seg.astype('int64')] 61 | elif mode == 'separate_seg_ones': 62 | out = [img.astype('float32'), np.ones_like(seg).astype('int64')] 63 | elif mode == 'separate_both_black': 64 | out = [img.astype('float32')*0, seg.astype('int64')*0] 65 | else: 66 | raise ValueError(f'invalid mode: {mode}') 67 | 68 | return out -------------------------------------------------------------------------------- /repositories/clipseg/environment.yml: -------------------------------------------------------------------------------- 1 | name: clipseg-environment 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - numpy 7 | - scipy 8 | - matplotlib-base 9 | - pip 10 | - pip: 11 | - --find-links https://download.pytorch.org/whl/torch_stable.html 12 | - torch==1.10.0+cpu 13 | - torchvision==0.11.1+cpu 14 | - opencv-python 15 | - git+https://github.com/openai/CLIP.git 16 | -------------------------------------------------------------------------------- /repositories/clipseg/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import requests 3 | 4 | from models.clipseg import CLIPDensePredT 5 | from PIL import Image 6 | from torchvision import transforms 7 | from matplotlib import pyplot as plt 8 | 9 | import cv2 10 | 11 | # load model 12 | model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64) 13 | model.eval(); 14 | 15 | # non-strict, because we only stored decoder weights (not CLIP weights) 16 | model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cuda')), strict=False); 17 | 18 | # load and normalize image 19 | input_image = Image.open('example_image.jpg') 20 | 21 | # or load from URL... 
22 | # image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg' 23 | # input_image = Image.open(requests.get(image_url, stream=True).raw) 24 | 25 | transform = transforms.Compose([ 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 28 | transforms.Resize((512, 512)), 29 | ]) 30 | img = transform(input_image).unsqueeze(0) 31 | 32 | prompts = ['a spoon'] 33 | 34 | # predict 35 | with torch.no_grad(): 36 | preds = model(img.repeat(1,1,1,1), prompts)[0] 37 | 38 | for i in range(len(prompts)): 39 | filename = f"{i}.png" 40 | plt.imsave(filename,torch.sigmoid(preds[i][0])) 41 | 42 | # TODO: Figure out how to convert the plot above to numpy instead of re-loading image 43 | img = cv2.imread(filename) 44 | 45 | gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 46 | 47 | (thresh, bw_image) = cv2.threshold(gray_image, 127, 255, cv2.THRESH_BINARY) 48 | 49 | blur_image = cv2.GaussianBlur(bw_image, (19,19), 0) 50 | 51 | cv2.imwrite(f"{i}_bw.png",bw_image) 52 | cv2.imwrite(f"{i}_blur.png",blur_image) -------------------------------------------------------------------------------- /repositories/clipseg/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | from torch.functional import Tensor 2 | from general_utils import load_model 3 | from torch.utils.data import DataLoader 4 | import torch 5 | import numpy as np 6 | 7 | def denorm(img): 8 | 9 | np_input = False 10 | if isinstance(img, np.ndarray): 11 | img = torch.from_numpy(img) 12 | np_input = True 13 | 14 | mean = torch.Tensor([0.485, 0.456, 0.406]) 15 | std = torch.Tensor([0.229, 0.224, 0.225]) 16 | 17 | img_denorm = (img*std[:,None,None]) + mean[:,None,None] 18 | 19 | if np_input: 20 | img_denorm = np.clip(img_denorm.numpy(), 0, 1) 21 | else: 22 | img_denorm = torch.clamp(img_denorm, 0, 1) 23 | 24 | return img_denorm 25 | 26 | 27 | def norm(img): 28 | mean = torch.Tensor([0.485, 0.456, 0.406]) 29 | std = torch.Tensor([0.229, 0.224, 0.225]) 30 | return (img - mean[:,None,None]) / std[:,None,None] 31 | 32 | 33 | def fast_iou_curve(p, g): 34 | 35 | g = g[p.sort().indices] 36 | p = torch.sigmoid(p.sort().values) 37 | 38 | scores = [] 39 | vals = np.linspace(0, 1, 50) 40 | 41 | for q in vals: 42 | 43 | n = int(len(g) * q) 44 | 45 | valid = torch.where(p > q)[0] 46 | if len(valid) > 0: 47 | n = int(valid[0]) 48 | else: 49 | n = len(g) 50 | 51 | fn = g[:n].sum() 52 | tn = n - fn 53 | tp = g[n:].sum() 54 | fp = len(g) - n - tp 55 | 56 | iou = tp / (tp + fn + fp) 57 | 58 | precision = tp / (tp + fp) 59 | recall = tp / (tp + fn) 60 | 61 | scores += [iou] 62 | 63 | return vals, scores 64 | 65 | 66 | def fast_rp_curve(p, g): 67 | 68 | g = g[p.sort().indices] 69 | p = torch.sigmoid(p.sort().values) 70 | 71 | precisions, recalls = [], [] 72 | vals = np.linspace(p.min(), p.max(), 250) 73 | 74 | for q in p[::100000]: 75 | 76 | n = int(len(g) * q) 77 | 78 | valid = torch.where(p > q)[0] 79 | if len(valid) > 0: 80 | n = int(valid[0]) 81 | else: 82 | n = len(g) 83 | 84 | fn = g[:n].sum() 85 | tn = n - fn 86 | tp = g[n:].sum() 87 | fp = len(g) - n - tp 88 | 89 | iou = tp / (tp + fn + fp) 90 | 91 | precision = tp / (tp + fp) 92 | recall = tp / (tp + fn) 93 | 94 | precisions += [precision] 95 | recalls += [recall] 96 | 97 | return recalls, precisions 98 | 99 | 100 | # Image processing 101 | 102 | def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2, 103 | brightness=1.0, 
bg_fac=1, colorize=False, outline=False, image_size=224): 104 | import cv2 105 | 106 | rw = rect_width 107 | 108 | out = [] 109 | for img, mask in zip(batch[1], batch[2]): 110 | 111 | img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img) 112 | mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask) 113 | 114 | img *= brightness 115 | img_bl = img 116 | if blur > 0: # best 5 117 | img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1) 118 | 119 | if grayscale: 120 | img_bl = img_bl[1][None] 121 | 122 | #img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl 123 | # img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask) 124 | img_inp = img*mask + (bg_fac) * img_bl * (1-mask) 125 | 126 | if rect: 127 | _, bbox = crop_mask(img, mask, context=0.1) 128 | img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None] 129 | img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None] 130 | img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None] 131 | img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None] 132 | 133 | 134 | if center_context is not None: 135 | img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size) 136 | 137 | if colorize: 138 | img_gray = denorm(img) 139 | img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY) 140 | img_gray = torch.stack([torch.from_numpy(img_gray)]*3) 141 | img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask) 142 | img_inp = norm(img_inp) 143 | 144 | if outline: 145 | cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 146 | outline_img = np.zeros(mask.shape, dtype=np.uint8) 147 | cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255)) 148 | outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255. 
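# The contour image built above (a white outline on black, scaled to [0, 1]) acts as a per-pixel blend weight in the next statement:
# outline pixels are painted red over the de-normalized image, and the result is passed through norm() again so the tensor is
# returned in the same normalized space as the other branches of img_preprocess.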
149 | img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img) 150 | img_inp = norm(img_inp) 151 | 152 | out += [img_inp] 153 | 154 | return torch.stack(out) 155 | 156 | 157 | def object_crop(img, mask, context=0.0, square=False, image_size=224): 158 | img_crop, bbox = crop_mask(img, mask, context=context, square=square) 159 | img_crop = pad_to_square(img_crop, channel_dim=0) 160 | img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0) 161 | return img_crop 162 | 163 | 164 | def crop_mask(img, mask, context=0.0, square=False): 165 | 166 | assert img.shape[1:] == mask.shape 167 | 168 | bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()] 169 | bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()] 170 | bbox = [int(x) for x in bbox] 171 | 172 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 173 | 174 | # square mask 175 | if square: 176 | bbox[0] = int(max(0, bbox[0] - context * height)) 177 | bbox[1] = int(min(mask.size(0), bbox[1] + context * height)) 178 | bbox[2] = int(max(0, bbox[2] - context * width)) 179 | bbox[3] = int(min(mask.size(1), bbox[3] + context * width)) 180 | 181 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 182 | if height > width: 183 | bbox[2] = int(max(0, (bbox[2] - 0.5*height))) 184 | bbox[3] = bbox[2] + height 185 | else: 186 | bbox[0] = int(max(0, (bbox[0] - 0.5*width))) 187 | bbox[1] = bbox[0] + width 188 | else: 189 | bbox[0] = int(max(0, bbox[0] - context * height)) 190 | bbox[1] = int(min(mask.size(0), bbox[1] + context * height)) 191 | bbox[2] = int(max(0, bbox[2] - context * width)) 192 | bbox[3] = int(min(mask.size(1), bbox[3] + context * width)) 193 | 194 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 195 | img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]] 196 | return img_crop, bbox 197 | 198 | 199 | def pad_to_square(img, channel_dim=2, fill=0): 200 | """ 201 | 202 | 203 | add padding such that a squared image is returned """ 204 | 205 | from torchvision.transforms.functional import pad 206 | 207 | if channel_dim == 2: 208 | img = img.permute(2, 0, 1) 209 | elif channel_dim == 0: 210 | pass 211 | else: 212 | raise ValueError('invalid channel_dim') 213 | 214 | h, w = img.shape[1:] 215 | pady1 = pady2 = padx1 = padx2 = 0 216 | 217 | if h > w: 218 | padx1 = (h - w) // 2 219 | padx2 = h - w - padx1 220 | elif w > h: 221 | pady1 = (w - h) // 2 222 | pady2 = w - h - pady1 223 | 224 | img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant') 225 | 226 | if channel_dim == 2: 227 | img_padded = img_padded.permute(1, 2, 0) 228 | 229 | return img_padded 230 | 231 | 232 | # qualitative 233 | 234 | def split_sentence(inp, limit=9): 235 | t_new, current_len = [], 0 236 | for k, t in enumerate(inp.split(' ')): 237 | current_len += len(t) + 1 238 | t_new += [t+' '] 239 | # not last 240 | if current_len > limit and k != len(inp.split(' ')) - 1: 241 | current_len = 0 242 | t_new += ['\n'] 243 | 244 | t_new = ''.join(t_new) 245 | return t_new 246 | 247 | 248 | from matplotlib import pyplot as plt 249 | 250 | 251 | def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None): 252 | 253 | row_off = 0 if labels is None else 1 254 | _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2))) 255 | [a.axis('off') for a in ax.flatten()] 256 | 257 | if 
labels is not None: 258 | for j in range(len(labels)): 259 | t_new = split_sentence(labels[j], limit=6) 260 | ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale) 261 | 262 | 263 | for i in range(len(imgs)): 264 | ax[i + row_off,0].imshow(imgs[i]) 265 | for j in range(len(preds)): 266 | img = preds[j][i][0].detach().cpu().numpy() 267 | 268 | if gt_labels is not None and labels[j] == gt_labels[i]: 269 | print(j, labels[j], gt_labels[i]) 270 | edgecolor = 'red' 271 | if aps is not None: 272 | ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8) 273 | else: 274 | edgecolor = 'k' 275 | 276 | rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none", 277 | edgecolor=edgecolor, linewidth=3) 278 | ax[i + row_off,1 + j].add_patch(rect) 279 | 280 | if vmax is None: 281 | this_vmax = 1 282 | elif vmax == 'per_prompt': 283 | this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))]) 284 | elif vmax == 'per_image': 285 | this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))]) 286 | 287 | ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap) 288 | 289 | 290 | # ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max()) 291 | plt.tight_layout() 292 | plt.subplots_adjust(wspace=0.05, hspace=0.05) -------------------------------------------------------------------------------- /repositories/clipseg/example_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThereforeGames/txt2mask/cebd3962cadc16b953048e9d473f0f36cdd8e5df/repositories/clipseg/example_image.jpg -------------------------------------------------------------------------------- /repositories/clipseg/experiments/ablation.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 # <-########################################## 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.phrasecut.PhraseCut # <----------------- 20 | split_mode: pascal_test 21 | split: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | negative_prob: 0.2 25 | mix_text_max: 0.5 26 | 27 | # general 28 | mix: True # <----------------- 29 | prompt: shuffle+ 30 | norm_cond: True 31 | mix_text_min: 0.0 32 | with_visual: True 33 | 34 | # model 35 | version: 'ViT-B/16' 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False # <-########################################## 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | batch_size: 32 48 | sigmoid: True 49 | split: test 50 | label_support: True 51 | 52 | test_configuration: 53 | 54 | - 55 | name: pc 56 | metric: metrics.FixedIntervalMetrics 57 | test_dataset: phrasecut 58 | mask: text 59 | 60 | - 61 | name: pc-vis 62 | metric: metrics.FixedIntervalMetrics 63 | test_dataset: phrasecut 64 | mask: crop_blur_highlight352 65 | with_visual: True 66 | visual_only: True 67 | 68 | 69 | columns: [name, 70 | pc_fgiou_best, pc_miou_best, pc_fgiou_0.5, 71 | pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5, 72 | 
duration] 73 | 74 | 75 | individual_configurations: 76 | 77 | - {name: rd64-uni} 78 | - {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003} 79 | - {name: rd64-no-negatives, negative_prob: 0.0} 80 | - {name: rd64-neg0.5, negative_prob: 0.5} 81 | - {name: rd64-no-visual, with_visual: False, mix: False} 82 | - {name: rd16-uni, reduce_dim: 16} 83 | - {name: rd64-layer3, extract_layers: [3], depth: 1} 84 | - {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}} -------------------------------------------------------------------------------- /repositories/clipseg/experiments/coco.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.coco_wrapper.COCOWrapper 20 | # split_mode: pascal_test 21 | split: train 22 | mask: text_and_blur3_highlight01 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | 28 | # general 29 | mix: True 30 | prompt: shuffle+ 31 | norm_cond: True 32 | mix_text_min: 0.0 33 | 34 | # model 35 | out: 1 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | # max_iterations: 10 48 | batch_size: 8 49 | sigmoid: True 50 | test_dataset: coco 51 | metric: metrics.FixedIntervalMetrics 52 | 53 | test_configuration: 54 | 55 | - 56 | name: coco_t 57 | mask: text 58 | 59 | - 60 | name: coco_h 61 | mask: blur3_highlight01 62 | 63 | - 64 | name: coco_h2 65 | mask: crop_blur_highlight352 66 | 67 | 68 | columns: [i, name, 69 | coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5, 70 | coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5, 71 | coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t, 72 | train_loss, duration, date 73 | ] 74 | 75 | individual_configurations: 76 | 77 | 78 | - {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 79 | - {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 80 | - {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 81 | - {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 82 | 83 | 84 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 85 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 86 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, 
max_iterations: 7000} 87 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 88 | 89 | 90 | # ViT 91 | - {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 92 | - {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 93 | - {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 94 | - {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 95 | 96 | 97 | # BASELINE 98 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 99 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 100 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 101 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} -------------------------------------------------------------------------------- /repositories/clipseg/experiments/pascal_1shot.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 # <-########################################## 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.phrasecut.PhraseCut 20 | split_mode: pascal_test 21 | mode: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | with_visual: True 28 | split: train 29 | 30 | # general 31 | mix: True 32 | prompt: shuffle+ 33 | norm_cond: True 34 | mix_text_min: 0.0 35 | 36 | # model 37 | out: 1 38 | version: 'ViT-B/16' 39 | extract_layers: [3, 7, 9] 40 | reduce_dim: 64 41 | depth: 3 42 | 43 | loss: torch.nn.functional.binary_cross_entropy_with_logits 44 | amp: True 45 | 46 | test_configuration_common: 47 | normalize: True 48 | image_size: 352 49 | metric: metrics.FixedIntervalMetrics 50 | batch_size: 1 51 | test_dataset: pascal 52 | sigmoid: True 53 | # max_iterations: 250 54 | 55 | test_configuration: 56 | 57 | - 58 | name: pas_t 59 | mask: 
text 60 | 61 | - 62 | name: pas_h 63 | mask: blur3_highlight01 64 | 65 | - 66 | name: pas_h2 67 | mask: crop_blur_highlight352 68 | 69 | 70 | columns: [name, 71 | pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct, 72 | pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct, 73 | pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t, 74 | train_loss, duration, date 75 | ] 76 | 77 | individual_configurations: 78 | 79 | - {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}} 80 | - {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}} 81 | - {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}} 82 | - {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}} 83 | 84 | 85 | - {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}} 86 | - {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}} 87 | - {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}} 88 | - {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}} 89 | 90 | 91 | # baseline 92 | - {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}} 93 | - {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}} 94 | - {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}} 95 | - {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}} 96 | 97 | # ViT 98 | - {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}} 99 | - {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}} 100 | - {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}} 101 | - {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}} 102 | -------------------------------------------------------------------------------- /repositories/clipseg/experiments/phrasecut.yaml: 
-------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.phrasecut.PhraseCut # <----------------- 20 | split_mode: pascal_test 21 | split: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | 28 | # general 29 | mix: False # <----------------- 30 | prompt: shuffle+ 31 | norm_cond: True 32 | mix_text_min: 0.0 33 | 34 | # model 35 | out: 1 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | batch_size: 32 48 | # max_iterations: 5 49 | # max_iterations: 150 50 | 51 | test_configuration: 52 | 53 | - 54 | name: pc # old: phrasecut 55 | metric: metrics.FixedIntervalMetrics 56 | test_dataset: phrasecut 57 | split: test 58 | mask: text 59 | label_support: True 60 | sigmoid: True 61 | 62 | 63 | columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date] 64 | 65 | 66 | individual_configurations: 67 | 68 | # important ones 69 | 70 | 71 | - {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5} 72 | 73 | # this was accedentally trained using old mask 74 | - {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01} 75 | - {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False} 76 | # this was accedentally trained using old mask 77 | - {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01} 78 | 79 | - {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003} 80 | - {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001} -------------------------------------------------------------------------------- /repositories/clipseg/general_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import inspect 3 | import torch 4 | import os 5 | import sys 6 | import yaml 7 | from shutil import copy, copytree 8 | from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename 9 | 10 | 11 | class Logger(object): 12 | 13 | def __getattr__(self, k): 14 | return print 15 | 16 | log = Logger() 17 | 18 | def training_config_from_cli_args(): 19 | experiment_name = sys.argv[1] 20 | experiment_id = int(sys.argv[2]) 21 | 22 | yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader) 23 | 24 | config = yaml_config['configuration'] 25 | config = {**config, **yaml_config['individual_configurations'][experiment_id]} 26 | config = AttributeDict(config) 27 | return config 28 | 29 | 30 | def score_config_from_cli_args(): 31 | experiment_name = 
sys.argv[1] 32 | experiment_id = int(sys.argv[2]) 33 | 34 | 35 | yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader) 36 | 37 | config = yaml_config['test_configuration_common'] 38 | 39 | if type(yaml_config['test_configuration']) == list: 40 | test_id = int(sys.argv[3]) 41 | config = {**config, **yaml_config['test_configuration'][test_id]} 42 | else: 43 | config = {**config, **yaml_config['test_configuration']} 44 | 45 | if 'test_configuration' in yaml_config['individual_configurations'][experiment_id]: 46 | config = {**config, **yaml_config['individual_configurations'][experiment_id]['test_configuration']} 47 | 48 | train_checkpoint_id = yaml_config['individual_configurations'][experiment_id]['name'] 49 | 50 | config = AttributeDict(config) 51 | return config, train_checkpoint_id 52 | 53 | 54 | def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository', 55 | local_dir='~/datasets'): 56 | """ copies files from repository to local folder. 57 | 58 | repo_files: list of filenames or list of tuples [filename, target path] 59 | 60 | e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar']) 61 | will create a folder 'MyDataset' in local_dir, and extract the content of 62 | '/data/dataset1.tar' to /MyDataset/other/path. 63 | """ 64 | 65 | local_dir = realpath(join(expanduser(local_dir), local_name)) 66 | 67 | dataset_exists = True 68 | 69 | # check if folder is available 70 | if not isdir(local_dir): 71 | dataset_exists = False 72 | 73 | if integrity_check is not None: 74 | try: 75 | integrity_ok = integrity_check(local_dir) 76 | except BaseException: 77 | integrity_ok = False 78 | 79 | if integrity_ok: 80 | log.hint('Passed custom integrity check') 81 | else: 82 | log.hint('Custom integrity check failed') 83 | 84 | dataset_exists = dataset_exists and integrity_ok 85 | 86 | if not dataset_exists: 87 | 88 | repo_dir = realpath(expanduser(repo_dir)) 89 | 90 | for i, filename in enumerate(repo_files): 91 | 92 | if type(filename) == str: 93 | origin, target = filename, filename 94 | archive_target = join(local_dir, basename(origin)) 95 | extract_target = join(local_dir) 96 | else: 97 | origin, target = filename 98 | archive_target = join(local_dir, dirname(target), basename(origin)) 99 | extract_target = join(local_dir, dirname(target)) 100 | 101 | archive_origin = join(repo_dir, origin) 102 | 103 | log.hint(f'copy: {archive_origin} to {archive_target}') 104 | 105 | # make sure the path exists 106 | os.makedirs(dirname(archive_target), exist_ok=True) 107 | 108 | if os.path.isfile(archive_target): 109 | # only copy if size differs 110 | if os.path.getsize(archive_target) != os.path.getsize(archive_origin): 111 | log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. 
origin {os.path.getsize(archive_origin)}') 112 | copy(archive_origin, archive_target) 113 | else: 114 | copy(archive_origin, archive_target) 115 | 116 | extract_archive(archive_target, extract_target, noarchive_ok=True) 117 | 118 | # concurrent processes might have deleted the file 119 | if os.path.isfile(archive_target): 120 | os.remove(archive_target) 121 | 122 | 123 | def extract_archive(filename, target_folder=None, noarchive_ok=False): 124 | from subprocess import run, PIPE 125 | 126 | if filename.endswith('.tgz') or filename.endswith('.tar'): 127 | command = f'tar -xf {filename}' 128 | command += f' -C {target_folder}' if target_folder is not None else '' 129 | elif filename.endswith('.tar.gz'): 130 | command = f'tar -xzf {filename}' 131 | command += f' -C {target_folder}' if target_folder is not None else '' 132 | elif filename.endswith('zip'): 133 | command = f'unzip {filename}' 134 | command += f' -d {target_folder}' if target_folder is not None else '' 135 | else: 136 | if noarchive_ok: 137 | return 138 | else: 139 | raise ValueError(f'unsuppored file ending of {filename}') 140 | 141 | log.hint(command) 142 | result = run(command.split(), stdout=PIPE, stderr=PIPE) 143 | if result.returncode != 0: 144 | print(result.stdout, result.stderr) 145 | 146 | 147 | class AttributeDict(dict): 148 | """ 149 | An extended dictionary that allows access to elements as atttributes and counts 150 | these accesses. This way, we know if some attributes were never used. 151 | """ 152 | 153 | def __init__(self, *args, **kwargs): 154 | from collections import Counter 155 | super().__init__(*args, **kwargs) 156 | self.__dict__['counter'] = Counter() 157 | 158 | def __getitem__(self, k): 159 | self.__dict__['counter'][k] += 1 160 | return super().__getitem__(k) 161 | 162 | def __getattr__(self, k): 163 | self.__dict__['counter'][k] += 1 164 | return super().get(k) 165 | 166 | def __setattr__(self, k, v): 167 | return super().__setitem__(k, v) 168 | 169 | def __delattr__(self, k, v): 170 | return super().__delitem__(k, v) 171 | 172 | def unused_keys(self, exceptions=()): 173 | return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions] 174 | 175 | def assume_no_unused_keys(self, exceptions=()): 176 | if len(self.unused_keys(exceptions=exceptions)) > 0: 177 | log.warning('Unused keys:', self.unused_keys(exceptions=exceptions)) 178 | 179 | 180 | def get_attribute(name): 181 | import importlib 182 | 183 | if name is None: 184 | raise ValueError('The provided attribute is None') 185 | 186 | name_split = name.split('.') 187 | mod = importlib.import_module('.'.join(name_split[:-1])) 188 | return getattr(mod, name_split[-1]) 189 | 190 | 191 | 192 | def filter_args(input_args, default_args): 193 | 194 | updated_args = {k: input_args[k] if k in input_args else v for k, v in default_args.items()} 195 | used_args = {k: v for k, v in input_args.items() if k in default_args} 196 | unused_args = {k: v for k, v in input_args.items() if k not in default_args} 197 | 198 | return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args) 199 | 200 | 201 | def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False): 202 | 203 | config = json.load(open(join('logs', checkpoint_id, 'config.json'))) 204 | 205 | if model_args != 'from_config' and type(model_args) != dict: 206 | raise ValueError('model_args must either be "from_config" or a dictionary of values') 207 | 208 | model_cls = get_attribute(config['model']) 209 | 
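# get_attribute resolves the dotted path stored in the config (for example 'models.clipseg.CLIPDensePredT' from the experiment YAMLs)
# into the actual class object; filter_args below then keeps only the config entries whose keys appear in that constructor's signature,
# so unrelated config keys are dropped rather than passed to the model.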
210 | # load model 211 | if model_args == 'from_config': 212 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) 213 | 214 | model = model_cls(**model_args) 215 | 216 | if weights_file is None: 217 | weights_file = realpath(join('logs', checkpoint_id, 'weights.pth')) 218 | else: 219 | weights_file = realpath(join('logs', checkpoint_id, weights_file)) 220 | 221 | if isfile(weights_file): 222 | weights = torch.load(weights_file) 223 | for _, w in weights.items(): 224 | assert not torch.any(torch.isnan(w)), 'weights contain NaNs' 225 | model.load_state_dict(weights, strict=strict) 226 | else: 227 | raise FileNotFoundError(f'model checkpoint {weights_file} was not found') 228 | 229 | if with_config: 230 | return model, config 231 | 232 | return model 233 | 234 | 235 | class TrainingLogger(object): 236 | 237 | def __init__(self, model, log_dir, config=None, *args): 238 | super().__init__() 239 | self.model = model 240 | self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None 241 | 242 | os.makedirs('logs/', exist_ok=True) 243 | os.makedirs(self.base_path, exist_ok=True) 244 | 245 | if config is not None: 246 | json.dump(config, open(join(self.base_path, 'config.json'), 'w')) 247 | 248 | def iter(self, i, **kwargs): 249 | if i % 100 == 0 and 'loss' in kwargs: 250 | loss = kwargs['loss'] 251 | print(f'iteration {i}: loss {loss:.4f}') 252 | 253 | def save_weights(self, only_trainable=False, weight_file='weights.pth'): 254 | if self.model is None: 255 | raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.') 256 | 257 | weights_path = join(self.base_path, weight_file) 258 | 259 | weight_dict = self.model.state_dict() 260 | 261 | if only_trainable: 262 | weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad} 263 | 264 | torch.save(weight_dict, weights_path) 265 | log.info(f'Saved weights to {weights_path}') 266 | 267 | def __enter__(self): 268 | return self 269 | 270 | def __exit__(self, type, value, traceback): 271 | """ automatically stop processes if used in a context manager """ 272 | pass -------------------------------------------------------------------------------- /repositories/clipseg/metrics.py: -------------------------------------------------------------------------------- 1 | from torch.functional import Tensor 2 | from general_utils import log 3 | from collections import defaultdict 4 | import numpy as np 5 | 6 | import torch 7 | from torch.nn import functional as nnf 8 | 9 | 10 | class BaseMetric(object): 11 | 12 | def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True, 13 | eval_validation=True): 14 | self._names = tuple(metric_names) 15 | self._eval_intermediate = eval_intermediate 16 | self._eval_validation = eval_validation 17 | 18 | self._pred_range = pred_range 19 | self._pred_index = pred_index 20 | self._gt_index = gt_index 21 | 22 | self.predictions = [] 23 | self.ground_truths = [] 24 | 25 | def eval_intermediate(self): 26 | return self._eval_intermediate 27 | 28 | def eval_validation(self): 29 | return self._eval_validation 30 | 31 | def names(self): 32 | return self._names 33 | 34 | def add(self, predictions, ground_truth): 35 | raise NotImplementedError 36 | 37 | def value(self): 38 | raise NotImplementedError 39 | 40 | def scores(self): 41 | # similar to value but returns dict 42 | value = self.value() 43 | if type(value) == dict: 44 | return value 45 | else: 46 | assert type(value) 
in {list, tuple} 47 | return list(zip(self.names(), self.value())) 48 | 49 | def _get_pred_gt(self, predictions, ground_truth): 50 | pred = predictions[self._pred_index] 51 | gt = ground_truth[self._gt_index] 52 | 53 | if self._pred_range is not None: 54 | pred = pred[:, self._pred_range[0]: self._pred_range[1]] 55 | 56 | return pred, gt 57 | 58 | 59 | class FixedIntervalMetrics(BaseMetric): 60 | 61 | def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None, 62 | resize_pred=None, n_values=51, custom_threshold=None): 63 | 64 | 65 | super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh')) 66 | self.intersections = [] 67 | self.unions = [] 68 | # self.threshold = threshold 69 | self.sigmoid = sigmoid 70 | self.resize_to = resize_to 71 | self.resize_pred = resize_pred # resize prediction to match ground truth 72 | self.class_count = defaultdict(lambda: 0) 73 | self.per_class = defaultdict(lambda : [0,0]) 74 | self.ignore_mask = ignore_mask 75 | self.custom_threshold = custom_threshold 76 | 77 | self.scores_ap = [] 78 | self.scores_iou = [] 79 | self.gts, self.preds = [], [] 80 | self.classes = [] 81 | 82 | # [1:-1] ignores 0 and 1 83 | self.threshold_values = np.linspace(0, 1, n_values)[1:-1] 84 | 85 | self.metrics = dict(tp=[], fp=[], fn=[], tn=[]) 86 | 87 | def add(self, pred, gt): 88 | 89 | pred_batch = pred[0].cpu() 90 | 91 | if self.sigmoid: 92 | pred_batch = torch.sigmoid(pred_batch) 93 | 94 | gt_batch = gt[0].cpu() 95 | mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch)) 96 | cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch) 97 | 98 | if self.resize_to is not None: 99 | gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest') 100 | pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False) 101 | 102 | if isinstance(cls_batch, torch.Tensor): 103 | cls_batch = cls_batch.cpu().numpy().tolist() 104 | 105 | assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}' 106 | 107 | for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch): 108 | 109 | if self.resize_pred: 110 | predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True) 111 | 112 | p = predictions.flatten() 113 | g = ground_truth.flatten() 114 | 115 | assert len(p) == len(g) 116 | 117 | if mask is not None: 118 | m = mask.flatten().bool() 119 | p = p[m] 120 | g = g[m] 121 | 122 | p_sorted = p.sort() 123 | p = p_sorted.values 124 | g = g[p_sorted.indices] 125 | 126 | tps, fps, fns, tns = [], [], [], [] 127 | for thresh in self.threshold_values: 128 | 129 | valid = torch.where(p > thresh)[0] 130 | if len(valid) > 0: 131 | n = int(valid[0]) 132 | else: 133 | n = len(g) 134 | 135 | fn = int(g[:n].sum()) 136 | tp = int(g[n:].sum()) 137 | fns += [fn] 138 | tns += [n - fn] 139 | tps += [tp] 140 | fps += [len(g) - n - tp] 141 | 142 | self.metrics['tp'] += [tps] 143 | self.metrics['fp'] += [fps] 144 | self.metrics['fn'] += [fns] 145 | self.metrics['tn'] += [tns] 146 | 147 | self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls] 148 | 149 | def value(self): 150 | 151 | import time 152 | t_start = time.time() 153 | 154 | if set(self.classes) == set([None]): 155 | all_classes = None 156 | log.warning('classes were not provided, cannot compute 
mIoU') 157 | else: 158 | all_classes = set(int(c) for c in self.classes) 159 | # log.info(f'compute metrics for {len(all_classes)} classes') 160 | 161 | summed = {k: [sum([self.metrics[k][i][j] 162 | for i in range(len(self.metrics[k]))]) 163 | for j in range(len(self.threshold_values))] 164 | for k in self.metrics.keys()} 165 | 166 | if all_classes is not None: 167 | 168 | assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn']) 169 | # group by class 170 | metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes} 171 | for i in range(len(self.metrics['tp'])): 172 | for k in self.metrics.keys(): 173 | metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]] 174 | 175 | # sum over all instances within the classes 176 | summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()} 177 | 178 | 179 | # Compute average precision 180 | 181 | assert (np.array(summed['fp']) + np.array(summed['tp']) ).sum(), 'no predictions is made' 182 | 183 | # only consider values where a prediction is made 184 | precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(len(self.threshold_values)) 185 | if summed['tp'][j] + summed['fp'][j] > 0] 186 | recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values)) 187 | if summed['tp'][j] + summed['fp'][j] > 0] 188 | 189 | # remove duplicate recall-precision-pairs (and sort by recall value) 190 | recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0])) 191 | 192 | from scipy.integrate import simps 193 | ap = simps(precisions, recalls) 194 | 195 | # Compute best IoU 196 | fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))] 197 | 198 | biniou_scores = [ 199 | 0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) + 200 | 0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j])) 201 | for j in range(len(self.threshold_values)) 202 | ] 203 | 204 | index_0p5 = self.threshold_values.tolist().index(0.5) 205 | index_0p1 = self.threshold_values.tolist().index(0.1) 206 | index_0p2 = self.threshold_values.tolist().index(0.2) 207 | index_0p3 = self.threshold_values.tolist().index(0.3) 208 | 209 | if self.custom_threshold is not None: 210 | index_ct = self.threshold_values.tolist().index(self.custom_threshold) 211 | 212 | if all_classes is not None: 213 | # mean IoU 214 | mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j]) 215 | for c in all_classes]) 216 | for j in range(len(self.threshold_values))] 217 | 218 | mean_iou_dict = { 219 | 'miou_best': max(mean_ious) if all_classes is not None else None, 220 | 'miou_0.5': mean_ious[index_0p5] if all_classes is not None else None, 221 | 'miou_0.1': mean_ious[index_0p1] if all_classes is not None else None, 222 | 'miou_0.2': mean_ious[index_0p2] if all_classes is not None else None, 223 | 'miou_0.3': mean_ious[index_0p3] if all_classes is not None else None, 224 | 'miou_best_t': self.threshold_values[np.argmax(mean_ious)], 225 | 'mean_iou_ct': mean_ious[index_ct] if all_classes is not None and self.custom_threshold is not None else None, 226 | 'mean_iou_scores': mean_ious, 227 | } 228 | 229 | print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took 
{time.time() - t_start:.1f}s') 230 | 231 | return { 232 | 'ap': ap, 233 | 234 | # fgiou 235 | 'fgiou_best': max(fgiou_scores), 236 | 'fgiou_0.5': fgiou_scores[index_0p5], 237 | 'fgiou_0.1': fgiou_scores[index_0p1], 238 | 'fgiou_0.2': fgiou_scores[index_0p2], 239 | 'fgiou_0.3': fgiou_scores[index_0p3], 240 | 'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)], 241 | 242 | # mean iou 243 | 244 | 245 | # biniou 246 | 'biniou_best': max(biniou_scores), 247 | 'biniou_0.5': biniou_scores[index_0p5], 248 | 'biniou_0.1': biniou_scores[index_0p1], 249 | 'biniou_0.2': biniou_scores[index_0p2], 250 | 'biniou_0.3': biniou_scores[index_0p3], 251 | 'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)], 252 | 253 | # custom threshold 254 | 'fgiou_ct': fgiou_scores[index_ct] if self.custom_threshold is not None else None, 255 | 'biniou_ct': biniou_scores[index_ct] if self.custom_threshold is not None else None, 256 | 'ct': self.custom_threshold, 257 | 258 | # statistics 259 | 'fgiou_scores': fgiou_scores, 260 | 'biniou_scores': biniou_scores, 261 | 'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))), 262 | 'summed_statistics': summed, 263 | 'summed_by_cls_statistics': summed_by_cls, 264 | 265 | **mean_iou_dict 266 | } 267 | 268 | # ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh' 269 | 270 | # return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls} 271 | 272 | -------------------------------------------------------------------------------- /repositories/clipseg/models/clipseg.py: -------------------------------------------------------------------------------- 1 | import math 2 | from os.path import basename, dirname, join, isfile 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as nnf 6 | from torch.nn.modules.activation import ReLU 7 | 8 | 9 | def precompute_clip_vectors(): 10 | 11 | from trails.initialization import init_dataset 12 | lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True, 13 | reduce_factor=None, add_bar=False, negative_prob=0.5) 14 | 15 | all_names = list(lvis.category_names.values()) 16 | 17 | import clip 18 | from models.clip_prompts import imagenet_templates 19 | clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0] 20 | prompt_vectors = {} 21 | for name in all_names[:100]: 22 | with torch.no_grad(): 23 | conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates] 24 | text_tokens = clip.tokenize(conditionals).cuda() 25 | cond = clip_model.encode_text(text_tokens).cpu() 26 | 27 | for cond, vec in zip(conditionals, cond): 28 | prompt_vectors[cond] = vec.cpu() 29 | 30 | import pickle 31 | 32 | pickle.dump(prompt_vectors, open('precomputed_prompt_vectors.pickle', 'wb')) 33 | 34 | 35 | def get_prompt_list(prompt): 36 | if prompt == 'plain': 37 | return ['{}'] 38 | elif prompt == 'fixed': 39 | return ['a photo of a {}.'] 40 | elif prompt == 'shuffle': 41 | return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] 42 | elif prompt == 'shuffle+': 43 | return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', 44 | 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', 45 | 'a bad photo of a {}.', 'a photo of the {}.'] 46 | elif prompt == 'shuffle_clip': 47 | from 
models.clip_prompts import imagenet_templates 48 | return imagenet_templates 49 | else: 50 | raise ValueError('Invalid value for prompt') 51 | 52 | 53 | def forward_multihead_attention(x, b, with_aff=False, attn_mask=None): 54 | """ 55 | Simplified version of multihead attention (taken from torch source code but without tons of if clauses). 56 | The mlp and layer norm come from CLIP. 57 | x: input. 58 | b: multihead attention module. 59 | """ 60 | 61 | x_ = b.ln_1(x) 62 | q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1) 63 | tgt_len, bsz, embed_dim = q.size() 64 | 65 | head_dim = embed_dim // b.attn.num_heads 66 | scaling = float(head_dim) ** -0.5 67 | 68 | q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) 69 | k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) 70 | v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) 71 | 72 | q = q * scaling 73 | 74 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2 75 | if attn_mask is not None: 76 | 77 | 78 | attn_mask_type, attn_mask = attn_mask 79 | n_heads = attn_output_weights.size(0) // attn_mask.size(0) 80 | attn_mask = attn_mask.repeat(n_heads, 1) 81 | 82 | if attn_mask_type == 'cls_token': 83 | # the mask only affects similarities compared to the readout-token. 84 | attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...] 85 | # attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0] 86 | 87 | if attn_mask_type == 'all': 88 | # print(attn_output_weights.shape, attn_mask[:, None].shape) 89 | attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None] 90 | 91 | 92 | attn_output_weights = torch.softmax(attn_output_weights, dim=-1) 93 | 94 | attn_output = torch.bmm(attn_output_weights, v) 95 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 96 | attn_output = b.attn.out_proj(attn_output) 97 | 98 | x = x + attn_output 99 | x = x + b.mlp(b.ln_2(x)) 100 | 101 | if with_aff: 102 | return x, attn_output_weights 103 | else: 104 | return x 105 | 106 | 107 | class CLIPDenseBase(nn.Module): 108 | 109 | def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens): 110 | super().__init__() 111 | 112 | import clip 113 | 114 | # prec = torch.FloatTensor 115 | self.clip_model, _ = clip.load(version, device='cpu', jit=False) 116 | self.model = self.clip_model.visual 117 | 118 | # if not None, scale conv weights such that we obtain n_tokens. 
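        # note: visual_forward below derives the patch stride as input_size // n_tokens
        # and bilinearly resizes conv1's kernel to that stride; e.g. a 352x352 input with
        # n_tokens=22 would give a 16x16 patch kernel and a 22x22 = 484-token grid.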
119 | self.n_tokens = n_tokens 120 | 121 | for p in self.clip_model.parameters(): 122 | p.requires_grad_(False) 123 | 124 | # conditional 125 | if reduce_cond is not None: 126 | self.reduce_cond = nn.Linear(512, reduce_cond) 127 | for p in self.reduce_cond.parameters(): 128 | p.requires_grad_(False) 129 | else: 130 | self.reduce_cond = None 131 | 132 | self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) 133 | self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) 134 | 135 | self.reduce = nn.Linear(768, reduce_dim) 136 | 137 | self.prompt_list = get_prompt_list(prompt) 138 | 139 | # precomputed prompts 140 | import pickle 141 | if isfile('precomputed_prompt_vectors.pickle'): 142 | precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) 143 | self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} 144 | else: 145 | self.precomputed_prompts = dict() 146 | 147 | def rescaled_pos_emb(self, new_size): 148 | assert len(new_size) == 2 149 | 150 | a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) 151 | b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T 152 | return torch.cat([self.model.positional_embedding[:1], b]) 153 | 154 | def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): 155 | 156 | 157 | with torch.no_grad(): 158 | 159 | inp_size = x_inp.shape[2:] 160 | 161 | if self.n_tokens is not None: 162 | stride2 = x_inp.shape[2] // self.n_tokens 163 | conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True) 164 | x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation) 165 | else: 166 | x = self.model.conv1(x_inp) # shape = [*, width, grid, grid] 167 | 168 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 169 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 170 | 171 | x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 172 | 173 | standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197 174 | 175 | if x.shape[1] != standard_n_tokens: 176 | new_shape = int(math.sqrt(x.shape[1]-1)) 177 | x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:] 178 | else: 179 | x = x + self.model.positional_embedding.to(x.dtype) 180 | 181 | x = self.model.ln_pre(x) 182 | 183 | x = x.permute(1, 0, 2) # NLD -> LND 184 | 185 | activations, affinities = [], [] 186 | for i, res_block in enumerate(self.model.transformer.resblocks): 187 | 188 | if mask is not None: 189 | mask_layer, mask_type, mask_tensor = mask 190 | if mask_layer == i or mask_layer == 'all': 191 | # import ipdb; ipdb.set_trace() 192 | size = int(math.sqrt(x.shape[0] - 1)) 193 | 194 | attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size)) 195 | 196 | else: 197 | attn_mask = None 198 | else: 199 | attn_mask = None 200 | 201 | x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask) 202 | 203 | if i in extract_layers: 204 | affinities += [aff_per_head] 205 | 206 | #if self.n_tokens is not None: 207 | # activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)] 208 | #else: 209 | activations += [x] 210 | 211 | if 
len(extract_layers) > 0 and i == max(extract_layers) and skip: 212 | print('early skip') 213 | break 214 | 215 | x = x.permute(1, 0, 2) # LND -> NLD 216 | x = self.model.ln_post(x[:, 0, :]) 217 | 218 | if self.model.proj is not None: 219 | x = x @ self.model.proj 220 | 221 | return x, activations, affinities 222 | 223 | def sample_prompts(self, words, prompt_list=None): 224 | 225 | prompt_list = prompt_list if prompt_list is not None else self.prompt_list 226 | 227 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) 228 | prompts = [prompt_list[i] for i in prompt_indices] 229 | return [promt.format(w) for promt, w in zip(prompts, words)] 230 | 231 | def get_cond_vec(self, conditional, batch_size): 232 | # compute conditional from a single string 233 | if conditional is not None and type(conditional) == str: 234 | cond = self.compute_conditional(conditional) 235 | cond = cond.repeat(batch_size, 1) 236 | 237 | # compute conditional from string list/tuple 238 | elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: 239 | assert len(conditional) == batch_size 240 | cond = self.compute_conditional(conditional) 241 | 242 | # use conditional directly 243 | elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: 244 | cond = conditional 245 | 246 | # compute conditional from image 247 | elif conditional is not None and type(conditional) == torch.Tensor: 248 | with torch.no_grad(): 249 | cond, _, _ = self.visual_forward(conditional) 250 | else: 251 | raise ValueError('invalid conditional') 252 | return cond 253 | 254 | def compute_conditional(self, conditional): 255 | import clip 256 | 257 | dev = next(self.parameters()).device 258 | 259 | if type(conditional) in {list, tuple}: 260 | text_tokens = clip.tokenize(conditional).to(dev) 261 | cond = self.clip_model.encode_text(text_tokens) 262 | else: 263 | if conditional in self.precomputed_prompts: 264 | cond = self.precomputed_prompts[conditional].float().to(dev) 265 | else: 266 | text_tokens = clip.tokenize([conditional]).to(dev) 267 | cond = self.clip_model.encode_text(text_tokens)[0] 268 | 269 | if self.shift_vector is not None: 270 | return cond + self.shift_vector 271 | else: 272 | return cond 273 | 274 | 275 | def clip_load_untrained(version): 276 | assert version == 'ViT-B/16' 277 | from clip.model import CLIP 278 | from clip.clip import _MODELS, _download 279 | model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval() 280 | state_dict = model.state_dict() 281 | 282 | vision_width = state_dict["visual.conv1.weight"].shape[0] 283 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 284 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 285 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 286 | image_resolution = vision_patch_size * grid_size 287 | embed_dim = state_dict["text_projection"].shape[1] 288 | context_length = state_dict["positional_embedding"].shape[0] 289 | vocab_size = state_dict["token_embedding.weight"].shape[0] 290 | transformer_width = state_dict["ln_final.weight"].shape[0] 291 | transformer_heads = transformer_width // 64 292 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 293 | 294 | return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, 295 | context_length, vocab_size, 
transformer_width, transformer_heads, transformer_layers) 296 | 297 | 298 | class CLIPDensePredT(CLIPDenseBase): 299 | 300 | def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', 301 | extra_blocks=0, reduce_cond=None, fix_shift=False, 302 | learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False, 303 | add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None): 304 | 305 | super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) 306 | # device = 'cpu' 307 | 308 | self.extract_layers = extract_layers 309 | self.cond_layer = cond_layer 310 | self.limit_to_clip_only = limit_to_clip_only 311 | self.process_cond = None 312 | self.rev_activations = rev_activations 313 | 314 | depth = len(extract_layers) 315 | 316 | if add_calibration: 317 | self.calibration_conds = 1 318 | 319 | self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None 320 | 321 | self.add_activation1 = True 322 | 323 | self.version = version 324 | 325 | self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] 326 | 327 | if fix_shift: 328 | # self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False) 329 | self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False) 330 | # self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False) 331 | else: 332 | self.shift_vector = None 333 | 334 | if trans_conv is None: 335 | trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] 336 | else: 337 | # explicitly define transposed conv kernel size 338 | trans_conv_ks = (trans_conv, trans_conv) 339 | 340 | self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) 341 | 342 | assert len(self.extract_layers) == depth 343 | 344 | self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) 345 | self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) 346 | self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) 347 | 348 | # refinement and trans conv 349 | 350 | if learn_trans_conv_only: 351 | for p in self.parameters(): 352 | p.requires_grad_(False) 353 | 354 | for p in self.trans_conv.parameters(): 355 | p.requires_grad_(True) 356 | 357 | self.prompt_list = get_prompt_list(prompt) 358 | 359 | 360 | def forward(self, inp_image, conditional=None, return_features=False, mask=None): 361 | 362 | assert type(return_features) == bool 363 | 364 | inp_image = inp_image.to(self.model.positional_embedding.device) 365 | 366 | if mask is not None: 367 | raise ValueError('mask not supported') 368 | 369 | # x_inp = normalize(inp_image) 370 | x_inp = inp_image 371 | 372 | bs, dev = inp_image.shape[0], x_inp.device 373 | 374 | cond = self.get_cond_vec(conditional, bs) 375 | 376 | visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) 377 | 378 | activation1 = activations[0] 379 | activations = activations[1:] 380 | 381 | _activations = activations[::-1] if not self.rev_activations else activations 382 | 383 | a = None 384 | for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)): 385 | 386 | if a is not None: 387 | a = 
reduce(activation) + a 388 | else: 389 | a = reduce(activation) 390 | 391 | if i == self.cond_layer: 392 | if self.reduce_cond is not None: 393 | cond = self.reduce_cond(cond) 394 | 395 | a = self.film_mul(cond) * a + self.film_add(cond) 396 | 397 | a = block(a) 398 | 399 | for block in self.extra_blocks: 400 | a = a + block(a) 401 | 402 | a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens 403 | 404 | size = int(math.sqrt(a.shape[2])) 405 | 406 | a = a.view(bs, a.shape[1], size, size) 407 | 408 | a = self.trans_conv(a) 409 | 410 | if self.n_tokens is not None: 411 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True) 412 | 413 | if self.upsample_proj is not None: 414 | a = self.upsample_proj(a) 415 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') 416 | 417 | if return_features: 418 | return a, visual_q, cond, [activation1] + activations 419 | else: 420 | return a, 421 | 422 | 423 | 424 | class CLIPDensePredTMasked(CLIPDensePredT): 425 | 426 | def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, 427 | prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False, 428 | refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None): 429 | 430 | super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim, 431 | n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond, 432 | fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only, 433 | limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration, 434 | n_tokens=n_tokens) 435 | 436 | def visual_forward_masked(self, img_s, seg_s): 437 | return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s)) 438 | 439 | def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False): 440 | 441 | if seg_s is None: 442 | cond = cond_or_img_s 443 | else: 444 | img_s = cond_or_img_s 445 | 446 | with torch.no_grad(): 447 | cond, _, _ = self.visual_forward_masked(img_s, seg_s) 448 | 449 | return super().forward(img_q, cond, return_features=return_features) 450 | 451 | 452 | 453 | class CLIPDenseBaseline(CLIPDenseBase): 454 | 455 | def __init__(self, version='ViT-B/32', cond_layer=0, 456 | extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed', 457 | reduce_cond=None, limit_to_clip_only=False, n_tokens=None): 458 | 459 | super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) 460 | device = 'cpu' 461 | 462 | # self.cond_layer = cond_layer 463 | self.extract_layer = extract_layer 464 | self.limit_to_clip_only = limit_to_clip_only 465 | self.shift_vector = None 466 | 467 | self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] 468 | 469 | assert reduce2_dim is not None 470 | 471 | self.reduce2 = nn.Sequential( 472 | nn.Linear(reduce_dim, reduce2_dim), 473 | nn.ReLU(), 474 | nn.Linear(reduce2_dim, reduce_dim) 475 | ) 476 | 477 | trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] 478 | self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) 479 | 480 | 481 | def forward(self, inp_image, conditional=None, return_features=False): 482 | 483 | inp_image = inp_image.to(self.model.positional_embedding.device) 484 | 485 | # x_inp = normalize(inp_image) 486 | x_inp = inp_image 487 | 488 | bs, dev = inp_image.shape[0], x_inp.device 489 | 490 | cond = self.get_cond_vec(conditional, bs) 491 
| 492 | visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer]) 493 | 494 | a = activations[0] 495 | a = self.reduce(a) 496 | a = self.film_mul(cond) * a + self.film_add(cond) 497 | 498 | if self.reduce2 is not None: 499 | a = self.reduce2(a) 500 | 501 | # the original model would execute a transformer block here 502 | 503 | a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens 504 | 505 | size = int(math.sqrt(a.shape[2])) 506 | 507 | a = a.view(bs, a.shape[1], size, size) 508 | a = self.trans_conv(a) 509 | 510 | if return_features: 511 | return a, visual_q, cond, activations 512 | else: 513 | return a, 514 | 515 | 516 | class CLIPSegMultiLabel(nn.Module): 517 | 518 | def __init__(self, model) -> None: 519 | super().__init__() 520 | 521 | from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC 522 | 523 | self.pascal_classes = VOC 524 | 525 | from models.clipseg import CLIPDensePredT 526 | from general_utils import load_model 527 | # self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False) 528 | self.clipseg = load_model(model, strict=False) 529 | 530 | self.clipseg.eval() 531 | 532 | def forward(self, x): 533 | 534 | bs = x.shape[0] 535 | out = torch.ones(21, bs, 352, 352).to(x.device) * -10 536 | 537 | for class_id, class_name in enumerate(self.pascal_classes): 538 | 539 | fac = 3 if class_name == 'background' else 1 540 | 541 | with torch.no_grad(): 542 | pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac 543 | 544 | out[class_id] += pred 545 | 546 | 547 | out = out.permute(1, 0, 2, 3) 548 | 549 | return out 550 | 551 | # construct output tensor 552 | -------------------------------------------------------------------------------- /repositories/clipseg/models/vitseg.py: -------------------------------------------------------------------------------- 1 | import math 2 | from posixpath import basename, dirname, join 3 | # import clip 4 | from clip.model import convert_weights 5 | import torch 6 | import json 7 | from torch import nn 8 | from torch.nn import functional as nnf 9 | from torch.nn.modules import activation 10 | from torch.nn.modules.activation import ReLU 11 | from torchvision import transforms 12 | 13 | normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 14 | 15 | from torchvision.models import ResNet 16 | 17 | 18 | def process_prompts(conditional, prompt_list, conditional_map): 19 | # DEPRECATED 20 | 21 | # randomly sample a synonym 22 | words = [conditional_map[int(i)] for i in conditional] 23 | words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words] 24 | words = [w.replace('_', ' ') for w in words] 25 | 26 | if prompt_list is not None: 27 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) 28 | prompts = [prompt_list[i] for i in prompt_indices] 29 | else: 30 | prompts = ['a photo of {}'] * (len(words)) 31 | 32 | return [promt.format(w) for promt, w in zip(prompts, words)] 33 | 34 | 35 | class VITDenseBase(nn.Module): 36 | 37 | def rescaled_pos_emb(self, new_size): 38 | assert len(new_size) == 2 39 | 40 | a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) 41 | b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T 42 | return torch.cat([self.model.positional_embedding[:1], b]) 43 | 44 | def visual_forward(self, x_inp, 
extract_layers=(), skip=False, mask=None): 45 | 46 | with torch.no_grad(): 47 | 48 | x_inp = nnf.interpolate(x_inp, (384, 384)) 49 | 50 | x = self.model.patch_embed(x_inp) 51 | cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 52 | if self.model.dist_token is None: 53 | x = torch.cat((cls_token, x), dim=1) 54 | else: 55 | x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 56 | x = self.model.pos_drop(x + self.model.pos_embed) 57 | 58 | activations = [] 59 | for i, block in enumerate(self.model.blocks): 60 | x = block(x) 61 | 62 | if i in extract_layers: 63 | # permute to be compatible with CLIP 64 | activations += [x.permute(1,0,2)] 65 | 66 | x = self.model.norm(x) 67 | x = self.model.head(self.model.pre_logits(x[:, 0])) 68 | 69 | # again for CLIP compatibility 70 | # x = x.permute(1, 0, 2) 71 | 72 | return x, activations, None 73 | 74 | def sample_prompts(self, words, prompt_list=None): 75 | 76 | prompt_list = prompt_list if prompt_list is not None else self.prompt_list 77 | 78 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) 79 | prompts = [prompt_list[i] for i in prompt_indices] 80 | return [promt.format(w) for promt, w in zip(prompts, words)] 81 | 82 | def get_cond_vec(self, conditional, batch_size): 83 | # compute conditional from a single string 84 | if conditional is not None and type(conditional) == str: 85 | cond = self.compute_conditional(conditional) 86 | cond = cond.repeat(batch_size, 1) 87 | 88 | # compute conditional from string list/tuple 89 | elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: 90 | assert len(conditional) == batch_size 91 | cond = self.compute_conditional(conditional) 92 | 93 | # use conditional directly 94 | elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: 95 | cond = conditional 96 | 97 | # compute conditional from image 98 | elif conditional is not None and type(conditional) == torch.Tensor: 99 | with torch.no_grad(): 100 | cond, _, _ = self.visual_forward(conditional) 101 | else: 102 | raise ValueError('invalid conditional') 103 | return cond 104 | 105 | def compute_conditional(self, conditional): 106 | import clip 107 | 108 | dev = next(self.parameters()).device 109 | 110 | if type(conditional) in {list, tuple}: 111 | text_tokens = clip.tokenize(conditional).to(dev) 112 | cond = self.clip_model.encode_text(text_tokens) 113 | else: 114 | if conditional in self.precomputed_prompts: 115 | cond = self.precomputed_prompts[conditional].float().to(dev) 116 | else: 117 | text_tokens = clip.tokenize([conditional]).to(dev) 118 | cond = self.clip_model.encode_text(text_tokens)[0] 119 | 120 | return cond 121 | 122 | 123 | class VITDensePredT(VITDenseBase): 124 | 125 | def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', 126 | depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False, 127 | learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False, 128 | add_calibration=False, process_cond=None, not_pretrained=False): 129 | super().__init__() 130 | # device = 'cpu' 131 | 132 | self.extract_layers = extract_layers 133 | self.cond_layer = cond_layer 134 | self.limit_to_clip_only = limit_to_clip_only 135 | self.process_cond = None 136 | 137 | if add_calibration: 138 | self.calibration_conds = 1 139 | 140 | self.upsample_proj = nn.Conv2d(reduce_dim, 1, 
kernel_size=1) if upsample else None 141 | 142 | self.add_activation1 = True 143 | 144 | import timm 145 | self.model = timm.create_model('vit_base_patch16_384', pretrained=True) 146 | self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond) 147 | 148 | for p in self.model.parameters(): 149 | p.requires_grad_(False) 150 | 151 | import clip 152 | self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False) 153 | # del self.clip_model.visual 154 | 155 | 156 | self.token_shape = (14, 14) 157 | 158 | # conditional 159 | if reduce_cond is not None: 160 | self.reduce_cond = nn.Linear(512, reduce_cond) 161 | for p in self.reduce_cond.parameters(): 162 | p.requires_grad_(False) 163 | else: 164 | self.reduce_cond = None 165 | 166 | # self.film = AVAILABLE_BLOCKS['film'](512, 128) 167 | self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) 168 | self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) 169 | 170 | # DEPRECATED 171 | # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))} 172 | 173 | assert len(self.extract_layers) == depth 174 | 175 | self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) 176 | self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) 177 | self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) 178 | 179 | trans_conv_ks = (16, 16) 180 | self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) 181 | 182 | # refinement and trans conv 183 | 184 | if learn_trans_conv_only: 185 | for p in self.parameters(): 186 | p.requires_grad_(False) 187 | 188 | for p in self.trans_conv.parameters(): 189 | p.requires_grad_(True) 190 | 191 | if prompt == 'fixed': 192 | self.prompt_list = ['a photo of a {}.'] 193 | elif prompt == 'shuffle': 194 | self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] 195 | elif prompt == 'shuffle+': 196 | self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', 197 | 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', 198 | 'a bad photo of a {}.', 'a photo of the {}.'] 199 | elif prompt == 'shuffle_clip': 200 | from models.clip_prompts import imagenet_templates 201 | self.prompt_list = imagenet_templates 202 | 203 | if process_cond is not None: 204 | if process_cond == 'clamp' or process_cond[0] == 'clamp': 205 | 206 | val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2 207 | 208 | def clamp_vec(x): 209 | return torch.clamp(x, -val, val) 210 | 211 | self.process_cond = clamp_vec 212 | 213 | elif process_cond.endswith('.pth'): 214 | 215 | shift = torch.load(process_cond) 216 | def add_shift(x): 217 | return x + shift.to(x.device) 218 | 219 | self.process_cond = add_shift 220 | 221 | import pickle 222 | precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) 223 | self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} 224 | 225 | 226 | def forward(self, inp_image, conditional=None, return_features=False, mask=None): 227 | 228 | assert type(return_features) == bool 229 | 230 | # inp_image = inp_image.to(self.model.positional_embedding.device) 231 | 232 | if mask is not None: 233 | raise ValueError('mask not supported') 234 | 235 | # x_inp = normalize(inp_image) 236 | x_inp = inp_image 237 | 
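        # summary of the decoder below: each extracted ViT activation is projected
        # 768 -> reduce_dim, FiLM-conditioned at cond_layer via
        # film_mul(cond) * a + film_add(cond), passed through a small transformer
        # encoder block, then reshaped to a square feature map and upsampled with the
        # 16x16 transposed convolution back to the original input resolution.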
238 | bs, dev = inp_image.shape[0], x_inp.device 239 | 240 | inp_image_size = inp_image.shape[2:] 241 | 242 | cond = self.get_cond_vec(conditional, bs) 243 | 244 | visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) 245 | 246 | activation1 = activations[0] 247 | activations = activations[1:] 248 | 249 | a = None 250 | for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)): 251 | 252 | if a is not None: 253 | a = reduce(activation) + a 254 | else: 255 | a = reduce(activation) 256 | 257 | if i == self.cond_layer: 258 | if self.reduce_cond is not None: 259 | cond = self.reduce_cond(cond) 260 | 261 | a = self.film_mul(cond) * a + self.film_add(cond) 262 | 263 | a = block(a) 264 | 265 | for block in self.extra_blocks: 266 | a = a + block(a) 267 | 268 | a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens 269 | 270 | size = int(math.sqrt(a.shape[2])) 271 | 272 | a = a.view(bs, a.shape[1], size, size) 273 | 274 | if self.trans_conv is not None: 275 | a = self.trans_conv(a) 276 | 277 | if self.upsample_proj is not None: 278 | a = self.upsample_proj(a) 279 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') 280 | 281 | a = nnf.interpolate(a, inp_image_size) 282 | 283 | if return_features: 284 | return a, visual_q, cond, [activation1] + activations 285 | else: 286 | return a, 287 | -------------------------------------------------------------------------------- /repositories/clipseg/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThereforeGames/txt2mask/cebd3962cadc16b953048e9d473f0f36cdd8e5df/repositories/clipseg/overview.png -------------------------------------------------------------------------------- /repositories/clipseg/score.py: -------------------------------------------------------------------------------- 1 | from torch.functional import Tensor 2 | 3 | import torch 4 | import inspect 5 | import json 6 | import yaml 7 | import time 8 | import sys 9 | 10 | from general_utils import log 11 | 12 | import numpy as np 13 | from os.path import expanduser, join, isfile, realpath 14 | 15 | from torch.utils.data import DataLoader 16 | 17 | from metrics import FixedIntervalMetrics 18 | 19 | from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args 20 | 21 | 22 | DATASET_CACHE = dict() 23 | 24 | def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False): 25 | 26 | config = json.load(open(join('logs', checkpoint_id, 'config.json'))) 27 | 28 | if model_args != 'from_config' and type(model_args) != dict: 29 | raise ValueError('model_args must either be "from_config" or a dictionary of values') 30 | 31 | model_cls = get_attribute(config['model']) 32 | 33 | # load model 34 | if model_args == 'from_config': 35 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) 36 | 37 | model = model_cls(**model_args) 38 | 39 | if weights_file is None: 40 | weights_file = realpath(join('logs', checkpoint_id, 'weights.pth')) 41 | else: 42 | weights_file = realpath(join('logs', checkpoint_id, weights_file)) 43 | 44 | if isfile(weights_file) and not ignore_weights: 45 | weights = torch.load(weights_file) 46 | for _, w in weights.items(): 47 | assert not torch.any(torch.isnan(w)), 'weights contain NaNs' 48 | model.load_state_dict(weights, strict=strict) 49 | else: 50 | if not 
ignore_weights: 51 | raise FileNotFoundError(f'model checkpoint {weights_file} was not found') 52 | 53 | if with_config: 54 | return model, config 55 | 56 | return model 57 | 58 | 59 | def compute_shift2(model, datasets, seed=123, repetitions=1): 60 | """ computes shift """ 61 | 62 | model.eval() 63 | model.cuda() 64 | 65 | import random 66 | random.seed(seed) 67 | 68 | preds, gts = [], [] 69 | for i_dataset, dataset in enumerate(datasets): 70 | 71 | loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False) 72 | 73 | max_iterations = int(repetitions * len(dataset.dataset.data_list)) 74 | 75 | with torch.no_grad(): 76 | 77 | i, losses = 0, [] 78 | for i_all, (data_x, data_y) in enumerate(loader): 79 | 80 | data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x] 81 | data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y] 82 | 83 | pred, = model(data_x[0], data_x[1], data_x[2]) 84 | preds += [pred.detach()] 85 | gts += [data_y] 86 | 87 | i += 1 88 | if max_iterations and i >= max_iterations: 89 | break 90 | 91 | from metrics import FixedIntervalMetrics 92 | n_values = 51 93 | thresholds = np.linspace(0, 1, n_values)[1:-1] 94 | metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values) 95 | 96 | for p, y in zip(preds, gts): 97 | metric.add(p.unsqueeze(1), y) 98 | 99 | best_idx = np.argmax(metric.value()['fgiou_scores']) 100 | best_thresh = thresholds[best_idx] 101 | 102 | return best_thresh 103 | 104 | 105 | def get_cached_pascal_pfe(split, config): 106 | from datasets.pfe_dataset import PFEPascalWrapper 107 | try: 108 | dataset = DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)] 109 | except KeyError: 110 | dataset = PFEPascalWrapper(mode='val', split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support) 111 | DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)] = dataset 112 | return dataset 113 | 114 | 115 | 116 | 117 | def main(): 118 | config, train_checkpoint_id = score_config_from_cli_args() 119 | 120 | metrics = score(config, train_checkpoint_id, None) 121 | 122 | for dataset in metrics.keys(): 123 | for k in metrics[dataset]: 124 | if type(metrics[dataset][k]) in {float, int}: 125 | print(dataset, f'{k:<16} {metrics[dataset][k]:.3f}') 126 | 127 | 128 | def score(config, train_checkpoint_id, train_config): 129 | 130 | config = AttributeDict(config) 131 | 132 | print(config) 133 | 134 | # use training dataset and loss 135 | train_config = AttributeDict(json.load(open(f'logs/{train_checkpoint_id}/config.json'))) 136 | 137 | cp_str = f'_{config.iteration_cp}' if config.iteration_cp is not None else '' 138 | 139 | 140 | model_cls = get_attribute(train_config['model']) 141 | 142 | _, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters) 143 | 144 | model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}} 145 | 146 | strict_models = {'ConditionBase4', 'PFENetWrapper'} 147 | model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args, 148 | weights_file=f'weights{cp_str}.pth', ) 149 | 150 | 151 | model.eval() 152 | model.cuda() 153 | 154 | metric_args = dict() 155 | 156 | if 'threshold' in config: 157 | if config.metric.split('.')[-1] == 'SkLearnMetrics': 158 | metric_args['threshold'] = config.threshold 159 | 160 | if 'resize_to' in config: 161 | metric_args['resize_to'] = config.resize_to 162 
| 163 | if 'sigmoid' in config: 164 | metric_args['sigmoid'] = config.sigmoid 165 | 166 | if 'custom_threshold' in config: 167 | metric_args['custom_threshold'] = config.custom_threshold 168 | 169 | if config.test_dataset == 'pascal': 170 | 171 | loss_fn = get_attribute(train_config.loss) 172 | # assume that if no split is specified in train_config, test on all splits, 173 | 174 | if 'splits' in config: 175 | splits = config.splits 176 | else: 177 | if 'split' in train_config and type(train_config.split) == int: 178 | # unless train_config has a split set, in that case assume train mode in training 179 | splits = [train_config.split] 180 | assert train_config.mode == 'train' 181 | else: 182 | splits = [0,1,2,3] 183 | 184 | log.info('Test on these splits', splits) 185 | 186 | scores = dict() 187 | for split in splits: 188 | 189 | shift = config.shift if 'shift' in config else 0 190 | 191 | # automatic shift 192 | if shift == 'auto': 193 | shift_compute_t = time.time() 194 | shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac) 195 | log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s') 196 | 197 | dataset = get_cached_pascal_pfe(split, config) 198 | 199 | eval_start_t = time.time() 200 | 201 | loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False) 202 | 203 | assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1' 204 | 205 | metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args) 206 | 207 | with torch.no_grad(): 208 | 209 | i, losses = 0, [] 210 | for i_all, (data_x, data_y) in enumerate(loader): 211 | 212 | data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x] 213 | data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y] 214 | 215 | if config.mask == 'separate': # for old CondBase model 216 | pred, = model(data_x[0], data_x[1], data_x[2]) 217 | else: 218 | # assert config.mask in {'text', 'highlight'} 219 | pred, _, _, _ = model(data_x[0], data_x[1], return_features=True) 220 | 221 | # loss = loss_fn(pred, data_y[0]) 222 | metric.add(pred.unsqueeze(1) + shift, data_y) 223 | 224 | # losses += [float(loss)] 225 | 226 | i += 1 227 | if config.max_iterations and i >= config.max_iterations: 228 | break 229 | 230 | #scores[split] = {m: s for m, s in zip(metric.names(), metric.value())} 231 | 232 | log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.') 233 | 234 | print(metric.value()['mean_iou_scores']) 235 | 236 | scores[split] = metric.scores() 237 | 238 | log.info(f'Completed split {split}') 239 | 240 | key_prefix = config['name'] if 'name' in config else 'pas' 241 | 242 | all_keys = set.intersection(*[set(v.keys()) for v in scores.values()]) 243 | 244 | valid_keys = [k for k in all_keys if all(v[k] is not None and isinstance(v[k], (int, float, np.float)) for v in scores.values())] 245 | 246 | return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}} 247 | 248 | 249 | if config.test_dataset == 'coco': 250 | from datasets.coco_wrapper import COCOWrapper 251 | 252 | coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask, 253 | with_class_label=True) 254 | 255 | log.info('Dataset length', 
len(coco_dataset)) 256 | loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False) 257 | 258 | metric = get_attribute(config.metric)(resize_pred=True, **metric_args) 259 | 260 | shift = config.shift if 'shift' in config else 0 261 | 262 | with torch.no_grad(): 263 | 264 | i, losses = 0, [] 265 | for i_all, (data_x, data_y) in enumerate(loader): 266 | data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x] 267 | data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y] 268 | 269 | if config.mask == 'separate': # for old CondBase model 270 | pred, = model(data_x[0], data_x[1], data_x[2]) 271 | else: 272 | # assert config.mask in {'text', 'highlight'} 273 | pred, _, _, _ = model(data_x[0], data_x[1], return_features=True) 274 | 275 | metric.add([pred + shift], data_y) 276 | 277 | i += 1 278 | if config.max_iterations and i >= config.max_iterations: 279 | break 280 | 281 | key_prefix = config['name'] if 'name' in config else 'coco' 282 | return {key_prefix: metric.scores()} 283 | #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}} 284 | 285 | 286 | if config.test_dataset == 'phrasecut': 287 | from datasets.phrasecut import PhraseCut 288 | 289 | only_visual = config.only_visual is not None and config.only_visual 290 | with_visual = config.with_visual is not None and config.with_visual 291 | 292 | dataset = PhraseCut('test', 293 | image_size=train_config.image_size, 294 | mask=config.mask, 295 | with_visual=with_visual, only_visual=only_visual, aug_crop=False, 296 | aug_color=False) 297 | 298 | loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False) 299 | metric = get_attribute(config.metric)(resize_pred=True, **metric_args) 300 | 301 | shift = config.shift if 'shift' in config else 0 302 | 303 | 304 | with torch.no_grad(): 305 | 306 | i, losses = 0, [] 307 | for i_all, (data_x, data_y) in enumerate(loader): 308 | data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x] 309 | data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y] 310 | 311 | pred, _, _, _ = model(data_x[0], data_x[1], return_features=True) 312 | metric.add([pred + shift], data_y) 313 | 314 | i += 1 315 | if config.max_iterations and i >= config.max_iterations: 316 | break 317 | 318 | key_prefix = config['name'] if 'name' in config else 'phrasecut' 319 | return {key_prefix: metric.scores()} 320 | #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}} 321 | 322 | if config.test_dataset == 'pascal_zs': 323 | from third_party.JoEm.model.metric import Evaluator 324 | from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC 325 | from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS 326 | 327 | from models.clipseg import CLIPSegMultiLabel 328 | 329 | n_unseen = train_config.remove_classes[1] 330 | 331 | pz = PascalZeroShot('val', n_unseen, image_size=352) 332 | m = CLIPSegMultiLabel(model=train_config.name).cuda() 333 | m.eval(); 334 | 335 | print(len(pz), n_unseen) 336 | print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set]) 337 | 338 | print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)]) 339 | print('seen', [VOC[i] for i in get_seen_idx(n_unseen)]) 340 | 341 | loader = DataLoader(pz, batch_size=8) 342 | evaluator = Evaluator(21, get_unseen_idx(n_unseen), 
get_seen_idx(n_unseen)) 343 | 344 | for i, (data_x, data_y) in enumerate(loader): 345 | pred = m(data_x[0].cuda()) 346 | evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy()) 347 | 348 | if config.max_iter is not None and i > config.max_iter: 349 | break 350 | 351 | scores = evaluator.Mean_Intersection_over_Union() 352 | key_prefix = config['name'] if 'name' in config else 'pas_zs' 353 | 354 | return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}} 355 | 356 | elif config.test_dataset in {'same_as_training', 'affordance'}: 357 | loss_fn = get_attribute(train_config.loss) 358 | 359 | metric_cls = get_attribute(config.metric) 360 | metric = metric_cls(**metric_args) 361 | 362 | if config.test_dataset == 'same_as_training': 363 | dataset_cls = get_attribute(train_config.dataset) 364 | elif config.test_dataset == 'affordance': 365 | dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance') 366 | dataset_name = 'aff' 367 | else: 368 | dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_OneShot') 369 | dataset_name = 'lvis' 370 | 371 | _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters) 372 | 373 | dataset_args['image_size'] = train_config.image_size # explicitly use training image size for evaluation 374 | 375 | if model.__class__.__name__ == 'PFENetWrapper': 376 | dataset_args['image_size'] = config.image_size 377 | 378 | log.info('init dataset', str(dataset_cls)) 379 | dataset = dataset_cls(**dataset_args) 380 | 381 | log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}') 382 | 383 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle) 384 | 385 | # explicitly set prompts 386 | if config.prompt == 'plain': 387 | model.prompt_list = ['{}'] 388 | elif config.prompt == 'fixed': 389 | model.prompt_list = ['a photo of a {}.'] 390 | elif config.prompt == 'shuffle': 391 | model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] 392 | elif config.prompt == 'shuffle_clip': 393 | from models.clip_prompts import imagenet_templates 394 | model.prompt_list = imagenet_templates 395 | 396 | config.assume_no_unused_keys(exceptions=['max_iterations']) 397 | 398 | t_start = time.time() 399 | 400 | with torch.no_grad(): # TODO: switch to inference_mode (torch 1.9) 401 | i, losses = 0, [] 402 | for data_x, data_y in data_loader: 403 | 404 | data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x] 405 | data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y] 406 | 407 | if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}: 408 | pred, = model(data_x[0], data_x[1], data_x[2]) 409 | visual_q = None 410 | else: 411 | pred, visual_q, _, _ = model(data_x[0], data_x[1], return_features=True) 412 | 413 | loss = loss_fn(pred, data_y[0]) 414 | 415 | metric.add([pred], data_y) 416 | 417 | losses += [float(loss)] 418 | 419 | i += 1 420 | if config.max_iterations and i >= config.max_iterations: 421 | break 422 | 423 | # scores = {m: s for m, s in zip(metric.names(), metric.value())} 424 | scores = metric.scores() 425 | 426 | keys = set(scores.keys()) 427 | if dataset.negative_prob > 0 and 'mIoU' in keys: 428 | keys.remove('mIoU') 429 | 430 | name_mask = dataset.mask.replace('text_label', 'txt')[:3] 431 | name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob) 432 | 433 | score_name = config.name if 'name' in config else 
f'{dataset_name}_{name_mask}{name_neg}' 434 | 435 | scores = {score_name: {k: v for k,v in scores.items() if k in keys}} 436 | scores[score_name].update({'test_loss': np.mean(losses)}) 437 | 438 | log.info(f'Evaluation took {time.time() - t_start:.1f}s') 439 | 440 | return scores 441 | else: 442 | raise ValueError('invalid test dataset') 443 | 444 | 445 | 446 | 447 | 448 | 449 | 450 | 451 | 452 | if __name__ == '__main__': 453 | main() -------------------------------------------------------------------------------- /repositories/clipseg/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("README.md", "r", encoding="utf-8") as readme_file: 4 | readme = readme_file.read() 5 | 6 | requirements = [ 7 | "numpy", 8 | "scipy", 9 | "matplotlib", 10 | "torch", 11 | "torchvision", 12 | "opencv-python", 13 | "CLIP @ git+https://github.com/openai/CLIP.git" 14 | ] 15 | 16 | setup( 17 | name='clipseg', 18 | packages=['clipseg'], 19 | package_dir={'clipseg': 'models'}, 20 | package_data={'clipseg': [ 21 | "../weights/*.pth", 22 | ]}, 23 | version='0.0.1', 24 | url='https://github.com/timojl/clipseg', 25 | python_requires='>=3.9', 26 | install_requires=requirements, 27 | description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".', 28 | long_description=readme, 29 | long_description_content_type="text/markdown", 30 | ) 31 | -------------------------------------------------------------------------------- /repositories/clipseg/supplementary.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThereforeGames/txt2mask/cebd3962cadc16b953048e9d473f0f36cdd8e5df/repositories/clipseg/supplementary.pdf -------------------------------------------------------------------------------- /repositories/clipseg/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import inspect 3 | import json 4 | import yaml 5 | import math 6 | import os 7 | import sys 8 | 9 | from general_utils import log 10 | 11 | import numpy as np 12 | from functools import partial 13 | from os.path import expanduser, join, isfile, basename 14 | 15 | from torch.cuda.amp import autocast, GradScaler 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from contextlib import nullcontext 18 | from torch.utils.data import DataLoader 19 | 20 | from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args 21 | 22 | 23 | def cosine_warmup_lr(i, warmup=10, max_iter=90): 24 | """ Cosine LR with Warmup """ 25 | if i < warmup: 26 | return (i+1)/(warmup+1) 27 | else: 28 | return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup)))) 29 | 30 | 31 | def validate(model, dataset, config): 32 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False) 33 | 34 | metric_class, use_metric = config.val_metric_class, config.use_val_metric 35 | loss_fn = get_attribute(config.loss) 36 | 37 | model.eval() 38 | model.cuda() 39 | 40 | if metric_class is not None: 41 | metric = get_attribute(metric_class)() 42 | 43 | with torch.no_grad(): 44 | 45 | i, losses = 0, [] 46 | for data_x, data_y in data_loader: 47 | 48 | data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x] 49 | data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y] 50 | 51 | prompts = model.sample_prompts(data_x[1], prompt_list=('a 
photo of a {}',)) 52 | pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True) 53 | 54 | if metric_class is not None: 55 | metric.add([pred], data_y) 56 | 57 | # pred = model(data_x[0], prompts) 58 | # loss = loss_fn(pred[0], data_y[0]) 59 | loss = loss_fn(pred, data_y[0]) 60 | losses += [float(loss)] 61 | 62 | i += 1 63 | 64 | if config.val_max_iterations is not None and i > config.val_max_iterations: 65 | break 66 | 67 | if use_metric is None: 68 | return np.mean(losses), {}, False 69 | else: 70 | metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {} 71 | return np.mean(losses), metric_scores, True 72 | 73 | 74 | def main(): 75 | 76 | config = training_config_from_cli_args() 77 | 78 | val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf') 79 | 80 | model_cls = get_attribute(config.model) 81 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) 82 | model = model_cls(**model_args).cuda() 83 | 84 | dataset_cls = get_attribute(config.dataset) 85 | _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters) 86 | 87 | dataset = dataset_cls(**dataset_args) 88 | 89 | log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})') 90 | 91 | if val_interval is not None: 92 | dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'} 93 | _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters) 94 | print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args}) 95 | 96 | dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args}) 97 | 98 | # optimizer 99 | opt_cls = get_attribute(config.optimizer) 100 | if config.optimize == 'torch.optim.SGD': 101 | opt_args = {'momentum': config.momentum if 'momentum' in config else 0} 102 | else: 103 | opt_args = {} 104 | opt = opt_cls(model.parameters(), lr=config.lr, **opt_args) 105 | 106 | if config.lr_scheduler == 'cosine': 107 | assert config.T_max is not None and config.eta_min is not None 108 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min) 109 | elif config.lr_scheduler == 'warmup_cosine': 110 | lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup)) 111 | else: 112 | lr_scheduler = None 113 | 114 | batch_size, max_iterations = config.batch_size, config.max_iterations 115 | 116 | loss_fn = get_attribute(config.loss) 117 | 118 | if config.amp: 119 | log.info('Using AMP') 120 | autocast_fn = autocast 121 | scaler = GradScaler() 122 | else: 123 | autocast_fn, scaler = nullcontext, None 124 | 125 | 126 | save_only_trainable = True 127 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4) 128 | 129 | # disable config when hyperparam. opt. to avoid writing logs. 130 | tracker_config = config if not config.hyperparameter_optimization else None 131 | 132 | with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger: 133 | 134 | i = 0 135 | while True: 136 | for data_x, data_y in data_loader: 137 | 138 | # between caption and output feature. 139 | # 1. Sample random captions 140 | # 2. 
Check alignment with CLIP 141 | 142 | # randomly mix text and visual support conditionals 143 | if config.mix: 144 | 145 | assert config.mask.startswith('text_and') 146 | 147 | with autocast_fn(): 148 | # data_x[1] = text label 149 | prompts = model.sample_prompts(data_x[1]) 150 | 151 | # model.clip_model() 152 | 153 | text_cond = model.compute_conditional(prompts) 154 | if model.__class__.__name__ == 'CLIPDensePredTMasked': 155 | # when mask=='separate' 156 | visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda()) 157 | else: 158 | # data_x[2] = visual prompt 159 | visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda()) 160 | 161 | max_txt = config.mix_text_max if config.mix_text_max is not None else 1 162 | batch_size = text_cond.shape[0] 163 | 164 | # sample weights for each element in batch 165 | text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None] 166 | text_weights = text_weights.cuda() 167 | 168 | if dataset.__class__.__name__ == 'PhraseCut': 169 | # give full weight to text where support_image is invalid 170 | visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3] 171 | text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1) 172 | 173 | cond = text_cond * text_weights + visual_s_cond * (1 - text_weights) 174 | 175 | else: 176 | # no mix 177 | 178 | if model.__class__.__name__ == 'CLIPDensePredTMasked': 179 | # compute conditional vector using CLIP masking 180 | with autocast_fn(): 181 | assert config.mask == 'separate' 182 | cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda()) 183 | else: 184 | cond = data_x[1] 185 | if isinstance(cond, torch.Tensor): 186 | cond = cond.cuda() 187 | 188 | with autocast_fn(): 189 | visual_q = None 190 | 191 | pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True) 192 | 193 | loss = loss_fn(pred, data_y[0].cuda()) 194 | 195 | if torch.isnan(loss) or torch.isinf(loss): 196 | # skip if loss is nan 197 | log.warning('Training stopped due to inf/nan loss.') 198 | sys.exit(-1) 199 | 200 | extra_loss = 0 201 | loss += extra_loss 202 | 203 | opt.zero_grad() 204 | 205 | if scaler is None: 206 | loss.backward() 207 | opt.step() 208 | else: 209 | scaler.scale(loss).backward() 210 | scaler.step(opt) 211 | scaler.update() 212 | 213 | if lr_scheduler is not None: 214 | lr_scheduler.step() 215 | if i % 2000 == 0: 216 | current_lr = [g['lr'] for g in opt.param_groups][0] 217 | log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)') 218 | 219 | logger.iter(i=i, loss=loss) 220 | i += 1 221 | 222 | if i >= max_iterations: 223 | 224 | if not isfile(join(logger.base_path, 'weights.pth')): 225 | # only write if no weights were already written 226 | logger.save_weights(only_trainable=save_only_trainable) 227 | 228 | sys.exit(0) 229 | 230 | 231 | if config.checkpoint_iterations is not None and i in config.checkpoint_iterations: 232 | logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth') 233 | 234 | 235 | if val_interval is not None and i % val_interval == val_interval - 1: 236 | 237 | val_loss, val_scores, maximize = validate(model, dataset_val, config) 238 | 239 | if len(val_scores) > 0: 240 | 241 | score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items()) 242 | 243 | if maximize and val_scores[config.use_val_metric] > best_val_score: 244 | 
logger.save_weights(only_trainable=save_only_trainable) 245 | best_val_score = val_scores[config.use_val_metric] 246 | 247 | elif not maximize and val_scores[config.use_val_metric] < best_val_score: 248 | logger.save_weights(only_trainable=save_only_trainable) 249 | best_val_score = val_scores[config.use_val_metric] 250 | 251 | else: 252 | score_str = '' 253 | # if no score is used, fall back to loss 254 | if val_loss < best_val_loss: 255 | logger.save_weights(only_trainable=save_only_trainable) 256 | best_val_loss = val_loss 257 | 258 | log.info(f'Validation loss: {val_loss}' + score_str) 259 | logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores) 260 | model.train() 261 | 262 | print('epoch complete') 263 | 264 | 265 | if __name__ == '__main__': 266 | main() -------------------------------------------------------------------------------- /scripts/txt2mask.py: -------------------------------------------------------------------------------- 1 | # Author: Therefore Games 2 | # https://github.com/ThereforeGames/txt2img2img 3 | 4 | import modules.scripts as scripts 5 | import gradio as gr 6 | 7 | from modules import processing, images, shared, sd_samplers 8 | from modules.processing import process_images, Processed 9 | from modules.shared import opts, cmd_opts, state, Options 10 | 11 | import torch 12 | import cv2 13 | import requests 14 | import os.path 15 | 16 | from repositories.clipseg.models.clipseg import CLIPDensePredT 17 | from PIL import ImageChops, Image, ImageOps 18 | from torchvision import transforms 19 | from matplotlib import pyplot as plt 20 | import numpy 21 | 22 | debug = False 23 | 24 | class Script(scripts.Script): 25 | def title(self): 26 | return "txt2mask v0.1.1" 27 | 28 | def show(self, is_img2img): 29 | return is_img2img 30 | 31 | def ui(self, is_img2img): 32 | if not is_img2img: 33 | return None 34 | 35 | mask_prompt = gr.Textbox(label="Mask prompt", lines=1) 36 | negative_mask_prompt = gr.Textbox(label="Negative mask prompt", lines=1) 37 | mask_precision = gr.Slider(label="Mask precision", minimum=0.0, maximum=255.0, step=1.0, value=100.0) 38 | mask_padding = gr.Slider(label="Mask padding", minimum=0.0, maximum=500.0, step=1.0, value=0.0) 39 | brush_mask_mode = gr.Radio(label="Brush mask mode", choices=['discard','add','subtract'], value='discard', type="index", visible=False) 40 | mask_output = gr.Checkbox(label="Show mask in output?",value=True) 41 | 42 | plug = gr.HTML(label="plug",value='
If you like my work, please consider showing your support on Patreon. Thank you! ❤
') 43 | 44 | return [mask_prompt,negative_mask_prompt, mask_precision, mask_padding, brush_mask_mode, mask_output, plug] 45 | 46 | def run(self, p, mask_prompt, negative_mask_prompt, mask_precision, mask_padding, brush_mask_mode, mask_output, plug): 47 | def download_file(filename, url): 48 | with open(filename, 'wb') as fout: 49 | response = requests.get(url, stream=True) 50 | response.raise_for_status() 51 | # Write response data to file 52 | for block in response.iter_content(4096): 53 | fout.write(block) 54 | def pil_to_cv2(img): 55 | return (cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)) 56 | def gray_to_pil(img): 57 | return (Image.fromarray(cv2.cvtColor(img,cv2.COLOR_GRAY2RGBA))) 58 | 59 | def center_crop(img,new_width,new_height): 60 | width, height = img.size # Get dimensions 61 | 62 | left = (width - new_width)/2 63 | top = (height - new_height)/2 64 | right = (width + new_width)/2 65 | bottom = (height + new_height)/2 66 | 67 | # Crop the center of the image 68 | return(img.crop((left, top, right, bottom))) 69 | 70 | def overlay_mask_part(img_a,img_b,mode): 71 | if (mode == 0): 72 | img_a = ImageChops.darker(img_a, img_b) 73 | else: img_a = ImageChops.lighter(img_a, img_b) 74 | return(img_a) 75 | 76 | def process_mask_parts(these_preds,these_prompt_parts,mode,final_img = None): 77 | for i in range(these_prompt_parts): 78 | filename = f"mask_{mode}_{i}.png" 79 | plt.imsave(filename,torch.sigmoid(these_preds[i][0])) 80 | 81 | # TODO: Figure out how to convert the plot above to numpy instead of re-loading image 82 | img = cv2.imread(filename) 83 | gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 84 | (thresh, bw_image) = cv2.threshold(gray_image, mask_precision, 255, cv2.THRESH_BINARY) 85 | 86 | if (mode == 0): bw_image = numpy.invert(bw_image) 87 | 88 | if (debug): 89 | print(f"bw_image: {bw_image}") 90 | print(f"final_img: {final_img}") 91 | 92 | # overlay mask parts 93 | bw_image = gray_to_pil(bw_image) 94 | if (i > 0 or final_img is not None): 95 | bw_image = overlay_mask_part(bw_image,final_img,mode) 96 | 97 | # For debugging only: 98 | if (debug): bw_image.save(f"processed_{filename}") 99 | 100 | final_img = bw_image 101 | 102 | return(final_img) 103 | 104 | def get_mask(): 105 | # load model 106 | model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64) 107 | model.eval(); 108 | model_dir = "./repositories/clipseg/weights" 109 | os.makedirs(model_dir, exist_ok=True) 110 | d64_file = f"{model_dir}/rd64-uni.pth" 111 | d16_file = f"{model_dir}/rd16-uni.pth" 112 | delimiter_string = "|" 113 | 114 | # Download model weights if we don't have them yet 115 | if not os.path.exists(d64_file): 116 | print("Downloading clipseg model weights...") 117 | download_file(d64_file,"https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download?path=%2F&files=rd64-uni.pth") 118 | download_file(d16_file,"https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download?path=%2F&files=rd16-uni.pth") 119 | # Mirror: 120 | # https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth 121 | # https://github.com/timojl/clipseg/raw/master/weights/rd16-uni.pth 122 | 123 | # non-strict, because we only stored decoder weights (not CLIP weights) 124 | model.load_state_dict(torch.load(d64_file, map_location=torch.device('cuda')), strict=False); 125 | 126 | transform = transforms.Compose([ 127 | transforms.ToTensor(), 128 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 129 | transforms.Resize((512, 512)), 130 | ]) 131 | img = transform(p.init_images[0]).unsqueeze(0) 
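# Each "|"-delimited part of the mask prompt (and of the negative mask prompt) is run through clipseg as its own query:
# the 512x512 image tensor is repeated along the batch dimension so every prompt part yields its own heatmap, which
# process_mask_parts() then thresholds against mask_precision and stacks into the final mask.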
132 | 133 | prompts = mask_prompt.split(delimiter_string) 134 | prompt_parts = len(prompts) 135 | negative_prompts = negative_mask_prompt.split(delimiter_string) 136 | negative_prompt_parts = len(negative_prompts) 137 | 138 | # predict 139 | with torch.no_grad(): 140 | preds = model(img.repeat(prompt_parts,1,1,1), prompts)[0] 141 | negative_preds = model(img.repeat(negative_prompt_parts,1,1,1), negative_prompts)[0] 142 | 143 | #tests 144 | if (debug): 145 | print("Check initial mask vars before processing...") 146 | print(f"p.image_mask: {p.image_mask}") 147 | print(f"p.latent_mask: {p.latent_mask}") 148 | print(f"p.mask_for_overlay: {p.mask_for_overlay}") 149 | 150 | if (brush_mask_mode == 1 and p.image_mask is not None): 151 | final_img = p.image_mask.convert("RGBA") 152 | else: final_img = None 153 | 154 | # process masking 155 | final_img = process_mask_parts(preds,prompt_parts,1,final_img) 156 | 157 | # process negative masking 158 | if (brush_mask_mode == 2 and p.image_mask is not None): 159 | p.image_mask = ImageOps.invert(p.image_mask) 160 | p.image_mask = p.image_mask.convert("RGBA") 161 | final_img = overlay_mask_part(final_img,p.image_mask,0) 162 | if (negative_mask_prompt): final_img = process_mask_parts(negative_preds,negative_prompt_parts,0,final_img) 163 | 164 | # Increase mask size with padding 165 | if (mask_padding > 0): 166 | aspect_ratio = p.init_images[0].width / p.init_images[0].height 167 | new_width = p.init_images[0].width+mask_padding*2 168 | new_height = round(new_width / aspect_ratio) 169 | final_img = final_img.resize((new_width,new_height)) 170 | final_img = center_crop(final_img,p.init_images[0].width,p.init_images[0].height) 171 | 172 | return (final_img) 173 | 174 | 175 | # Set up processor parameters correctly 176 | p.mode = 1 177 | p.mask_mode = 1 178 | p.image_mask = get_mask().resize((p.init_images[0].width,p.init_images[0].height)) 179 | p.mask_for_overlay = p.image_mask 180 | p.latent_mask = None # fixes inpainting full resolution 181 | 182 | 183 | processed = processing.process_images(p) 184 | 185 | if (mask_output): 186 | processed.images.append(p.image_mask) 187 | 188 | return processed --------------------------------------------------------------------------------
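For quick experimentation outside the WebUI, the clipseg masking flow that `get_mask()` implements can be exercised on its own. The following is a minimal sketch under stated assumptions: it is run from the WebUI root so the `repositories.clipseg` import resolves, the `rd64-uni.pth` weights have already been downloaded to `repositories/clipseg/weights` (the script normally handles this), and `photo.jpg` plus the prompt string are placeholders. The thresholding step is a simplified normalize-and-compare approximation of the script's plt.imsave / grayscale / cv2.threshold round-trip, not an exact reproduction.

```python
# Minimal standalone sketch of the clipseg masking step (see get_mask() above).
# Assumptions: run from the WebUI root, weights already present, "photo.jpg"
# and the prompt are placeholders; thresholding is simplified.
import numpy
import torch
from PIL import Image
from torchvision import transforms

from repositories.clipseg.models.clipseg import CLIPDensePredT

model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
# non-strict load, because only the decoder weights are stored (not CLIP's)
model.load_state_dict(
    torch.load("repositories/clipseg/weights/rd64-uni.pth", map_location="cpu"),
    strict=False)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((512, 512)),
])
img = transform(Image.open("photo.jpg").convert("RGB")).unsqueeze(0)

parts = "a red shirt|a face".split("|")  # same "|" delimiter as the Mask prompt field
precision = 100                          # comparable to the Mask precision slider (0-255)

with torch.no_grad():
    # one repeated copy of the image per prompt part, as in the script
    preds = model(img.repeat(len(parts), 1, 1, 1), parts)[0]

mask = None
for i in range(len(parts)):
    heat = torch.sigmoid(preds[i][0]).cpu().numpy()
    # normalize the heatmap to 0-255 and threshold it (simplified vs. the script)
    heat = 255.0 * (heat - heat.min()) / max(float(heat.max() - heat.min()), 1e-8)
    part = (heat > precision).astype(numpy.uint8) * 255
    mask = part if mask is None else numpy.maximum(mask, part)  # additive stacking

Image.fromarray(mask).save("mask.png")
```

The `|` handling and additive stacking mirror how the script combines sub-masks; in the WebUI itself the resulting mask is additionally padded, optionally combined with the brush mask, and resized back to the init image's dimensions before inpainting.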