├── README.md ├── checkpoint_handler.py ├── clipseg ├── LICENSE ├── Quickstart.ipynb ├── README.md ├── Tables.ipynb ├── Visual_Feature_Engineering.ipynb ├── clip_masking_lvis_image_ids.yml ├── datasets │ ├── coco_wrapper.py │ ├── pascal_classes.json │ ├── pascal_zeroshot.py │ ├── pfe_dataset.py │ ├── phrasecut.py │ └── utils.py ├── environment.yml ├── evaluation_utils.py ├── example_image.jpg ├── experiments │ ├── ablation.yaml │ ├── coco.yaml │ ├── pascal_0shot.yaml │ ├── pascal_1shot.yaml │ └── phrasecut.yaml ├── general_utils.py ├── metrics.py ├── models │ ├── clipseg.py │ └── vitseg.py ├── mycode.py ├── overview.png ├── sample_rd64.png ├── sample_rd64_refined.png ├── score.py ├── setup.py ├── training.py └── weights │ ├── rd16-uni.pth │ ├── rd64-uni-refined.pth │ └── rd64-uni.pth ├── constants.py ├── data └── person_1 │ └── person_1.jpg ├── environment ├── environment.yaml └── requirements.txt ├── inference_text.txt ├── input_configs ├── inference.yaml └── train.yaml ├── models ├── __init__.py ├── clip_prior.py ├── clip_text_embedding.py ├── clip_text_encoder.py ├── mapper.py ├── positional_encoding.py └── xti_attention_processor.py ├── prompt_manager.py ├── scripts ├── __init__.py ├── inference.py ├── seg.py └── train.py ├── sd_pipeline_call.py ├── training ├── __init__.py ├── coach_2.py ├── config.py ├── dataset.py ├── logger.py └── validate.py └── utils ├── __init__.py ├── types.py └── vis_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # PersonaMagic (AAAI 2025) 2 | 3 | 🚧 Work in Progress - Feature not finished yet 🚧 4 | 5 | -------------------------------------------------------------------------------- /checkpoint_handler.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Tuple 3 | 4 | import pyrallis 5 | import torch 6 | from accelerate import Accelerator 7 | from torch import nn 8 | from transformers import CLIPTokenizer 9 | 10 | from models.clip_text_encoder import PersonaCLIPTextModel 11 | from models.mapper import Mapper 12 | from models.positional_encoding import BasicEncoder, TimePositionalEncoding 13 | from training.config import RunConfig 14 | 15 | 16 | class CheckpointHandler: 17 | 18 | def __init__(self, cfg: RunConfig, placeholder_token_string: str, placeholder_token_id: int, save_root: Path): 19 | self.cfg = cfg 20 | self.placeholder_token_string = placeholder_token_string 21 | self.placeholder_token_id = placeholder_token_id 22 | self.save_root = save_root 23 | 24 | def save_model(self, text_encoder: PersonaCLIPTextModel, 25 | accelerator: Accelerator, 26 | embeds_save_name: str, 27 | mapper_save_name: str): 28 | self.save_learned_embeds(text_encoder, accelerator, embeds_save_name) 29 | self.save_mapper(text_encoder, mapper_save_name) 30 | 31 | def save_learned_embeds(self, text_encoder: PersonaCLIPTextModel, accelerator: Accelerator, save_name: str): 32 | """ 33 | Save learned embeddings. This embedding isn't really learned, but we'll add it to the tokenizer at inference 34 | to take the place of our placeholder token. 
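        Concretely, a dict mapping each placeholder token string to its learned embedding row
        (taken from the text encoder's input-embedding matrix) is saved to `save_root / save_name`.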
35 | """ 36 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[self.placeholder_token_id] 37 | learned_embeds = learned_embeds.detach().cpu() 38 | learned_embeds_dict = {} 39 | for i in range(len(self.placeholder_token_string)): 40 | learned_embeds_dict[self.placeholder_token_string[i]] = learned_embeds[i] 41 | torch.save(learned_embeds_dict, self.save_root / save_name) 42 | 43 | def save_mapper(self, text_encoder: PersonaCLIPTextModel, save_name: str): 44 | """ Save the mapper and config to be used at inference. """ 45 | cfg_ = RunConfig(**self.cfg.__dict__.copy()) 46 | state_dict = { 47 | "state_dict": text_encoder.text_model.embeddings.mapper.state_dict(), 48 | "cfg": pyrallis.encode(cfg_), 49 | "encoder": text_encoder.text_model.embeddings.mapper.encoder 50 | } 51 | torch.save(state_dict, self.save_root / save_name) 52 | 53 | 54 | @staticmethod 55 | def load_my_mapper(mapper_path: Path) -> Tuple[RunConfig, Mapper]: 56 | mapper_ckpt = torch.load(mapper_path, map_location="cpu") 57 | cfg = pyrallis.decode(RunConfig, mapper_ckpt['cfg']) 58 | neti_mapper = Mapper(output_dim=768, 59 | norm_scale=cfg.model.target_norm, 60 | use_positional_encoding=cfg.model.use_positional_encoding, 61 | num_pe_time_anchors=cfg.model.num_pe_time_anchors, 62 | sigma_t=cfg.model.sigma_t, 63 | output_bypass=cfg.model.output_bypass, 64 | token_num=4) 65 | neti_mapper.load_state_dict(mapper_ckpt['state_dict'], strict=True) 66 | encoder = mapper_ckpt['encoder'] 67 | if isinstance(encoder, TimePositionalEncoding): 68 | encoder.w = nn.Parameter(mapper_ckpt['encoder'].w.cuda()) 69 | elif isinstance(encoder, BasicEncoder): 70 | encoder.normalized_timesteps = mapper_ckpt['encoder'].normalized_timesteps.cuda() 71 | encoder.normalized_unet_layers = mapper_ckpt['encoder'].normalized_unet_layers.cuda() 72 | neti_mapper.encoder = encoder.cuda() 73 | neti_mapper.cuda() 74 | neti_mapper.eval() 75 | return cfg, neti_mapper 76 | 77 | @staticmethod 78 | def load_learned_embed_in_clip(learned_embeds_path: Path, 79 | text_encoder: PersonaCLIPTextModel, 80 | tokenizer: CLIPTokenizer) -> Tuple[str, int]: 81 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") 82 | 83 | # separate token and the embeds 84 | trained_tokens = list(loaded_learned_embeds.keys()) 85 | embeds = list(loaded_learned_embeds.values()) 86 | 87 | # cast to dtype of text_encoder 88 | dtype = text_encoder.get_input_embeddings().weight.dtype 89 | embeds = [e.to(dtype) for e in embeds] 90 | 91 | # add the tokens in tokenizer 92 | num_added_tokens = tokenizer.add_tokens(trained_tokens) 93 | if num_added_tokens == 0: 94 | raise ValueError(f"The tokenizer already contains the token {trained_tokens[0]}. 
" 95 | f"Please pass a different `token` that is not already in the tokenizer.") 96 | 97 | # resize the token embeddings 98 | text_encoder.resize_token_embeddings(len(tokenizer)) 99 | 100 | # get the id for the token and assign the embeds 101 | placeholder_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in trained_tokens] 102 | 103 | for idx, (token, token_id, embed) in enumerate(zip(trained_tokens, placeholder_token_ids, embeds)): 104 | text_encoder.get_input_embeddings().weight.data[token_id] = embed 105 | 106 | # assert len(trained_tokens) == 1, "Only one placeholder token is supported" 107 | # placeholder_token = trained_tokens[0] 108 | # placeholder_token_id = placeholder_token_ids[0] 109 | placeholder_token = trained_tokens 110 | placeholder_token_id = placeholder_token_ids 111 | return placeholder_token, placeholder_token_id 112 | -------------------------------------------------------------------------------- /clipseg/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | This license does not apply to the model weights. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /clipseg/Quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import requests\n", 11 | "\n", 12 | "! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n", 13 | "! unzip -d weights -j weights.zip\n", 14 | "from models.clipseg import CLIPDensePredT\n", 15 | "from PIL import Image\n", 16 | "from torchvision import transforms\n", 17 | "from matplotlib import pyplot as plt\n", 18 | "\n", 19 | "# load model\n", 20 | "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n", 21 | "model.eval();\n", 22 | "\n", 23 | "# non-strict, because we only stored decoder weights (not CLIP weights)\n", 24 | "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "Load and normalize `example_image.jpg`. You can also load through an URL." 
32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# load and normalize image\n", 41 | "input_image = Image.open('example_image.jpg')\n", 42 | "\n", 43 | "# or load from URL...\n", 44 | "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n", 45 | "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n", 46 | "\n", 47 | "transform = transforms.Compose([\n", 48 | " transforms.ToTensor(),\n", 49 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 50 | " transforms.Resize((352, 352)),\n", 51 | "])\n", 52 | "img = transform(input_image).unsqueeze(0)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "Predict and visualize (this might take a few seconds if running without GPU support)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n", 69 | "\n", 70 | "# predict\n", 71 | "with torch.no_grad():\n", 72 | " preds = model(img.repeat(4,1,1,1), prompts)[0]\n", 73 | "\n", 74 | "# visualize prediction\n", 75 | "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n", 76 | "[a.axis('off') for a in ax.flatten()]\n", 77 | "ax[0].imshow(input_image)\n", 78 | "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n", 79 | "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];" 80 | ] 81 | } 82 | ], 83 | "metadata": { 84 | "interpreter": { 85 | "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586" 86 | }, 87 | "kernelspec": { 88 | "display_name": "Python 3", 89 | "language": "python", 90 | "name": "python3" 91 | }, 92 | "language_info": { 93 | "codemirror_mode": { 94 | "name": "ipython", 95 | "version": 3 96 | }, 97 | "file_extension": ".py", 98 | "mimetype": "text/x-python", 99 | "name": "python", 100 | "nbconvert_exporter": "python", 101 | "pygments_lexer": "ipython3", 102 | "version": "3.8.10" 103 | } 104 | }, 105 | "nbformat": 4, 106 | "nbformat_minor": 4 107 | } 108 | -------------------------------------------------------------------------------- /clipseg/README.md: -------------------------------------------------------------------------------- 1 | # Image Segmentation Using Text and Image Prompts 2 | This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003). 3 | 4 | **November 2022:** CLIPSeg has been integrated into the [HuggingFace Transformers library](https://huggingface.co/docs/transformers/main/en/model_doc/clipseg). Thank you, [NielsRogge](https://github.com/NielsRogge)! 5 | **September 2022:** We released new weights for fine-grained predictions (see below for details). 6 | **March 2022:** The Paper has been accepted to CVPR 2022! 7 | 8 | 9 | drawing 10 | 11 | The systems allows to create segmentation models without training based on: 12 | - An arbitrary text query 13 | - Or an image with a mask highlighting stuff or an object. 14 | 15 | ### Quick Start 16 | 17 | In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you downloaded the `rd64-uni.pth` weights, either manually or via git lfs extension. 
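At its core, the notebook boils down to the following (a condensed sketch of `Quickstart.ipynb`, assuming the downloaded `weights/rd64-uni.pth` and the bundled `example_image.jpg`):

```python
import torch
from PIL import Image
from torchvision import transforms
from models.clipseg import CLIPDensePredT

# load the model; strict=False because only decoder weights are stored (not CLIP weights)
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location='cpu'), strict=False)

# preprocess exactly as in the notebook
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((352, 352)),
])
img = transform(Image.open('example_image.jpg')).unsqueeze(0)

# one forward pass per text prompt; repeat the image to match the number of prompts
prompts = ['a glass', 'something to fill', 'wood', 'a jar']
with torch.no_grad():
    preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
heatmaps = torch.sigmoid(preds)  # per-prompt segmentation heatmaps
```

Here `heatmaps[i, 0]` is the prediction for `prompts[i]`.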
18 | It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb) 19 | (please note that the VM does not use a GPU, thus inference takes a few seconds). 20 | 21 | 22 | ### Dependencies 23 | This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`). 24 | Additional dependencies are hidden for double blind review. 25 | 26 | 27 | ### Datasets 28 | 29 | * `PhraseCut` and `PhraseCutPlus`: Referring expression dataset 30 | * `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation 31 | * `PascalZeroShot`: Wrapper class for PascalZeroShot 32 | * `COCOWrapper`: Wrapper class for COCO. 33 | 34 | ### Models 35 | 36 | * `CLIPDensePredT`: CLIPSeg model with transformer-based decoder. 37 | * `ViTDensePredT`: CLIPSeg model with transformer-based decoder. 38 | 39 | ### Third Party Dependencies 40 | For some of the datasets third party dependencies are required. Run the following commands in the `third_party` folder. 41 | ```bash 42 | git clone https://github.com/cvlab-yonsei/JoEm 43 | git clone https://github.com/Jia-Research-Lab/PFENet.git 44 | git clone https://github.com/ChenyunWu/PhraseCutDataset.git 45 | git clone https://github.com/juhongm999/hsnet.git 46 | ``` 47 | 48 | ### Weights 49 | 50 | The MIT license does not apply to these weights. 51 | 52 | We provide three model weights, for D=64 (2x, ~4MB each) and D=16 (~1MB). 53 | ``` 54 | wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip 55 | unzip -d weights -j weights.zip 56 | ``` 57 | 58 | #### New Fine-grained Weights 59 | We introduced a more complex module for transforming tokens into predictions that allow for more refined predictions (in contrast to the square-like predictions of other weights). Corresponding weights are available in the weight download above called `rd64-uni-refined.pth`. 60 | They can be loaded by: 61 | ```python 62 | model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True) 63 | model.load_state_dict(torch.load('weights/rd64-uni-refined.pth'), strict=False) 64 | ``` 65 | 66 | See below for a direct comparison of the new fine-grained weights (top) and the old weights (below). 67 | drawing 68 | drawing 69 | 70 | 71 | 72 | ### Training and Evaluation 73 | 74 | To train use the `training.py` script with experiment file and experiment id parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment which is defined by the `configuration` and first `individual_configurations` parameters. Model weights will be written in `logs/`. 75 | 76 | For evaluation use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will train the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`. 77 | 78 | 79 | ### Usage of PFENet Wrappers 80 | 81 | In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder. 82 | `git clone https://github.com/Jia-Research-Lab/PFENet.git ` 83 | 84 | 85 | ### License 86 | 87 | The source code files in this repository (excluding model weights) are released under MIT license. 
88 | 89 | ### Citation 90 | ``` 91 | @InProceedings{lueddecke22_cvpr, 92 | author = {L\"uddecke, Timo and Ecker, Alexander}, 93 | title = {Image Segmentation Using Text and Image Prompts}, 94 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 95 | month = {June}, 96 | year = {2022}, 97 | pages = {7086-7096} 98 | } 99 | 100 | ``` 101 | -------------------------------------------------------------------------------- /clipseg/Tables.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import clip\n", 13 | "from evaluation_utils import norm, denorm\n", 14 | "from general_utils import *\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# PhraseCut" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n", 49 | "tab1 = pc[['name'] + cols]\n", 50 | "for k in cols:\n", 51 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 52 | "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n", 53 | "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n", 54 | "print(tab1.to_latex(header=False, index=False))" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "For 0.1 threshold" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n", 71 | "tab1 = pc[['name'] + cols]\n", 72 | "for k in cols:\n", 73 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 74 | "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n", 75 | "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n", 76 | "print(tab1.to_latex(header=False, index=False))" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "# One-shot" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "### Pascal" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 118 | 
"tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n", 119 | "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 120 | "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 121 | "\n", 122 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 123 | "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n", 124 | "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 125 | "\n", 126 | "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n", 127 | "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n", 128 | "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "#### Pascal Zero-shot (in one-shot setting)\n", 136 | "\n", 137 | "Using the same setting as one-shot (hence different from the other zero-shot benchmark)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 147 | "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n", 148 | "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 149 | "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 150 | "\n", 151 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 152 | "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n", 153 | "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 154 | "\n", 155 | "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n", 156 | "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n", 157 | "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "# without fixed thresholds...\n", 167 | "\n", 168 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", 169 | "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n", 170 | "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", 171 | "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 172 | "\n", 173 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", 174 | "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n", 175 | "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "### COCO" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()" 192 | ] 
193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n", 201 | "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n", 202 | "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n", 203 | "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n", 204 | "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", 205 | "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n", 206 | "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "# Zero-shot" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "\n", 232 | "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n", 233 | "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n", 234 | "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n", 235 | "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "# Ablation" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n", 261 | "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n", 262 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", 263 | "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "# Generalization" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "generalization = 
experiment('experiments/generalize.yaml').dataframe()" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "print(\n", 316 | " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n", 317 | " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n", 318 | " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n", 319 | " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n", 320 | ")" 321 | ] 322 | } 323 | ], 324 | "metadata": { 325 | "interpreter": { 326 | "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586" 327 | }, 328 | "kernelspec": { 329 | "display_name": "env2", 330 | "language": "python", 331 | "name": "env2" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.8.8" 344 | } 345 | }, 346 | "nbformat": 4, 347 | "nbformat_minor": 4 348 | } 349 | -------------------------------------------------------------------------------- /clipseg/Visual_Feature_Engineering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Systematic" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "\n", 19 | "import clip\n", 20 | "from evaluation_utils import norm, denorm\n", 21 | "from general_utils import *\n", 22 | "from datasets.lvis_oneshot3 import LVIS_OneShot3\n", 23 | "\n", 24 | "clip_device = 'cuda'\n", 25 | "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n", 26 | "clip_model.eval();\n", 27 | "\n", 28 | "from models.clipseg import CLIPDensePredTMasked\n", 29 | "\n", 30 | "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n", 31 | "clip_mask_model.eval();" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n", 41 | " text_class_labels=True, image_size=352, min_area=0.1,\n", 42 | " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "plot_data(lvis)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "from collections import defaultdict\n", 61 | "import json\n", 62 | "\n", 63 | "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n", 64 | "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n", 65 | "\n", 66 | "objects_per_image = defaultdict(lambda : set())\n", 
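    "# collect the LVIS category ids present in each image (train and val annotations)\n",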
67 | "for ann in lvis_raw['annotations']:\n", 68 | " objects_per_image[ann['image_id']].add(ann['category_id'])\n", 69 | " \n", 70 | "for ann in lvis_val_raw['annotations']:\n", 71 | " objects_per_image[ann['image_id']].add(ann['category_id']) \n", 72 | " \n", 73 | "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n", 74 | "\n", 75 | "del lvis_raw, lvis_val_raw" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "#bs = 32\n", 85 | "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "from general_utils import get_batch\n", 95 | "from functools import partial\n", 96 | "from evaluation_utils import img_preprocess\n", 97 | "import torch\n", 98 | "\n", 99 | "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n", 100 | "\n", 101 | " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n", 102 | "\n", 103 | " all_prompts = []\n", 104 | " \n", 105 | " with torch.no_grad():\n", 106 | " valid_sims = []\n", 107 | " torch.manual_seed(571)\n", 108 | " \n", 109 | " if type(batches_or_dataset) == list:\n", 110 | " loader = batches_or_dataset # already loaded\n", 111 | " max_iter = float('inf')\n", 112 | " else:\n", 113 | " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n", 114 | " max_iter = 50\n", 115 | " \n", 116 | " global batch\n", 117 | " for i_batch, (batch, batch_y) in enumerate(loader):\n", 118 | " \n", 119 | " if i_batch >= max_iter: break\n", 120 | " \n", 121 | " processed_batch = process(batch)\n", 122 | " if type(processed_batch) == dict:\n", 123 | " \n", 124 | " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n", 125 | " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n", 126 | " else:\n", 127 | " processed_batch = process(batch).to(clip_device)\n", 128 | " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n", 129 | " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n", 130 | " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n", 131 | " \n", 132 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", 133 | " bs = len(batch[0])\n", 134 | " for j in range(bs):\n", 135 | " \n", 136 | " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n", 137 | " support_image = basename(lvis.samples[c][sid])\n", 138 | " \n", 139 | " img_objs = [o for o in objects_per_image[int(support_image)]]\n", 140 | " img_objs = [o.replace('_', ' ') for o in img_objs]\n", 141 | " \n", 142 | " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n", 143 | " if o != batch_y[2][j]]\n", 144 | " \n", 145 | " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n", 146 | " all_prompts += [prompts]\n", 147 | " \n", 148 | " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n", 149 | " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n", 150 | "\n", 151 | " global logits\n", 152 | " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n", 153 | "\n", 154 | " global sim\n", 155 | " sim = torch.softmax(logits, dim=-1)\n", 156 | " \n", 157 | " valid_sims 
+= [sim]\n", 158 | " \n", 159 | " #valid_sims = torch.stack(valid_sims)\n", 160 | " return valid_sims, all_prompts\n", 161 | " \n", 162 | "\n", 163 | "def new_img_preprocess(x):\n", 164 | " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n", 165 | " \n", 166 | "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n", 167 | "get_similarities(lvis, lambda x: x[1]);" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "preprocessing_functions = [\n", 177 | "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n", 178 | "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n", 179 | "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n", 180 | "# ['colorize object red', partial(img_preprocess, colorize=True)],\n", 181 | "# ['add red outline', partial(img_preprocess, outline=True)],\n", 182 | " \n", 183 | "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n", 184 | "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n", 185 | "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n", 186 | "# ['BG blur', partial(img_preprocess, blur=3)],\n", 187 | "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n", 188 | " \n", 189 | "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n", 190 | "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n", 191 | " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n", 192 | " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n", 193 | "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n", 194 | "]\n", 195 | "\n", 196 | "preprocessing_functions = preprocessing_functions\n", 197 | "\n", 198 | "base, base_p = get_similarities(lvis, lambda x: x[1])\n", 199 | "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "for j in range(1):\n", 218 | " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "from pandas import DataFrame\n", 228 | "tab = dict()\n", 229 | "for j, (name, _) in enumerate(preprocessing_functions):\n", 230 | " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n", 231 | " \n", 232 | " \n", 233 | "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) " 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "# Visual" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "from evaluation_utils import denorm, norm" 250 | ] 251 | }, 252 | { 253 
| "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "def load_sample(filename, filename2):\n", 259 | " from os.path import join\n", 260 | " bp = expanduser('~/cloud/resources/sample_images')\n", 261 | " tf = transforms.Compose([\n", 262 | " transforms.ToTensor(),\n", 263 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 264 | " transforms.Resize(224),\n", 265 | " transforms.CenterCrop(224)\n", 266 | " ])\n", 267 | " tf2 = transforms.Compose([\n", 268 | " transforms.ToTensor(),\n", 269 | " transforms.Resize(224),\n", 270 | " transforms.CenterCrop(224)\n", 271 | " ])\n", 272 | " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n", 273 | " inp1[1] = inp1[1].unsqueeze(0)\n", 274 | " inp1[2] = inp1[2][:1] \n", 275 | " return inp1\n", 276 | "\n", 277 | "def all_preprocessing(inp1):\n", 278 | " return [\n", 279 | " img_preprocess(inp1),\n", 280 | " img_preprocess(inp1, colorize=True),\n", 281 | " img_preprocess(inp1, outline=True), \n", 282 | " img_preprocess(inp1, blur=3),\n", 283 | " img_preprocess(inp1, bg_fac=0.1),\n", 284 | " #img_preprocess(inp1, bg_fac=0.5),\n", 285 | " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n", 286 | " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n", 287 | " ]\n", 288 | "\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "from torchvision import transforms\n", 298 | "from PIL import Image\n", 299 | "from matplotlib import pyplot as plt\n", 300 | "from evaluation_utils import img_preprocess\n", 301 | "import clip\n", 302 | "\n", 303 | "images_queries = [\n", 304 | " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n", 305 | " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n", 306 | "]\n", 307 | "\n", 308 | "\n", 309 | "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n", 310 | "\n", 311 | "for j, (images, objects) in enumerate(images_queries):\n", 312 | " \n", 313 | " joint_image = all_preprocessing(images)\n", 314 | " \n", 315 | " joint_image = torch.stack(joint_image)[:,0]\n", 316 | " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n", 317 | " image_features = clip_model.encode_image(joint_image)\n", 318 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", 319 | " \n", 320 | " prompts = [f'a photo of a {obj}'for obj in objects]\n", 321 | " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n", 322 | " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n", 323 | " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n", 324 | " sim = torch.softmax(logits, dim=-1).detach().cpu()\n", 325 | "\n", 326 | " for i, img in enumerate(joint_image):\n", 327 | " ax[2*j, i].axis('off')\n", 328 | " \n", 329 | " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n", 330 | " ax[2*j+ 1, i].grid(True)\n", 331 | " \n", 332 | " ax[2*j + 1, i].set_ylim(0,1)\n", 333 | " ax[2*j + 1, i].set_yticklabels([])\n", 334 | " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n", 335 | "# ax[1, i].set_xticklabels(objects, rotation=90)\n", 336 | " for k in range(len(sim[i])):\n", 337 | " ax[2*j + 1, i].bar(k, sim[i][k], 
color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n", 338 | " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n", 339 | "\n", 340 | "plt.tight_layout()\n", 341 | "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')" 342 | ] 343 | } 344 | ], 345 | "metadata": { 346 | "kernelspec": { 347 | "display_name": "env2", 348 | "language": "python", 349 | "name": "env2" 350 | }, 351 | "language_info": { 352 | "codemirror_mode": { 353 | "name": "ipython", 354 | "version": 3 355 | }, 356 | "file_extension": ".py", 357 | "mimetype": "text/x-python", 358 | "name": "python", 359 | "nbconvert_exporter": "python", 360 | "pygments_lexer": "ipython3", 361 | "version": "3.8.8" 362 | } 363 | }, 364 | "nbformat": 4, 365 | "nbformat_minor": 4 366 | } 367 | -------------------------------------------------------------------------------- /clipseg/datasets/coco_wrapper.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from types import new_class 3 | import torch 4 | import numpy as np 5 | import os 6 | import json 7 | 8 | from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename 9 | from random import shuffle, seed as set_seed 10 | from PIL import Image 11 | 12 | from itertools import combinations 13 | from torchvision import transforms 14 | from torchvision.transforms.transforms import Resize 15 | 16 | from datasets.utils import blend_image_segmentation 17 | from general_utils import get_from_repository 18 | 19 | COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'} 20 | 21 | class COCOWrapper(object): 22 | 23 | def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0, 24 | with_class_label=False): 25 | super().__init__() 26 | 27 | self.mask = mask 28 | self.with_class_label = with_class_label 29 | self.negative_prob = negative_prob 30 | 31 | from third_party.hsnet.data.coco import DatasetCOCO 32 | 33 | get_from_repository('COCO-20i', ['COCO-20i.tar']) 34 | 35 | foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl') 36 | 37 | def build_img_metadata_classwise(self): 38 | with open(foldpath % (self.split, self.fold), 'rb') as f: 39 | img_metadata_classwise = pickle.load(f) 40 | return img_metadata_classwise 41 | 42 | 43 | DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise 44 | # 
DatasetCOCO.read_mask = read_mask 45 | 46 | mean = [0.485, 0.456, 0.406] 47 | std = [0.229, 0.224, 0.225] 48 | transform = transforms.Compose([ 49 | transforms.Resize((image_size, image_size)), 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean, std) 52 | ]) 53 | 54 | self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False) 55 | 56 | self.all_classes = [self.coco.class_ids] 57 | self.coco.base_path = join(expanduser('~/datasets/COCO-20i')) 58 | 59 | def __len__(self): 60 | return len(self.coco) 61 | 62 | def __getitem__(self, i): 63 | sample = self.coco[i] 64 | 65 | label_name = COCO_CLASSES[int(sample['class_id'])] 66 | 67 | img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0] 68 | 69 | if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob: 70 | new_class_id = sample['class_id'] 71 | while new_class_id == sample['class_id']: 72 | sample2 = self.coco[torch.randint(0, len(self), (1,)).item()] 73 | new_class_id = sample2['class_id'] 74 | img_s = sample2['support_imgs'][0] 75 | seg_s = torch.zeros_like(seg_s) 76 | 77 | mask = self.mask 78 | if mask == 'separate': 79 | supp = (img_s, seg_s) 80 | elif mask == 'text_label': 81 | # DEPRECATED 82 | supp = [int(sample['class_id'])] 83 | elif mask == 'text': 84 | supp = [label_name] 85 | else: 86 | if mask.startswith('text_and_'): 87 | mask = mask[9:] 88 | label_add = [label_name] 89 | else: 90 | label_add = [] 91 | 92 | supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask) 93 | 94 | if self.with_class_label: 95 | label = (torch.zeros(0), sample['class_id'],) 96 | else: 97 | label = (torch.zeros(0), ) 98 | 99 | return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label -------------------------------------------------------------------------------- /clipseg/datasets/pascal_classes.json: -------------------------------------------------------------------------------- 1 | [{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}] -------------------------------------------------------------------------------- /clipseg/datasets/pascal_zeroshot.py: -------------------------------------------------------------------------------- 1 | from os.path import expanduser 2 | import torch 3 | import json 4 | import torchvision 5 | from general_utils import get_from_repository 6 | from general_utils import log 7 | from torchvision import transforms 8 | 9 | PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'], 10 | ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'], 11 | ['chair.n.01', 'pot_plant.n.01']] 12 | 13 | 14 | class PascalZeroShot(object): 15 | 16 | def __init__(self, split, n_unseen, image_size=224) -> None: 17 | super().__init__() 18 | 19 | import sys 20 | sys.path.append('third_party/JoEm') 21 | from third_party.JoEm.data_loader.dataset import 
VOCSegmentation 22 | from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC 23 | 24 | self.pascal_classes = VOC 25 | self.image_size = image_size 26 | 27 | self.transform = transforms.Compose([ 28 | transforms.Resize((image_size, image_size)), 29 | ]) 30 | 31 | if split == 'train': 32 | self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen), 33 | split=split, transform=True, transform_args=dict(base_size=312, crop_size=312), 34 | ignore_bg=False, ignore_unseen=False, remv_unseen_img=True) 35 | elif split == 'val': 36 | self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen), 37 | split=split, transform=False, 38 | ignore_bg=False, ignore_unseen=False) 39 | 40 | self.unseen_idx = get_unseen_idx(n_unseen) 41 | 42 | def __len__(self): 43 | return len(self.voc) 44 | 45 | def __getitem__(self, i): 46 | 47 | sample = self.voc[i] 48 | label = sample['label'].long() 49 | all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255] 50 | class_indices = [l for l in all_labels] 51 | class_names = [self.pascal_classes[l] for l in all_labels] 52 | 53 | image = self.transform(sample['image']) 54 | 55 | label = transforms.Resize((self.image_size, self.image_size), 56 | interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0] 57 | 58 | return (image,), (label, ) 59 | 60 | 61 | -------------------------------------------------------------------------------- /clipseg/datasets/pfe_dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import expanduser 2 | import torch 3 | import json 4 | from general_utils import get_from_repository 5 | from datasets.utils import blend_image_segmentation 6 | from general_utils import log 7 | 8 | PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))} 9 | 10 | 11 | class PFEPascalWrapper(object): 12 | 13 | def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None): 14 | import sys 15 | # sys.path.append(expanduser('~/projects/new_one_shot')) 16 | from third_party.PFENet.util.dataset import SemData 17 | 18 | get_from_repository('PascalVOC2012', ['Pascal5i.tar']) 19 | 20 | self.p_negative = p_negative 21 | self.size = size 22 | self.mode = mode 23 | self.image_size = image_size 24 | 25 | if label_support in {True, False}: 26 | log.warning('label_support argument is deprecated. 
Use mask instead.') 27 | #raise ValueError() 28 | 29 | self.mask = mask 30 | 31 | value_scale = 255 32 | mean = [0.485, 0.456, 0.406] 33 | mean = [item * value_scale for item in mean] 34 | std = [0.229, 0.224, 0.225] 35 | std = [item * value_scale for item in std] 36 | 37 | import third_party.PFENet.util.transform as transform 38 | 39 | if mode == 'val': 40 | data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt') 41 | 42 | data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else [] 43 | data_transform += [ 44 | transform.ToTensor(), 45 | transform.Normalize(mean=mean, std=std) 46 | ] 47 | 48 | 49 | elif mode == 'train': 50 | data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt') 51 | 52 | assert image_size != 'original' 53 | 54 | data_transform = [ 55 | transform.RandScale([0.9, 1.1]), 56 | transform.RandRotate([-10, 10], padding=mean, ignore_label=255), 57 | transform.RandomGaussianBlur(), 58 | transform.RandomHorizontalFlip(), 59 | transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255), 60 | transform.ToTensor(), 61 | transform.Normalize(mean=mean, std=std) 62 | ] 63 | 64 | data_transform = transform.Compose(data_transform) 65 | 66 | self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'), 67 | data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False) 68 | 69 | self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list 70 | 71 | # verify that subcls_list always has length 1 72 | # assert len(set([len(d[4]) for d in self.dataset])) == 1 73 | 74 | print('actual length', len(self.dataset.data_list)) 75 | 76 | def __len__(self): 77 | if self.mode == 'val': 78 | return len(self.dataset.data_list) 79 | else: 80 | return len(self.dataset.data_list) 81 | 82 | def __getitem__(self, index): 83 | if self.dataset.mode == 'train': 84 | image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)] 85 | elif self.dataset.mode == 'val': 86 | image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)] 87 | ori_label = torch.from_numpy(ori_label).unsqueeze(0) 88 | 89 | if self.image_size != 'original': 90 | longerside = max(ori_label.size(1), ori_label.size(2)) 91 | backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255 92 | backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label 93 | label = backmask.clone().long() 94 | else: 95 | label = label.unsqueeze(0) 96 | 97 | # assert label.shape == (473, 473) 98 | 99 | if self.p_negative > 0: 100 | if torch.rand(1).item() < self.p_negative: 101 | while True: 102 | idx = torch.randint(0, len(self.dataset.data_list), (1,)).item() 103 | _, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx] 104 | if subcls_list[0] != subcls_list_tmp[0]: 105 | break 106 | 107 | s_x = s_x[0] 108 | s_y = (s_y == 1)[0] 109 | label_fg = (label == 1).float() 110 | val_mask = (label != 255).float() 111 | 112 | class_id = self.class_list[subcls_list[0]] 113 | 114 | label_name = PASCAL_CLASSES[class_id][0] 115 | label_add = () 116 | mask = self.mask 117 | 118 | if mask == 'text': 119 | support = ('a photo of a ' + label_name + '.',) 120 | elif mask == 'separate': 121 | support = (s_x, s_y) 122 | else: 123 | if mask.startswith('text_and_'): 124 | label_add = (label_name,) 125 | mask = mask[9:] 126 | 127 | support = (blend_image_segmentation(s_x, 
s_y.float(), mask)[0],) 128 | 129 | return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0]) 130 | -------------------------------------------------------------------------------- /clipseg/datasets/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def blend_image_segmentation(img, seg, mode, image_size=224): 7 | 8 | 9 | if mode in {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}: 10 | if isinstance(img, np.ndarray): 11 | img = torch.from_numpy(img) 12 | 13 | if isinstance(seg, np.ndarray): 14 | seg = torch.from_numpy(seg) 15 | 16 | if mode == 'overlay': 17 | out = img * seg 18 | out = [out.astype('float32')] 19 | elif mode == 'highlight': 20 | out = img * seg[None, :, :] * 0.85 + 0.15 * img 21 | out = [out.astype('float32')] 22 | elif mode == 'highlight2': 23 | img = img / 2 24 | out = (img+0.1) * seg[None, :, :] + 0.3 * img 25 | out = [out.astype('float32')] 26 | elif mode == 'blur_highlight': 27 | from evaluation_utils import img_preprocess 28 | out = [img_preprocess((None, [img], [seg]), blur=1, bg_fac=0.5).numpy()[0] - 0.01] 29 | elif mode == 'blur3_highlight': 30 | from evaluation_utils import img_preprocess 31 | out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.5).numpy()[0] - 0.01] 32 | elif mode == 'blur3_highlight01': 33 | from evaluation_utils import img_preprocess 34 | out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.1).numpy()[0] - 0.01] 35 | elif mode == 'blur_highlight_random': 36 | from evaluation_utils import img_preprocess 37 | out = [img_preprocess((None, [img], [seg]), blur=0 + torch.randint(0, 3, (1,)).item(), bg_fac=0.1 + 0.8*torch.rand(1).item()).numpy()[0] - 0.01] 38 | elif mode == 'crop': 39 | from evaluation_utils import img_preprocess 40 | out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()] 41 | elif mode == 'crop_blur_highlight': 42 | from evaluation_utils import img_preprocess 43 | out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)[0].numpy()] 44 | elif mode == 'crop_blur_highlight352': 45 | from evaluation_utils import img_preprocess 46 | out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()] 47 | elif mode == 'shape': 48 | out = [np.stack([seg[:, :]]*3).astype('float32')] 49 | elif mode == 'concat': 50 | out = [np.concatenate([img, seg[None, :, :]]).astype('float32')] 51 | elif mode == 'image_only': 52 | out = [img.astype('float32')] 53 | elif mode == 'image_black': 54 | out = [img.astype('float32')*0] 55 | elif mode is None: 56 | out = [img.astype('float32')] 57 | elif mode == 'separate': 58 | out = [img.astype('float32'), seg.astype('int64')] 59 | elif mode == 'separate_img_black': 60 | out = [img.astype('float32')*0, seg.astype('int64')] 61 | elif mode == 'separate_seg_ones': 62 | out = [img.astype('float32'), np.ones_like(seg).astype('int64')] 63 | elif mode == 'separate_both_black': 64 | out = [img.astype('float32')*0, seg.astype('int64')*0] 65 | else: 66 | raise ValueError(f'invalid mode: {mode}') 67 | 68 | return out -------------------------------------------------------------------------------- /clipseg/environment.yml: -------------------------------------------------------------------------------- 1 | name: clipseg-environment 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | 
dependencies: 6 | - numpy 7 | - scipy 8 | - matplotlib-base 9 | - pip 10 | - pip: 11 | - --find-links https://download.pytorch.org/whl/torch_stable.html 12 | - torch==1.12.1+cpu 13 | - torchvision==0.13.1+cpu 14 | - opencv-python 15 | - git+https://github.com/openai/CLIP.git 16 | -------------------------------------------------------------------------------- /clipseg/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | from torch.functional import Tensor 2 | from general_utils import load_model 3 | from torch.utils.data import DataLoader 4 | import torch 5 | import numpy as np 6 | 7 | def denorm(img): 8 | 9 | np_input = False 10 | if isinstance(img, np.ndarray): 11 | img = torch.from_numpy(img) 12 | np_input = True 13 | 14 | mean = torch.Tensor([0.485, 0.456, 0.406]) 15 | std = torch.Tensor([0.229, 0.224, 0.225]) 16 | 17 | img_denorm = (img*std[:,None,None]) + mean[:,None,None] 18 | 19 | if np_input: 20 | img_denorm = np.clip(img_denorm.numpy(), 0, 1) 21 | else: 22 | img_denorm = torch.clamp(img_denorm, 0, 1) 23 | 24 | return img_denorm 25 | 26 | 27 | def norm(img): 28 | mean = torch.Tensor([0.485, 0.456, 0.406]) 29 | std = torch.Tensor([0.229, 0.224, 0.225]) 30 | return (img - mean[:,None,None]) / std[:,None,None] 31 | 32 | 33 | def fast_iou_curve(p, g): 34 | 35 | g = g[p.sort().indices] 36 | p = torch.sigmoid(p.sort().values) 37 | 38 | scores = [] 39 | vals = np.linspace(0, 1, 50) 40 | 41 | for q in vals: 42 | 43 | n = int(len(g) * q) 44 | 45 | valid = torch.where(p > q)[0] 46 | if len(valid) > 0: 47 | n = int(valid[0]) 48 | else: 49 | n = len(g) 50 | 51 | fn = g[:n].sum() 52 | tn = n - fn 53 | tp = g[n:].sum() 54 | fp = len(g) - n - tp 55 | 56 | iou = tp / (tp + fn + fp) 57 | 58 | precision = tp / (tp + fp) 59 | recall = tp / (tp + fn) 60 | 61 | scores += [iou] 62 | 63 | return vals, scores 64 | 65 | 66 | def fast_rp_curve(p, g): 67 | 68 | g = g[p.sort().indices] 69 | p = torch.sigmoid(p.sort().values) 70 | 71 | precisions, recalls = [], [] 72 | vals = np.linspace(p.min(), p.max(), 250) 73 | 74 | for q in p[::100000]: 75 | 76 | n = int(len(g) * q) 77 | 78 | valid = torch.where(p > q)[0] 79 | if len(valid) > 0: 80 | n = int(valid[0]) 81 | else: 82 | n = len(g) 83 | 84 | fn = g[:n].sum() 85 | tn = n - fn 86 | tp = g[n:].sum() 87 | fp = len(g) - n - tp 88 | 89 | iou = tp / (tp + fn + fp) 90 | 91 | precision = tp / (tp + fp) 92 | recall = tp / (tp + fn) 93 | 94 | precisions += [precision] 95 | recalls += [recall] 96 | 97 | return recalls, precisions 98 | 99 | 100 | # Image processing 101 | 102 | def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2, 103 | brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224): 104 | import cv2 105 | 106 | rw = rect_width 107 | 108 | out = [] 109 | for img, mask in zip(batch[1], batch[2]): 110 | 111 | img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img) 112 | mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask) 113 | 114 | img *= brightness 115 | img_bl = img 116 | if blur > 0: # best 5 117 | img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1) 118 | 119 | if grayscale: 120 | img_bl = img_bl[1][None] 121 | 122 | #img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl 123 | # img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask) 124 | img_inp = img*mask + (bg_fac) * img_bl * (1-mask) 125 | 126 | if rect: 127 | _, bbox = 
crop_mask(img, mask, context=0.1) 128 | img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None] 129 | img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None] 130 | img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None] 131 | img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None] 132 | 133 | 134 | if center_context is not None: 135 | img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size) 136 | 137 | if colorize: 138 | img_gray = denorm(img) 139 | img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY) 140 | img_gray = torch.stack([torch.from_numpy(img_gray)]*3) 141 | img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask) 142 | img_inp = norm(img_inp) 143 | 144 | if outline: 145 | cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 146 | outline_img = np.zeros(mask.shape, dtype=np.uint8) 147 | cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255)) 148 | outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255. 149 | img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img) 150 | img_inp = norm(img_inp) 151 | 152 | out += [img_inp] 153 | 154 | return torch.stack(out) 155 | 156 | 157 | def object_crop(img, mask, context=0.0, square=False, image_size=224): 158 | img_crop, bbox = crop_mask(img, mask, context=context, square=square) 159 | img_crop = pad_to_square(img_crop, channel_dim=0) 160 | img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0) 161 | return img_crop 162 | 163 | 164 | def crop_mask(img, mask, context=0.0, square=False): 165 | 166 | assert img.shape[1:] == mask.shape 167 | 168 | bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()] 169 | bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()] 170 | bbox = [int(x) for x in bbox] 171 | 172 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 173 | 174 | # square mask 175 | if square: 176 | bbox[0] = int(max(0, bbox[0] - context * height)) 177 | bbox[1] = int(min(mask.size(0), bbox[1] + context * height)) 178 | bbox[2] = int(max(0, bbox[2] - context * width)) 179 | bbox[3] = int(min(mask.size(1), bbox[3] + context * width)) 180 | 181 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 182 | if height > width: 183 | bbox[2] = int(max(0, (bbox[2] - 0.5*height))) 184 | bbox[3] = bbox[2] + height 185 | else: 186 | bbox[0] = int(max(0, (bbox[0] - 0.5*width))) 187 | bbox[1] = bbox[0] + width 188 | else: 189 | bbox[0] = int(max(0, bbox[0] - context * height)) 190 | bbox[1] = int(min(mask.size(0), bbox[1] + context * height)) 191 | bbox[2] = int(max(0, bbox[2] - context * width)) 192 | bbox[3] = int(min(mask.size(1), bbox[3] + context * width)) 193 | 194 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) 195 | img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]] 196 | return img_crop, bbox 197 | 198 | 199 | def pad_to_square(img, channel_dim=2, fill=0): 200 | """ 201 | 202 | 203 | add padding such that a squared image is returned """ 204 | 205 | from torchvision.transforms.functional import pad 206 | 207 | if channel_dim == 2: 208 | img = img.permute(2, 0, 1) 209 | elif channel_dim == 0: 210 | pass 211 | else: 212 | raise 
ValueError('invalid channel_dim') 213 | 214 | h, w = img.shape[1:] 215 | pady1 = pady2 = padx1 = padx2 = 0 216 | 217 | if h > w: 218 | padx1 = (h - w) // 2 219 | padx2 = h - w - padx1 220 | elif w > h: 221 | pady1 = (w - h) // 2 222 | pady2 = w - h - pady1 223 | 224 | img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant') 225 | 226 | if channel_dim == 2: 227 | img_padded = img_padded.permute(1, 2, 0) 228 | 229 | return img_padded 230 | 231 | 232 | # qualitative 233 | 234 | def split_sentence(inp, limit=9): 235 | t_new, current_len = [], 0 236 | for k, t in enumerate(inp.split(' ')): 237 | current_len += len(t) + 1 238 | t_new += [t+' '] 239 | # not last 240 | if current_len > limit and k != len(inp.split(' ')) - 1: 241 | current_len = 0 242 | t_new += ['\n'] 243 | 244 | t_new = ''.join(t_new) 245 | return t_new 246 | 247 | 248 | from matplotlib import pyplot as plt 249 | 250 | 251 | def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None): 252 | 253 | row_off = 0 if labels is None else 1 254 | _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2))) 255 | [a.axis('off') for a in ax.flatten()] 256 | 257 | if labels is not None: 258 | for j in range(len(labels)): 259 | t_new = split_sentence(labels[j], limit=6) 260 | ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale) 261 | 262 | 263 | for i in range(len(imgs)): 264 | ax[i + row_off,0].imshow(imgs[i]) 265 | for j in range(len(preds)): 266 | img = preds[j][i][0].detach().cpu().numpy() 267 | 268 | if gt_labels is not None and labels[j] == gt_labels[i]: 269 | print(j, labels[j], gt_labels[i]) 270 | edgecolor = 'red' 271 | if aps is not None: 272 | ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8) 273 | else: 274 | edgecolor = 'k' 275 | 276 | rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none", 277 | edgecolor=edgecolor, linewidth=3) 278 | ax[i + row_off,1 + j].add_patch(rect) 279 | 280 | if vmax is None: 281 | this_vmax = 1 282 | elif vmax == 'per_prompt': 283 | this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))]) 284 | elif vmax == 'per_image': 285 | this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))]) 286 | 287 | ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap) 288 | 289 | 290 | # ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max()) 291 | plt.tight_layout() 292 | plt.subplots_adjust(wspace=0.05, hspace=0.05) -------------------------------------------------------------------------------- /clipseg/example_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/example_image.jpg -------------------------------------------------------------------------------- /clipseg/experiments/ablation.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 # <-########################################## 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: 
datasets.phrasecut.PhraseCut # <----------------- 20 | split_mode: pascal_test 21 | split: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | negative_prob: 0.2 25 | mix_text_max: 0.5 26 | 27 | # general 28 | mix: True # <----------------- 29 | prompt: shuffle+ 30 | norm_cond: True 31 | mix_text_min: 0.0 32 | with_visual: True 33 | 34 | # model 35 | version: 'ViT-B/16' 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False # <-########################################## 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | batch_size: 32 48 | sigmoid: True 49 | split: test 50 | label_support: True 51 | 52 | test_configuration: 53 | 54 | - 55 | name: pc 56 | metric: metrics.FixedIntervalMetrics 57 | test_dataset: phrasecut 58 | mask: text 59 | 60 | - 61 | name: pc-vis 62 | metric: metrics.FixedIntervalMetrics 63 | test_dataset: phrasecut 64 | mask: crop_blur_highlight352 65 | with_visual: True 66 | visual_only: True 67 | 68 | 69 | columns: [name, 70 | pc_fgiou_best, pc_miou_best, pc_fgiou_0.5, 71 | pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5, 72 | duration] 73 | 74 | 75 | individual_configurations: 76 | 77 | - {name: rd64-uni} 78 | - {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003} 79 | - {name: rd64-no-negatives, negative_prob: 0.0} 80 | - {name: rd64-neg0.5, negative_prob: 0.5} 81 | - {name: rd64-no-visual, with_visual: False, mix: False} 82 | - {name: rd16-uni, reduce_dim: 16} 83 | - {name: rd64-layer3, extract_layers: [3], depth: 1} 84 | - {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}} -------------------------------------------------------------------------------- /clipseg/experiments/coco.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.coco_wrapper.COCOWrapper 20 | # split_mode: pascal_test 21 | split: train 22 | mask: text_and_blur3_highlight01 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | 28 | # general 29 | mix: True 30 | prompt: shuffle+ 31 | norm_cond: True 32 | mix_text_min: 0.0 33 | 34 | # model 35 | out: 1 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | # max_iterations: 10 48 | batch_size: 8 49 | sigmoid: True 50 | test_dataset: coco 51 | metric: metrics.FixedIntervalMetrics 52 | 53 | test_configuration: 54 | 55 | - 56 | name: coco_t 57 | mask: text 58 | 59 | - 60 | name: coco_h 61 | mask: blur3_highlight01 62 | 63 | - 64 | name: coco_h2 65 | mask: crop_blur_highlight352 66 | 67 | 68 | columns: [i, name, 69 | coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5, 70 | coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5, 71 | coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t, 72 | train_loss, duration, date 73 | ] 74 | 75 | 
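# How the entries below are consumed (see general_utils.training_config_from_cli_args):
# each dict under individual_configurations is merged on top of the `configuration:`
# block above, overriding any keys it repeats, and is selected by its zero-based index
# passed as the second command-line argument. A hypothetical invocation from the
# clipseg/ directory, using main() from mycode.py or an equivalent training script:
#   python mycode.py coco.yaml 0   # -> rd64-7K-vit16-cbh-coco-0
#   python mycode.py coco.yaml 4   # -> rd64-7K-vit16-cbh-neg0.2-coco-0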
individual_configurations: 76 | 77 | 78 | - {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 79 | - {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 80 | - {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 81 | - {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 82 | 83 | 84 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 85 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 86 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 87 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 88 | 89 | 90 | # ViT 91 | - {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 92 | - {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 93 | - {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 94 | - {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} 95 | 96 | 97 | # BASELINE 98 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 99 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 100 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} 101 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} -------------------------------------------------------------------------------- /clipseg/experiments/pascal_0shot.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: 
torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | # val_metric_class: metrics.BinaryIoU 8 | # use_val_metric: BIoU_fg 9 | 10 | trainer: experiment_setup.train_loop 11 | scorer: experiment_setup.score 12 | model: models.clipseg.CLIPDensePredT 13 | 14 | lr_scheduler: cosine 15 | T_max: 20000 16 | eta_min: 0.0001 17 | 18 | max_iterations: 20000 # <-########################################## 19 | val_interval: null 20 | 21 | # dataset 22 | dataset: datasets.phrasecut.PhraseCut # <----------------- 23 | split_mode: pascal_test 24 | split: train 25 | image_size: 352 26 | normalize: True 27 | pre_crop_image_size: [sample, 1, 1.5] 28 | aug: 1new 29 | 30 | # new, not 31 | with_visual: True 32 | 33 | # general 34 | mix: False # <----------------- 35 | prompt: shuffle+ 36 | norm_cond: True 37 | mix_text_min: 0.0 38 | 39 | # model 40 | out: 1 41 | extract_layers: [3, 7, 9] 42 | reduce_dim: 64 43 | depth: 3 44 | fix_shift: False # <-########################################## 45 | 46 | loss: torch.nn.functional.binary_cross_entropy_with_logits 47 | amp: True 48 | 49 | test_configuration_common: 50 | normalize: True 51 | image_size: 352 52 | batch_size: 32 53 | # max_iterations: 150 54 | 55 | test_configuration: 56 | test_dataset: pascal_zs 57 | 58 | columns: [name, pas_zs_seen, pas_zs_unseen, duration, date] 59 | 60 | - {name: rd64-uni-zs5, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, remove_classes: [zs, 5], negative_prob: 0.2, mix: True, mix_text_max: 0.5, mask: text_and_crop_blur_highlight352} 61 | - {name: rd64-uni-zs2, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, remove_classes: [zs, 2], negative_prob: 0.2, mix: True, mix_text_max: 0.5, mask: text_and_crop_blur_highlight352} 62 | -------------------------------------------------------------------------------- /clipseg/experiments/pascal_1shot.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 # <-########################################## 16 | val_interval: null 17 | 18 | # dataset 19 | dataset: datasets.phrasecut.PhraseCut 20 | split_mode: pascal_test 21 | mode: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | with_visual: True 28 | split: train 29 | 30 | # general 31 | mix: True 32 | prompt: shuffle+ 33 | norm_cond: True 34 | mix_text_min: 0.0 35 | 36 | # model 37 | out: 1 38 | version: 'ViT-B/16' 39 | extract_layers: [3, 7, 9] 40 | reduce_dim: 64 41 | depth: 3 42 | 43 | loss: torch.nn.functional.binary_cross_entropy_with_logits 44 | amp: True 45 | 46 | test_configuration_common: 47 | normalize: True 48 | image_size: 352 49 | metric: metrics.FixedIntervalMetrics 50 | batch_size: 1 51 | test_dataset: pascal 52 | sigmoid: True 53 | # max_iterations: 250 54 | 55 | test_configuration: 56 | 57 | - 58 | name: pas_t 59 | mask: text 60 | 61 | - 62 | name: pas_h 63 | mask: blur3_highlight01 64 | 65 | - 66 | name: pas_h2 67 | mask: crop_blur_highlight352 68 | 69 | 70 | columns: [name, 71 | pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct, 72 | pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct, 73 | pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t, 74 | train_loss, 
duration, date 75 | ] 76 | 77 | individual_configurations: 78 | 79 | - {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}} 80 | - {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}} 81 | - {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}} 82 | - {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}} 83 | 84 | 85 | - {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}} 86 | - {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}} 87 | - {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}} 88 | - {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}} 89 | 90 | 91 | # baseline 92 | - {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}} 93 | - {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}} 94 | - {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}} 95 | - {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}} 96 | 97 | # ViT 98 | - {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}} 99 | - {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}} 100 | - {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}} 101 | - {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}} 102 | -------------------------------------------------------------------------------- /clipseg/experiments/phrasecut.yaml: -------------------------------------------------------------------------------- 1 | configuration: 2 | batch_size: 64 3 | optimizer: torch.optim.AdamW 4 | 5 | lr: 0.001 6 | 7 | trainer: experiment_setup.train_loop 8 | scorer: experiment_setup.score 9 | model: models.clipseg.CLIPDensePredT 10 | 11 | lr_scheduler: cosine 12 | T_max: 20000 13 | eta_min: 0.0001 14 | 15 | max_iterations: 20000 16 | val_interval: null 17 | 18 | # dataset 19 | 
dataset: datasets.phrasecut.PhraseCut # <----------------- 20 | split_mode: pascal_test 21 | split: train 22 | mask: text_and_crop_blur_highlight352 23 | image_size: 352 24 | normalize: True 25 | pre_crop_image_size: [sample, 1, 1.5] 26 | aug: 1new 27 | 28 | # general 29 | mix: False # <----------------- 30 | prompt: shuffle+ 31 | norm_cond: True 32 | mix_text_min: 0.0 33 | 34 | # model 35 | out: 1 36 | extract_layers: [3, 7, 9] 37 | reduce_dim: 64 38 | depth: 3 39 | fix_shift: False 40 | 41 | loss: torch.nn.functional.binary_cross_entropy_with_logits 42 | amp: True 43 | 44 | test_configuration_common: 45 | normalize: True 46 | image_size: 352 47 | batch_size: 32 48 | # max_iterations: 5 49 | # max_iterations: 150 50 | 51 | test_configuration: 52 | 53 | - 54 | name: pc # old: phrasecut 55 | metric: metrics.FixedIntervalMetrics 56 | test_dataset: phrasecut 57 | split: test 58 | mask: text 59 | label_support: True 60 | sigmoid: True 61 | 62 | 63 | columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date] 64 | 65 | 66 | individual_configurations: 67 | 68 | # important ones 69 | 70 | 71 | - {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5} 72 | 73 | # this is almost the same training setting as the refined model except for transformer dropout of 0.1 (currently not implemented in the model) 74 | - {name: rd64-uni-refined, version: 'ViT-B/16', reduce_dim: 64, negative_prob: 0.2, complex_trans_conv: True, with_visual: True, mix: True, mix_text_max: 0.5, T_max: 50000, max_iterations: 50000} 75 | 76 | 77 | # this was accedentally trained using old mask 78 | - {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01} 79 | - {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False} 80 | # this was accedentally trained using old mask 81 | - {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01} 82 | 83 | - {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003} 84 | - {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001} 85 | -------------------------------------------------------------------------------- /clipseg/general_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import inspect 3 | import torch 4 | import os 5 | import sys 6 | import yaml 7 | from shutil import copy, copytree 8 | from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename 9 | 10 | 11 | class Logger(object): 12 | 13 | def __getattr__(self, k): 14 | return print 15 | 16 | log = Logger() 17 | 18 | def training_config_from_cli_args(): 19 | experiment_name = sys.argv[1] 20 | experiment_id = int(sys.argv[2]) 21 | 22 | yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader) 23 | 24 | config = yaml_config['configuration'] 25 | config = {**config, **yaml_config['individual_configurations'][experiment_id]} 26 | config = AttributeDict(config) 27 | return config 28 | 29 | 30 | def score_config_from_cli_args(): 31 | experiment_name = sys.argv[1] 32 | experiment_id = int(sys.argv[2]) 33 | 34 | 35 | yaml_config = 
yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader) 36 | 37 | config = yaml_config['test_configuration_common'] 38 | 39 | if type(yaml_config['test_configuration']) == list: 40 | test_id = int(sys.argv[3]) 41 | config = {**config, **yaml_config['test_configuration'][test_id]} 42 | else: 43 | config = {**config, **yaml_config['test_configuration']} 44 | 45 | if 'test_configuration' in yaml_config['individual_configurations'][experiment_id]: 46 | config = {**config, **yaml_config['individual_configurations'][experiment_id]['test_configuration']} 47 | 48 | train_checkpoint_id = yaml_config['individual_configurations'][experiment_id]['name'] 49 | 50 | config = AttributeDict(config) 51 | return config, train_checkpoint_id 52 | 53 | 54 | def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository', 55 | local_dir='~/datasets'): 56 | """ copies files from repository to local folder. 57 | 58 | repo_files: list of filenames or list of tuples [filename, target path] 59 | 60 | e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar']) 61 | will create a folder 'MyDataset' in local_dir, and extract the content of 62 | '/data/dataset1.tar' to /MyDataset/other/path. 63 | """ 64 | 65 | local_dir = realpath(join(expanduser(local_dir), local_name)) 66 | 67 | dataset_exists = True 68 | 69 | # check if folder is available 70 | if not isdir(local_dir): 71 | dataset_exists = False 72 | 73 | if integrity_check is not None: 74 | try: 75 | integrity_ok = integrity_check(local_dir) 76 | except BaseException: 77 | integrity_ok = False 78 | 79 | if integrity_ok: 80 | log.hint('Passed custom integrity check') 81 | else: 82 | log.hint('Custom integrity check failed') 83 | 84 | dataset_exists = dataset_exists and integrity_ok 85 | 86 | if not dataset_exists: 87 | 88 | repo_dir = realpath(expanduser(repo_dir)) 89 | 90 | for i, filename in enumerate(repo_files): 91 | 92 | if type(filename) == str: 93 | origin, target = filename, filename 94 | archive_target = join(local_dir, basename(origin)) 95 | extract_target = join(local_dir) 96 | else: 97 | origin, target = filename 98 | archive_target = join(local_dir, dirname(target), basename(origin)) 99 | extract_target = join(local_dir, dirname(target)) 100 | 101 | archive_origin = join(repo_dir, origin) 102 | 103 | log.hint(f'copy: {archive_origin} to {archive_target}') 104 | 105 | # make sure the path exists 106 | os.makedirs(dirname(archive_target), exist_ok=True) 107 | 108 | if os.path.isfile(archive_target): 109 | # only copy if size differs 110 | if os.path.getsize(archive_target) != os.path.getsize(archive_origin): 111 | log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. 
origin {os.path.getsize(archive_origin)}') 112 | copy(archive_origin, archive_target) 113 | else: 114 | copy(archive_origin, archive_target) 115 | 116 | extract_archive(archive_target, extract_target, noarchive_ok=True) 117 | 118 | # concurrent processes might have deleted the file 119 | if os.path.isfile(archive_target): 120 | os.remove(archive_target) 121 | 122 | 123 | def extract_archive(filename, target_folder=None, noarchive_ok=False): 124 | from subprocess import run, PIPE 125 | 126 | if filename.endswith('.tgz') or filename.endswith('.tar'): 127 | command = f'tar -xf {filename}' 128 | command += f' -C {target_folder}' if target_folder is not None else '' 129 | elif filename.endswith('.tar.gz'): 130 | command = f'tar -xzf {filename}' 131 | command += f' -C {target_folder}' if target_folder is not None else '' 132 | elif filename.endswith('zip'): 133 | command = f'unzip {filename}' 134 | command += f' -d {target_folder}' if target_folder is not None else '' 135 | else: 136 | if noarchive_ok: 137 | return 138 | else: 139 | raise ValueError(f'unsuppored file ending of {filename}') 140 | 141 | log.hint(command) 142 | result = run(command.split(), stdout=PIPE, stderr=PIPE) 143 | if result.returncode != 0: 144 | print(result.stdout, result.stderr) 145 | 146 | 147 | class AttributeDict(dict): 148 | """ 149 | An extended dictionary that allows access to elements as atttributes and counts 150 | these accesses. This way, we know if some attributes were never used. 151 | """ 152 | 153 | def __init__(self, *args, **kwargs): 154 | from collections import Counter 155 | super().__init__(*args, **kwargs) 156 | self.__dict__['counter'] = Counter() 157 | 158 | def __getitem__(self, k): 159 | self.__dict__['counter'][k] += 1 160 | return super().__getitem__(k) 161 | 162 | def __getattr__(self, k): 163 | self.__dict__['counter'][k] += 1 164 | return super().get(k) 165 | 166 | def __setattr__(self, k, v): 167 | return super().__setitem__(k, v) 168 | 169 | def __delattr__(self, k, v): 170 | return super().__delitem__(k, v) 171 | 172 | def unused_keys(self, exceptions=()): 173 | return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions] 174 | 175 | def assume_no_unused_keys(self, exceptions=()): 176 | if len(self.unused_keys(exceptions=exceptions)) > 0: 177 | log.warning('Unused keys:', self.unused_keys(exceptions=exceptions)) 178 | 179 | 180 | def get_attribute(name): 181 | import importlib 182 | 183 | if name is None: 184 | raise ValueError('The provided attribute is None') 185 | 186 | name_split = name.split('.') 187 | mod = importlib.import_module('.'.join(name_split[:-1])) 188 | return getattr(mod, name_split[-1]) 189 | 190 | 191 | 192 | def filter_args(input_args, default_args): 193 | 194 | updated_args = {k: input_args[k] if k in input_args else v for k, v in default_args.items()} 195 | used_args = {k: v for k, v in input_args.items() if k in default_args} 196 | unused_args = {k: v for k, v in input_args.items() if k not in default_args} 197 | 198 | return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args) 199 | 200 | 201 | def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False): 202 | 203 | config = json.load(open(join('logs', checkpoint_id, 'config.json'))) 204 | 205 | if model_args != 'from_config' and type(model_args) != dict: 206 | raise ValueError('model_args must either be "from_config" or a dictionary of values') 207 | 208 | model_cls = get_attribute(config['model']) 209 | 
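    # get_attribute() above resolves the dotted path stored in the checkpoint's
    # config.json (e.g. 'models.clipseg.CLIPDensePredT') into the actual class via
    # importlib, so the saved config fully determines which model gets instantiated.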
210 | # load model 211 | if model_args == 'from_config': 212 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) 213 | 214 | model = model_cls(**model_args) 215 | 216 | if weights_file is None: 217 | weights_file = realpath(join('logs', checkpoint_id, 'weights.pth')) 218 | else: 219 | weights_file = realpath(join('logs', checkpoint_id, weights_file)) 220 | 221 | if isfile(weights_file): 222 | weights = torch.load(weights_file) 223 | for _, w in weights.items(): 224 | assert not torch.any(torch.isnan(w)), 'weights contain NaNs' 225 | model.load_state_dict(weights, strict=strict) 226 | else: 227 | raise FileNotFoundError(f'model checkpoint {weights_file} was not found') 228 | 229 | if with_config: 230 | return model, config 231 | 232 | return model 233 | 234 | 235 | class TrainingLogger(object): 236 | 237 | def __init__(self, model, log_dir, config=None, *args): 238 | super().__init__() 239 | self.model = model 240 | self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None 241 | 242 | os.makedirs('logs/', exist_ok=True) 243 | os.makedirs(self.base_path, exist_ok=True) 244 | 245 | if config is not None: 246 | json.dump(config, open(join(self.base_path, 'config.json'), 'w')) 247 | 248 | def iter(self, i, **kwargs): 249 | if i % 100 == 0 and 'loss' in kwargs: 250 | loss = kwargs['loss'] 251 | print(f'iteration {i}: loss {loss:.4f}') 252 | 253 | def save_weights(self, only_trainable=False, weight_file='weights.pth'): 254 | if self.model is None: 255 | raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.') 256 | 257 | weights_path = join(self.base_path, weight_file) 258 | 259 | weight_dict = self.model.state_dict() 260 | 261 | if only_trainable: 262 | weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad} 263 | 264 | torch.save(weight_dict, weights_path) 265 | log.info(f'Saved weights to {weights_path}') 266 | 267 | def __enter__(self): 268 | return self 269 | 270 | def __exit__(self, type, value, traceback): 271 | """ automatically stop processes if used in a context manager """ 272 | pass -------------------------------------------------------------------------------- /clipseg/metrics.py: -------------------------------------------------------------------------------- 1 | from torch.functional import Tensor 2 | from general_utils import log 3 | from collections import defaultdict 4 | import numpy as np 5 | 6 | import torch 7 | from torch.nn import functional as nnf 8 | 9 | 10 | class BaseMetric(object): 11 | 12 | def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True, 13 | eval_validation=True): 14 | self._names = tuple(metric_names) 15 | self._eval_intermediate = eval_intermediate 16 | self._eval_validation = eval_validation 17 | 18 | self._pred_range = pred_range 19 | self._pred_index = pred_index 20 | self._gt_index = gt_index 21 | 22 | self.predictions = [] 23 | self.ground_truths = [] 24 | 25 | def eval_intermediate(self): 26 | return self._eval_intermediate 27 | 28 | def eval_validation(self): 29 | return self._eval_validation 30 | 31 | def names(self): 32 | return self._names 33 | 34 | def add(self, predictions, ground_truth): 35 | raise NotImplementedError 36 | 37 | def value(self): 38 | raise NotImplementedError 39 | 40 | def scores(self): 41 | # similar to value but returns dict 42 | value = self.value() 43 | if type(value) == dict: 44 | return value 45 | else: 46 | assert type(value) in {list, 
tuple} 47 | return list(zip(self.names(), self.value())) 48 | 49 | def _get_pred_gt(self, predictions, ground_truth): 50 | pred = predictions[self._pred_index] 51 | gt = ground_truth[self._gt_index] 52 | 53 | if self._pred_range is not None: 54 | pred = pred[:, self._pred_range[0]: self._pred_range[1]] 55 | 56 | return pred, gt 57 | 58 | 59 | class FixedIntervalMetrics(BaseMetric): 60 | 61 | def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None, 62 | resize_pred=None, n_values=51, custom_threshold=None): 63 | 64 | 65 | super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh')) 66 | self.intersections = [] 67 | self.unions = [] 68 | # self.threshold = threshold 69 | self.sigmoid = sigmoid 70 | self.resize_to = resize_to 71 | self.resize_pred = resize_pred # resize prediction to match ground truth 72 | self.class_count = defaultdict(lambda: 0) 73 | self.per_class = defaultdict(lambda : [0,0]) 74 | self.ignore_mask = ignore_mask 75 | self.custom_threshold = custom_threshold 76 | 77 | self.scores_ap = [] 78 | self.scores_iou = [] 79 | self.gts, self.preds = [], [] 80 | self.classes = [] 81 | 82 | # [1:-1] ignores 0 and 1 83 | self.threshold_values = np.linspace(0, 1, n_values)[1:-1] 84 | 85 | self.metrics = dict(tp=[], fp=[], fn=[], tn=[]) 86 | 87 | def add(self, pred, gt): 88 | 89 | pred_batch = pred[0].cpu() 90 | 91 | if self.sigmoid: 92 | pred_batch = torch.sigmoid(pred_batch) 93 | 94 | gt_batch = gt[0].cpu() 95 | mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch)) 96 | cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch) 97 | 98 | if self.resize_to is not None: 99 | gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest') 100 | pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False) 101 | 102 | if isinstance(cls_batch, torch.Tensor): 103 | cls_batch = cls_batch.cpu().numpy().tolist() 104 | 105 | assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}' 106 | 107 | for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch): 108 | 109 | if self.resize_pred: 110 | predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True) 111 | 112 | p = predictions.flatten() 113 | g = ground_truth.flatten() 114 | 115 | assert len(p) == len(g) 116 | 117 | if mask is not None: 118 | m = mask.flatten().bool() 119 | p = p[m] 120 | g = g[m] 121 | 122 | p_sorted = p.sort() 123 | p = p_sorted.values 124 | g = g[p_sorted.indices] 125 | 126 | tps, fps, fns, tns = [], [], [], [] 127 | for thresh in self.threshold_values: 128 | 129 | valid = torch.where(p > thresh)[0] 130 | if len(valid) > 0: 131 | n = int(valid[0]) 132 | else: 133 | n = len(g) 134 | 135 | fn = int(g[:n].sum()) 136 | tp = int(g[n:].sum()) 137 | fns += [fn] 138 | tns += [n - fn] 139 | tps += [tp] 140 | fps += [len(g) - n - tp] 141 | 142 | self.metrics['tp'] += [tps] 143 | self.metrics['fp'] += [fps] 144 | self.metrics['fn'] += [fns] 145 | self.metrics['tn'] += [tns] 146 | 147 | self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls] 148 | 149 | def value(self): 150 | 151 | import time 152 | t_start = time.time() 153 | 154 | if set(self.classes) == set([None]): 155 | all_classes = None 156 | log.warning('classes were not provided, cannot compute mIoU') 
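        # Without class ids only pooled statistics are possible: AP and foreground IoU
        # below are computed from counts summed over all samples, whereas mIoU
        # additionally needs class labels so the counts can be grouped per class
        # (summed_by_cls further down) and IoU averaged over classes instead of
        # pooled over pixels.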
157 | else: 158 | all_classes = set(int(c) for c in self.classes) 159 | # log.info(f'compute metrics for {len(all_classes)} classes') 160 | 161 | summed = {k: [sum([self.metrics[k][i][j] 162 | for i in range(len(self.metrics[k]))]) 163 | for j in range(len(self.threshold_values))] 164 | for k in self.metrics.keys()} 165 | 166 | if all_classes is not None: 167 | 168 | assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn']) 169 | # group by class 170 | metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes} 171 | for i in range(len(self.metrics['tp'])): 172 | for k in self.metrics.keys(): 173 | metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]] 174 | 175 | # sum over all instances within the classes 176 | summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()} 177 | 178 | 179 | # Compute average precision 180 | 181 | assert (np.array(summed['fp']) + np.array(summed['tp']) ).sum(), 'no predictions is made' 182 | 183 | # only consider values where a prediction is made 184 | precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(len(self.threshold_values)) 185 | if summed['tp'][j] + summed['fp'][j] > 0] 186 | recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values)) 187 | if summed['tp'][j] + summed['fp'][j] > 0] 188 | 189 | # remove duplicate recall-precision-pairs (and sort by recall value) 190 | recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0])) 191 | 192 | from scipy.integrate import simps 193 | ap = simps(precisions, recalls) 194 | 195 | # Compute best IoU 196 | fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))] 197 | 198 | biniou_scores = [ 199 | 0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) + 200 | 0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j])) 201 | for j in range(len(self.threshold_values)) 202 | ] 203 | 204 | index_0p5 = self.threshold_values.tolist().index(0.5) 205 | index_0p1 = self.threshold_values.tolist().index(0.1) 206 | index_0p2 = self.threshold_values.tolist().index(0.2) 207 | index_0p3 = self.threshold_values.tolist().index(0.3) 208 | 209 | if self.custom_threshold is not None: 210 | index_ct = self.threshold_values.tolist().index(self.custom_threshold) 211 | 212 | if all_classes is not None: 213 | # mean IoU 214 | mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j]) 215 | for c in all_classes]) 216 | for j in range(len(self.threshold_values))] 217 | 218 | mean_iou_dict = { 219 | 'miou_best': max(mean_ious) if all_classes is not None else None, 220 | 'miou_0.5': mean_ious[index_0p5] if all_classes is not None else None, 221 | 'miou_0.1': mean_ious[index_0p1] if all_classes is not None else None, 222 | 'miou_0.2': mean_ious[index_0p2] if all_classes is not None else None, 223 | 'miou_0.3': mean_ious[index_0p3] if all_classes is not None else None, 224 | 'miou_best_t': self.threshold_values[np.argmax(mean_ious)], 225 | 'mean_iou_ct': mean_ious[index_ct] if all_classes is not None and self.custom_threshold is not None else None, 226 | 'mean_iou_scores': mean_ious, 227 | } 228 | 229 | print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took 
{time.time() - t_start:.1f}s') 230 | 231 | return { 232 | 'ap': ap, 233 | 234 | # fgiou 235 | 'fgiou_best': max(fgiou_scores), 236 | 'fgiou_0.5': fgiou_scores[index_0p5], 237 | 'fgiou_0.1': fgiou_scores[index_0p1], 238 | 'fgiou_0.2': fgiou_scores[index_0p2], 239 | 'fgiou_0.3': fgiou_scores[index_0p3], 240 | 'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)], 241 | 242 | # mean iou 243 | 244 | 245 | # biniou 246 | 'biniou_best': max(biniou_scores), 247 | 'biniou_0.5': biniou_scores[index_0p5], 248 | 'biniou_0.1': biniou_scores[index_0p1], 249 | 'biniou_0.2': biniou_scores[index_0p2], 250 | 'biniou_0.3': biniou_scores[index_0p3], 251 | 'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)], 252 | 253 | # custom threshold 254 | 'fgiou_ct': fgiou_scores[index_ct] if self.custom_threshold is not None else None, 255 | 'biniou_ct': biniou_scores[index_ct] if self.custom_threshold is not None else None, 256 | 'ct': self.custom_threshold, 257 | 258 | # statistics 259 | 'fgiou_scores': fgiou_scores, 260 | 'biniou_scores': biniou_scores, 261 | 'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))), 262 | 'summed_statistics': summed, 263 | 'summed_by_cls_statistics': summed_by_cls, 264 | 265 | **mean_iou_dict 266 | } 267 | 268 | # ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh' 269 | 270 | # return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls} 271 | 272 | -------------------------------------------------------------------------------- /clipseg/models/vitseg.py: -------------------------------------------------------------------------------- 1 | import math 2 | from posixpath import basename, dirname, join 3 | # import clip 4 | from clip.model import convert_weights 5 | import torch 6 | import json 7 | from torch import nn 8 | from torch.nn import functional as nnf 9 | from torch.nn.modules import activation 10 | from torch.nn.modules.activation import ReLU 11 | from torchvision import transforms 12 | 13 | normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 14 | 15 | from torchvision.models import ResNet 16 | 17 | 18 | def process_prompts(conditional, prompt_list, conditional_map): 19 | # DEPRECATED 20 | 21 | # randomly sample a synonym 22 | words = [conditional_map[int(i)] for i in conditional] 23 | words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words] 24 | words = [w.replace('_', ' ') for w in words] 25 | 26 | if prompt_list is not None: 27 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) 28 | prompts = [prompt_list[i] for i in prompt_indices] 29 | else: 30 | prompts = ['a photo of {}'] * (len(words)) 31 | 32 | return [promt.format(w) for promt, w in zip(prompts, words)] 33 | 34 | 35 | class VITDenseBase(nn.Module): 36 | 37 | def rescaled_pos_emb(self, new_size): 38 | assert len(new_size) == 2 39 | 40 | a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) 41 | b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T 42 | return torch.cat([self.model.positional_embedding[:1], b]) 43 | 44 | def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): 45 | 46 | with torch.no_grad(): 47 | 48 | x_inp = 
nnf.interpolate(x_inp, (384, 384)) 49 | 50 | x = self.model.patch_embed(x_inp) 51 | cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 52 | if self.model.dist_token is None: 53 | x = torch.cat((cls_token, x), dim=1) 54 | else: 55 | x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 56 | x = self.model.pos_drop(x + self.model.pos_embed) 57 | 58 | activations = [] 59 | for i, block in enumerate(self.model.blocks): 60 | x = block(x) 61 | 62 | if i in extract_layers: 63 | # permute to be compatible with CLIP 64 | activations += [x.permute(1,0,2)] 65 | 66 | x = self.model.norm(x) 67 | x = self.model.head(self.model.pre_logits(x[:, 0])) 68 | 69 | # again for CLIP compatibility 70 | # x = x.permute(1, 0, 2) 71 | 72 | return x, activations, None 73 | 74 | def sample_prompts(self, words, prompt_list=None): 75 | 76 | prompt_list = prompt_list if prompt_list is not None else self.prompt_list 77 | 78 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) 79 | prompts = [prompt_list[i] for i in prompt_indices] 80 | return [promt.format(w) for promt, w in zip(prompts, words)] 81 | 82 | def get_cond_vec(self, conditional, batch_size): 83 | # compute conditional from a single string 84 | if conditional is not None and type(conditional) == str: 85 | cond = self.compute_conditional(conditional) 86 | cond = cond.repeat(batch_size, 1) 87 | 88 | # compute conditional from string list/tuple 89 | elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: 90 | assert len(conditional) == batch_size 91 | cond = self.compute_conditional(conditional) 92 | 93 | # use conditional directly 94 | elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: 95 | cond = conditional 96 | 97 | # compute conditional from image 98 | elif conditional is not None and type(conditional) == torch.Tensor: 99 | with torch.no_grad(): 100 | cond, _, _ = self.visual_forward(conditional) 101 | else: 102 | raise ValueError('invalid conditional') 103 | return cond 104 | 105 | def compute_conditional(self, conditional): 106 | import clip 107 | 108 | dev = next(self.parameters()).device 109 | 110 | if type(conditional) in {list, tuple}: 111 | text_tokens = clip.tokenize(conditional).to(dev) 112 | cond = self.clip_model.encode_text(text_tokens) 113 | else: 114 | if conditional in self.precomputed_prompts: 115 | cond = self.precomputed_prompts[conditional].float().to(dev) 116 | else: 117 | text_tokens = clip.tokenize([conditional]).to(dev) 118 | cond = self.clip_model.encode_text(text_tokens)[0] 119 | 120 | return cond 121 | 122 | 123 | class VITDensePredT(VITDenseBase): 124 | 125 | def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', 126 | depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False, 127 | learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False, 128 | add_calibration=False, process_cond=None, not_pretrained=False): 129 | super().__init__() 130 | # device = 'cpu' 131 | 132 | self.extract_layers = extract_layers 133 | self.cond_layer = cond_layer 134 | self.limit_to_clip_only = limit_to_clip_only 135 | self.process_cond = None 136 | 137 | if add_calibration: 138 | self.calibration_conds = 1 139 | 140 | self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None 141 | 142 | self.add_activation1 = True 143 | 144 | import timm 145 
| self.model = timm.create_model('vit_base_patch16_384', pretrained=True) 146 | self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond) 147 | 148 | for p in self.model.parameters(): 149 | p.requires_grad_(False) 150 | 151 | import clip 152 | self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False) 153 | # del self.clip_model.visual 154 | 155 | 156 | self.token_shape = (14, 14) 157 | 158 | # conditional 159 | if reduce_cond is not None: 160 | self.reduce_cond = nn.Linear(512, reduce_cond) 161 | for p in self.reduce_cond.parameters(): 162 | p.requires_grad_(False) 163 | else: 164 | self.reduce_cond = None 165 | 166 | # self.film = AVAILABLE_BLOCKS['film'](512, 128) 167 | self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) 168 | self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) 169 | 170 | # DEPRECATED 171 | # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))} 172 | 173 | assert len(self.extract_layers) == depth 174 | 175 | self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) 176 | self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) 177 | self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) 178 | 179 | trans_conv_ks = (16, 16) 180 | self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) 181 | 182 | # refinement and trans conv 183 | 184 | if learn_trans_conv_only: 185 | for p in self.parameters(): 186 | p.requires_grad_(False) 187 | 188 | for p in self.trans_conv.parameters(): 189 | p.requires_grad_(True) 190 | 191 | if prompt == 'fixed': 192 | self.prompt_list = ['a photo of a {}.'] 193 | elif prompt == 'shuffle': 194 | self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] 195 | elif prompt == 'shuffle+': 196 | self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', 197 | 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', 198 | 'a bad photo of a {}.', 'a photo of the {}.'] 199 | elif prompt == 'shuffle_clip': 200 | from models.clip_prompts import imagenet_templates 201 | self.prompt_list = imagenet_templates 202 | 203 | if process_cond is not None: 204 | if process_cond == 'clamp' or process_cond[0] == 'clamp': 205 | 206 | val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2 207 | 208 | def clamp_vec(x): 209 | return torch.clamp(x, -val, val) 210 | 211 | self.process_cond = clamp_vec 212 | 213 | elif process_cond.endswith('.pth'): 214 | 215 | shift = torch.load(process_cond) 216 | def add_shift(x): 217 | return x + shift.to(x.device) 218 | 219 | self.process_cond = add_shift 220 | 221 | import pickle 222 | precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) 223 | self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} 224 | 225 | 226 | def forward(self, inp_image, conditional=None, return_features=False, mask=None): 227 | 228 | assert type(return_features) == bool 229 | 230 | # inp_image = inp_image.to(self.model.positional_embedding.device) 231 | 232 | if mask is not None: 233 | raise ValueError('mask not supported') 234 | 235 | # x_inp = normalize(inp_image) 236 | x_inp = inp_image 237 | 238 | bs, dev = inp_image.shape[0], x_inp.device 239 | 240 | inp_image_size = inp_image.shape[2:] 241 | 
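        # Decoder sketch for the steps below: a conditional vector is built from the
        # text prompt or support image, the activations taken from `extract_layers` of
        # the frozen ViT are linearly reduced, FiLM-modulated with that conditional at
        # `cond_layer`, passed through the small transformer blocks (plus any
        # extra_blocks), and finally mapped back to the input resolution with a
        # transposed convolution.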
242 | cond = self.get_cond_vec(conditional, bs) 243 | 244 | visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) 245 | 246 | activation1 = activations[0] 247 | activations = activations[1:] 248 | 249 | a = None 250 | for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)): 251 | 252 | if a is not None: 253 | a = reduce(activation) + a 254 | else: 255 | a = reduce(activation) 256 | 257 | if i == self.cond_layer: 258 | if self.reduce_cond is not None: 259 | cond = self.reduce_cond(cond) 260 | 261 | a = self.film_mul(cond) * a + self.film_add(cond) 262 | 263 | a = block(a) 264 | 265 | for block in self.extra_blocks: 266 | a = a + block(a) 267 | 268 | a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens 269 | 270 | size = int(math.sqrt(a.shape[2])) 271 | 272 | a = a.view(bs, a.shape[1], size, size) 273 | 274 | if self.trans_conv is not None: 275 | a = self.trans_conv(a) 276 | 277 | if self.upsample_proj is not None: 278 | a = self.upsample_proj(a) 279 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') 280 | 281 | a = nnf.interpolate(a, inp_image_size) 282 | 283 | if return_features: 284 | return a, visual_q, cond, [activation1] + activations 285 | else: 286 | return a, 287 | -------------------------------------------------------------------------------- /clipseg/mycode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import inspect 3 | import json 4 | import yaml 5 | import math 6 | import os 7 | import sys 8 | 9 | from general_utils import log 10 | 11 | import numpy as np 12 | from functools import partial 13 | from os.path import expanduser, join, isfile, basename 14 | 15 | from torch.cuda.amp import autocast, GradScaler 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from contextlib import nullcontext 18 | from torch.utils.data import DataLoader 19 | 20 | from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args 21 | 22 | 23 | def cosine_warmup_lr(i, warmup=10, max_iter=90): 24 | """ Cosine LR with Warmup """ 25 | if i < warmup: 26 | return (i+1)/(warmup+1) 27 | else: 28 | return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup)))) 29 | 30 | 31 | def validate(model, dataset, config): 32 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False) 33 | 34 | metric_class, use_metric = config.val_metric_class, config.use_val_metric 35 | loss_fn = get_attribute(config.loss) 36 | 37 | model.eval() 38 | model.cuda() 39 | 40 | if metric_class is not None: 41 | metric = get_attribute(metric_class)() 42 | 43 | with torch.no_grad(): 44 | 45 | i, losses = 0, [] 46 | for data_x, data_y in data_loader: 47 | 48 | data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x] 49 | data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y] 50 | 51 | prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',)) 52 | pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True) 53 | 54 | if metric_class is not None: 55 | metric.add([pred], data_y) 56 | 57 | # pred = model(data_x[0], prompts) 58 | # loss = loss_fn(pred[0], data_y[0]) 59 | loss = loss_fn(pred, data_y[0]) 60 | losses += [float(loss)] 61 | 62 | i += 1 63 | 64 | if config.val_max_iterations is not None and i > config.val_max_iterations: 65 | break 66 | 67 | if use_metric is None: 68 | return np.mean(losses), {}, False 69 | else: 70 | 
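        # a metric class was configured: return its named scores alongside the mean
        # validation loss and flag (True) that main() should treat higher values of
        # config.use_val_metric as better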
metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {} 71 | return np.mean(losses), metric_scores, True 72 | 73 | 74 | def main(): 75 | 76 | config = training_config_from_cli_args() 77 | 78 | val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf') 79 | 80 | model_cls = get_attribute(config.model) 81 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) 82 | model = model_cls(**model_args).cuda() 83 | 84 | dataset_cls = get_attribute(config.dataset) 85 | _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters) 86 | 87 | dataset = dataset_cls(**dataset_args) 88 | 89 | log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})') 90 | 91 | if val_interval is not None: 92 | dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'} 93 | _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters) 94 | print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args}) 95 | 96 | dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args}) 97 | 98 | # optimizer 99 | opt_cls = get_attribute(config.optimizer) 100 | if config.optimize == 'torch.optim.SGD': 101 | opt_args = {'momentum': config.momentum if 'momentum' in config else 0} 102 | else: 103 | opt_args = {} 104 | opt = opt_cls(model.parameters(), lr=config.lr, **opt_args) 105 | 106 | if config.lr_scheduler == 'cosine': 107 | assert config.T_max is not None and config.eta_min is not None 108 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min) 109 | elif config.lr_scheduler == 'warmup_cosine': 110 | lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup)) 111 | else: 112 | lr_scheduler = None 113 | 114 | batch_size, max_iterations = config.batch_size, config.max_iterations 115 | 116 | loss_fn = get_attribute(config.loss) 117 | 118 | if config.amp: 119 | log.info('Using AMP') 120 | autocast_fn = autocast 121 | scaler = GradScaler() 122 | else: 123 | autocast_fn, scaler = nullcontext, None 124 | 125 | 126 | save_only_trainable = True 127 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4) 128 | 129 | # disable config when hyperparam. opt. to avoid writing logs. 130 | tracker_config = config if not config.hyperparameter_optimization else None 131 | 132 | with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger: 133 | 134 | i = 0 135 | while True: 136 | for data_x, data_y in data_loader: 137 | 138 | # between caption and output feature. 139 | # 1. Sample random captions 140 | # 2. 
Check alignment with CLIP 141 | 142 | # randomly mix text and visual support conditionals 143 | if config.mix: 144 | 145 | assert config.mask.startswith('text_and') 146 | 147 | with autocast_fn(): 148 | # data_x[1] = text label 149 | prompts = model.sample_prompts(data_x[1]) 150 | 151 | # model.clip_model() 152 | 153 | text_cond = model.compute_conditional(prompts) 154 | if model.__class__.__name__ == 'CLIPDensePredTMasked': 155 | # when mask=='separate' 156 | visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda()) 157 | else: 158 | # data_x[2] = visual prompt 159 | visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda()) 160 | 161 | max_txt = config.mix_text_max if config.mix_text_max is not None else 1 162 | batch_size = text_cond.shape[0] 163 | 164 | # sample weights for each element in batch 165 | text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None] 166 | text_weights = text_weights.cuda() 167 | 168 | if dataset.__class__.__name__ == 'PhraseCut': 169 | # give full weight to text where support_image is invalid 170 | visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3] 171 | text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1) 172 | 173 | cond = text_cond * text_weights + visual_s_cond * (1 - text_weights) 174 | 175 | else: 176 | # no mix 177 | 178 | if model.__class__.__name__ == 'CLIPDensePredTMasked': 179 | # compute conditional vector using CLIP masking 180 | with autocast_fn(): 181 | assert config.mask == 'separate' 182 | cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda()) 183 | else: 184 | cond = data_x[1] 185 | if isinstance(cond, torch.Tensor): 186 | cond = cond.cuda() 187 | 188 | with autocast_fn(): 189 | visual_q = None 190 | 191 | pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True) 192 | 193 | loss = loss_fn(pred, data_y[0].cuda()) 194 | 195 | if torch.isnan(loss) or torch.isinf(loss): 196 | # skip if loss is nan 197 | log.warning('Training stopped due to inf/nan loss.') 198 | sys.exit(-1) 199 | 200 | extra_loss = 0 201 | loss += extra_loss 202 | 203 | opt.zero_grad() 204 | 205 | if scaler is None: 206 | loss.backward() 207 | opt.step() 208 | else: 209 | scaler.scale(loss).backward() 210 | scaler.step(opt) 211 | scaler.update() 212 | 213 | if lr_scheduler is not None: 214 | lr_scheduler.step() 215 | if i % 2000 == 0: 216 | current_lr = [g['lr'] for g in opt.param_groups][0] 217 | log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)') 218 | 219 | logger.iter(i=i, loss=loss) 220 | i += 1 221 | 222 | if i >= max_iterations: 223 | 224 | if not isfile(join(logger.base_path, 'weights.pth')): 225 | # only write if no weights were already written 226 | logger.save_weights(only_trainable=save_only_trainable) 227 | 228 | sys.exit(0) 229 | 230 | 231 | if config.checkpoint_iterations is not None and i in config.checkpoint_iterations: 232 | logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth') 233 | 234 | 235 | if val_interval is not None and i % val_interval == val_interval - 1: 236 | 237 | val_loss, val_scores, maximize = validate(model, dataset_val, config) 238 | 239 | if len(val_scores) > 0: 240 | 241 | score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items()) 242 | 243 | if maximize and val_scores[config.use_val_metric] > best_val_score: 244 | 
logger.save_weights(only_trainable=save_only_trainable) 245 | best_val_score = val_scores[config.use_val_metric] 246 | 247 | elif not maximize and val_scores[config.use_val_metric] < best_val_score: 248 | logger.save_weights(only_trainable=save_only_trainable) 249 | best_val_score = val_scores[config.use_val_metric] 250 | 251 | else: 252 | score_str = '' 253 | # if no score is used, fall back to loss 254 | if val_loss < best_val_loss: 255 | logger.save_weights(only_trainable=save_only_trainable) 256 | best_val_loss = val_loss 257 | 258 | log.info(f'Validation loss: {val_loss}' + score_str) 259 | logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores) 260 | model.train() 261 | 262 | print('epoch complete') 263 | 264 | 265 | if __name__ == '__main__': 266 | main() -------------------------------------------------------------------------------- /clipseg/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/overview.png -------------------------------------------------------------------------------- /clipseg/sample_rd64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/sample_rd64.png -------------------------------------------------------------------------------- /clipseg/sample_rd64_refined.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/sample_rd64_refined.png -------------------------------------------------------------------------------- /clipseg/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("README.md", "r", encoding="utf-8") as readme_file: 4 | readme = readme_file.read() 5 | 6 | requirements = [ 7 | "numpy", 8 | "scipy", 9 | "matplotlib", 10 | "torch", 11 | "torchvision", 12 | "opencv-python", 13 | "CLIP @ git+https://github.com/openai/CLIP.git" 14 | ] 15 | 16 | setup( 17 | name='clipseg', 18 | packages=['clipseg'], 19 | package_dir={'clipseg': 'models'}, 20 | package_data={'clipseg': [ 21 | "../weights/*.pth", 22 | ]}, 23 | version='0.0.1', 24 | url='https://github.com/timojl/clipseg', 25 | python_requires='>=3.9', 26 | install_requires=requirements, 27 | description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".', 28 | long_description=readme, 29 | long_description_content_type="text/markdown", 30 | ) 31 | -------------------------------------------------------------------------------- /clipseg/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import inspect 3 | import json 4 | import yaml 5 | import math 6 | import os 7 | import sys 8 | 9 | from general_utils import log 10 | 11 | import numpy as np 12 | from functools import partial 13 | from os.path import expanduser, join, isfile, basename 14 | 15 | from torch.cuda.amp import autocast, GradScaler 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from contextlib import nullcontext 18 | from torch.utils.data import DataLoader 19 | 20 | from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args 21 | 22 | 23 | def 
cosine_warmup_lr(i, warmup=10, max_iter=90): 24 | """ Cosine LR with Warmup """ 25 | if i < warmup: 26 | return (i+1)/(warmup+1) 27 | else: 28 | return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup)))) 29 | 30 | 31 | def validate(model, dataset, config): 32 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False) 33 | 34 | metric_class, use_metric = config.val_metric_class, config.use_val_metric 35 | loss_fn = get_attribute(config.loss) 36 | 37 | model.eval() 38 | model.cuda() 39 | 40 | if metric_class is not None: 41 | metric = get_attribute(metric_class)() 42 | 43 | with torch.no_grad(): 44 | 45 | i, losses = 0, [] 46 | for data_x, data_y in data_loader: 47 | 48 | data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x] 49 | data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y] 50 | 51 | prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',)) 52 | pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True) 53 | 54 | if metric_class is not None: 55 | metric.add([pred], data_y) 56 | 57 | # pred = model(data_x[0], prompts) 58 | # loss = loss_fn(pred[0], data_y[0]) 59 | loss = loss_fn(pred, data_y[0]) 60 | losses += [float(loss)] 61 | 62 | i += 1 63 | 64 | if config.val_max_iterations is not None and i > config.val_max_iterations: 65 | break 66 | 67 | if use_metric is None: 68 | return np.mean(losses), {}, False 69 | else: 70 | metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {} 71 | return np.mean(losses), metric_scores, True 72 | 73 | 74 | def main(): 75 | 76 | config = training_config_from_cli_args() 77 | 78 | val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf') 79 | 80 | model_cls = get_attribute(config.model) 81 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) 82 | model = model_cls(**model_args).cuda() 83 | 84 | dataset_cls = get_attribute(config.dataset) 85 | _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters) 86 | 87 | dataset = dataset_cls(**dataset_args) 88 | 89 | log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})') 90 | 91 | if val_interval is not None: 92 | dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'} 93 | _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters) 94 | print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args}) 95 | 96 | dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args}) 97 | 98 | # optimizer 99 | opt_cls = get_attribute(config.optimizer) 100 | if config.optimize == 'torch.optim.SGD': 101 | opt_args = {'momentum': config.momentum if 'momentum' in config else 0} 102 | else: 103 | opt_args = {} 104 | opt = opt_cls(model.parameters(), lr=config.lr, **opt_args) 105 | 106 | if config.lr_scheduler == 'cosine': 107 | assert config.T_max is not None and config.eta_min is not None 108 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min) 109 | elif config.lr_scheduler == 'warmup_cosine': 110 | lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup)) 111 | else: 112 | lr_scheduler = None 113 | 114 | batch_size, max_iterations = config.batch_size, config.max_iterations 115 | 116 | loss_fn = get_attribute(config.loss) 117 | 118 | 
if config.amp: 119 | log.info('Using AMP') 120 | autocast_fn = autocast 121 | scaler = GradScaler() 122 | else: 123 | autocast_fn, scaler = nullcontext, None 124 | 125 | 126 | save_only_trainable = True 127 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4) 128 | 129 | # disable config when hyperparam. opt. to avoid writing logs. 130 | tracker_config = config if not config.hyperparameter_optimization else None 131 | 132 | with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger: 133 | 134 | i = 0 135 | while True: 136 | for data_x, data_y in data_loader: 137 | 138 | # between caption and output feature. 139 | # 1. Sample random captions 140 | # 2. Check alignment with CLIP 141 | 142 | # randomly mix text and visual support conditionals 143 | if config.mix: 144 | 145 | assert config.mask.startswith('text_and') 146 | 147 | with autocast_fn(): 148 | # data_x[1] = text label 149 | prompts = model.sample_prompts(data_x[1]) 150 | 151 | # model.clip_model() 152 | 153 | text_cond = model.compute_conditional(prompts) 154 | if model.__class__.__name__ == 'CLIPDensePredTMasked': 155 | # when mask=='separate' 156 | visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda()) 157 | else: 158 | # data_x[2] = visual prompt 159 | visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda()) 160 | 161 | max_txt = config.mix_text_max if config.mix_text_max is not None else 1 162 | batch_size = text_cond.shape[0] 163 | 164 | # sample weights for each element in batch 165 | text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None] 166 | text_weights = text_weights.cuda() 167 | 168 | if dataset.__class__.__name__ == 'PhraseCut': 169 | # give full weight to text where support_image is invalid 170 | visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3] 171 | text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1) 172 | 173 | cond = text_cond * text_weights + visual_s_cond * (1 - text_weights) 174 | 175 | else: 176 | # no mix 177 | 178 | if model.__class__.__name__ == 'CLIPDensePredTMasked': 179 | # compute conditional vector using CLIP masking 180 | with autocast_fn(): 181 | assert config.mask == 'separate' 182 | cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda()) 183 | else: 184 | cond = data_x[1] 185 | if isinstance(cond, torch.Tensor): 186 | cond = cond.cuda() 187 | 188 | with autocast_fn(): 189 | visual_q = None 190 | 191 | pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True) 192 | 193 | loss = loss_fn(pred, data_y[0].cuda()) 194 | 195 | if torch.isnan(loss) or torch.isinf(loss): 196 | # skip if loss is nan 197 | log.warning('Training stopped due to inf/nan loss.') 198 | sys.exit(-1) 199 | 200 | extra_loss = 0 201 | loss += extra_loss 202 | 203 | opt.zero_grad() 204 | 205 | if scaler is None: 206 | loss.backward() 207 | opt.step() 208 | else: 209 | scaler.scale(loss).backward() 210 | scaler.step(opt) 211 | scaler.update() 212 | 213 | if lr_scheduler is not None: 214 | lr_scheduler.step() 215 | if i % 2000 == 0: 216 | current_lr = [g['lr'] for g in opt.param_groups][0] 217 | log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)') 218 | 219 | logger.iter(i=i, loss=loss) 220 | i += 1 221 | 222 | if i >= max_iterations: 223 | 224 | if not isfile(join(logger.base_path, 'weights.pth')): 225 | # only write if no weights were already written 
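# (the final weights.pth is written at most once here; periodic checkpoints
# requested via config.checkpoint_iterations are saved separately below as
# weights_{i}.pth)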
226 | logger.save_weights(only_trainable=save_only_trainable) 227 | 228 | sys.exit(0) 229 | 230 | 231 | if config.checkpoint_iterations is not None and i in config.checkpoint_iterations: 232 | logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth') 233 | 234 | 235 | if val_interval is not None and i % val_interval == val_interval - 1: 236 | 237 | val_loss, val_scores, maximize = validate(model, dataset_val, config) 238 | 239 | if len(val_scores) > 0: 240 | 241 | score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items()) 242 | 243 | if maximize and val_scores[config.use_val_metric] > best_val_score: 244 | logger.save_weights(only_trainable=save_only_trainable) 245 | best_val_score = val_scores[config.use_val_metric] 246 | 247 | elif not maximize and val_scores[config.use_val_metric] < best_val_score: 248 | logger.save_weights(only_trainable=save_only_trainable) 249 | best_val_score = val_scores[config.use_val_metric] 250 | 251 | else: 252 | score_str = '' 253 | # if no score is used, fall back to loss 254 | if val_loss < best_val_loss: 255 | logger.save_weights(only_trainable=save_only_trainable) 256 | best_val_loss = val_loss 257 | 258 | log.info(f'Validation loss: {val_loss}' + score_str) 259 | logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores) 260 | model.train() 261 | 262 | print('epoch complete') 263 | 264 | 265 | if __name__ == '__main__': 266 | main() -------------------------------------------------------------------------------- /clipseg/weights/rd16-uni.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/weights/rd16-uni.pth -------------------------------------------------------------------------------- /clipseg/weights/rd64-uni-refined.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/weights/rd64-uni-refined.pth -------------------------------------------------------------------------------- /clipseg/weights/rd64-uni.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/weights/rd64-uni.pth -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | UNET_LAYERS = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 2 | 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11'] 3 | 4 | SD_INFERENCE_TIMESTEPS = [999, 979, 959, 939, 919, 899, 879, 859, 839, 819, 799, 779, 759, 739, 719, 699, 679, 659, 5 | 639, 619, 599, 579, 559, 539, 519, 500, 480, 460, 440, 420, 400, 380, 360, 340, 320, 300, 6 | 280, 260, 240, 220, 200, 180, 160, 140, 120, 100, 80, 60, 40, 20] 7 | 8 | # PROMPTS = [ 9 | # "A photo of {} in the jungle", 10 | # "A photo of {} on a beach", 11 | # "A photo of {} in Times Square", 12 | # "A photo of {} in the moon", 13 | # "A painting of {} in the style of Monet", 14 | # "Oil painting of {}", 15 | # "A Marc Chagall painting of {}", 16 | # "A manga drawing of {}", 17 | # 'A watercolor painting of {}', 18 | # "A statue of {}", 19 | # "App icon of {}", 20 | # "A sand sculpture of {}", 21 | # "Colorful graffiti of {}", 22 | # 
"{} wearing a Santa hat", 23 | # "{} in a construction outfit", 24 | # "{} on the cover of Time magazine", 25 | # "A movie poster of {}", 26 | # "{} is wearing a doctoral gown and a doctoral cap", 27 | # "{} in an astronaut suit", 28 | # "A photo of {} wearing a life jacket", 29 | # "{} as a jedi master, inside hogwarts with a mystical scepter", 30 | # "{} wearing a sombrero", 31 | # "{} dressed like a wizard, reading a grimoire", 32 | # "{} as a character in the Sherlock Holmes' movie", 33 | # "{} buckled in a seat on a plane", 34 | # "A selfie of {} in Times Square", 35 | # "An oil painting of {}", 36 | # "A photo of {} as a cowboy", 37 | # "{} in a chef's outfit, cooking in a kitchen", 38 | # "{} selfie standing under the pink blossoms of a cherry tree", 39 | # ] 40 | 41 | PROMPTS = [ 42 | "A {} is giving a speech on the podium", 43 | "A {} is walking on the red carpet at the evening gala", 44 | "A {} is holding a Halloween pumpkin lantern", 45 | "A {} is reading a book", 46 | "A {} wearing winter camo military gear in the snow", 47 | "A {} sitting in a hammock", 48 | "A photo of {} as an knight in armor", 49 | "A {} as a police officer", 50 | "A photo of {} in the lab", 51 | "A {} sitting a leather sofa", 52 | "A {} is reading newspaper, sitting in a busy train", 53 | "A {} in a jungle", 54 | "A {} is playing the piano", 55 | "A {} is doing a skateboard", 56 | "A {} is riding a motorcycle, road background", 57 | "{} is working on a beautiful design project, creating design projects, in a beautiful workspace", 58 | "{} driving a racing car", 59 | "{} floats on a raft along the street of a flooded city", 60 | "{} baking pizza in front of a wood-fired oven", 61 | "{} plays bass guitar on stage", 62 | ] 63 | 64 | VALIDATION_PROMPTS = [ 65 | "A photo of a {}", 66 | "A photo of a {} on a beach", 67 | "App icon of {}", 68 | "A painting of {} in the style of Monet", 69 | ] 70 | 71 | IMAGENET_TEMPLATES_SMALL = [ 72 | "a photo of a {}", 73 | "a rendering of a {}", 74 | "a cropped photo of the {}", 75 | "the photo of a {}", 76 | "a photo of a clean {}", 77 | "a photo of a dirty {}", 78 | "a dark photo of the {}", 79 | "a photo of my {}", 80 | "a photo of the cool {}", 81 | "a close-up photo of a {}", 82 | "a bright photo of the {}", 83 | "a cropped photo of a {}", 84 | "a photo of the {}", 85 | "a good photo of the {}", 86 | "a photo of one {}", 87 | "a close-up photo of the {}", 88 | "a rendition of the {}", 89 | "a photo of the clean {}", 90 | "a rendition of a {}", 91 | "a photo of a nice {}", 92 | "a good photo of a {}", 93 | "a photo of the nice {}", 94 | "a photo of the small {}", 95 | "a photo of the weird {}", 96 | "a photo of the large {}", 97 | "a photo of a cool {}", 98 | "a photo of a small {}", 99 | ] 100 | 101 | IMAGENET_STYLE_TEMPLATES_SMALL = [ 102 | "a painting in the style of {}", 103 | "a rendering in the style of {}", 104 | "a cropped painting in the style of {}", 105 | "the painting in the style of {}", 106 | "a clean painting in the style of {}", 107 | "a dirty painting in the style of {}", 108 | "a dark painting in the style of {}", 109 | "a picture in the style of {}", 110 | "a cool painting in the style of {}", 111 | "a close-up painting in the style of {}", 112 | "a bright painting in the style of {}", 113 | "a cropped painting in the style of {}", 114 | "a good painting in the style of {}", 115 | "a close-up painting in the style of {}", 116 | "a rendition in the style of {}", 117 | "a nice painting in the style of {}", 118 | "a small painting in the style of {}", 119 | 
"a weird painting in the style of {}", 120 | "a large painting in the style of {}", 121 | ] 122 | -------------------------------------------------------------------------------- /data/person_1/person_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/data/person_1/person_1.jpg -------------------------------------------------------------------------------- /environment/environment.yaml: -------------------------------------------------------------------------------- 1 | name: persona 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - diffusers 15 | - opencv-python==4.1.2.30 16 | - pudb==2019.2 17 | - invisible-watermark 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.4.2 21 | - omegaconf==2.1.1 22 | - test-tube>=0.7.5 23 | - streamlit>=0.73.1 24 | - einops==0.3.0 25 | - torch-fidelity==0.3.0 26 | - torchmetrics==0.6.0 27 | - kornia==0.6 28 | - -r requirements.txt 29 | -------------------------------------------------------------------------------- /environment/requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.4.3 2 | diffusers 3 | opencv-python==4.1.2.30 4 | pudb==2019.2 5 | invisible-watermark 6 | imageio==2.9.0 7 | imageio-ffmpeg==0.4.2 8 | pytorch-lightning==1.4.2 9 | omegaconf==2.1.1 10 | test-tube>=0.7.5 11 | streamlit>=0.73.1 12 | einops==0.3.0 13 | torch-fidelity==0.3.0 14 | torchmetrics==0.6.0 15 | kornia==0.6 16 | ftfy 17 | opencv-python 18 | ipywidgets 19 | matplotlib 20 | jupyter 21 | pyrallis==0.3.1 22 | loguru==0.7.0 23 | diffusers==0.14.0 24 | transformers==4.27.4 25 | accelerate==0.18.0 26 | -------------------------------------------------------------------------------- /inference_text.txt: -------------------------------------------------------------------------------- 1 | {} wearing a Santa hat 2 | A photo of {} as a cowboy 3 | {} in an astronaut suit -------------------------------------------------------------------------------- /input_configs/inference.yaml: -------------------------------------------------------------------------------- 1 | input_dir: ./logs/person_1 2 | prompts_file_path: ./inference_text.txt 3 | clip_ckpt_path: "/path/to/clip/ckpt/" 4 | iteration: 1000 5 | seeds: [ 233, 234, 235, 236 ] 6 | torch_dtype: fp16 7 | inference_dir: ./results/ 8 | super_category_token: face 9 | image_path: ./data/person_1/person_1.jpg -------------------------------------------------------------------------------- /input_configs/train.yaml: -------------------------------------------------------------------------------- 1 | log: 2 | exp_name: person_1 3 | exp_dir: ./log/person_1 4 | save_steps: 1000 5 | data: 6 | train_data_dir: ./data/person_1 7 | placeholder_token: 8 | super_category_token: face 9 | dataloader_num_workers: 2 10 | model: 11 | pretrained_model_name_or_path: /path/to/diffusion/dreamlike-artdreamlike-photoreal-2.0/ 12 | clip_ckpt_path: "/path/to/clip/ckpt/" 13 | normalize_mapper_output: True 14 | use_positional_encoding: True 15 | num_pe_time_anchors: 200 16 | output_bypass: True 17 | eval: 18 | validation_steps: 2000 19 | optim: 20 | max_train_steps: 1000 21 | learning_rate: 5e-5 22 | train_batch_size: 1 23 | gradient_accumulation_steps: 2 24 
| -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/models/__init__.py -------------------------------------------------------------------------------- /models/clip_prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import clip 5 | import numpy as np 6 | import kornia 7 | import os 8 | 9 | class MultiCLIP(torch.nn.Module): 10 | def __init__(self, clip_ckpt_path, device="cpu"): 11 | super().__init__() 12 | model_32, _ = clip.load(os.path.join(clip_ckpt_path,"ViT-B-32.pt"), device=device) 13 | # model_16, _ = clip.load(os.path.join(clip_ckpt_path,"ViT-B-16.pt"), device=device) 14 | # model_101, _ = clip.load(os.path.join(clip_ckpt_path,"RN101.pt"), device=device) 15 | self.model_32 = model_32 16 | # self.model_16 = model_16 17 | # self.model_101 = model_101 18 | 19 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 20 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 21 | 22 | def preprocess(self, x): 23 | # normalize to [0,1] 24 | x = kornia.geometry.resize(x, (224, 224), 25 | interpolation='bicubic',align_corners=True, 26 | antialias=False) 27 | x = (x + 1.) / 2. 28 | x = kornia.enhance.normalize(x, self.mean, self.std) 29 | return x 30 | 31 | def encode_image(self, image, dtype): 32 | with torch.no_grad(): 33 | image = self.preprocess(image) 34 | vectors = [self.model_32.encode_image(image.to(dtype))] 35 | return torch.cat(vectors, dim=-1).to(dtype) 36 | -------------------------------------------------------------------------------- /models/clip_text_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from transformers import CLIPTextConfig 6 | from typing import Union 7 | from models.mapper import Mapper 8 | from utils.types import Batch 9 | 10 | 11 | class PersonaCLIPTextEmbeddings(nn.Module): 12 | """ Modification of CLIPTextEmbedding to allow for the use of a Mapper to overwrite the concept token. 
""" 13 | 14 | def __init__(self, config: CLIPTextConfig): 15 | super().__init__() 16 | embed_dim = config.hidden_size 17 | self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) 18 | self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) 19 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 20 | 21 | def set_mapper(self, mapper: Mapper): 22 | self.mapper = mapper 23 | 24 | def forward(self, input_ids: Optional[torch.LongTensor] = None, 25 | position_ids: Optional[torch.LongTensor] = None, 26 | inputs_embeds: Optional[torch.FloatTensor] = None, 27 | batch: Optional[Batch] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 28 | 29 | if batch is not None: 30 | input_ids = batch.input_ids 31 | 32 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] 33 | 34 | if position_ids is None: 35 | position_ids = self.position_ids[:, :seq_length] 36 | 37 | if inputs_embeds is None: 38 | inputs_embeds = self.token_embedding(input_ids) 39 | 40 | bypass_outputs = None 41 | if batch is not None: 42 | mapper_outputs = self.mapper(input=batch.mapper_input) 43 | 44 | mapper_outputs = mapper_outputs.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) 45 | if self.mapper.output_bypass: 46 | bypass_outputs = mapper_outputs[:, mapper_outputs.shape[1] // 2:] 47 | mapper_outputs = mapper_outputs[:, :mapper_outputs.shape[1] // 2] 48 | 49 | # Overwrite the index of the placeholder token with the mapper output for each entry in the batch 50 | 51 | for i in range(len(batch.placeholder_token_id)): 52 | learnable_idxs = (input_ids == batch.placeholder_token_id[i]).nonzero(as_tuple=True)[1] 53 | inputs_embeds[torch.arange(input_ids.shape[0]), learnable_idxs] = mapper_outputs[i] 54 | 55 | position_embeddings = self.position_embedding(position_ids) 56 | embeddings = inputs_embeds + position_embeddings 57 | 58 | return embeddings, bypass_outputs 59 | -------------------------------------------------------------------------------- /models/clip_text_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.utils.checkpoint 5 | from torch import nn 6 | from transformers.modeling_outputs import BaseModelOutputWithPooling 7 | from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPEncoder 8 | from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask 9 | import torch.nn.functional as F 10 | from models.clip_text_embedding import PersonaCLIPTextEmbeddings 11 | from utils.types import Batch 12 | import numpy as np 13 | from PIL import Image 14 | 15 | class PersonaCLIPTextModel(CLIPTextModel): 16 | def __init__(self, config: CLIPTextConfig): 17 | super().__init__(config) 18 | self.text_model = PersonaCLIPTextTransformer(config) 19 | self.post_init() 20 | 21 | def forward(self, input_ids: Optional[torch.Tensor] = None, 22 | attention_mask: Optional[torch.Tensor] = None, 23 | position_ids: Optional[torch.Tensor] = None, 24 | output_attentions: Optional[bool] = None, 25 | output_hidden_states: Optional[bool] = None, 26 | return_dict: Optional[bool] = None, 27 | batch: Optional[Batch] = None) -> Union[Tuple, BaseModelOutputWithPooling]: 28 | return self.text_model.forward( 29 | batch=batch, 30 | input_ids=input_ids, 31 | attention_mask=attention_mask, 32 | position_ids=position_ids, 33 | output_attentions=output_attentions, 34 | 
output_hidden_states=output_hidden_states, 35 | return_dict=return_dict, 36 | ) 37 | 38 | 39 | class PersonaCLIPTextTransformer(CLIPTextTransformer): 40 | def __init__(self, config: CLIPTextConfig): 41 | super().__init__(config=config) 42 | self.config = config 43 | embed_dim = config.hidden_size 44 | self.embeddings = PersonaCLIPTextEmbeddings(config) 45 | self.encoder = CLIPEncoder(config) 46 | self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 47 | 48 | def forward(self, input_ids: Optional[torch.Tensor] = None, 49 | attention_mask: Optional[torch.Tensor] = None, 50 | position_ids: Optional[torch.Tensor] = None, 51 | output_attentions: Optional[bool] = None, 52 | output_hidden_states: Optional[bool] = None, 53 | return_dict: Optional[bool] = None, 54 | batch: Optional[Batch] = None) -> Union[Tuple, BaseModelOutputWithPooling]: 55 | 56 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 57 | output_hidden_states = ( 58 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 59 | ) 60 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 61 | 62 | bypass_output = None 63 | 64 | if input_ids is not None: 65 | input_shape = input_ids.size() 66 | input_ids = input_ids.view(-1, input_shape[-1]) 67 | hidden_states, _ = self.embeddings(input_ids=input_ids, position_ids=position_ids) 68 | 69 | elif batch is not None: 70 | input_shape = batch.input_ids.size() 71 | batch.input_ids = batch.input_ids.view(-1, input_shape[-1]) 72 | hidden_states, bypass_output = self.embeddings(batch=batch, position_ids=position_ids) 73 | 74 | else: 75 | raise ValueError("You have to specify either batch or input_ids!") 76 | 77 | bsz, seq_len = input_shape 78 | # CLIP's text model uses causal mask, prepare it here. 
79 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 80 | causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( 81 | hidden_states.device 82 | ) 83 | 84 | # expand attention_mask 85 | if attention_mask is not None: 86 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 87 | attention_mask = _expand_mask(attention_mask, hidden_states.dtype) 88 | 89 | output_attentions = True 90 | 91 | encoder_outputs = self.encoder( 92 | inputs_embeds=hidden_states, 93 | attention_mask=attention_mask, 94 | causal_attention_mask=causal_attention_mask, 95 | output_attentions=output_attentions, 96 | output_hidden_states=output_hidden_states, 97 | return_dict=return_dict, 98 | ) 99 | 100 | last_hidden_state = encoder_outputs[0] 101 | last_hidden_state_with_bypass = last_hidden_state.clone() 102 | 103 | if bypass_output is not None: 104 | for i in range(len(batch.placeholder_token_id)): 105 | learnable_idxs = (batch.input_ids == batch.placeholder_token_id[i]).nonzero(as_tuple=True)[1] 106 | existing_state = last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs] 107 | 108 | new_state = F.normalize(bypass_output[i] + existing_state, dim=-1) * existing_state.norm(dim=1, keepdim=True) 109 | new_state = new_state.to(dtype=hidden_states.dtype) 110 | last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs] = new_state 111 | 112 | last_hidden_state = self.final_layer_norm(last_hidden_state) 113 | last_hidden_state_with_bypass = self.final_layer_norm(last_hidden_state_with_bypass) 114 | 115 | if input_ids is not None: 116 | pooled_output = last_hidden_state[ 117 | torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1) 118 | ] 119 | pooled_output_with_bypass = last_hidden_state_with_bypass[ 120 | torch.arange(last_hidden_state_with_bypass.shape[0]), input_ids.to(torch.int).argmax(dim=-1) 121 | ] 122 | elif batch is not None: 123 | pooled_output = last_hidden_state[ 124 | torch.arange(last_hidden_state.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1) 125 | ] 126 | pooled_output_with_bypass = last_hidden_state_with_bypass[ 127 | torch.arange(last_hidden_state_with_bypass.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1) 128 | ] 129 | else: 130 | raise ValueError("You have to specify either batch or input_ids!") 131 | 132 | if bypass_output is not None: 133 | return BaseModelOutputWithPooling( 134 | last_hidden_state=last_hidden_state, 135 | pooler_output=pooled_output, 136 | hidden_states=encoder_outputs.hidden_states, 137 | attentions=encoder_outputs.attentions, 138 | ), BaseModelOutputWithPooling( 139 | last_hidden_state=last_hidden_state_with_bypass, 140 | pooler_output=pooled_output_with_bypass, 141 | hidden_states=encoder_outputs.hidden_states, 142 | attentions=encoder_outputs.attentions, 143 | ) 144 | else: 145 | return BaseModelOutputWithPooling( 146 | last_hidden_state=last_hidden_state, 147 | pooler_output=pooled_output, 148 | hidden_states=encoder_outputs.hidden_states, 149 | attentions=encoder_outputs.attentions, 150 | ), None 151 | -------------------------------------------------------------------------------- /models/mapper.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from models.positional_encoding import BasicEncoder, 
TimePositionalEncoding 9 | from utils.types import Mapper_input 10 | 11 | class Mapper(nn.Module): 12 | def __init__(self, output_dim: int = 768, 13 | norm_scale: Optional[torch.Tensor] = None, 14 | use_positional_encoding: bool = True, 15 | num_pe_time_anchors: int = 200, 16 | sigma_t: float = 1.0, 17 | output_bypass: bool = True, 18 | token_num = 1): 19 | super().__init__() 20 | self.norm_scale = norm_scale 21 | self.output_bypass = output_bypass 22 | self.token_num = token_num 23 | 24 | 25 | self.use_positional_encoding = use_positional_encoding 26 | if self.use_positional_encoding: 27 | self.encoder = TimePositionalEncoding(sigma_t=sigma_t).cuda() 28 | self.input_dim = num_pe_time_anchors 29 | else: 30 | self.encoder = BasicEncoder().cuda() 31 | self.input_dim = 2 32 | 33 | self.input_layer = self.set_input_layer(num_time_anchors=num_pe_time_anchors) 34 | 35 | self.timestep_proj = nn.Sequential( 36 | nn.Linear(self.input_dim, 128), 37 | nn.LayerNorm(128), 38 | nn.LeakyReLU(), 39 | nn.Linear(128, 128), 40 | nn.LayerNorm(128), 41 | nn.LeakyReLU()) 42 | if self.output_bypass: 43 | self.image_proj = nn.Sequential( 44 | nn.Linear(512 + 128, 128), 45 | nn.LayerNorm(128), 46 | nn.LeakyReLU(), 47 | nn.Linear(128, 128), 48 | nn.LayerNorm(128), 49 | nn.LeakyReLU()) 50 | self.image_output_layer = nn.Sequential(nn.Linear(128, token_num * output_dim)) 51 | 52 | self.output_layer = nn.Sequential(nn.Linear(128, token_num * output_dim)) 53 | 54 | 55 | def set_input_layer(self, num_time_anchors: int) -> nn.Module: 56 | if self.use_positional_encoding: 57 | input_layer = nn.Linear(self.encoder.num_w, self.input_dim) 58 | input_layer.weight.data = self.encoder.init_layer(num_time_anchors) 59 | else: 60 | input_layer = nn.Identity() 61 | return input_layer 62 | 63 | def get_time_embedding(self, timestep: torch.Tensor) -> torch.Tensor: 64 | time_embedding = self.encoder.encode(timestep) 65 | time_embedding = self.input_layer(time_embedding) 66 | return time_embedding 67 | 68 | def forward(self, input: Mapper_input) -> torch.Tensor: 69 | timestep = input.timesteps.float() 70 | word_embedding = input.word_embedding 71 | image_embedding = input.image_embedding 72 | time_embedding = self.get_time_embedding(timestep) 73 | embedding = self.timestep_proj(time_embedding) 74 | 75 | if self.output_bypass: 76 | bypass = torch.cat([embedding, image_embedding],dim=-1) 77 | bypass = self.image_proj(bypass) 78 | bypass = self.image_output_layer(bypass) 79 | if self.training and random.random() < 0.5: 80 | for idx in torch.arange(bypass.shape[0]): 81 | bypass[idx][0:] = 0 82 | bypass = bypass.view(-1,768) 83 | 84 | embedding = self.output_layer(embedding) 85 | embedding = embedding.view(-1,768) 86 | embedding = F.normalize(embedding + word_embedding, dim=-1) * self.norm_scale 87 | if self.output_bypass: 88 | embedding = torch.cat([embedding, bypass],dim=-1) 89 | return embedding 90 | -------------------------------------------------------------------------------- /models/positional_encoding.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import math 3 | import torch 4 | from torch import nn 5 | 6 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 7 | """ 8 | Create sinusoidal timestep embeddings. 9 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 10 | These may be fractional. 11 | :param dim: the dimension of the output. 12 | :param max_period: controls the minimum frequency of the embeddings. 
13 | :return: an [N x dim] Tensor of positional embeddings. 14 | """ 15 | if not repeat_only: 16 | half = dim // 2 17 | freqs = torch.exp( 18 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 19 | ).to(device=timesteps.device) 20 | args = timesteps[:, None].float() * freqs[None] 21 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 22 | if dim % 2: 23 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 24 | else: 25 | embedding = repeat(timesteps, 'b -> b d', d=dim) 26 | return embedding 27 | 28 | class BasicEncoder(nn.Module): 29 | """ Simply normalizes the given timestep and unet layer to be between -1 and 1. """ 30 | 31 | def __init__(self, num_denoising_timesteps: int = 1000, num_unet_layers: int = 16): 32 | super().__init__() 33 | self.normalized_timesteps = (torch.arange(num_denoising_timesteps) / (num_denoising_timesteps - 1)) * 2 - 1 34 | self.normalized_unet_layers = (torch.arange(num_unet_layers) / (num_unet_layers - 1)) * 2 - 1 35 | self.normalized_timesteps = nn.Parameter(self.normalized_timesteps).cuda() 36 | self.normalized_unet_layers = nn.Parameter(self.normalized_unet_layers).cuda() 37 | 38 | def encode(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor: 39 | normalized_input = torch.stack([self.normalized_timesteps[timestep.long()], 40 | self.normalized_unet_layers[unet_layer.long()]]).T 41 | return normalized_input 42 | 43 | 44 | class TimePositionalEncoding(nn.Module): 45 | 46 | def __init__(self, sigma_t: float, num_w: int = 1024): 47 | super().__init__() 48 | self.sigma_t = sigma_t 49 | self.num_w = num_w 50 | self.w = torch.randn((num_w, 1)) 51 | self.w[:, 0] *= sigma_t 52 | self.w = nn.Parameter(self.w).cuda() 53 | 54 | def encode(self, t: torch.Tensor): 55 | """ Maps the given time and layer input into a 2048-dimensional vector. 
""" 56 | if type(t) == int or t.ndim == 0: 57 | x = torch.tensor([t]).float() 58 | else: 59 | x = t.unsqueeze(0) 60 | x = x.cuda() 61 | v_norm = timestep_embedding(x, 2048).squeeze(0) 62 | return v_norm 63 | 64 | def init_layer(self, num_time_anchors: int) -> torch.Tensor: 65 | """ Computes the weights for the positional encoding layer of size 200x2048.""" 66 | anchor_vectors = [] 67 | for t_anchor in range(400, 800, 400 // num_time_anchors): 68 | anchor_vectors.append(self.encode(t_anchor).float()) 69 | A = torch.stack(anchor_vectors) 70 | return A -------------------------------------------------------------------------------- /models/xti_attention_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch 3 | from diffusers.models.cross_attention import CrossAttention 4 | import numpy as np 5 | from PIL import Image 6 | 7 | class XTIAttenProc: 8 | 9 | def __call__(self, attn: CrossAttention, 10 | hidden_states: torch.Tensor, 11 | encoder_hidden_states: Optional[Dict[str, torch.Tensor]] = None, 12 | attention_mask: Optional[torch.Tensor] = None): 13 | _ehs_bypass = None 14 | if encoder_hidden_states is not None: 15 | if isinstance(encoder_hidden_states, dict): 16 | this_idx = encoder_hidden_states["this_idx"] 17 | _ehs = encoder_hidden_states[f"CONTEXT_TENSOR_{this_idx}"] 18 | if f"CONTEXT_TENSOR_BYPASS_{this_idx}" in encoder_hidden_states: 19 | _ehs_bypass = encoder_hidden_states[f"CONTEXT_TENSOR_BYPASS_{this_idx}"] 20 | encoder_hidden_states["this_idx"] += 1 21 | encoder_hidden_states["this_idx"] %= 16 22 | else: 23 | _ehs = encoder_hidden_states 24 | else: 25 | _ehs = None 26 | 27 | batch_size, sequence_length, _ = (hidden_states.shape if _ehs is None else _ehs.shape) 28 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 29 | query = attn.to_q(hidden_states) 30 | 31 | if _ehs is None: 32 | _ehs = hidden_states 33 | elif attn.cross_attention_norm: 34 | _ehs = attn.norm_cross(_ehs) 35 | _ehs_bypass = attn.norm_cross(_ehs_bypass) 36 | 37 | key = attn.to_k(_ehs) 38 | if _ehs_bypass is not None: 39 | value = attn.to_v(_ehs_bypass) 40 | else: 41 | value = attn.to_v(_ehs) 42 | 43 | query = attn.head_to_batch_dim(query) 44 | key = attn.head_to_batch_dim(key) 45 | value = attn.head_to_batch_dim(value) 46 | 47 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 48 | hidden_states = torch.bmm(attention_probs, value) 49 | hidden_states = attn.batch_to_head_dim(hidden_states) 50 | 51 | # linear proj 52 | hidden_states = attn.to_out[0](hidden_states) 53 | # dropout 54 | hidden_states = attn.to_out[1](hidden_states) 55 | 56 | return hidden_states 57 | 58 | 59 | class MyAttenProc: 60 | def __call__(self, attn: CrossAttention, 61 | hidden_states: torch.Tensor, 62 | encoder_hidden_states: Optional[Dict[str, torch.Tensor]] = None, 63 | attention_mask: Optional[torch.Tensor] = None): 64 | 65 | _ehs_bypass = None 66 | if encoder_hidden_states is not None: 67 | if isinstance(encoder_hidden_states, dict): 68 | _ehs = encoder_hidden_states[f"CONTEXT_TENSOR"] 69 | if f"CONTEXT_TENSOR_BYPASS" in encoder_hidden_states: 70 | _ehs_bypass = encoder_hidden_states[f"CONTEXT_TENSOR_BYPASS"] 71 | else: 72 | _ehs = encoder_hidden_states 73 | else: 74 | _ehs = None 75 | 76 | batch_size, sequence_length, _ = (hidden_states.shape if _ehs is None else _ehs.shape) 77 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 78 | query = 
attn.to_q(hidden_states) 79 | 80 | if _ehs is None: 81 | _ehs = hidden_states 82 | elif attn.cross_attention_norm: 83 | _ehs = attn.norm_cross(_ehs) 84 | _ehs_bypass = attn.norm_cross(_ehs_bypass) 85 | 86 | key = attn.to_k(_ehs) 87 | if _ehs_bypass is not None: 88 | value = attn.to_v(_ehs_bypass) 89 | else: 90 | value = attn.to_v(_ehs) 91 | 92 | query = attn.head_to_batch_dim(query) 93 | key = attn.head_to_batch_dim(key) 94 | value = attn.head_to_batch_dim(value) 95 | 96 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 97 | hidden_states = torch.bmm(attention_probs, value) 98 | hidden_states = attn.batch_to_head_dim(hidden_states) 99 | 100 | # linear proj 101 | hidden_states = attn.to_out[0](hidden_states) 102 | # dropout 103 | hidden_states = attn.to_out[1](hidden_states) 104 | 105 | return hidden_states -------------------------------------------------------------------------------- /prompt_manager.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Dict, Any 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from transformers import CLIPTokenizer 6 | 7 | import constants 8 | from models.clip_text_encoder import PersonaCLIPTextModel 9 | from utils.types import Mapper_input, Batch 10 | 11 | 12 | class PromptManager: 13 | """ Class for computing all time and space embeddings for a given prompt. """ 14 | def __init__(self, tokenizer: CLIPTokenizer, 15 | text_encoder: PersonaCLIPTextModel, 16 | timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS, 17 | unet_layers: List[str] = constants.UNET_LAYERS, 18 | placeholder_token_id: Optional[List] = None, 19 | placeholder_token: Optional[List] = None, 20 | torch_dtype: torch.dtype = torch.float32): 21 | self.tokenizer = tokenizer 22 | self.text_encoder = text_encoder 23 | self.timesteps = timesteps 24 | self.unet_layers = unet_layers 25 | self.placeholder_token = placeholder_token 26 | self.placeholder_token_id = placeholder_token_id 27 | self.dtype = torch_dtype 28 | 29 | def my_embed_prompt(self, text: str, 30 | word_embedding: torch.Tensor, 31 | image_embedding: torch.Tensor, 32 | num_images_per_prompt: int = 1, 33 | super_category_token: str = 'face') -> List[Dict[str, Any]]: 34 | 35 | constant_text = text.format(super_category_token) 36 | constant_ids = self.tokenizer( 37 | constant_text, 38 | padding="max_length", 39 | max_length=self.tokenizer.model_max_length, 40 | return_tensors="pt", 41 | ).input_ids 42 | 43 | dynamic_text = text.format(' '.join(self.placeholder_token)) 44 | dynamic_ids = self.tokenizer( 45 | dynamic_text, 46 | padding="max_length", 47 | max_length=self.tokenizer.model_max_length, 48 | return_tensors="pt", 49 | ).input_ids 50 | 51 | # print(dynamic_text, dynamic_ids) 52 | # Compute embeddings for each timestep and each U-Net layer 53 | print(f"Computing embeddings over {len(self.timesteps)} timesteps.") 54 | 55 | hidden_states_per_timestep = [] 56 | for timestep in tqdm(self.timesteps): 57 | _hs = {}.copy() 58 | if timestep > 800: 59 | ids = constant_ids 60 | elif 800 >= timestep >=400: 61 | ids = dynamic_ids 62 | else: 63 | _hs = hidden_states_per_timestep[-1] 64 | hidden_states_per_timestep.append(_hs) 65 | continue 66 | 67 | mapper_input = Mapper_input(timesteps=timestep.unsqueeze(0), 68 | word_embedding=word_embedding.unsqueeze(0), 69 | image_embedding=image_embedding.unsqueeze(0)) 70 | batch = Batch( 71 | input_ids=ids.to(device=self.text_encoder.device), 72 | placeholder_token_id=self.placeholder_token_id, 73 | 
mapper_input=mapper_input) 74 | layer_hidden_state, layer_hidden_state_bypass = self.text_encoder(batch=batch) 75 | layer_hidden_state = layer_hidden_state[0].to(dtype=self.dtype) 76 | _hs[f"CONTEXT_TENSOR"] = layer_hidden_state.repeat(num_images_per_prompt, 1, 1) 77 | if layer_hidden_state_bypass is not None: 78 | layer_hidden_state_bypass = layer_hidden_state_bypass[0].to(dtype=self.dtype) 79 | _hs[f"CONTEXT_TENSOR_BYPASS"] = layer_hidden_state_bypass.repeat(num_images_per_prompt, 1, 1) 80 | # _hs['timestep'] = timestep 81 | hidden_states_per_timestep.append(_hs) 82 | return hidden_states_per_timestep -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass, field 3 | from pathlib import Path 4 | from typing import Optional, List, Tuple, Union 5 | 6 | import numpy as np 7 | import pyrallis 8 | import torch 9 | import PIL 10 | from PIL import Image 11 | from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline 12 | from transformers import CLIPTokenizer 13 | 14 | sys.path.append(".") 15 | sys.path.append("..") 16 | 17 | import constants 18 | from models.clip_text_encoder import PersonaCLIPTextModel 19 | from models.mapper import Mapper 20 | from prompt_manager import PromptManager 21 | from sd_pipeline_call import sd_pipeline_call 22 | from models.xti_attention_processor import XTIAttenProc, MyAttenProc 23 | from checkpoint_handler import CheckpointHandler 24 | from utils import vis_utils 25 | from models.clip_prior import MultiCLIP 26 | 27 | import time 28 | @dataclass 29 | class InferenceConfig: 30 | # Specifies which checkpoint iteration we want to load 31 | iteration: Optional[int] = None 32 | # The input directory containing the saved models and embeddings 33 | input_dir: Optional[Path] = None 34 | # Where the save the inference results to 35 | inference_dir: Optional[Path] = None 36 | # Specific path to the mapper you want to load, overrides `input_dir` 37 | mapper_checkpoint_path: Optional[Path] = None 38 | # Specific path to the embeddings you want to load, overrides `input_dir` 39 | learned_embeds_path: Optional[Path] = None 40 | # List of prompts to run inference on 41 | prompts: Optional[List[str]] = None 42 | # Text file containing a prompts to run inference on (one prompt per line), overrides `prompts` 43 | prompts_file_path: Optional[Path] = None 44 | # List of random seeds to run on 45 | seeds: List[int] = field(default_factory=lambda: [42]) 46 | # If you want to run with dropout at inference time, this specifies the truncation indices for applying dropout. 47 | # None indicates that no dropout will be performed. If a list of indices is provided, will run all indices. 
48 | truncation_idxs: Optional[Union[int, List[int]]] = None 49 | # Whether to run with torch.float16 or torch.float32 50 | torch_dtype: str = "fp16" 51 | clip_ckpt_path: Optional[Path] = "/path/to/clip/ckpt/" 52 | super_category_token: str = "face" 53 | image_path: Optional[str] = None 54 | 55 | def __post_init__(self): 56 | assert bool(self.prompts) != bool(self.prompts_file_path), \ 57 | "You must provide either prompts or prompts_file_path, but not both!" 58 | self._set_prompts() 59 | self._set_input_paths() 60 | self.inference_dir.mkdir(exist_ok=True, parents=True) 61 | if type(self.truncation_idxs) == int: 62 | self.truncation_idxs = [self.truncation_idxs] 63 | self.torch_dtype = torch.float16 if self.torch_dtype == "fp16" else torch.float32 64 | 65 | def _set_input_paths(self): 66 | if self.inference_dir is None: 67 | assert self.input_dir is not None, "You must pass an input_dir if you do not specify inference_dir" 68 | self.inference_dir = self.input_dir / f"inference_{self.iteration}" 69 | if self.mapper_checkpoint_path is None: 70 | assert self.input_dir is not None, "You must pass an input_dir if you do not specify mapper_checkpoint_path" 71 | self.mapper_checkpoint_path = self.input_dir / f"mapper-steps-{self.iteration}.pt" 72 | if self.learned_embeds_path is None: 73 | assert self.input_dir is not None, "You must pass an input_dir if you do not specify learned_embeds_path" 74 | self.learned_embeds_path = self.input_dir / f"learned_embeds-steps-{self.iteration}.bin" 75 | 76 | def _set_prompts(self): 77 | if self.prompts_file_path is not None: 78 | assert self.prompts_file_path.exists(), f"Prompts file {self.prompts_file_path} does not exist!" 79 | self.prompts = self.prompts_file_path.read_text().splitlines() 80 | 81 | 82 | @pyrallis.wrap() 83 | def main(infer_cfg: InferenceConfig): 84 | train_cfg, mapper = CheckpointHandler.load_my_mapper(infer_cfg.mapper_checkpoint_path) 85 | pipeline, placeholder_token, placeholder_token_id = load_stable_diffusion_model( 86 | pretrained_model_name_or_path=train_cfg.model.pretrained_model_name_or_path, 87 | mapper=mapper, 88 | learned_embeds_path=infer_cfg.learned_embeds_path, 89 | torch_dtype=infer_cfg.torch_dtype 90 | ) 91 | clip = MultiCLIP(clip_ckpt_path=infer_cfg.clip_ckpt_path).to(pipeline.device, dtype=infer_cfg.torch_dtype) 92 | prompt_manager = PromptManager(tokenizer=pipeline.tokenizer, 93 | text_encoder=pipeline.text_encoder, 94 | timesteps=pipeline.scheduler.timesteps, 95 | unet_layers=constants.UNET_LAYERS, 96 | placeholder_token=placeholder_token, 97 | placeholder_token_id=placeholder_token_id, 98 | torch_dtype=infer_cfg.torch_dtype) 99 | 100 | with torch.autocast("cuda"): 101 | with torch.no_grad(): 102 | token_embeds = pipeline.text_encoder.get_input_embeddings().weight.data 103 | super_category_token_id = pipeline.tokenizer.encode(infer_cfg.super_category_token, add_special_tokens=False)[0] 104 | word_embedding = token_embeds[super_category_token_id].clone().detach() 105 | image = read_image(infer_cfg.image_path).unsqueeze(0).to(pipeline.device) 106 | image_embedding = clip.encode_image(image=image, dtype=infer_cfg.torch_dtype).detach()[0] 107 | 108 | 109 | for prompt in infer_cfg.prompts: 110 | output_path = infer_cfg.inference_dir 111 | output_path.mkdir(exist_ok=True, parents=True) 112 | prompt_image = run_inference(prompt=prompt, 113 | pipeline=pipeline, 114 | prompt_manager=prompt_manager, 115 | word_embedding=word_embedding, 116 | image_embedding=image_embedding, 117 | seeds=infer_cfg.seeds, 118 | 
output_path=output_path, 119 | num_images_per_prompt=1, 120 | super_category_token=infer_cfg.super_category_token) 121 | 122 | 123 | def run_inference(prompt: str, 124 | pipeline: StableDiffusionPipeline, 125 | prompt_manager: PromptManager, 126 | word_embedding: torch.Tensor, 127 | image_embedding: torch.Tensor, 128 | seeds: List[int], 129 | output_path: Optional[Path] = None, 130 | num_images_per_prompt: int = 1, 131 | super_category_token: str = 'face') -> Image.Image: 132 | with torch.autocast("cuda"): 133 | with torch.no_grad(): 134 | prompt_embeds = prompt_manager.my_embed_prompt(prompt, 135 | word_embedding=word_embedding, 136 | image_embedding=image_embedding, 137 | num_images_per_prompt=num_images_per_prompt, 138 | super_category_token=super_category_token) 139 | joined_images = [] 140 | 141 | for seed in seeds: 142 | generator = torch.Generator(device='cuda').manual_seed(seed) 143 | images = sd_pipeline_call(pipeline, 144 | prompt_embeds=prompt_embeds, 145 | generator=generator, 146 | num_images_per_prompt=num_images_per_prompt).images 147 | seed_image = Image.fromarray(np.concatenate(images, axis=1)).convert("RGB") 148 | if output_path is not None: 149 | tmp_prompt = prompt.format(super_category_token).replace(' ', '_') 150 | image_name = output_path.name 151 | save_name = f'{tmp_prompt}_{image_name}_{seed}.jpg' 152 | seed_image.save(output_path / save_name) 153 | joined_images.append(seed_image) 154 | joined_image = vis_utils.get_image_grid(joined_images) 155 | 156 | 157 | return joined_image 158 | 159 | 160 | def load_stable_diffusion_model(pretrained_model_name_or_path: str, 161 | learned_embeds_path: Path, 162 | mapper: Optional[Mapper] = None, 163 | num_denoising_steps: int = 50, 164 | torch_dtype: torch.dtype = torch.float16) -> Tuple[StableDiffusionPipeline, str, int]: 165 | tokenizer = CLIPTokenizer.from_pretrained( 166 | pretrained_model_name_or_path, subfolder="tokenizer") 167 | text_encoder = PersonaCLIPTextModel.from_pretrained( 168 | pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype, 169 | ) 170 | if mapper is not None: 171 | text_encoder.text_model.embeddings.set_mapper(mapper) 172 | placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip( 173 | learned_embeds_path=learned_embeds_path, 174 | text_encoder=text_encoder, 175 | tokenizer=tokenizer 176 | ) 177 | pipeline = StableDiffusionPipeline.from_pretrained( 178 | pretrained_model_name_or_path, 179 | torch_dtype=torch_dtype, 180 | text_encoder=text_encoder, 181 | tokenizer=tokenizer 182 | ).to("cuda") 183 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 184 | pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device) 185 | # pipeline.unet.set_attn_processor(XTIAttenProc()) 186 | pipeline.unet.set_attn_processor(MyAttenProc()) 187 | return pipeline, placeholder_token, placeholder_token_id 188 | 189 | def read_image(image_path): 190 | image = Image.open(image_path) 191 | if not image.mode == "RGB": 192 | image = image.convert("RGB") 193 | img = np.array(image).astype(np.uint8) 194 | 195 | image = Image.fromarray(img) 196 | image = image.resize((512, 512), resample=PIL.Image.BICUBIC) 197 | 198 | image = np.array(image).astype(np.uint8) 199 | image = (image / 127.5 - 1.0).astype(np.float32) 200 | 201 | image = torch.from_numpy(image).permute(2, 0, 1) 202 | return image 203 | 204 | if __name__ == '__main__': 205 | main() 206 | 
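A note on driving scripts/inference.py: because main() is wrapped with @pyrallis.wrap(), InferenceConfig can be populated either from command-line flags or from a YAML file passed via --config_path. The sketch below only illustrates what input_configs/inference.yaml could look like for the sample data in this repository; the field names are taken from InferenceConfig above, but every concrete path and value is an assumption, not the shipped config.

# hypothetical inference config (values are placeholders)
input_dir: ./logs/person_1                 # folder holding mapper-steps-<iteration>.pt and learned_embeds-steps-<iteration>.bin
iteration: 1000                            # which saved training step to load
prompts_file_path: ./inference_text.txt    # one prompt per line
seeds: [42]
torch_dtype: fp16
clip_ckpt_path: /path/to/clip/ckpt/
super_category_token: face
image_path: ./data/person_1/person_1.jpg

# hypothetical invocation
python scripts/inference.py --config_path input_configs/inference.yaml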
-------------------------------------------------------------------------------- /scripts/seg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | import argparse 5 | 6 | import numpy as np 7 | import torch 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | from PIL import Image 11 | 12 | from clipseg.models.clipseg import CLIPDensePredT 13 | 14 | PROMPT_TEMPLATE = { 15 | 'a photo of a {}': 4, 16 | 'a good photo of a {}': 5, 17 | 'the photo of a {}': 4, 18 | 'a good photo of the {}': 5, 19 | 'image of a {}': 3, 20 | 'image of the {}': 3, 21 | 'A photograph of {}': 3, 22 | 'A {} shown in a photo': 1, 23 | 'A photo of {}': 3, 24 | } 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--image_dir', type=str, required=True) 29 | parser.add_argument('--super_class', type=str, required=True) 30 | parser.add_argument('--model_weight', type=str, default='./clipseg/weights/rd64-uni-refined.pth') 31 | parser.add_argument('--device', type=str, default='cuda') 32 | args = parser.parse_args() 33 | 34 | image_dir = args.image_dir 35 | super_class = args.super_class 36 | model_weight_path = args.model_weight 37 | device = args.device 38 | 39 | files = [f for f in os.listdir(image_dir) if f.split('.')[-1].lower() in ['png', 'jpg', 'jpeg']] 40 | image_paths = [os.path.join(image_dir, f) for f in files] 41 | mask_save_paths = [os.path.join(image_dir, 'mask', f) for f in files] 42 | 43 | os.makedirs(os.path.join(image_dir, 'mask'), exist_ok=True) 44 | 45 | prompts = [prompt.format(super_class) for prompt in PROMPT_TEMPLATE] 46 | 47 | model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True) 48 | model.load_state_dict(torch.load(model_weight_path, map_location='cpu'), strict=False) 49 | model = model.eval().to(device) 50 | transform = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 53 | transforms.Resize((352, 352)), 54 | ]) 55 | 56 | with torch.no_grad(): 57 | for img_p, save_p in tqdm(zip(image_paths, mask_save_paths)): 58 | img = Image.open(img_p).convert('RGB') 59 | img = transform(img).unsqueeze(0).to(device) 60 | preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0] 61 | preds = torch.sigmoid(preds) 62 | pred = preds.mean(dim=0)[0] 63 | pred = Image.fromarray((pred.cpu().numpy() * 255).astype(np.uint8)) 64 | pred.save(save_p) -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import pyrallis 4 | from diffusers.utils import check_min_version 5 | 6 | sys.path.append(".") 7 | sys.path.append("..") 8 | 9 | from training.coach_2 import Coach 10 | from training.config import RunConfig 11 | 12 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
13 | check_min_version("0.14.0") 14 | 15 | 16 | @pyrallis.wrap() 17 | def main(cfg: RunConfig): 18 | prepare_directories(cfg=cfg) 19 | coach = Coach(cfg) 20 | coach.train() 21 | 22 | 23 | def prepare_directories(cfg: RunConfig): 24 | cfg.log.exp_dir.mkdir(parents=True, exist_ok=True) 25 | cfg.log.logging_dir = cfg.log.exp_dir / cfg.log.logging_dir 26 | cfg.log.logging_dir.mkdir(parents=True, exist_ok=True) 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /sd_pipeline_call.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | import tqdm 3 | import torch 4 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionPipeline 5 | 6 | @torch.no_grad() 7 | def sd_pipeline_call( 8 | pipeline: StableDiffusionPipeline, 9 | prompt_embeds: torch.FloatTensor, 10 | height: Optional[int] = None, 11 | width: Optional[int] = None, 12 | num_inference_steps: int = 50, 13 | guidance_scale: float = 8.0, 14 | negative_prompt: Optional[Union[str, List[str]]] = None, 15 | num_images_per_prompt: Optional[int] = 1, 16 | eta: float = 0.0, 17 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 18 | latents: Optional[torch.FloatTensor] = None, 19 | output_type: Optional[str] = "pil", 20 | return_dict: bool = True, 21 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 22 | callback_steps: int = 1, 23 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 24 | noise_latents: Optional[torch.FloatTensor] = None): 25 | """ Modification of the standard SD pipeline call to support NeTI embeddings passed with prompt_embeds argument.""" 26 | 27 | # 0. Default height and width to unet 28 | height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor 29 | width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor 30 | 31 | # 2. Define call parameters 32 | batch_size = 1 33 | device = pipeline._execution_device 34 | 35 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 36 | neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt) 37 | negative_prompt_embeds, _ = pipeline.text_encoder( 38 | input_ids=neg_prompt.input_ids.to(device), 39 | attention_mask=None, 40 | ) 41 | negative_prompt_embeds = negative_prompt_embeds[0] 42 | 43 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 44 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 45 | # corresponds to doing no classifier free guidance. 46 | do_classifier_free_guidance = guidance_scale > 1.0 47 | 48 | # 4. Prepare timesteps 49 | pipeline.scheduler.set_timesteps(num_inference_steps, device=device) 50 | timesteps = pipeline.scheduler.timesteps 51 | 52 | # 5. Prepare latent variables 53 | if noise_latents is not None: 54 | latents = noise_latents 55 | else: 56 | num_channels_latents = pipeline.unet.in_channels 57 | latents = pipeline.prepare_latents( 58 | batch_size * num_images_per_prompt, 59 | num_channels_latents, 60 | height, 61 | width, 62 | pipeline.text_encoder.dtype, 63 | device, 64 | generator, 65 | latents, 66 | ) 67 | 68 | # 6. Prepare extra step kwargs. 69 | extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) 70 | 71 | # 7. 
Denoising loop 72 | num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order 73 | with pipeline.progress_bar(total=num_inference_steps) as progress_bar: 74 | for i, t in enumerate(timesteps): 75 | 76 | if do_classifier_free_guidance: 77 | latent_model_input = latents 78 | latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) 79 | 80 | # predict the noise residual 81 | noise_pred_uncond = pipeline.unet( 82 | latent_model_input, 83 | t, 84 | encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1), 85 | cross_attention_kwargs=cross_attention_kwargs, 86 | ).sample 87 | 88 | embed = prompt_embeds[i] if type(prompt_embeds) == list else prompt_embeds 89 | 90 | noise_pred_text = pipeline.unet( 91 | latent_model_input, 92 | t, 93 | encoder_hidden_states=embed, 94 | cross_attention_kwargs=cross_attention_kwargs, 95 | ).sample 96 | 97 | # perform guidance 98 | if do_classifier_free_guidance: 99 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 100 | 101 | # compute the previous noisy sample x_t -> x_t-1 102 | latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 103 | 104 | # call the callback, if provided 105 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): 106 | progress_bar.update() 107 | if callback is not None and i % callback_steps == 0: 108 | callback(i, t, latents) 109 | 110 | if output_type == "latent": 111 | image = latents 112 | has_nsfw_concept = None 113 | elif output_type == "pil": 114 | # 8. Post-processing 115 | image = pipeline.decode_latents(latents) 116 | # 9. Run safety checker 117 | # image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype) 118 | has_nsfw_concept = None 119 | # 10. Convert to PIL 120 | image = pipeline.numpy_to_pil(image) 121 | else: 122 | # 8. Post-processing 123 | image = pipeline.decode_latents(latents) 124 | # 9. 
Run safety checker 125 | image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype) 126 | 127 | 128 | # Offload last model to CPU 129 | if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None: 130 | pipeline.final_offload_hook.offload() 131 | 132 | if not return_dict: 133 | return image, has_nsfw_concept 134 | 135 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 136 | 137 | 138 | def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline, 139 | negative_prompt: Optional[Union[str, List[str]]] = None): 140 | if negative_prompt is None: 141 | negative_prompt = "" 142 | uncond_tokens = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 143 | uncond_input = pipeline.tokenizer( 144 | uncond_tokens, 145 | padding="max_length", 146 | max_length=pipeline.tokenizer.model_max_length, 147 | truncation=True, 148 | return_tensors="pt", 149 | ) 150 | return uncond_input 151 | 152 | @torch.no_grad() 153 | def invert( 154 | pipeline: StableDiffusionPipeline, 155 | prompt_embeds: torch.FloatTensor, 156 | start_latents: torch.FloatTensor, 157 | num_inference_steps: int = 50, 158 | guidance_scale: float = 8.0, 159 | negative_prompt: Optional[Union[str, List[str]]] = None, 160 | num_images_per_prompt: Optional[int] = 1): 161 | 162 | batch_size = 1 163 | device = pipeline._execution_device 164 | 165 | neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt) 166 | negative_prompt_embeds, _ = pipeline.text_encoder( 167 | input_ids=neg_prompt.input_ids.to(device), 168 | attention_mask=None, 169 | ) 170 | negative_prompt_embeds = negative_prompt_embeds[0] 171 | 172 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 173 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 174 | # corresponds to doing no classifier free guidance. 
175 | do_classifier_free_guidance = guidance_scale > 1.0 176 | 177 | pipeline.scheduler.set_timesteps(num_inference_steps, device=device) 178 | timesteps = reversed(pipeline.scheduler.timesteps) 179 | 180 | latents = start_latents.clone() 181 | intermediate_latents = [] 182 | 183 | for i, t in enumerate(timesteps): 184 | if i >= num_inference_steps - 1: continue 185 | 186 | t = timesteps[i] 187 | 188 | # predict the noise residual 189 | latent_model_input = latents 190 | latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) 191 | 192 | embed = prompt_embeds[0] if type(prompt_embeds) == list else prompt_embeds 193 | 194 | noise_pred_text = pipeline.unet( 195 | latent_model_input, 196 | t, 197 | encoder_hidden_states=embed, 198 | ).sample 199 | 200 | if do_classifier_free_guidance: 201 | noise_pred_uncond = pipeline.unet( 202 | latent_model_input, 203 | t, 204 | encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1), 205 | ).sample 206 | 207 | noise_pred_text = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 208 | 209 | current_t = max(0, t.item() - (1000//num_inference_steps))#t 210 | next_t = t # min(999, t.item() + (1000//num_inference_steps)) # t+1 211 | alpha_t = pipeline.scheduler.alphas_cumprod[current_t] 212 | alpha_t_next = pipeline.scheduler.alphas_cumprod[next_t] 213 | 214 | latents = (latents - (1-alpha_t).sqrt()*noise_pred_text)*(alpha_t_next.sqrt()/alpha_t.sqrt()) + (1-alpha_t_next).sqrt()*noise_pred_text 215 | intermediate_latents.append(latents) 216 | 217 | return torch.cat(intermediate_latents) 218 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/training/__init__.py -------------------------------------------------------------------------------- /training/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, Dict 4 | 5 | from constants import VALIDATION_PROMPTS 6 | 7 | 8 | @dataclass 9 | class LogConfig: 10 | """ Parameters for logging and saving """ 11 | # Name of experiment. This will be the name of the output folder 12 | exp_name: str 13 | # The output directory where the model predictions and checkpoints will be written 14 | exp_dir: Path = Path("./logs") 15 | # Save interval 16 | save_steps: int = 250 17 | # [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to 18 | # `output_dir/runs/**CURRENT_DATETIME_HOSTNAME` 19 | logging_dir: Path = Path("logs") 20 | # The integration to report the results to. Supported platforms are "tensorboard" ' 21 | # (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 22 | report_to: str = "tensorboard" 23 | # Max number of checkpoints to store. 
Passed as `total_limit` to the `Accelerator` 24 | checkpoints_total_limit: Optional[int] = None 25 | 26 | 27 | @dataclass 28 | class DataConfig: 29 | """ Parameters for data """ 30 | # A folder containing the training data 31 | train_data_dir: Path 32 | # A token to use as a placeholder for the concept 33 | placeholder_token: str 34 | # Super category token to use for normalizing the mapper output 35 | super_category_token: Optional[str] = "object" 36 | # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process 37 | dataloader_num_workers: int = 8 38 | # Choose between 'object' and 'style' - used for selecting the prompts for training 39 | learnable_property: str = "object" 40 | # How many times to repeat the training data 41 | repeats: int = 100 42 | # The resolution for input images, all the images in the train/validation dataset will be resized to this resolution 43 | resolution: int = 512 44 | # Whether to center crop images before resizing to resolution 45 | center_crop: bool = False 46 | 47 | 48 | @dataclass 49 | class ModelConfig: 50 | """ Parameters for defining all models """ 51 | # Path to pretrained model or model identifier from huggingface.co/models 52 | pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4" 53 | # Path to pretrained clip model 54 | clip_ckpt_path: str = "/path/to/clip/ckpt" 55 | # Whether to use our Nested Dropout technique 56 | use_nested_dropout: bool = True 57 | # Probability to apply nested dropout during training 58 | nested_dropout_prob: float = 0.5 59 | # Whether to normalize the norm of the mapper's output vector 60 | normalize_mapper_output: bool = True 61 | # Target norm for the mapper's output vector 62 | target_norm: Optional[float] = None 63 | # Whether to use positional encoding over the input to the mapper 64 | use_positional_encoding: bool = True 65 | # Sigmas used for computing positional encoding 66 | sigma_t: float = 1.0 67 | # Number of time anchors for computing our positional encodings 68 | num_pe_time_anchors: int = 10 69 | # Whether to output the textual bypass vector 70 | output_bypass: bool = True 71 | # Revision of pretrained model identifier from huggingface.co/models 72 | revision: Optional[str] = None 73 | # Whether training should be resumed from a previous checkpoint. 74 | mapper_checkpoint_path: Optional[Path] = None 75 | 76 | 77 | @dataclass 78 | class EvalConfig: 79 | """ Parameters for validation """ 80 | # A list of prompts that will be used during validation to verify that the model is learning 81 | validation_prompts: List[str] = field(default_factory=lambda: VALIDATION_PROMPTS) 82 | # Number of images that should be generated during validation with `validation_prompt` 83 | num_validation_images: int = 4 84 | # Seeds to use for generating the validation images 85 | validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456]) 86 | # Run validation every X steps. 87 | validation_steps: int = 100 88 | # Number of denoising steps 89 | num_denoising_steps: int = 50 90 | 91 | def __post_init__(self): 92 | if self.validation_seeds is None: 93 | self.validation_seeds = list(range(self.num_validation_images)) 94 | assert len(self.validation_seeds) == self.num_validation_images, \ 95 | "Length of validation_seeds should equal num_validation_images" 96 | 97 | @dataclass 98 | class OptimConfig: 99 | """ Parameters for the optimization process """ 100 | # Total number of training steps to perform. 
101 | max_train_steps: Optional[int] = 1_000 102 | # Learning rate 103 | learning_rate: float = 1e-3 104 | # Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size 105 | scale_lr: bool = True 106 | # Batch size (per device) for the training dataloader 107 | train_batch_size: int = 2 108 | # Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass 109 | gradient_checkpointing: bool = False 110 | # Number of updates steps to accumulate before performing a backward/update pass 111 | gradient_accumulation_steps: int = 4 112 | # A seed for reproducible training 113 | seed: Optional[int] = None 114 | # The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", 115 | # "constant", "constant_with_warmup"] 116 | lr_scheduler: str = "constant" 117 | # Number of steps for the warmup in the lr scheduler 118 | lr_warmup_steps: int = 0 119 | # The beta1 parameter for the Adam optimizer 120 | adam_beta1: float = 0.9 121 | # The beta2 parameter for the Adam optimizer 122 | adam_beta2: float = 0.999 123 | # Weight decay to use 124 | adam_weight_decay: float = 1e-2 125 | # Epsilon value for the Adam optimizer 126 | adam_epsilon: float = 1e-08 127 | # Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10. 128 | # and an Nvidia Ampere GPU. 129 | mixed_precision: str = "no" 130 | # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see 131 | # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 132 | allow_tf32: bool = False 133 | 134 | 135 | @dataclass 136 | class RunConfig: 137 | """ The main configuration for the coach trainer """ 138 | log: LogConfig = field(default_factory=LogConfig) 139 | data: DataConfig = field(default_factory=DataConfig) 140 | model: ModelConfig = field(default_factory=ModelConfig) 141 | eval: EvalConfig = field(default_factory=EvalConfig) 142 | optim: OptimConfig = field(default_factory=OptimConfig) 143 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Dict, Any 4 | import PIL 5 | import numpy as np 6 | import torch 7 | import torch.utils.checkpoint 8 | from PIL import Image 9 | from packaging import version 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | from transformers import CLIPTokenizer 13 | from constants import IMAGENET_STYLE_TEMPLATES_SMALL, IMAGENET_TEMPLATES_SMALL, PROMPTS 14 | 15 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 16 | PIL_INTERPOLATION = { 17 | "linear": PIL.Image.Resampling.BILINEAR, 18 | "bilinear": PIL.Image.Resampling.BILINEAR, 19 | "bicubic": PIL.Image.Resampling.BICUBIC, 20 | "lanczos": PIL.Image.Resampling.LANCZOS, 21 | "nearest": PIL.Image.Resampling.NEAREST, 22 | } 23 | else: 24 | PIL_INTERPOLATION = { 25 | "linear": PIL.Image.LINEAR, 26 | "bilinear": PIL.Image.BILINEAR, 27 | "bicubic": PIL.Image.BICUBIC, 28 | "lanczos": PIL.Image.LANCZOS, 29 | "nearest": PIL.Image.NEAREST, 30 | } 31 | 32 | 33 | class TextualInversionDataset(Dataset): 34 | 35 | def __init__(self, data_root: Path, 36 | tokenizer: CLIPTokenizer, 37 | learnable_property: str = "object", # [object, style] 38 | size: int = 512, 39 | repeats: int = 100, 40 
| interpolation: str = "bicubic", 41 | flip_p: float = 0.5, 42 | set: str = "train", 43 | super_category_token: str = "face", 44 | placeholder_token: list = [], 45 | center_crop: bool = False): 46 | self.data_root = data_root 47 | self.tokenizer = tokenizer 48 | self.learnable_property = learnable_property 49 | self.size = size 50 | self.super_category_token = super_category_token 51 | self.placeholder_token = placeholder_token 52 | self.center_crop = center_crop 53 | self.flip_p = flip_p 54 | 55 | if "." in str(self.data_root) and str(self.data_root).split(".")[-1].lower() in ["jpg", "jpeg", "png"]: 56 | self.image_paths = [self.data_root] 57 | else: 58 | self.image_paths = list(self.data_root.glob("*.*")) 59 | 60 | self.num_images = len(self.image_paths) 61 | self._length = self.num_images 62 | 63 | print(f"Running on {self.num_images} images") 64 | 65 | if set == "train": 66 | self._length = self.num_images * repeats 67 | 68 | self.interpolation = { 69 | "linear": PIL_INTERPOLATION["linear"], 70 | "bilinear": PIL_INTERPOLATION["bilinear"], 71 | "bicubic": PIL_INTERPOLATION["bicubic"], 72 | "lanczos": PIL_INTERPOLATION["lanczos"], 73 | }[interpolation] 74 | 75 | self.templates = IMAGENET_STYLE_TEMPLATES_SMALL if learnable_property == "style" else IMAGENET_TEMPLATES_SMALL 76 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 77 | 78 | self.templates_2 = PROMPTS 79 | 80 | def __len__(self) -> int: 81 | return self._length 82 | 83 | def __getitem__(self, i: int) -> Dict[str, Any]: 84 | image_path = self.image_paths[i % self.num_images] 85 | image = Image.open(image_path) 86 | 87 | if not image.mode == "RGB": 88 | image = image.convert("RGB") 89 | 90 | example = dict() 91 | 92 | image_name = image_path.name 93 | pre_path = image_path.parent.joinpath('mask') 94 | if pre_path.exists(): 95 | mask_path = pre_path.joinpath(image_name) 96 | org_mask = Image.open(mask_path).convert("L") 97 | mask = org_mask.resize((self.size//8, self.size//8), Image.NEAREST) 98 | mask = np.array(mask) / 255 99 | else: 100 | mask = np.ones((self.size//8, self.size//8)) 101 | mask = mask[np.newaxis, ...] 
102 | mask = torch.from_numpy(mask) 103 | example['mask'] = mask 104 | 105 | text = random.choice(self.templates) 106 | example['text'] = text.format(' '.join(self.placeholder_token)) 107 | example["input_ids"] = self.tokenizer( 108 | example['text'], 109 | padding="max_length", 110 | truncation=True, 111 | max_length=self.tokenizer.model_max_length, 112 | return_tensors="pt", 113 | ).input_ids[0] 114 | 115 | example["input_ids_length"] = len(self.tokenizer(example['text'])['input_ids'])-2 116 | 117 | open_text = random.choice(self.templates_2).format(' '.join(self.placeholder_token)) 118 | example["open_input_ids"] = self.tokenizer( 119 | open_text, 120 | padding="max_length", 121 | truncation=True, 122 | max_length=self.tokenizer.model_max_length, 123 | return_tensors="pt", 124 | ).input_ids[0] 125 | example["open_input_ids_length"] = len(self.tokenizer(open_text)['input_ids'])-2 126 | 127 | # default to score-sde preprocessing 128 | img = np.array(image).astype(np.uint8) 129 | 130 | if self.center_crop: 131 | crop = min(img.shape[0], img.shape[1]) 132 | h, w = img.shape[0], img.shape[1] 133 | img = img[(h - crop) // 2: (h + crop) // 2, (w - crop) // 2: (w + crop) // 2] 134 | 135 | image = Image.fromarray(img) 136 | image = image.resize((self.size, self.size), resample=self.interpolation) 137 | 138 | # image = self.flip_transform(image) 139 | image = np.array(image).astype(np.uint8) 140 | image = (image / 127.5 - 1.0).astype(np.float32) 141 | 142 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 143 | return example 144 | -------------------------------------------------------------------------------- /training/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pyrallis 4 | from diffusers.utils import is_wandb_available 5 | from loguru import logger 6 | 7 | from training.config import RunConfig 8 | 9 | 10 | class CoachLogger: 11 | 12 | def __init__(self, cfg: RunConfig): 13 | self.cfg = cfg 14 | self.step = 0 15 | self.configure_loguru() 16 | self.log_config() 17 | self.validate_wandb() 18 | 19 | def configure_loguru(self): 20 | logger.remove() 21 | format = '{time:YYYY-MM-DD HH:mm:ss} {message}' 22 | logger.add(sys.stdout, colorize=True, format=format) 23 | logger.add(self.cfg.log.logging_dir / 'log.txt', colorize=False, format=format) 24 | 25 | def log_config(self): 26 | with (self.cfg.log.exp_dir / 'config.yaml').open('w') as f: 27 | pyrallis.dump(self.cfg, f) 28 | self.log_message('\n' + pyrallis.dump(self.cfg)) 29 | 30 | def validate_wandb(self): 31 | if self.cfg.log.report_to == "wandb": 32 | if not is_wandb_available(): 33 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 34 | 35 | @staticmethod 36 | def log_message(msg: str): 37 | logger.info(msg) 38 | 39 | def log_start_of_training(self, total_batch_size: int, num_samples: int): 40 | self.log_message("***** Running training *****") 41 | self.log_message(f" Num examples = {num_samples}") 42 | self.log_message(f" Instantaneous batch size per device = {self.cfg.optim.train_batch_size}") 43 | self.log_message(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 44 | self.log_message(f" Gradient Accumulation steps = {self.cfg.optim.gradient_accumulation_steps}") 45 | self.log_message(f" Total optimization steps = {self.cfg.optim.max_train_steps}") 46 | 47 | def update_step(self, step: int): 48 | self.step = step 49 | -------------------------------------------------------------------------------- /training/validate.py: -------------------------------------------------------------------------------- 1 | from email.mime import image 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from accelerate import Accelerator 8 | from accelerate.utils import set_seed 9 | from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, AutoencoderKL 10 | from diffusers.utils import is_wandb_available 11 | from tqdm import tqdm 12 | from transformers import CLIPTokenizer 13 | 14 | from training.config import RunConfig 15 | from models.clip_text_encoder import PersonaCLIPTextModel 16 | from models.xti_attention_processor import XTIAttenProc, MyAttenProc 17 | from prompt_manager import PromptManager 18 | from sd_pipeline_call import sd_pipeline_call 19 | 20 | if is_wandb_available(): 21 | import wandb 22 | 23 | 24 | class ValidationHandler: 25 | 26 | def __init__(self, cfg: RunConfig, placeholder_token_id: int, weights_dtype: torch.dtype): 27 | self.cfg = cfg 28 | self.placeholder_token_id = placeholder_token_id 29 | self.weight_dtype = weights_dtype 30 | 31 | def infer(self, accelerator: Accelerator, 32 | tokenizer: CLIPTokenizer, 33 | text_encoder: PersonaCLIPTextModel, 34 | unet: UNet2DConditionModel, vae: AutoencoderKL, 35 | prompts: List[str], 36 | num_images_per_prompt: int, 37 | seeds: List[int], 38 | step: int, 39 | word_embedding: torch.Tensor, 40 | image_embedding: torch.Tensor): 41 | """ Runs inference during our training scheme. 
""" 42 | pipeline = self.load_stable_diffusion_model(accelerator, tokenizer, text_encoder, unet, vae) 43 | prompt_manager = PromptManager(tokenizer=pipeline.tokenizer, 44 | text_encoder=pipeline.text_encoder, 45 | timesteps=pipeline.scheduler.timesteps, 46 | placeholder_token=self.cfg.data.placeholder_token, 47 | placeholder_token_id=self.placeholder_token_id) 48 | joined_images = [] 49 | for prompt in prompts: 50 | images = self.infer_on_prompt(pipeline=pipeline, 51 | prompt_manager=prompt_manager, 52 | prompt=prompt, 53 | num_images_per_prompt=num_images_per_prompt, 54 | seeds=seeds, 55 | word_embedding=word_embedding, 56 | image_embedding=image_embedding) 57 | prompt_image = Image.fromarray(np.concatenate(images, axis=1)) 58 | joined_images.append(prompt_image) 59 | final_image = Image.fromarray(np.concatenate(joined_images, axis=0)) 60 | final_image.save(self.cfg.log.exp_dir / f"val-image-{step}.png") 61 | self.log_with_accelerator(accelerator, joined_images, step=step) 62 | del pipeline 63 | torch.cuda.empty_cache() 64 | text_encoder.text_model.embeddings.mapper.train() 65 | if self.cfg.optim.seed is not None: 66 | set_seed(self.cfg.optim.seed) 67 | return final_image 68 | 69 | def infer_on_prompt(self, pipeline: StableDiffusionPipeline, 70 | prompt_manager: PromptManager, 71 | prompt: str, 72 | seeds: List[int], 73 | word_embedding: torch.Tensor, 74 | image_embedding: torch.Tensor, 75 | num_images_per_prompt: int = 1) -> List[Image.Image]: 76 | prompt_embeds = self.compute_embeddings(prompt_manager=prompt_manager, prompt=prompt, word_embedding=word_embedding, image_embedding=image_embedding) 77 | all_images = [] 78 | for idx in tqdm(range(num_images_per_prompt)): 79 | generator = torch.Generator(device='cuda').manual_seed(seeds[idx]) 80 | images = sd_pipeline_call(pipeline, 81 | prompt_embeds=prompt_embeds, 82 | generator=generator, 83 | num_images_per_prompt=1).images 84 | all_images.extend(images) 85 | return all_images 86 | 87 | @staticmethod 88 | def compute_embeddings(prompt_manager: PromptManager, prompt: str, word_embedding: torch.Tensor, image_embedding: torch.Tensor) -> torch.Tensor: 89 | with torch.autocast("cuda"): 90 | with torch.no_grad(): 91 | prompt_embeds = prompt_manager.my_embed_prompt(prompt, word_embedding, image_embedding) 92 | return prompt_embeds 93 | 94 | def load_stable_diffusion_model(self, accelerator: Accelerator, 95 | tokenizer: CLIPTokenizer, 96 | text_encoder: PersonaCLIPTextModel, 97 | unet: UNet2DConditionModel, 98 | vae: AutoencoderKL) -> StableDiffusionPipeline: 99 | """ Loads SD model given the current text encoder and our mapper. 
""" 100 | pipeline = StableDiffusionPipeline.from_pretrained(self.cfg.model.pretrained_model_name_or_path, 101 | text_encoder=accelerator.unwrap_model(text_encoder), 102 | tokenizer=tokenizer, 103 | unet=unet, 104 | vae=vae, 105 | torch_dtype=self.weight_dtype) 106 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 107 | pipeline = pipeline.to(accelerator.device) 108 | pipeline.set_progress_bar_config(disable=True) 109 | pipeline.scheduler.set_timesteps(self.cfg.eval.num_denoising_steps, device=pipeline.device) 110 | # pipeline.unet.set_attn_processor(XTIAttenProc()) 111 | pipeline.unet.set_attn_processor(MyAttenProc()) 112 | text_encoder.text_model.embeddings.mapper.eval() 113 | return pipeline 114 | 115 | def log_with_accelerator(self, accelerator: Accelerator, images: List[Image.Image], step: int): 116 | for tracker in accelerator.trackers: 117 | if tracker.name == "tensorboard": 118 | np_images = np.stack([np.asarray(img) for img in images]) 119 | tracker.writer.add_images("validation", np_images, step, dataformats="NHWC") 120 | if tracker.name == "wandb": 121 | tracker.log({"validation": [wandb.Image(image, caption=f"{i}: {self.cfg.eval.validation_prompts[i]}") 122 | for i, image in enumerate(images)]}) 123 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/utils/__init__.py -------------------------------------------------------------------------------- /utils/types.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | @dataclass 9 | class Mapper_input: 10 | timesteps: torch.Tensor 11 | word_embedding: torch.Tensor 12 | image_embedding: torch.Tensor 13 | 14 | @dataclass 15 | class Batch: 16 | input_ids: torch.Tensor 17 | placeholder_token_id: int 18 | mapper_input: Mapper_input 19 | 20 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | 4 | from PIL import Image 5 | 6 | 7 | def get_image_grid(images: List[Image.Image]) -> Image: 8 | num_images = len(images) 9 | cols = int(math.ceil(math.sqrt(num_images))) 10 | rows = int(math.ceil(num_images / cols)) 11 | width, height = images[0].size 12 | grid_image = Image.new('RGB', (cols * width, rows * height)) 13 | for i, img in enumerate(images): 14 | x = i % cols 15 | y = i // cols 16 | grid_image.paste(img, (x * width, y * height)) 17 | return grid_image 18 | --------------------------------------------------------------------------------