├── README.md
├── checkpoint_handler.py
├── clipseg
│   ├── LICENSE
│   ├── Quickstart.ipynb
│   ├── README.md
│   ├── Tables.ipynb
│   ├── Visual_Feature_Engineering.ipynb
│   ├── clip_masking_lvis_image_ids.yml
│   ├── datasets
│   │   ├── coco_wrapper.py
│   │   ├── pascal_classes.json
│   │   ├── pascal_zeroshot.py
│   │   ├── pfe_dataset.py
│   │   ├── phrasecut.py
│   │   └── utils.py
│   ├── environment.yml
│   ├── evaluation_utils.py
│   ├── example_image.jpg
│   ├── experiments
│   │   ├── ablation.yaml
│   │   ├── coco.yaml
│   │   ├── pascal_0shot.yaml
│   │   ├── pascal_1shot.yaml
│   │   └── phrasecut.yaml
│   ├── general_utils.py
│   ├── metrics.py
│   ├── models
│   │   ├── clipseg.py
│   │   └── vitseg.py
│   ├── mycode.py
│   ├── overview.png
│   ├── sample_rd64.png
│   ├── sample_rd64_refined.png
│   ├── score.py
│   ├── setup.py
│   ├── training.py
│   └── weights
│       ├── rd16-uni.pth
│       ├── rd64-uni-refined.pth
│       └── rd64-uni.pth
├── constants.py
├── data
│   └── person_1
│       └── person_1.jpg
├── environment
│   ├── environment.yaml
│   └── requirements.txt
├── inference_text.txt
├── input_configs
│   ├── inference.yaml
│   └── train.yaml
├── models
│   ├── __init__.py
│   ├── clip_prior.py
│   ├── clip_text_embedding.py
│   ├── clip_text_encoder.py
│   ├── mapper.py
│   ├── positional_encoding.py
│   └── xti_attention_processor.py
├── prompt_manager.py
├── scripts
│   ├── __init__.py
│   ├── inference.py
│   ├── seg.py
│   └── train.py
├── sd_pipeline_call.py
├── training
│   ├── __init__.py
│   ├── coach_2.py
│   ├── config.py
│   ├── dataset.py
│   ├── logger.py
│   └── validate.py
└── utils
    ├── __init__.py
    ├── types.py
    └── vis_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # PersonaMagic (AAAI 2025)
2 |
3 | 🚧 Work in Progress - features are not finished yet 🚧
4 |
5 |
--------------------------------------------------------------------------------
/checkpoint_handler.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Tuple
3 |
4 | import pyrallis
5 | import torch
6 | from accelerate import Accelerator
7 | from torch import nn
8 | from transformers import CLIPTokenizer
9 |
10 | from models.clip_text_encoder import PersonaCLIPTextModel
11 | from models.mapper import Mapper
12 | from models.positional_encoding import BasicEncoder, TimePositionalEncoding
13 | from training.config import RunConfig
14 |
15 |
16 | class CheckpointHandler:
17 |
18 | def __init__(self, cfg: RunConfig, placeholder_token_string: str, placeholder_token_id: int, save_root: Path):
19 | self.cfg = cfg
20 | self.placeholder_token_string = placeholder_token_string
21 | self.placeholder_token_id = placeholder_token_id
22 | self.save_root = save_root
23 |
24 | def save_model(self, text_encoder: PersonaCLIPTextModel,
25 | accelerator: Accelerator,
26 | embeds_save_name: str,
27 | mapper_save_name: str):
28 | self.save_learned_embeds(text_encoder, accelerator, embeds_save_name)
29 | self.save_mapper(text_encoder, mapper_save_name)
30 |
31 | def save_learned_embeds(self, text_encoder: PersonaCLIPTextModel, accelerator: Accelerator, save_name: str):
32 | """
33 | Save learned embeddings. This embedding isn't really learned, but we'll add it to the tokenizer at inference
34 | to take the place of our placeholder token.
35 | """
36 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[self.placeholder_token_id]
37 | learned_embeds = learned_embeds.detach().cpu()
38 | learned_embeds_dict = {}
39 | for i in range(len(self.placeholder_token_string)):
40 | learned_embeds_dict[self.placeholder_token_string[i]] = learned_embeds[i]
41 | torch.save(learned_embeds_dict, self.save_root / save_name)
42 |
43 | def save_mapper(self, text_encoder: PersonaCLIPTextModel, save_name: str):
44 | """ Save the mapper and config to be used at inference. """
45 | cfg_ = RunConfig(**self.cfg.__dict__.copy())
46 | state_dict = {
47 | "state_dict": text_encoder.text_model.embeddings.mapper.state_dict(),
48 | "cfg": pyrallis.encode(cfg_),
49 | "encoder": text_encoder.text_model.embeddings.mapper.encoder
50 | }
51 | torch.save(state_dict, self.save_root / save_name)
52 |
53 |
54 | @staticmethod
55 | def load_my_mapper(mapper_path: Path) -> Tuple[RunConfig, Mapper]:
56 | mapper_ckpt = torch.load(mapper_path, map_location="cpu")
57 | cfg = pyrallis.decode(RunConfig, mapper_ckpt['cfg'])
58 | neti_mapper = Mapper(output_dim=768,
59 | norm_scale=cfg.model.target_norm,
60 | use_positional_encoding=cfg.model.use_positional_encoding,
61 | num_pe_time_anchors=cfg.model.num_pe_time_anchors,
62 | sigma_t=cfg.model.sigma_t,
63 | output_bypass=cfg.model.output_bypass,
64 | token_num=4)
65 | neti_mapper.load_state_dict(mapper_ckpt['state_dict'], strict=True)
66 | encoder = mapper_ckpt['encoder']
67 | if isinstance(encoder, TimePositionalEncoding):
68 | encoder.w = nn.Parameter(mapper_ckpt['encoder'].w.cuda())
69 | elif isinstance(encoder, BasicEncoder):
70 | encoder.normalized_timesteps = mapper_ckpt['encoder'].normalized_timesteps.cuda()
71 | encoder.normalized_unet_layers = mapper_ckpt['encoder'].normalized_unet_layers.cuda()
72 | neti_mapper.encoder = encoder.cuda()
73 | neti_mapper.cuda()
74 | neti_mapper.eval()
75 | return cfg, neti_mapper
76 |
77 | @staticmethod
78 | def load_learned_embed_in_clip(learned_embeds_path: Path,
79 | text_encoder: PersonaCLIPTextModel,
80 | tokenizer: CLIPTokenizer) -> Tuple[str, int]:
81 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
82 |
83 | # separate token and the embeds
84 | trained_tokens = list(loaded_learned_embeds.keys())
85 | embeds = list(loaded_learned_embeds.values())
86 |
87 | # cast to dtype of text_encoder
88 | dtype = text_encoder.get_input_embeddings().weight.dtype
89 | embeds = [e.to(dtype) for e in embeds]
90 |
91 | # add the tokens in tokenizer
92 | num_added_tokens = tokenizer.add_tokens(trained_tokens)
93 | if num_added_tokens == 0:
94 | raise ValueError(f"The tokenizer already contains the token {trained_tokens[0]}. "
95 | f"Please pass a different `token` that is not already in the tokenizer.")
96 |
97 | # resize the token embeddings
98 | text_encoder.resize_token_embeddings(len(tokenizer))
99 |
100 | # get the id for the token and assign the embeds
101 | placeholder_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in trained_tokens]
102 |
103 | for idx, (token, token_id, embed) in enumerate(zip(trained_tokens, placeholder_token_ids, embeds)):
104 | text_encoder.get_input_embeddings().weight.data[token_id] = embed
105 |
106 | # assert len(trained_tokens) == 1, "Only one placeholder token is supported"
107 | # placeholder_token = trained_tokens[0]
108 | # placeholder_token_id = placeholder_token_ids[0]
109 | placeholder_token = trained_tokens
110 | placeholder_token_id = placeholder_token_ids
111 | return placeholder_token, placeholder_token_id
112 |
--------------------------------------------------------------------------------
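A hedged sketch of how the loaders above are typically wired together at inference; the base model name, checkpoint file names, and pipeline setup are illustrative assumptions rather than code taken from this repository:

```python
from pathlib import Path

from transformers import CLIPTokenizer

from checkpoint_handler import CheckpointHandler
from models.clip_text_encoder import PersonaCLIPTextModel

# Illustrative base model and checkpoint paths (assumptions, not repository defaults).
pretrained = "runwayml/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(pretrained, subfolder="tokenizer")
text_encoder = PersonaCLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

# Restore the trained mapper (load_my_mapper moves it to CUDA and switches to eval mode).
cfg, mapper = CheckpointHandler.load_my_mapper(Path("mapper.pt"))

# Register the saved placeholder tokens and copy their embeddings into the text encoder.
tokens, token_ids = CheckpointHandler.load_learned_embed_in_clip(
    Path("learned_embeds.bin"), text_encoder, tokenizer
)

# Attach the mapper at the location save_mapper() reads it from.
text_encoder.text_model.embeddings.mapper = mapper
```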
/clipseg/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | This license does not apply to the model weights.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/clipseg/Quickstart.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import requests\n",
11 | "\n",
12 | "! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
13 | "! unzip -d weights -j weights.zip\n",
14 | "from models.clipseg import CLIPDensePredT\n",
15 | "from PIL import Image\n",
16 | "from torchvision import transforms\n",
17 | "from matplotlib import pyplot as plt\n",
18 | "\n",
19 | "# load model\n",
20 | "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
21 | "model.eval();\n",
22 | "\n",
23 | "# non-strict, because we only stored decoder weights (not CLIP weights)\n",
24 | "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "Load and normalize `example_image.jpg`. You can also load through an URL."
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "# load and normalize image\n",
41 | "input_image = Image.open('example_image.jpg')\n",
42 | "\n",
43 | "# or load from URL...\n",
44 | "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
45 | "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n",
46 | "\n",
47 | "transform = transforms.Compose([\n",
48 | " transforms.ToTensor(),\n",
49 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
50 | " transforms.Resize((352, 352)),\n",
51 | "])\n",
52 | "img = transform(input_image).unsqueeze(0)"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "metadata": {},
58 | "source": [
59 | "Predict and visualize (this might take a few seconds if running without GPU support)"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
69 | "\n",
70 | "# predict\n",
71 | "with torch.no_grad():\n",
72 | " preds = model(img.repeat(4,1,1,1), prompts)[0]\n",
73 | "\n",
74 | "# visualize prediction\n",
75 | "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
76 | "[a.axis('off') for a in ax.flatten()]\n",
77 | "ax[0].imshow(input_image)\n",
78 | "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n",
79 | "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];"
80 | ]
81 | }
82 | ],
83 | "metadata": {
84 | "interpreter": {
85 | "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
86 | },
87 | "kernelspec": {
88 | "display_name": "Python 3",
89 | "language": "python",
90 | "name": "python3"
91 | },
92 | "language_info": {
93 | "codemirror_mode": {
94 | "name": "ipython",
95 | "version": 3
96 | },
97 | "file_extension": ".py",
98 | "mimetype": "text/x-python",
99 | "name": "python",
100 | "nbconvert_exporter": "python",
101 | "pygments_lexer": "ipython3",
102 | "version": "3.8.10"
103 | }
104 | },
105 | "nbformat": 4,
106 | "nbformat_minor": 4
107 | }
108 |
--------------------------------------------------------------------------------
/clipseg/README.md:
--------------------------------------------------------------------------------
1 | # Image Segmentation Using Text and Image Prompts
2 | This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
3 |
4 | **November 2022:** CLIPSeg has been integrated into the [HuggingFace Transformers library](https://huggingface.co/docs/transformers/main/en/model_doc/clipseg). Thank you, [NielsRogge](https://github.com/NielsRogge)!
5 | **September 2022:** We released new weights for fine-grained predictions (see below for details).
6 | **March 2022:** The paper has been accepted to CVPR 2022!
7 |
8 |
9 |
10 |
11 | The system allows you to create segmentation models without training, based on:
12 | - An arbitrary text query, or
13 | - An image with a mask highlighting stuff or an object.
14 |
15 | ### Quick Start
16 |
17 | In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you have downloaded the `rd64-uni.pth` weights, either manually or via the git lfs extension.
18 | It can also be used interactively via [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb)
19 | (note that the VM does not provide a GPU, so inference takes a few seconds).
20 |
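For reference, the core of the notebook boils down to the following sketch (file names and prompts are just the notebook's examples):

```python
import torch
from PIL import Image
from torchvision import transforms
from models.clipseg import CLIPDensePredT

# load the model; strict=False because the checkpoint only contains decoder weights
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location='cpu'), strict=False)

# normalize and resize the input image as in the notebook
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((352, 352)),
])
img = transform(Image.open('example_image.jpg')).unsqueeze(0)

# one forward pass per text prompt; sigmoid turns the logits into soft masks
prompts = ['a glass', 'wood']
with torch.no_grad():
    preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
masks = torch.sigmoid(preds)  # masks[i, 0] is the segmentation for prompts[i]
```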
21 |
22 | ### Dependencies
23 | This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`).
24 | Additional dependencies are hidden for double blind review.
25 |
26 |
27 | ### Datasets
28 |
29 | * `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
30 | * `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
31 | * `PascalZeroShot`: Wrapper class for PascalZeroShot
32 | * `COCOWrapper`: Wrapper class for COCO.
33 |
34 | ### Models
35 |
36 | * `CLIPDensePredT`: CLIPSeg model (CLIP ViT backbone) with transformer-based decoder.
37 | * `ViTDensePredT`: ViTSeg baseline (ImageNet-trained ViT backbone) with transformer-based decoder.
38 |
39 | ### Third Party Dependencies
40 | For some of the datasets third party dependencies are required. Run the following commands in the `third_party` folder.
41 | ```bash
42 | git clone https://github.com/cvlab-yonsei/JoEm
43 | git clone https://github.com/Jia-Research-Lab/PFENet.git
44 | git clone https://github.com/ChenyunWu/PhraseCutDataset.git
45 | git clone https://github.com/juhongm999/hsnet.git
46 | ```
47 |
48 | ### Weights
49 |
50 | The MIT license does not apply to these weights.
51 |
52 | We provide three sets of model weights: two for D=64 (~4MB each) and one for D=16 (~1MB).
53 | ```
54 | wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
55 | unzip -d weights -j weights.zip
56 | ```
57 |
58 | #### New Fine-grained Weights
59 | We introduced a more complex module for transforming tokens into predictions, which allows for more refined predictions (in contrast to the square-like predictions of the other weights). The corresponding weights are included in the weight download above as `rd64-uni-refined.pth`.
60 | They can be loaded by:
61 | ```python
62 | model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
63 | model.load_state_dict(torch.load('weights/rd64-uni-refined.pth'), strict=False)
64 | ```
65 |
66 | See below for a direct comparison of the new fine-grained weights (top) and the old weights (below).
67 |
68 |
69 |
70 |
71 |
72 | ### Training and Evaluation
73 |
74 | To train, use the `training.py` script with an experiment file and an experiment id as parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment, which is defined by `configuration` and the first entry of `individual_configurations`. Model weights will be written to `logs/`.
75 |
76 | For evaluation, use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will evaluate the first phrasecut experiment using `test_configuration` and the first configuration in `individual_configurations`.
77 |
78 |
79 | ### Usage of PFENet Wrappers
80 |
81 | In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
82 | `git clone https://github.com/Jia-Research-Lab/PFENet.git `
83 |
84 |
85 | ### License
86 |
87 | The source code files in this repository (excluding model weights) are released under MIT license.
88 |
89 | ### Citation
90 | ```
91 | @InProceedings{lueddecke22_cvpr,
92 | author = {L\"uddecke, Timo and Ecker, Alexander},
93 | title = {Image Segmentation Using Text and Image Prompts},
94 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
95 | month = {June},
96 | year = {2022},
97 | pages = {7086-7096}
98 | }
99 |
100 | ```
101 |
--------------------------------------------------------------------------------
/clipseg/Tables.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "import clip\n",
13 | "from evaluation_utils import norm, denorm\n",
14 | "from general_utils import *\n"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "# PhraseCut"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n",
49 | "tab1 = pc[['name'] + cols]\n",
50 | "for k in cols:\n",
51 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
52 | "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
53 | "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n",
54 | "print(tab1.to_latex(header=False, index=False))"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "For 0.1 threshold"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n",
71 | "tab1 = pc[['name'] + cols]\n",
72 | "for k in cols:\n",
73 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
74 | "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
75 | "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n",
76 | "print(tab1.to_latex(header=False, index=False))"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "# One-shot"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {},
89 | "source": [
90 | "### Pascal"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
118 | "tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n",
119 | "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
120 | "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
121 | "\n",
122 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
123 | "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n",
124 | "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
125 | "\n",
126 | "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
127 | "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
128 | "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "#### Pascal Zero-shot (in one-shot setting)\n",
136 | "\n",
137 | "Using the same setting as one-shot (hence different from the other zero-shot benchmark)"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
147 | "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
148 | "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
149 | "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
150 | "\n",
151 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
152 | "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
153 | "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
154 | "\n",
155 | "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
156 | "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
157 | "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": null,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "# without fixed thresholds...\n",
167 | "\n",
168 | "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
169 | "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
170 | "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
171 | "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
172 | "\n",
173 | "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
174 | "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
175 | "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | "### COCO"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n",
201 | "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n",
202 | "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n",
203 | "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n",
204 | "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
205 | "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n",
206 | "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')"
207 | ]
208 | },
209 | {
210 | "cell_type": "markdown",
211 | "metadata": {},
212 | "source": [
213 | "# Zero-shot"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": null,
219 | "metadata": {},
220 | "outputs": [],
221 | "source": [
222 | "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {},
229 | "outputs": [],
230 | "source": [
231 | "\n",
232 | "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n",
233 | "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n",
234 | "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n",
235 | "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "metadata": {},
241 | "source": [
242 | "# Ablation"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": null,
248 | "metadata": {},
249 | "outputs": [],
250 | "source": [
251 | "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "metadata": {},
258 | "outputs": [],
259 | "source": [
260 | "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n",
261 | "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n",
262 | " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
263 | "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": [
272 | "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": null,
278 | "metadata": {},
279 | "outputs": [],
280 | "source": [
281 | "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {},
287 | "source": [
288 | "# Generalization"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "generalization = experiment('experiments/generalize.yaml').dataframe()"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values"
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": null,
312 | "metadata": {},
313 | "outputs": [],
314 | "source": [
315 | "print(\n",
316 | " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n",
317 | " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n",
318 | " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n",
319 | " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n",
320 | ")"
321 | ]
322 | }
323 | ],
324 | "metadata": {
325 | "interpreter": {
326 | "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
327 | },
328 | "kernelspec": {
329 | "display_name": "env2",
330 | "language": "python",
331 | "name": "env2"
332 | },
333 | "language_info": {
334 | "codemirror_mode": {
335 | "name": "ipython",
336 | "version": 3
337 | },
338 | "file_extension": ".py",
339 | "mimetype": "text/x-python",
340 | "name": "python",
341 | "nbconvert_exporter": "python",
342 | "pygments_lexer": "ipython3",
343 | "version": "3.8.8"
344 | }
345 | },
346 | "nbformat": 4,
347 | "nbformat_minor": 4
348 | }
349 |
--------------------------------------------------------------------------------
/clipseg/Visual_Feature_Engineering.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Systematic"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%load_ext autoreload\n",
17 | "%autoreload 2\n",
18 | "\n",
19 | "import clip\n",
20 | "from evaluation_utils import norm, denorm\n",
21 | "from general_utils import *\n",
22 | "from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
23 | "\n",
24 | "clip_device = 'cuda'\n",
25 | "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
26 | "clip_model.eval();\n",
27 | "\n",
28 | "from models.clipseg import CLIPDensePredTMasked\n",
29 | "\n",
30 | "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
31 | "clip_mask_model.eval();"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
41 | " text_class_labels=True, image_size=352, min_area=0.1,\n",
42 | " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "plot_data(lvis)"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "from collections import defaultdict\n",
61 | "import json\n",
62 | "\n",
63 | "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
64 | "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
65 | "\n",
66 | "objects_per_image = defaultdict(lambda : set())\n",
67 | "for ann in lvis_raw['annotations']:\n",
68 | " objects_per_image[ann['image_id']].add(ann['category_id'])\n",
69 | " \n",
70 | "for ann in lvis_val_raw['annotations']:\n",
71 | " objects_per_image[ann['image_id']].add(ann['category_id']) \n",
72 | " \n",
73 | "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
74 | "\n",
75 | "del lvis_raw, lvis_val_raw"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "#bs = 32\n",
85 | "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "from general_utils import get_batch\n",
95 | "from functools import partial\n",
96 | "from evaluation_utils import img_preprocess\n",
97 | "import torch\n",
98 | "\n",
99 | "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
100 | "\n",
101 | " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
102 | "\n",
103 | " all_prompts = []\n",
104 | " \n",
105 | " with torch.no_grad():\n",
106 | " valid_sims = []\n",
107 | " torch.manual_seed(571)\n",
108 | " \n",
109 | " if type(batches_or_dataset) == list:\n",
110 | " loader = batches_or_dataset # already loaded\n",
111 | " max_iter = float('inf')\n",
112 | " else:\n",
113 | " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
114 | " max_iter = 50\n",
115 | " \n",
116 | " global batch\n",
117 | " for i_batch, (batch, batch_y) in enumerate(loader):\n",
118 | " \n",
119 | " if i_batch >= max_iter: break\n",
120 | " \n",
121 | " processed_batch = process(batch)\n",
122 | " if type(processed_batch) == dict:\n",
123 | " \n",
124 | " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
125 | " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
126 | " else:\n",
127 | " processed_batch = process(batch).to(clip_device)\n",
128 | " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
129 | " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
130 | " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
131 | " \n",
132 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
133 | " bs = len(batch[0])\n",
134 | " for j in range(bs):\n",
135 | " \n",
136 | " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
137 | " support_image = basename(lvis.samples[c][sid])\n",
138 | " \n",
139 | " img_objs = [o for o in objects_per_image[int(support_image)]]\n",
140 | " img_objs = [o.replace('_', ' ') for o in img_objs]\n",
141 | " \n",
142 | " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
143 | " if o != batch_y[2][j]]\n",
144 | " \n",
145 | " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
146 | " all_prompts += [prompts]\n",
147 | " \n",
148 | " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
149 | " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
150 | "\n",
151 | " global logits\n",
152 | " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
153 | "\n",
154 | " global sim\n",
155 | " sim = torch.softmax(logits, dim=-1)\n",
156 | " \n",
157 | " valid_sims += [sim]\n",
158 | " \n",
159 | " #valid_sims = torch.stack(valid_sims)\n",
160 | " return valid_sims, all_prompts\n",
161 | " \n",
162 | "\n",
163 | "def new_img_preprocess(x):\n",
164 | " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
165 | " \n",
166 | "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
167 | "get_similarities(lvis, lambda x: x[1]);"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "preprocessing_functions = [\n",
177 | "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
178 | "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
179 | "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
180 | "# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
181 | "# ['add red outline', partial(img_preprocess, outline=True)],\n",
182 | " \n",
183 | "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
184 | "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
185 | "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
186 | "# ['BG blur', partial(img_preprocess, blur=3)],\n",
187 | "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
188 | " \n",
189 | "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
190 | "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
191 | " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
192 | " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
193 | "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
194 | "]\n",
195 | "\n",
196 | "preprocessing_functions = preprocessing_functions\n",
197 | "\n",
198 | "base, base_p = get_similarities(lvis, lambda x: x[1])\n",
199 | "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": null,
214 | "metadata": {},
215 | "outputs": [],
216 | "source": [
217 | "for j in range(1):\n",
218 | " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {},
225 | "outputs": [],
226 | "source": [
227 | "from pandas import DataFrame\n",
228 | "tab = dict()\n",
229 | "for j, (name, _) in enumerate(preprocessing_functions):\n",
230 | " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
231 | " \n",
232 | " \n",
233 | "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {},
239 | "source": [
240 | "# Visual"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": null,
246 | "metadata": {},
247 | "outputs": [],
248 | "source": [
249 | "from evaluation_utils import denorm, norm"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": null,
255 | "metadata": {},
256 | "outputs": [],
257 | "source": [
258 | "def load_sample(filename, filename2):\n",
259 | " from os.path import join\n",
260 | " bp = expanduser('~/cloud/resources/sample_images')\n",
261 | " tf = transforms.Compose([\n",
262 | " transforms.ToTensor(),\n",
263 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
264 | " transforms.Resize(224),\n",
265 | " transforms.CenterCrop(224)\n",
266 | " ])\n",
267 | " tf2 = transforms.Compose([\n",
268 | " transforms.ToTensor(),\n",
269 | " transforms.Resize(224),\n",
270 | " transforms.CenterCrop(224)\n",
271 | " ])\n",
272 | " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
273 | " inp1[1] = inp1[1].unsqueeze(0)\n",
274 | " inp1[2] = inp1[2][:1] \n",
275 | " return inp1\n",
276 | "\n",
277 | "def all_preprocessing(inp1):\n",
278 | " return [\n",
279 | " img_preprocess(inp1),\n",
280 | " img_preprocess(inp1, colorize=True),\n",
281 | " img_preprocess(inp1, outline=True), \n",
282 | " img_preprocess(inp1, blur=3),\n",
283 | " img_preprocess(inp1, bg_fac=0.1),\n",
284 | " #img_preprocess(inp1, bg_fac=0.5),\n",
285 | " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
286 | " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
287 | " ]\n",
288 | "\n"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "from torchvision import transforms\n",
298 | "from PIL import Image\n",
299 | "from matplotlib import pyplot as plt\n",
300 | "from evaluation_utils import img_preprocess\n",
301 | "import clip\n",
302 | "\n",
303 | "images_queries = [\n",
304 | " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
305 | " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
306 | "]\n",
307 | "\n",
308 | "\n",
309 | "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
310 | "\n",
311 | "for j, (images, objects) in enumerate(images_queries):\n",
312 | " \n",
313 | " joint_image = all_preprocessing(images)\n",
314 | " \n",
315 | " joint_image = torch.stack(joint_image)[:,0]\n",
316 | " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
317 | " image_features = clip_model.encode_image(joint_image)\n",
318 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
319 | " \n",
320 | " prompts = [f'a photo of a {obj}'for obj in objects]\n",
321 | " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
322 | " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
323 | " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
324 | " sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
325 | "\n",
326 | " for i, img in enumerate(joint_image):\n",
327 | " ax[2*j, i].axis('off')\n",
328 | " \n",
329 | " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
330 | " ax[2*j+ 1, i].grid(True)\n",
331 | " \n",
332 | " ax[2*j + 1, i].set_ylim(0,1)\n",
333 | " ax[2*j + 1, i].set_yticklabels([])\n",
334 | " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
335 | "# ax[1, i].set_xticklabels(objects, rotation=90)\n",
336 | " for k in range(len(sim[i])):\n",
337 | " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
338 | " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
339 | "\n",
340 | "plt.tight_layout()\n",
341 | "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
342 | ]
343 | }
344 | ],
345 | "metadata": {
346 | "kernelspec": {
347 | "display_name": "env2",
348 | "language": "python",
349 | "name": "env2"
350 | },
351 | "language_info": {
352 | "codemirror_mode": {
353 | "name": "ipython",
354 | "version": 3
355 | },
356 | "file_extension": ".py",
357 | "mimetype": "text/x-python",
358 | "name": "python",
359 | "nbconvert_exporter": "python",
360 | "pygments_lexer": "ipython3",
361 | "version": "3.8.8"
362 | }
363 | },
364 | "nbformat": 4,
365 | "nbformat_minor": 4
366 | }
367 |
--------------------------------------------------------------------------------
/clipseg/datasets/coco_wrapper.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from types import new_class
3 | import torch
4 | import numpy as np
5 | import os
6 | import json
7 |
8 | from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename
9 | from random import shuffle, seed as set_seed
10 | from PIL import Image
11 |
12 | from itertools import combinations
13 | from torchvision import transforms
14 | from torchvision.transforms.transforms import Resize
15 |
16 | from datasets.utils import blend_image_segmentation
17 | from general_utils import get_from_repository
18 |
19 | COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
20 |
21 | class COCOWrapper(object):
22 |
23 | def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0,
24 | with_class_label=False):
25 | super().__init__()
26 |
27 | self.mask = mask
28 | self.with_class_label = with_class_label
29 | self.negative_prob = negative_prob
30 |
31 | from third_party.hsnet.data.coco import DatasetCOCO
32 |
33 | get_from_repository('COCO-20i', ['COCO-20i.tar'])
34 |
35 | foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl')
36 |
37 | def build_img_metadata_classwise(self):
38 | with open(foldpath % (self.split, self.fold), 'rb') as f:
39 | img_metadata_classwise = pickle.load(f)
40 | return img_metadata_classwise
41 |
42 |
43 | DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise
44 | # DatasetCOCO.read_mask = read_mask
45 |
46 | mean = [0.485, 0.456, 0.406]
47 | std = [0.229, 0.224, 0.225]
48 | transform = transforms.Compose([
49 | transforms.Resize((image_size, image_size)),
50 | transforms.ToTensor(),
51 | transforms.Normalize(mean, std)
52 | ])
53 |
54 | self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False)
55 |
56 | self.all_classes = [self.coco.class_ids]
57 | self.coco.base_path = join(expanduser('~/datasets/COCO-20i'))
58 |
59 | def __len__(self):
60 | return len(self.coco)
61 |
62 | def __getitem__(self, i):
63 | sample = self.coco[i]
64 |
65 | label_name = COCO_CLASSES[int(sample['class_id'])]
66 |
67 | img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0]
68 |
69 | if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob:
70 | new_class_id = sample['class_id']
71 | while new_class_id == sample['class_id']:
72 | sample2 = self.coco[torch.randint(0, len(self), (1,)).item()]
73 | new_class_id = sample2['class_id']
74 | img_s = sample2['support_imgs'][0]
75 | seg_s = torch.zeros_like(seg_s)
76 |
77 | mask = self.mask
78 | if mask == 'separate':
79 | supp = (img_s, seg_s)
80 | elif mask == 'text_label':
81 | # DEPRECATED
82 | supp = [int(sample['class_id'])]
83 | elif mask == 'text':
84 | supp = [label_name]
85 | else:
86 | if mask.startswith('text_and_'):
87 | mask = mask[9:]
88 | label_add = [label_name]
89 | else:
90 | label_add = []
91 |
92 | supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask)
93 |
94 | if self.with_class_label:
95 | label = (torch.zeros(0), sample['class_id'],)
96 | else:
97 | label = (torch.zeros(0), )
98 |
99 | return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label
--------------------------------------------------------------------------------
/clipseg/datasets/pascal_classes.json:
--------------------------------------------------------------------------------
1 | [{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}]
--------------------------------------------------------------------------------
/clipseg/datasets/pascal_zeroshot.py:
--------------------------------------------------------------------------------
1 | from os.path import expanduser
2 | import torch
3 | import json
4 | import torchvision
5 | from general_utils import get_from_repository
6 | from general_utils import log
7 | from torchvision import transforms
8 |
9 | PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
10 | ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
11 | ['chair.n.01', 'pot_plant.n.01']]
12 |
13 |
14 | class PascalZeroShot(object):
15 |
16 | def __init__(self, split, n_unseen, image_size=224) -> None:
17 | super().__init__()
18 |
19 | import sys
20 | sys.path.append('third_party/JoEm')
21 | from third_party.JoEm.data_loader.dataset import VOCSegmentation
22 | from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
23 |
24 | self.pascal_classes = VOC
25 | self.image_size = image_size
26 |
27 | self.transform = transforms.Compose([
28 | transforms.Resize((image_size, image_size)),
29 | ])
30 |
31 | if split == 'train':
32 | self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
33 | split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
34 | ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
35 | elif split == 'val':
36 | self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
37 | split=split, transform=False,
38 | ignore_bg=False, ignore_unseen=False)
39 |
40 | self.unseen_idx = get_unseen_idx(n_unseen)
41 |
42 | def __len__(self):
43 | return len(self.voc)
44 |
45 | def __getitem__(self, i):
46 |
47 | sample = self.voc[i]
48 | label = sample['label'].long()
49 | all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255]
50 | class_indices = [l for l in all_labels]
51 | class_names = [self.pascal_classes[l] for l in all_labels]
52 |
53 | image = self.transform(sample['image'])
54 |
55 | label = transforms.Resize((self.image_size, self.image_size),
56 | interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]
57 |
58 | return (image,), (label, )
59 |
60 |
61 |
--------------------------------------------------------------------------------
/clipseg/datasets/pfe_dataset.py:
--------------------------------------------------------------------------------
1 | from os.path import expanduser
2 | import torch
3 | import json
4 | from general_utils import get_from_repository
5 | from datasets.utils import blend_image_segmentation
6 | from general_utils import log
7 |
8 | PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))}
9 |
10 |
11 | class PFEPascalWrapper(object):
12 |
13 | def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
14 | import sys
15 | # sys.path.append(expanduser('~/projects/new_one_shot'))
16 | from third_party.PFENet.util.dataset import SemData
17 |
18 | get_from_repository('PascalVOC2012', ['Pascal5i.tar'])
19 |
20 | self.p_negative = p_negative
21 | self.size = size
22 | self.mode = mode
23 | self.image_size = image_size
24 |
25 | if label_support in {True, False}:
26 | log.warning('label_support argument is deprecated. Use mask instead.')
27 | #raise ValueError()
28 |
29 | self.mask = mask
30 |
31 | value_scale = 255
32 | mean = [0.485, 0.456, 0.406]
33 | mean = [item * value_scale for item in mean]
34 | std = [0.229, 0.224, 0.225]
35 | std = [item * value_scale for item in std]
36 |
37 | import third_party.PFENet.util.transform as transform
38 |
39 | if mode == 'val':
40 | data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')
41 |
42 | data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
43 | data_transform += [
44 | transform.ToTensor(),
45 | transform.Normalize(mean=mean, std=std)
46 | ]
47 |
48 |
49 | elif mode == 'train':
50 | data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')
51 |
52 | assert image_size != 'original'
53 |
54 | data_transform = [
55 | transform.RandScale([0.9, 1.1]),
56 | transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
57 | transform.RandomGaussianBlur(),
58 | transform.RandomHorizontalFlip(),
59 | transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
60 | transform.ToTensor(),
61 | transform.Normalize(mean=mean, std=std)
62 | ]
63 |
64 | data_transform = transform.Compose(data_transform)
65 |
66 | self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
67 | data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)
68 |
69 | self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list
70 |
71 | # verify that subcls_list always has length 1
72 | # assert len(set([len(d[4]) for d in self.dataset])) == 1
73 |
74 | print('actual length', len(self.dataset.data_list))
75 |
76 | def __len__(self):
77 | if self.mode == 'val':
78 | return len(self.dataset.data_list)
79 | else:
80 | return len(self.dataset.data_list)
81 |
82 | def __getitem__(self, index):
83 | if self.dataset.mode == 'train':
84 | image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
85 | elif self.dataset.mode == 'val':
86 | image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
87 | ori_label = torch.from_numpy(ori_label).unsqueeze(0)
88 |
89 | if self.image_size != 'original':
90 | longerside = max(ori_label.size(1), ori_label.size(2))
91 | backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255
92 | backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
93 | label = backmask.clone().long()
94 | else:
95 | label = label.unsqueeze(0)
96 |
97 | # assert label.shape == (473, 473)
98 |
99 | if self.p_negative > 0:
100 | if torch.rand(1).item() < self.p_negative:
101 | while True:
102 | idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
103 | _, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx]
104 | if subcls_list[0] != subcls_list_tmp[0]:
105 | break
106 |
107 | s_x = s_x[0]
108 | s_y = (s_y == 1)[0]
109 | label_fg = (label == 1).float()
110 | val_mask = (label != 255).float()
111 |
112 | class_id = self.class_list[subcls_list[0]]
113 |
114 | label_name = PASCAL_CLASSES[class_id][0]
115 | label_add = ()
116 | mask = self.mask
117 |
118 | if mask == 'text':
119 | support = ('a photo of a ' + label_name + '.',)
120 | elif mask == 'separate':
121 | support = (s_x, s_y)
122 | else:
123 | if mask.startswith('text_and_'):
124 | label_add = (label_name,)
125 | mask = mask[9:]
126 |
127 | support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)
128 |
129 | return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
130 |
--------------------------------------------------------------------------------
/clipseg/datasets/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def blend_image_segmentation(img, seg, mode, image_size=224):
7 |
8 |
9 | if mode in {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}:
10 | if isinstance(img, np.ndarray):
11 | img = torch.from_numpy(img)
12 |
13 | if isinstance(seg, np.ndarray):
14 | seg = torch.from_numpy(seg)
15 |
16 | if mode == 'overlay':
17 | out = img * seg
18 | out = [out.astype('float32')]
19 | elif mode == 'highlight':
20 | out = img * seg[None, :, :] * 0.85 + 0.15 * img
21 | out = [out.astype('float32')]
22 | elif mode == 'highlight2':
23 | img = img / 2
24 | out = (img+0.1) * seg[None, :, :] + 0.3 * img
25 | out = [out.astype('float32')]
26 | elif mode == 'blur_highlight':
27 | from evaluation_utils import img_preprocess
28 | out = [img_preprocess((None, [img], [seg]), blur=1, bg_fac=0.5).numpy()[0] - 0.01]
29 | elif mode == 'blur3_highlight':
30 | from evaluation_utils import img_preprocess
31 | out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.5).numpy()[0] - 0.01]
32 | elif mode == 'blur3_highlight01':
33 | from evaluation_utils import img_preprocess
34 | out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.1).numpy()[0] - 0.01]
35 | elif mode == 'blur_highlight_random':
36 | from evaluation_utils import img_preprocess
37 | out = [img_preprocess((None, [img], [seg]), blur=0 + torch.randint(0, 3, (1,)).item(), bg_fac=0.1 + 0.8*torch.rand(1).item()).numpy()[0] - 0.01]
38 | elif mode == 'crop':
39 | from evaluation_utils import img_preprocess
40 | out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()]
41 | elif mode == 'crop_blur_highlight':
42 | from evaluation_utils import img_preprocess
43 | out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)[0].numpy()]
44 | elif mode == 'crop_blur_highlight352':
45 | from evaluation_utils import img_preprocess
46 | out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()]
47 | elif mode == 'shape':
48 | out = [np.stack([seg[:, :]]*3).astype('float32')]
49 | elif mode == 'concat':
50 | out = [np.concatenate([img, seg[None, :, :]]).astype('float32')]
51 | elif mode == 'image_only':
52 | out = [img.astype('float32')]
53 | elif mode == 'image_black':
54 | out = [img.astype('float32')*0]
55 | elif mode is None:
56 | out = [img.astype('float32')]
57 | elif mode == 'separate':
58 | out = [img.astype('float32'), seg.astype('int64')]
59 | elif mode == 'separate_img_black':
60 | out = [img.astype('float32')*0, seg.astype('int64')]
61 | elif mode == 'separate_seg_ones':
62 | out = [img.astype('float32'), np.ones_like(seg).astype('int64')]
63 | elif mode == 'separate_both_black':
64 | out = [img.astype('float32')*0, seg.astype('int64')*0]
65 | else:
66 | raise ValueError(f'invalid mode: {mode}')
67 |
68 | return out
--------------------------------------------------------------------------------
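Below is a minimal usage sketch for `blend_image_segmentation` from the file above. It is illustrative only, not repository code: the array shapes, the `'highlight'` mode, and running from inside the `clipseg/` directory (so that `datasets.utils` is importable) are assumptions.

```python
# Illustrative sketch: blend a dummy image with its segmentation mask.
# The 'highlight' mode keeps the foreground and dims the background.
import numpy as np
from datasets.utils import blend_image_segmentation  # assumes clipseg/ is the working directory

img = np.random.rand(3, 224, 224).astype('float32')        # CHW image (placeholder)
seg = (np.random.rand(224, 224) > 0.5).astype('float32')   # binary mask (placeholder)

out = blend_image_segmentation(img, seg, mode='highlight')
print(out[0].shape)  # (3, 224, 224)
```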
/clipseg/environment.yml:
--------------------------------------------------------------------------------
1 | name: clipseg-environment
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - numpy
7 | - scipy
8 | - matplotlib-base
9 | - pip
10 | - pip:
11 | - --find-links https://download.pytorch.org/whl/torch_stable.html
12 | - torch==1.12.1+cpu
13 | - torchvision==0.13.1+cpu
14 | - opencv-python
15 | - git+https://github.com/openai/CLIP.git
16 |
--------------------------------------------------------------------------------
/clipseg/evaluation_utils.py:
--------------------------------------------------------------------------------
1 | from torch.functional import Tensor
2 | from general_utils import load_model
3 | from torch.utils.data import DataLoader
4 | import torch
5 | import numpy as np
6 |
7 | def denorm(img):
8 |
9 | np_input = False
10 | if isinstance(img, np.ndarray):
11 | img = torch.from_numpy(img)
12 | np_input = True
13 |
14 | mean = torch.Tensor([0.485, 0.456, 0.406])
15 | std = torch.Tensor([0.229, 0.224, 0.225])
16 |
17 | img_denorm = (img*std[:,None,None]) + mean[:,None,None]
18 |
19 | if np_input:
20 | img_denorm = np.clip(img_denorm.numpy(), 0, 1)
21 | else:
22 | img_denorm = torch.clamp(img_denorm, 0, 1)
23 |
24 | return img_denorm
25 |
26 |
27 | def norm(img):
28 | mean = torch.Tensor([0.485, 0.456, 0.406])
29 | std = torch.Tensor([0.229, 0.224, 0.225])
30 | return (img - mean[:,None,None]) / std[:,None,None]
31 |
32 |
33 | def fast_iou_curve(p, g):
34 |
35 | g = g[p.sort().indices]
36 | p = torch.sigmoid(p.sort().values)
37 |
38 | scores = []
39 | vals = np.linspace(0, 1, 50)
40 |
41 | for q in vals:
42 |
43 | n = int(len(g) * q)
44 |
45 | valid = torch.where(p > q)[0]
46 | if len(valid) > 0:
47 | n = int(valid[0])
48 | else:
49 | n = len(g)
50 |
51 | fn = g[:n].sum()
52 | tn = n - fn
53 | tp = g[n:].sum()
54 | fp = len(g) - n - tp
55 |
56 | iou = tp / (tp + fn + fp)
57 |
58 | precision = tp / (tp + fp)
59 | recall = tp / (tp + fn)
60 |
61 | scores += [iou]
62 |
63 | return vals, scores
64 |
65 |
66 | def fast_rp_curve(p, g):
67 |
68 | g = g[p.sort().indices]
69 | p = torch.sigmoid(p.sort().values)
70 |
71 | precisions, recalls = [], []
72 | vals = np.linspace(p.min(), p.max(), 250)
73 |
74 | for q in p[::100000]:
75 |
76 | n = int(len(g) * q)
77 |
78 | valid = torch.where(p > q)[0]
79 | if len(valid) > 0:
80 | n = int(valid[0])
81 | else:
82 | n = len(g)
83 |
84 | fn = g[:n].sum()
85 | tn = n - fn
86 | tp = g[n:].sum()
87 | fp = len(g) - n - tp
88 |
89 | iou = tp / (tp + fn + fp)
90 |
91 | precision = tp / (tp + fp)
92 | recall = tp / (tp + fn)
93 |
94 | precisions += [precision]
95 | recalls += [recall]
96 |
97 | return recalls, precisions
98 |
99 |
100 | # Image processing
101 |
102 | def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2,
103 | brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224):
104 | import cv2
105 |
106 | rw = rect_width
107 |
108 | out = []
109 | for img, mask in zip(batch[1], batch[2]):
110 |
111 | img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img)
112 | mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask)
113 |
114 | img *= brightness
115 | img_bl = img
116 | if blur > 0: # best 5
117 | img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1)
118 |
119 | if grayscale:
120 | img_bl = img_bl[1][None]
121 |
122 | #img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl
123 | # img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask)
124 | img_inp = img*mask + (bg_fac) * img_bl * (1-mask)
125 |
126 | if rect:
127 | _, bbox = crop_mask(img, mask, context=0.1)
128 | img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None]
129 | img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None]
130 | img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
131 | img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
132 |
133 |
134 | if center_context is not None:
135 | img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size)
136 |
137 | if colorize:
138 | img_gray = denorm(img)
139 | img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY)
140 | img_gray = torch.stack([torch.from_numpy(img_gray)]*3)
141 | img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask)
142 | img_inp = norm(img_inp)
143 |
144 | if outline:
145 | cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
146 | outline_img = np.zeros(mask.shape, dtype=np.uint8)
147 | cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255))
148 | outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255.
149 | img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img)
150 | img_inp = norm(img_inp)
151 |
152 | out += [img_inp]
153 |
154 | return torch.stack(out)
155 |
156 |
157 | def object_crop(img, mask, context=0.0, square=False, image_size=224):
158 | img_crop, bbox = crop_mask(img, mask, context=context, square=square)
159 | img_crop = pad_to_square(img_crop, channel_dim=0)
160 | img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0)
161 | return img_crop
162 |
163 |
164 | def crop_mask(img, mask, context=0.0, square=False):
165 |
166 | assert img.shape[1:] == mask.shape
167 |
168 | bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()]
169 | bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()]
170 | bbox = [int(x) for x in bbox]
171 |
172 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
173 |
174 | # square mask
175 | if square:
176 | bbox[0] = int(max(0, bbox[0] - context * height))
177 | bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
178 | bbox[2] = int(max(0, bbox[2] - context * width))
179 | bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
180 |
181 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
182 | if height > width:
183 | bbox[2] = int(max(0, (bbox[2] - 0.5*height)))
184 | bbox[3] = bbox[2] + height
185 | else:
186 | bbox[0] = int(max(0, (bbox[0] - 0.5*width)))
187 | bbox[1] = bbox[0] + width
188 | else:
189 | bbox[0] = int(max(0, bbox[0] - context * height))
190 | bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
191 | bbox[2] = int(max(0, bbox[2] - context * width))
192 | bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
193 |
194 | width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
195 | img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]]
196 | return img_crop, bbox
197 |
198 |
199 | def pad_to_square(img, channel_dim=2, fill=0):
200 |     """
201 |     Add padding such that a square image is returned.
202 |     """
203 | 
204 |
205 | from torchvision.transforms.functional import pad
206 |
207 | if channel_dim == 2:
208 | img = img.permute(2, 0, 1)
209 | elif channel_dim == 0:
210 | pass
211 | else:
212 | raise ValueError('invalid channel_dim')
213 |
214 | h, w = img.shape[1:]
215 | pady1 = pady2 = padx1 = padx2 = 0
216 |
217 | if h > w:
218 | padx1 = (h - w) // 2
219 | padx2 = h - w - padx1
220 | elif w > h:
221 | pady1 = (w - h) // 2
222 | pady2 = w - h - pady1
223 |
224 | img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant')
225 |
226 | if channel_dim == 2:
227 | img_padded = img_padded.permute(1, 2, 0)
228 |
229 | return img_padded
230 |
231 |
232 | # qualitative
233 |
234 | def split_sentence(inp, limit=9):
235 | t_new, current_len = [], 0
236 | for k, t in enumerate(inp.split(' ')):
237 | current_len += len(t) + 1
238 | t_new += [t+' ']
239 | # not last
240 | if current_len > limit and k != len(inp.split(' ')) - 1:
241 | current_len = 0
242 | t_new += ['\n']
243 |
244 | t_new = ''.join(t_new)
245 | return t_new
246 |
247 |
248 | from matplotlib import pyplot as plt
249 |
250 |
251 | def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None):
252 |
253 | row_off = 0 if labels is None else 1
254 | _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2)))
255 | [a.axis('off') for a in ax.flatten()]
256 |
257 | if labels is not None:
258 | for j in range(len(labels)):
259 | t_new = split_sentence(labels[j], limit=6)
260 | ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale)
261 |
262 |
263 | for i in range(len(imgs)):
264 | ax[i + row_off,0].imshow(imgs[i])
265 | for j in range(len(preds)):
266 | img = preds[j][i][0].detach().cpu().numpy()
267 |
268 | if gt_labels is not None and labels[j] == gt_labels[i]:
269 | print(j, labels[j], gt_labels[i])
270 | edgecolor = 'red'
271 | if aps is not None:
272 | ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8)
273 | else:
274 | edgecolor = 'k'
275 |
276 | rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none",
277 | edgecolor=edgecolor, linewidth=3)
278 | ax[i + row_off,1 + j].add_patch(rect)
279 |
280 | if vmax is None:
281 | this_vmax = 1
282 | elif vmax == 'per_prompt':
283 | this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))])
284 | elif vmax == 'per_image':
285 | this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))])
286 |
287 | ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap)
288 |
289 |
290 | # ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max())
291 | plt.tight_layout()
292 | plt.subplots_adjust(wspace=0.05, hspace=0.05)
--------------------------------------------------------------------------------
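A short sketch of `pad_to_square` from the file above, with an assumed channel-first tensor: the shorter spatial side is padded symmetrically so the result is square. Not repository code; it only assumes the `clipseg/` directory is on the path.

```python
# Illustrative sketch: pad a CHW tensor to a square with pad_to_square.
import torch
from evaluation_utils import pad_to_square  # assumes clipseg/ is the working directory

img = torch.rand(3, 200, 300)               # wider than tall (placeholder tensor)
img_sq = pad_to_square(img, channel_dim=0)  # channel_dim=0 means CHW layout
print(img_sq.shape)                         # torch.Size([3, 300, 300])
```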
/clipseg/example_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/example_image.jpg
--------------------------------------------------------------------------------
/clipseg/experiments/ablation.yaml:
--------------------------------------------------------------------------------
1 | configuration:
2 | batch_size: 64
3 | optimizer: torch.optim.AdamW
4 |
5 | lr: 0.001
6 |
7 | trainer: experiment_setup.train_loop
8 | scorer: experiment_setup.score
9 | model: models.clipseg.CLIPDensePredT
10 |
11 | lr_scheduler: cosine
12 | T_max: 20000
13 | eta_min: 0.0001
14 |
15 | max_iterations: 20000 # <-##########################################
16 | val_interval: null
17 |
18 | # dataset
19 | dataset: datasets.phrasecut.PhraseCut # <-----------------
20 | split_mode: pascal_test
21 | split: train
22 | mask: text_and_crop_blur_highlight352
23 | image_size: 352
24 | negative_prob: 0.2
25 | mix_text_max: 0.5
26 |
27 | # general
28 | mix: True # <-----------------
29 | prompt: shuffle+
30 | norm_cond: True
31 | mix_text_min: 0.0
32 | with_visual: True
33 |
34 | # model
35 | version: 'ViT-B/16'
36 | extract_layers: [3, 7, 9]
37 | reduce_dim: 64
38 | depth: 3
39 | fix_shift: False # <-##########################################
40 |
41 | loss: torch.nn.functional.binary_cross_entropy_with_logits
42 | amp: True
43 |
44 | test_configuration_common:
45 | normalize: True
46 | image_size: 352
47 | batch_size: 32
48 | sigmoid: True
49 | split: test
50 | label_support: True
51 |
52 | test_configuration:
53 |
54 | -
55 | name: pc
56 | metric: metrics.FixedIntervalMetrics
57 | test_dataset: phrasecut
58 | mask: text
59 |
60 | -
61 | name: pc-vis
62 | metric: metrics.FixedIntervalMetrics
63 | test_dataset: phrasecut
64 | mask: crop_blur_highlight352
65 | with_visual: True
66 | visual_only: True
67 |
68 |
69 | columns: [name,
70 | pc_fgiou_best, pc_miou_best, pc_fgiou_0.5,
71 | pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5,
72 | duration]
73 |
74 |
75 | individual_configurations:
76 |
77 | - {name: rd64-uni}
78 | - {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003}
79 | - {name: rd64-no-negatives, negative_prob: 0.0}
80 | - {name: rd64-neg0.5, negative_prob: 0.5}
81 | - {name: rd64-no-visual, with_visual: False, mix: False}
82 | - {name: rd16-uni, reduce_dim: 16}
83 | - {name: rd64-layer3, extract_layers: [3], depth: 1}
84 | - {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}}
--------------------------------------------------------------------------------
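The experiment files in this directory are resolved into a single run configuration by `training_config_from_cli_args` in `clipseg/general_utils.py`: the shared `configuration` block is merged with one entry of `individual_configurations`, and the per-run keys take precedence. A hedged sketch of that merge (the index `5` and the relative path are chosen for illustration):

```python
# Sketch of how an experiment yaml resolves to one run configuration
# (mirrors training_config_from_cli_args; index and path are illustrative).
import yaml

yaml_config = yaml.safe_load(open('experiments/ablation.yaml'))

base = yaml_config['configuration']                     # shared defaults
override = yaml_config['individual_configurations'][5]  # {'name': 'rd16-uni', 'reduce_dim': 16}
run_config = {**base, **override}                       # per-run keys win

print(run_config['name'], run_config['reduce_dim'])     # rd16-uni 16
```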
/clipseg/experiments/coco.yaml:
--------------------------------------------------------------------------------
1 | configuration:
2 | batch_size: 64
3 | optimizer: torch.optim.AdamW
4 |
5 | lr: 0.001
6 |
7 | trainer: experiment_setup.train_loop
8 | scorer: experiment_setup.score
9 | model: models.clipseg.CLIPDensePredT
10 |
11 | lr_scheduler: cosine
12 | T_max: 20000
13 | eta_min: 0.0001
14 |
15 | max_iterations: 20000
16 | val_interval: null
17 |
18 | # dataset
19 | dataset: datasets.coco_wrapper.COCOWrapper
20 | # split_mode: pascal_test
21 | split: train
22 | mask: text_and_blur3_highlight01
23 | image_size: 352
24 | normalize: True
25 | pre_crop_image_size: [sample, 1, 1.5]
26 | aug: 1new
27 |
28 | # general
29 | mix: True
30 | prompt: shuffle+
31 | norm_cond: True
32 | mix_text_min: 0.0
33 |
34 | # model
35 | out: 1
36 | extract_layers: [3, 7, 9]
37 | reduce_dim: 64
38 | depth: 3
39 | fix_shift: False
40 |
41 | loss: torch.nn.functional.binary_cross_entropy_with_logits
42 | amp: True
43 |
44 | test_configuration_common:
45 | normalize: True
46 | image_size: 352
47 | # max_iterations: 10
48 | batch_size: 8
49 | sigmoid: True
50 | test_dataset: coco
51 | metric: metrics.FixedIntervalMetrics
52 |
53 | test_configuration:
54 |
55 | -
56 | name: coco_t
57 | mask: text
58 |
59 | -
60 | name: coco_h
61 | mask: blur3_highlight01
62 |
63 | -
64 | name: coco_h2
65 | mask: crop_blur_highlight352
66 |
67 |
68 | columns: [i, name,
69 | coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5,
70 | coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5,
71 | coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t,
72 | train_loss, duration, date
73 | ]
74 |
75 | individual_configurations:
76 |
77 |
78 | - {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
79 | - {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
80 | - {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
81 | - {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
82 |
83 |
84 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
85 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
86 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
87 | - {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
88 |
89 |
90 | # ViT
91 | - {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
92 | - {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
93 | - {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
94 | - {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
95 |
96 |
97 | # BASELINE
98 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
99 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
100 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
101 | - {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
--------------------------------------------------------------------------------
/clipseg/experiments/pascal_0shot.yaml:
--------------------------------------------------------------------------------
1 | configuration:
2 | batch_size: 64
3 | optimizer: torch.optim.AdamW
4 |
5 | lr: 0.001
6 |
7 | # val_metric_class: metrics.BinaryIoU
8 | # use_val_metric: BIoU_fg
9 |
10 | trainer: experiment_setup.train_loop
11 | scorer: experiment_setup.score
12 | model: models.clipseg.CLIPDensePredT
13 |
14 | lr_scheduler: cosine
15 | T_max: 20000
16 | eta_min: 0.0001
17 |
18 | max_iterations: 20000 # <-##########################################
19 | val_interval: null
20 |
21 | # dataset
22 | dataset: datasets.phrasecut.PhraseCut # <-----------------
23 | split_mode: pascal_test
24 | split: train
25 | image_size: 352
26 | normalize: True
27 | pre_crop_image_size: [sample, 1, 1.5]
28 | aug: 1new
29 |
30 | # new, not
31 | with_visual: True
32 |
33 | # general
34 | mix: False # <-----------------
35 | prompt: shuffle+
36 | norm_cond: True
37 | mix_text_min: 0.0
38 |
39 | # model
40 | out: 1
41 | extract_layers: [3, 7, 9]
42 | reduce_dim: 64
43 | depth: 3
44 | fix_shift: False # <-##########################################
45 |
46 | loss: torch.nn.functional.binary_cross_entropy_with_logits
47 | amp: True
48 |
49 | test_configuration_common:
50 | normalize: True
51 | image_size: 352
52 | batch_size: 32
53 | # max_iterations: 150
54 |
55 | test_configuration:
56 | test_dataset: pascal_zs
57 |
 58 | columns: [name, pas_zs_seen, pas_zs_unseen, duration, date]
 59 | 
 60 | individual_configurations:
 61 | 
 62 | - {name: rd64-uni-zs5, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, remove_classes: [zs, 5], negative_prob: 0.2, mix: True, mix_text_max: 0.5, mask: text_and_crop_blur_highlight352}
 63 | - {name: rd64-uni-zs2, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, remove_classes: [zs, 2], negative_prob: 0.2, mix: True, mix_text_max: 0.5, mask: text_and_crop_blur_highlight352}
 64 | 
--------------------------------------------------------------------------------
/clipseg/experiments/pascal_1shot.yaml:
--------------------------------------------------------------------------------
1 | configuration:
2 | batch_size: 64
3 | optimizer: torch.optim.AdamW
4 |
5 | lr: 0.001
6 |
7 | trainer: experiment_setup.train_loop
8 | scorer: experiment_setup.score
9 | model: models.clipseg.CLIPDensePredT
10 |
11 | lr_scheduler: cosine
12 | T_max: 20000
13 | eta_min: 0.0001
14 |
15 | max_iterations: 20000 # <-##########################################
16 | val_interval: null
17 |
18 | # dataset
19 | dataset: datasets.phrasecut.PhraseCut
20 | split_mode: pascal_test
21 | mode: train
22 | mask: text_and_crop_blur_highlight352
23 | image_size: 352
24 | normalize: True
25 | pre_crop_image_size: [sample, 1, 1.5]
26 | aug: 1new
27 | with_visual: True
28 | split: train
29 |
30 | # general
31 | mix: True
32 | prompt: shuffle+
33 | norm_cond: True
34 | mix_text_min: 0.0
35 |
36 | # model
37 | out: 1
38 | version: 'ViT-B/16'
39 | extract_layers: [3, 7, 9]
40 | reduce_dim: 64
41 | depth: 3
42 |
43 | loss: torch.nn.functional.binary_cross_entropy_with_logits
44 | amp: True
45 |
46 | test_configuration_common:
47 | normalize: True
48 | image_size: 352
49 | metric: metrics.FixedIntervalMetrics
50 | batch_size: 1
51 | test_dataset: pascal
52 | sigmoid: True
53 | # max_iterations: 250
54 |
55 | test_configuration:
56 |
57 | -
58 | name: pas_t
59 | mask: text
60 |
61 | -
62 | name: pas_h
63 | mask: blur3_highlight01
64 |
65 | -
66 | name: pas_h2
67 | mask: crop_blur_highlight352
68 |
69 |
70 | columns: [name,
71 | pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct,
72 | pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct,
73 | pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t,
74 | train_loss, duration, date
75 | ]
76 |
77 | individual_configurations:
78 |
79 | - {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}}
80 | - {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}}
81 | - {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}}
82 | - {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}}
83 |
84 |
85 | - {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}}
86 | - {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}}
87 | - {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}}
88 | - {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}}
89 |
90 |
91 | # baseline
92 | - {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}}
93 | - {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}}
94 | - {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}}
95 | - {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}}
96 |
97 | # ViT
98 | - {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}}
99 | - {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}}
100 | - {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}}
101 | - {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}}
102 |
--------------------------------------------------------------------------------
/clipseg/experiments/phrasecut.yaml:
--------------------------------------------------------------------------------
1 | configuration:
2 | batch_size: 64
3 | optimizer: torch.optim.AdamW
4 |
5 | lr: 0.001
6 |
7 | trainer: experiment_setup.train_loop
8 | scorer: experiment_setup.score
9 | model: models.clipseg.CLIPDensePredT
10 |
11 | lr_scheduler: cosine
12 | T_max: 20000
13 | eta_min: 0.0001
14 |
15 | max_iterations: 20000
16 | val_interval: null
17 |
18 | # dataset
19 | dataset: datasets.phrasecut.PhraseCut # <-----------------
20 | split_mode: pascal_test
21 | split: train
22 | mask: text_and_crop_blur_highlight352
23 | image_size: 352
24 | normalize: True
25 | pre_crop_image_size: [sample, 1, 1.5]
26 | aug: 1new
27 |
28 | # general
29 | mix: False # <-----------------
30 | prompt: shuffle+
31 | norm_cond: True
32 | mix_text_min: 0.0
33 |
34 | # model
35 | out: 1
36 | extract_layers: [3, 7, 9]
37 | reduce_dim: 64
38 | depth: 3
39 | fix_shift: False
40 |
41 | loss: torch.nn.functional.binary_cross_entropy_with_logits
42 | amp: True
43 |
44 | test_configuration_common:
45 | normalize: True
46 | image_size: 352
47 | batch_size: 32
48 | # max_iterations: 5
49 | # max_iterations: 150
50 |
51 | test_configuration:
52 |
53 | -
54 | name: pc # old: phrasecut
55 | metric: metrics.FixedIntervalMetrics
56 | test_dataset: phrasecut
57 | split: test
58 | mask: text
59 | label_support: True
60 | sigmoid: True
61 |
62 |
63 | columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date]
64 |
65 |
66 | individual_configurations:
67 |
68 | # important ones
69 |
70 |
71 | - {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5}
72 |
73 | # this is almost the same training setting as the refined model except for transformer dropout of 0.1 (currently not implemented in the model)
74 | - {name: rd64-uni-refined, version: 'ViT-B/16', reduce_dim: 64, negative_prob: 0.2, complex_trans_conv: True, with_visual: True, mix: True, mix_text_max: 0.5, T_max: 50000, max_iterations: 50000}
75 |
76 |
 77 | # this was accidentally trained using the old mask
78 | - {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01}
79 | - {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False}
 80 | # this was accidentally trained using the old mask
81 | - {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01}
82 |
83 | - {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003}
84 | - {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001}
85 |
--------------------------------------------------------------------------------
/clipseg/general_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import inspect
3 | import torch
4 | import os
5 | import sys
6 | import yaml
7 | from shutil import copy, copytree
8 | from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename
9 |
10 |
11 | class Logger(object):
12 |
13 | def __getattr__(self, k):
14 | return print
15 |
16 | log = Logger()
17 |
18 | def training_config_from_cli_args():
19 | experiment_name = sys.argv[1]
20 | experiment_id = int(sys.argv[2])
21 |
22 | yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader)
23 |
24 | config = yaml_config['configuration']
25 | config = {**config, **yaml_config['individual_configurations'][experiment_id]}
26 | config = AttributeDict(config)
27 | return config
28 |
29 |
30 | def score_config_from_cli_args():
31 | experiment_name = sys.argv[1]
32 | experiment_id = int(sys.argv[2])
33 |
34 |
35 | yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader)
36 |
37 | config = yaml_config['test_configuration_common']
38 |
39 | if type(yaml_config['test_configuration']) == list:
40 | test_id = int(sys.argv[3])
41 | config = {**config, **yaml_config['test_configuration'][test_id]}
42 | else:
43 | config = {**config, **yaml_config['test_configuration']}
44 |
45 | if 'test_configuration' in yaml_config['individual_configurations'][experiment_id]:
46 | config = {**config, **yaml_config['individual_configurations'][experiment_id]['test_configuration']}
47 |
48 | train_checkpoint_id = yaml_config['individual_configurations'][experiment_id]['name']
49 |
50 | config = AttributeDict(config)
51 | return config, train_checkpoint_id
52 |
53 |
54 | def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository',
55 | local_dir='~/datasets'):
56 | """ copies files from repository to local folder.
57 |
58 | repo_files: list of filenames or list of tuples [filename, target path]
59 |
 60 |     e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar']])
61 | will create a folder 'MyDataset' in local_dir, and extract the content of
62 | '/data/dataset1.tar' to /MyDataset/other/path.
63 | """
64 |
65 | local_dir = realpath(join(expanduser(local_dir), local_name))
66 |
67 | dataset_exists = True
68 |
69 | # check if folder is available
70 | if not isdir(local_dir):
71 | dataset_exists = False
72 |
73 | if integrity_check is not None:
74 | try:
75 | integrity_ok = integrity_check(local_dir)
76 | except BaseException:
77 | integrity_ok = False
78 |
79 | if integrity_ok:
80 | log.hint('Passed custom integrity check')
81 | else:
82 | log.hint('Custom integrity check failed')
83 |
84 | dataset_exists = dataset_exists and integrity_ok
85 |
86 | if not dataset_exists:
87 |
88 | repo_dir = realpath(expanduser(repo_dir))
89 |
90 | for i, filename in enumerate(repo_files):
91 |
92 | if type(filename) == str:
93 | origin, target = filename, filename
94 | archive_target = join(local_dir, basename(origin))
95 | extract_target = join(local_dir)
96 | else:
97 | origin, target = filename
98 | archive_target = join(local_dir, dirname(target), basename(origin))
99 | extract_target = join(local_dir, dirname(target))
100 |
101 | archive_origin = join(repo_dir, origin)
102 |
103 | log.hint(f'copy: {archive_origin} to {archive_target}')
104 |
105 | # make sure the path exists
106 | os.makedirs(dirname(archive_target), exist_ok=True)
107 |
108 | if os.path.isfile(archive_target):
109 | # only copy if size differs
110 | if os.path.getsize(archive_target) != os.path.getsize(archive_origin):
111 | log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. origin {os.path.getsize(archive_origin)}')
112 | copy(archive_origin, archive_target)
113 | else:
114 | copy(archive_origin, archive_target)
115 |
116 | extract_archive(archive_target, extract_target, noarchive_ok=True)
117 |
118 | # concurrent processes might have deleted the file
119 | if os.path.isfile(archive_target):
120 | os.remove(archive_target)
121 |
122 |
123 | def extract_archive(filename, target_folder=None, noarchive_ok=False):
124 | from subprocess import run, PIPE
125 |
126 | if filename.endswith('.tgz') or filename.endswith('.tar'):
127 | command = f'tar -xf {filename}'
128 | command += f' -C {target_folder}' if target_folder is not None else ''
129 | elif filename.endswith('.tar.gz'):
130 | command = f'tar -xzf {filename}'
131 | command += f' -C {target_folder}' if target_folder is not None else ''
132 | elif filename.endswith('zip'):
133 | command = f'unzip {filename}'
134 | command += f' -d {target_folder}' if target_folder is not None else ''
135 | else:
136 | if noarchive_ok:
137 | return
138 | else:
139 |             raise ValueError(f'unsupported file ending of {filename}')
140 |
141 | log.hint(command)
142 | result = run(command.split(), stdout=PIPE, stderr=PIPE)
143 | if result.returncode != 0:
144 | print(result.stdout, result.stderr)
145 |
146 |
147 | class AttributeDict(dict):
148 | """
149 |     An extended dictionary that allows access to elements as attributes and counts
150 | these accesses. This way, we know if some attributes were never used.
151 | """
152 |
153 | def __init__(self, *args, **kwargs):
154 | from collections import Counter
155 | super().__init__(*args, **kwargs)
156 | self.__dict__['counter'] = Counter()
157 |
158 | def __getitem__(self, k):
159 | self.__dict__['counter'][k] += 1
160 | return super().__getitem__(k)
161 |
162 | def __getattr__(self, k):
163 | self.__dict__['counter'][k] += 1
164 | return super().get(k)
165 |
166 | def __setattr__(self, k, v):
167 | return super().__setitem__(k, v)
168 |
169 |     def __delattr__(self, k):
170 |         return super().__delitem__(k)
171 |
172 | def unused_keys(self, exceptions=()):
173 | return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions]
174 |
175 | def assume_no_unused_keys(self, exceptions=()):
176 | if len(self.unused_keys(exceptions=exceptions)) > 0:
177 | log.warning('Unused keys:', self.unused_keys(exceptions=exceptions))
178 |
179 |
180 | def get_attribute(name):
181 | import importlib
182 |
183 | if name is None:
184 | raise ValueError('The provided attribute is None')
185 |
186 | name_split = name.split('.')
187 | mod = importlib.import_module('.'.join(name_split[:-1]))
188 | return getattr(mod, name_split[-1])
189 |
190 |
191 |
192 | def filter_args(input_args, default_args):
193 |
194 | updated_args = {k: input_args[k] if k in input_args else v for k, v in default_args.items()}
195 | used_args = {k: v for k, v in input_args.items() if k in default_args}
196 | unused_args = {k: v for k, v in input_args.items() if k not in default_args}
197 |
198 | return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args)
199 |
200 |
201 | def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False):
202 |
203 | config = json.load(open(join('logs', checkpoint_id, 'config.json')))
204 |
205 | if model_args != 'from_config' and type(model_args) != dict:
206 | raise ValueError('model_args must either be "from_config" or a dictionary of values')
207 |
208 | model_cls = get_attribute(config['model'])
209 |
210 | # load model
211 | if model_args == 'from_config':
212 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
213 |
214 | model = model_cls(**model_args)
215 |
216 | if weights_file is None:
217 | weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
218 | else:
219 | weights_file = realpath(join('logs', checkpoint_id, weights_file))
220 |
221 | if isfile(weights_file):
222 | weights = torch.load(weights_file)
223 | for _, w in weights.items():
224 | assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
225 | model.load_state_dict(weights, strict=strict)
226 | else:
227 | raise FileNotFoundError(f'model checkpoint {weights_file} was not found')
228 |
229 | if with_config:
230 | return model, config
231 |
232 | return model
233 |
234 |
235 | class TrainingLogger(object):
236 |
237 | def __init__(self, model, log_dir, config=None, *args):
238 | super().__init__()
239 | self.model = model
240 | self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None
241 |
242 | os.makedirs('logs/', exist_ok=True)
243 | os.makedirs(self.base_path, exist_ok=True)
244 |
245 | if config is not None:
246 | json.dump(config, open(join(self.base_path, 'config.json'), 'w'))
247 |
248 | def iter(self, i, **kwargs):
249 | if i % 100 == 0 and 'loss' in kwargs:
250 | loss = kwargs['loss']
251 | print(f'iteration {i}: loss {loss:.4f}')
252 |
253 | def save_weights(self, only_trainable=False, weight_file='weights.pth'):
254 | if self.model is None:
255 |             raise AttributeError('You need to provide a model reference when initializing TrainingLogger to save weights.')
256 |
257 | weights_path = join(self.base_path, weight_file)
258 |
259 | weight_dict = self.model.state_dict()
260 |
261 | if only_trainable:
262 | weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad}
263 |
264 | torch.save(weight_dict, weights_path)
265 | log.info(f'Saved weights to {weights_path}')
266 |
267 | def __enter__(self):
268 | return self
269 |
270 | def __exit__(self, type, value, traceback):
271 | """ automatically stop processes if used in a context manager """
272 | pass
--------------------------------------------------------------------------------
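A small sketch of `AttributeDict` in isolation: every attribute or item access is counted, so configuration keys that were never read (for example, misspelled options) can be reported afterwards. Illustrative only; it assumes `general_utils.py` is importable from the working directory.

```python
# Illustrative sketch: AttributeDict counts key accesses to spot unused options.
from general_utils import AttributeDict  # assumes clipseg/ is the working directory

cfg = AttributeDict({'lr': 0.001, 'batch_size': 64, 'unused_flag': True})

print(cfg.lr, cfg['batch_size'])  # attribute and item access are both counted
print(cfg.unused_keys())          # ['unused_flag']
cfg.assume_no_unused_keys()       # prints a warning listing ['unused_flag']
```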
/clipseg/metrics.py:
--------------------------------------------------------------------------------
1 | from torch.functional import Tensor
2 | from general_utils import log
3 | from collections import defaultdict
4 | import numpy as np
5 |
6 | import torch
7 | from torch.nn import functional as nnf
8 |
9 |
10 | class BaseMetric(object):
11 |
12 | def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True,
13 | eval_validation=True):
14 | self._names = tuple(metric_names)
15 | self._eval_intermediate = eval_intermediate
16 | self._eval_validation = eval_validation
17 |
18 | self._pred_range = pred_range
19 | self._pred_index = pred_index
20 | self._gt_index = gt_index
21 |
22 | self.predictions = []
23 | self.ground_truths = []
24 |
25 | def eval_intermediate(self):
26 | return self._eval_intermediate
27 |
28 | def eval_validation(self):
29 | return self._eval_validation
30 |
31 | def names(self):
32 | return self._names
33 |
34 | def add(self, predictions, ground_truth):
35 | raise NotImplementedError
36 |
37 | def value(self):
38 | raise NotImplementedError
39 |
40 | def scores(self):
41 | # similar to value but returns dict
42 | value = self.value()
43 | if type(value) == dict:
44 | return value
45 | else:
46 | assert type(value) in {list, tuple}
47 | return list(zip(self.names(), self.value()))
48 |
49 | def _get_pred_gt(self, predictions, ground_truth):
50 | pred = predictions[self._pred_index]
51 | gt = ground_truth[self._gt_index]
52 |
53 | if self._pred_range is not None:
54 | pred = pred[:, self._pred_range[0]: self._pred_range[1]]
55 |
56 | return pred, gt
57 |
58 |
59 | class FixedIntervalMetrics(BaseMetric):
60 |
61 | def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None,
62 | resize_pred=None, n_values=51, custom_threshold=None):
63 |
64 |
65 | super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'))
66 | self.intersections = []
67 | self.unions = []
68 | # self.threshold = threshold
69 | self.sigmoid = sigmoid
70 | self.resize_to = resize_to
71 | self.resize_pred = resize_pred # resize prediction to match ground truth
72 | self.class_count = defaultdict(lambda: 0)
73 | self.per_class = defaultdict(lambda : [0,0])
74 | self.ignore_mask = ignore_mask
75 | self.custom_threshold = custom_threshold
76 |
77 | self.scores_ap = []
78 | self.scores_iou = []
79 | self.gts, self.preds = [], []
80 | self.classes = []
81 |
82 | # [1:-1] ignores 0 and 1
83 | self.threshold_values = np.linspace(0, 1, n_values)[1:-1]
84 |
85 | self.metrics = dict(tp=[], fp=[], fn=[], tn=[])
86 |
87 | def add(self, pred, gt):
88 |
89 | pred_batch = pred[0].cpu()
90 |
91 | if self.sigmoid:
92 | pred_batch = torch.sigmoid(pred_batch)
93 |
94 | gt_batch = gt[0].cpu()
95 | mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch))
96 | cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch)
97 |
98 | if self.resize_to is not None:
99 | gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest')
100 | pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False)
101 |
102 | if isinstance(cls_batch, torch.Tensor):
103 | cls_batch = cls_batch.cpu().numpy().tolist()
104 |
105 | assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}'
106 |
107 | for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):
108 |
109 | if self.resize_pred:
110 | predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True)
111 |
112 | p = predictions.flatten()
113 | g = ground_truth.flatten()
114 |
115 | assert len(p) == len(g)
116 |
117 | if mask is not None:
118 | m = mask.flatten().bool()
119 | p = p[m]
120 | g = g[m]
121 |
122 | p_sorted = p.sort()
123 | p = p_sorted.values
124 | g = g[p_sorted.indices]
125 |
126 | tps, fps, fns, tns = [], [], [], []
127 | for thresh in self.threshold_values:
128 |
129 | valid = torch.where(p > thresh)[0]
130 | if len(valid) > 0:
131 | n = int(valid[0])
132 | else:
133 | n = len(g)
134 |
135 | fn = int(g[:n].sum())
136 | tp = int(g[n:].sum())
137 | fns += [fn]
138 | tns += [n - fn]
139 | tps += [tp]
140 | fps += [len(g) - n - tp]
141 |
142 | self.metrics['tp'] += [tps]
143 | self.metrics['fp'] += [fps]
144 | self.metrics['fn'] += [fns]
145 | self.metrics['tn'] += [tns]
146 |
147 | self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls]
148 |
149 | def value(self):
150 |
151 | import time
152 | t_start = time.time()
153 |
154 | if set(self.classes) == set([None]):
155 | all_classes = None
156 | log.warning('classes were not provided, cannot compute mIoU')
157 | else:
158 | all_classes = set(int(c) for c in self.classes)
159 | # log.info(f'compute metrics for {len(all_classes)} classes')
160 |
161 | summed = {k: [sum([self.metrics[k][i][j]
162 | for i in range(len(self.metrics[k]))])
163 | for j in range(len(self.threshold_values))]
164 | for k in self.metrics.keys()}
165 |
166 | if all_classes is not None:
167 |
168 | assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn'])
169 | # group by class
170 | metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes}
171 | for i in range(len(self.metrics['tp'])):
172 | for k in self.metrics.keys():
173 | metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]]
174 |
175 | # sum over all instances within the classes
176 | summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()}
177 | 
178 | 
179 |         # Compute average precision
180 | 
181 |         assert (np.array(summed['fp']) + np.array(summed['tp'])).sum(), 'no predictions were made'
182 |
183 | # only consider values where a prediction is made
184 | precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(len(self.threshold_values))
185 | if summed['tp'][j] + summed['fp'][j] > 0]
186 | recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))
187 | if summed['tp'][j] + summed['fp'][j] > 0]
188 |
189 | # remove duplicate recall-precision-pairs (and sort by recall value)
190 | recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0]))
191 |
192 | from scipy.integrate import simps
193 | ap = simps(precisions, recalls)
194 |
195 | # Compute best IoU
196 | fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))]
197 |
198 | biniou_scores = [
199 | 0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) +
200 | 0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j]))
201 | for j in range(len(self.threshold_values))
202 | ]
203 |
204 | index_0p5 = self.threshold_values.tolist().index(0.5)
205 | index_0p1 = self.threshold_values.tolist().index(0.1)
206 | index_0p2 = self.threshold_values.tolist().index(0.2)
207 | index_0p3 = self.threshold_values.tolist().index(0.3)
208 |
209 | if self.custom_threshold is not None:
210 | index_ct = self.threshold_values.tolist().index(self.custom_threshold)
211 |
212 | if all_classes is not None:
213 | # mean IoU
214 | mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j])
215 | for c in all_classes])
216 | for j in range(len(self.threshold_values))]
217 |
218 | mean_iou_dict = {
219 | 'miou_best': max(mean_ious) if all_classes is not None else None,
220 | 'miou_0.5': mean_ious[index_0p5] if all_classes is not None else None,
221 | 'miou_0.1': mean_ious[index_0p1] if all_classes is not None else None,
222 | 'miou_0.2': mean_ious[index_0p2] if all_classes is not None else None,
223 | 'miou_0.3': mean_ious[index_0p3] if all_classes is not None else None,
224 | 'miou_best_t': self.threshold_values[np.argmax(mean_ious)],
225 | 'mean_iou_ct': mean_ious[index_ct] if all_classes is not None and self.custom_threshold is not None else None,
226 | 'mean_iou_scores': mean_ious,
227 | }
228 |
229 | print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s')
230 |
231 | return {
232 | 'ap': ap,
233 |
234 | # fgiou
235 | 'fgiou_best': max(fgiou_scores),
236 | 'fgiou_0.5': fgiou_scores[index_0p5],
237 | 'fgiou_0.1': fgiou_scores[index_0p1],
238 | 'fgiou_0.2': fgiou_scores[index_0p2],
239 | 'fgiou_0.3': fgiou_scores[index_0p3],
240 | 'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)],
241 |
242 | # mean iou
243 |
244 |
245 | # biniou
246 | 'biniou_best': max(biniou_scores),
247 | 'biniou_0.5': biniou_scores[index_0p5],
248 | 'biniou_0.1': biniou_scores[index_0p1],
249 | 'biniou_0.2': biniou_scores[index_0p2],
250 | 'biniou_0.3': biniou_scores[index_0p3],
251 | 'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)],
252 |
253 | # custom threshold
254 | 'fgiou_ct': fgiou_scores[index_ct] if self.custom_threshold is not None else None,
255 | 'biniou_ct': biniou_scores[index_ct] if self.custom_threshold is not None else None,
256 | 'ct': self.custom_threshold,
257 |
258 | # statistics
259 | 'fgiou_scores': fgiou_scores,
260 | 'biniou_scores': biniou_scores,
261 | 'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))),
262 | 'summed_statistics': summed,
263 | 'summed_by_cls_statistics': summed_by_cls,
264 |
265 | **mean_iou_dict
266 | }
267 |
268 | # ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'
269 |
270 | # return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls}
271 |
272 |
--------------------------------------------------------------------------------
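A sketch of accumulating one batch in `FixedIntervalMetrics` and reading a few of the returned scores. The tensors are random placeholders; the `[pred]` / `[gt, valid_mask, class_ids]` structure follows the `add()` method above, but the exact shapes here are assumptions.

```python
# Illustrative sketch: feed placeholder logits and binary ground truth into
# FixedIntervalMetrics, then read a few aggregate scores.
import torch
from metrics import FixedIntervalMetrics  # assumes clipseg/ is the working directory

metric = FixedIntervalMetrics(sigmoid=True)

pred = torch.randn(4, 1, 352, 352)                # raw logits (placeholder)
gt = (torch.rand(4, 1, 352, 352) > 0.5).float()   # binary ground truth (placeholder)
valid = torch.ones_like(gt)                       # evaluate all pixels
classes = torch.zeros(4, dtype=torch.long)        # one class id per sample

metric.add([pred], [gt, valid, classes])
scores = metric.value()
print(scores['ap'], scores['fgiou_0.5'], scores['miou_best'])
```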
/clipseg/models/vitseg.py:
--------------------------------------------------------------------------------
1 | import math
2 | from posixpath import basename, dirname, join
3 | # import clip
4 | from clip.model import convert_weights
5 | import torch
6 | import json
7 | from torch import nn
8 | from torch.nn import functional as nnf
9 | from torch.nn.modules import activation
10 | from torch.nn.modules.activation import ReLU
11 | from torchvision import transforms
12 |
13 | normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
14 |
15 | from torchvision.models import ResNet
16 |
17 |
18 | def process_prompts(conditional, prompt_list, conditional_map):
19 | # DEPRECATED
20 |
21 | # randomly sample a synonym
22 | words = [conditional_map[int(i)] for i in conditional]
23 | words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words]
24 | words = [w.replace('_', ' ') for w in words]
25 |
26 | if prompt_list is not None:
27 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
28 | prompts = [prompt_list[i] for i in prompt_indices]
29 | else:
30 | prompts = ['a photo of {}'] * (len(words))
31 |
 32 |     return [prompt.format(w) for prompt, w in zip(prompts, words)]
33 |
34 |
35 | class VITDenseBase(nn.Module):
36 |
37 | def rescaled_pos_emb(self, new_size):
38 | assert len(new_size) == 2
39 |
40 | a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
41 | b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
42 | return torch.cat([self.model.positional_embedding[:1], b])
43 |
44 | def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
45 |
46 | with torch.no_grad():
47 |
48 | x_inp = nnf.interpolate(x_inp, (384, 384))
49 |
50 | x = self.model.patch_embed(x_inp)
51 | cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
52 | if self.model.dist_token is None:
53 | x = torch.cat((cls_token, x), dim=1)
54 | else:
55 | x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
56 | x = self.model.pos_drop(x + self.model.pos_embed)
57 |
58 | activations = []
59 | for i, block in enumerate(self.model.blocks):
60 | x = block(x)
61 |
62 | if i in extract_layers:
63 | # permute to be compatible with CLIP
64 | activations += [x.permute(1,0,2)]
65 |
66 | x = self.model.norm(x)
67 | x = self.model.head(self.model.pre_logits(x[:, 0]))
68 |
69 | # again for CLIP compatibility
70 | # x = x.permute(1, 0, 2)
71 |
72 | return x, activations, None
73 |
74 | def sample_prompts(self, words, prompt_list=None):
75 |
76 | prompt_list = prompt_list if prompt_list is not None else self.prompt_list
77 |
78 | prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
79 | prompts = [prompt_list[i] for i in prompt_indices]
 80 |         return [prompt.format(w) for prompt, w in zip(prompts, words)]
81 |
82 | def get_cond_vec(self, conditional, batch_size):
83 | # compute conditional from a single string
84 | if conditional is not None and type(conditional) == str:
85 | cond = self.compute_conditional(conditional)
86 | cond = cond.repeat(batch_size, 1)
87 |
88 | # compute conditional from string list/tuple
89 | elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
90 | assert len(conditional) == batch_size
91 | cond = self.compute_conditional(conditional)
92 |
93 | # use conditional directly
94 | elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
95 | cond = conditional
96 |
97 | # compute conditional from image
98 | elif conditional is not None and type(conditional) == torch.Tensor:
99 | with torch.no_grad():
100 | cond, _, _ = self.visual_forward(conditional)
101 | else:
102 | raise ValueError('invalid conditional')
103 | return cond
104 |
105 | def compute_conditional(self, conditional):
106 | import clip
107 |
108 | dev = next(self.parameters()).device
109 |
110 | if type(conditional) in {list, tuple}:
111 | text_tokens = clip.tokenize(conditional).to(dev)
112 | cond = self.clip_model.encode_text(text_tokens)
113 | else:
114 | if conditional in self.precomputed_prompts:
115 | cond = self.precomputed_prompts[conditional].float().to(dev)
116 | else:
117 | text_tokens = clip.tokenize([conditional]).to(dev)
118 | cond = self.clip_model.encode_text(text_tokens)[0]
119 |
120 | return cond
121 |
122 |
123 | class VITDensePredT(VITDenseBase):
124 |
125 | def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
126 | depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
127 | learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
128 | add_calibration=False, process_cond=None, not_pretrained=False):
129 | super().__init__()
130 | # device = 'cpu'
131 |
132 | self.extract_layers = extract_layers
133 | self.cond_layer = cond_layer
134 | self.limit_to_clip_only = limit_to_clip_only
135 | self.process_cond = None
136 |
137 | if add_calibration:
138 | self.calibration_conds = 1
139 |
140 | self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
141 |
142 | self.add_activation1 = True
143 |
144 | import timm
145 | self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
146 | self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)
147 |
148 | for p in self.model.parameters():
149 | p.requires_grad_(False)
150 |
151 | import clip
152 | self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
153 | # del self.clip_model.visual
154 |
155 |
156 | self.token_shape = (14, 14)
157 |
158 | # conditional
159 | if reduce_cond is not None:
160 | self.reduce_cond = nn.Linear(512, reduce_cond)
161 | for p in self.reduce_cond.parameters():
162 | p.requires_grad_(False)
163 | else:
164 | self.reduce_cond = None
165 |
166 | # self.film = AVAILABLE_BLOCKS['film'](512, 128)
167 | self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
168 | self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
169 |
170 | # DEPRECATED
171 | # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}
172 |
173 | assert len(self.extract_layers) == depth
174 |
175 | self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
176 | self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
177 | self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
178 |
179 | trans_conv_ks = (16, 16)
180 | self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
181 |
182 | # refinement and trans conv
183 |
184 | if learn_trans_conv_only:
185 | for p in self.parameters():
186 | p.requires_grad_(False)
187 |
188 | for p in self.trans_conv.parameters():
189 | p.requires_grad_(True)
190 |
191 | if prompt == 'fixed':
192 | self.prompt_list = ['a photo of a {}.']
193 | elif prompt == 'shuffle':
194 | self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
195 | elif prompt == 'shuffle+':
196 | self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
197 | 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
198 | 'a bad photo of a {}.', 'a photo of the {}.']
199 | elif prompt == 'shuffle_clip':
200 | from models.clip_prompts import imagenet_templates
201 | self.prompt_list = imagenet_templates
202 |
203 | if process_cond is not None:
204 | if process_cond == 'clamp' or process_cond[0] == 'clamp':
205 |
206 | val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2
207 |
208 | def clamp_vec(x):
209 | return torch.clamp(x, -val, val)
210 |
211 | self.process_cond = clamp_vec
212 |
213 | elif process_cond.endswith('.pth'):
214 |
215 | shift = torch.load(process_cond)
216 | def add_shift(x):
217 | return x + shift.to(x.device)
218 |
219 | self.process_cond = add_shift
220 |
221 | import pickle
222 | precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
223 | self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
224 |
225 |
226 | def forward(self, inp_image, conditional=None, return_features=False, mask=None):
227 |
228 | assert type(return_features) == bool
229 |
230 | # inp_image = inp_image.to(self.model.positional_embedding.device)
231 |
232 | if mask is not None:
233 | raise ValueError('mask not supported')
234 |
235 | # x_inp = normalize(inp_image)
236 | x_inp = inp_image
237 |
238 | bs, dev = inp_image.shape[0], x_inp.device
239 |
240 | inp_image_size = inp_image.shape[2:]
241 |
242 | cond = self.get_cond_vec(conditional, bs)
243 |
244 | visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
245 |
246 | activation1 = activations[0]
247 | activations = activations[1:]
248 |
249 | a = None
250 | for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):
251 |
252 | if a is not None:
253 | a = reduce(activation) + a
254 | else:
255 | a = reduce(activation)
256 |
257 | if i == self.cond_layer:
258 | if self.reduce_cond is not None:
259 | cond = self.reduce_cond(cond)
260 |
261 | a = self.film_mul(cond) * a + self.film_add(cond)
262 |
263 | a = block(a)
264 |
265 | for block in self.extra_blocks:
266 | a = a + block(a)
267 |
268 | a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
269 |
270 | size = int(math.sqrt(a.shape[2]))
271 |
272 | a = a.view(bs, a.shape[1], size, size)
273 |
274 | if self.trans_conv is not None:
275 | a = self.trans_conv(a)
276 |
277 | if self.upsample_proj is not None:
278 | a = self.upsample_proj(a)
279 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
280 |
281 | a = nnf.interpolate(a, inp_image_size)
282 |
283 | if return_features:
284 | return a, visual_q, cond, [activation1] + activations
285 | else:
286 | return a,
287 |
--------------------------------------------------------------------------------
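A minimal text-prompted inference sketch for the decoder above, assuming it is clipseg's CLIPDensePredT from clipseg/models/clipseg.py and that the rd64 checkpoints in clipseg/weights/ match reduce_dim=64 (run from inside the clipseg/ directory; the file names and the 352x352 input size follow the clipseg quickstart and are not guaranteed by this repo):

import torch
from PIL import Image
from torchvision import transforms
from models.clipseg import CLIPDensePredT  # clipseg/models/clipseg.py

model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
# strict=False: the checkpoint only stores the trainable decoder parameters, not the frozen CLIP weights
model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location='cpu'), strict=False)

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    transforms.Resize((352, 352)),
])
img = preprocess(Image.open('example_image.jpg')).unsqueeze(0)

prompts = ['a face', 'a person']
with torch.no_grad():
    # one image copy per prompt; forward() returns a tuple whose first element is the logit map
    preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
masks = torch.sigmoid(preds)  # shape (len(prompts), 1, 352, 352)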
/clipseg/mycode.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import inspect
3 | import json
4 | import yaml
5 | import math
6 | import os
7 | import sys
8 |
9 | from general_utils import log
10 |
11 | import numpy as np
12 | from functools import partial
13 | from os.path import expanduser, join, isfile, basename
14 |
15 | from torch.cuda.amp import autocast, GradScaler
16 | from torch.optim.lr_scheduler import LambdaLR
17 | from contextlib import nullcontext
18 | from torch.utils.data import DataLoader
19 |
20 | from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args
21 |
22 |
23 | def cosine_warmup_lr(i, warmup=10, max_iter=90):
24 | """ Cosine LR with Warmup """
25 | if i < warmup:
26 | return (i+1)/(warmup+1)
27 | else:
28 | return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup))))
29 |
30 |
31 | def validate(model, dataset, config):
32 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
33 |
34 | metric_class, use_metric = config.val_metric_class, config.use_val_metric
35 | loss_fn = get_attribute(config.loss)
36 |
37 | model.eval()
38 | model.cuda()
39 |
40 | if metric_class is not None:
41 | metric = get_attribute(metric_class)()
42 |
43 | with torch.no_grad():
44 |
45 | i, losses = 0, []
46 | for data_x, data_y in data_loader:
47 |
48 | data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
49 | data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
50 |
51 | prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
52 | pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True)
53 |
54 | if metric_class is not None:
55 | metric.add([pred], data_y)
56 |
57 | # pred = model(data_x[0], prompts)
58 | # loss = loss_fn(pred[0], data_y[0])
59 | loss = loss_fn(pred, data_y[0])
60 | losses += [float(loss)]
61 |
62 | i += 1
63 |
64 | if config.val_max_iterations is not None and i > config.val_max_iterations:
65 | break
66 |
67 | if use_metric is None:
68 | return np.mean(losses), {}, False
69 | else:
70 | metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
71 | return np.mean(losses), metric_scores, True
72 |
73 |
74 | def main():
75 |
76 | config = training_config_from_cli_args()
77 |
78 | val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')
79 |
80 | model_cls = get_attribute(config.model)
81 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
82 | model = model_cls(**model_args).cuda()
83 |
84 | dataset_cls = get_attribute(config.dataset)
85 | _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
86 |
87 | dataset = dataset_cls(**dataset_args)
88 |
89 | log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')
90 |
91 | if val_interval is not None:
92 | dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'}
93 | _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
94 | print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
95 |
96 | dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
97 |
98 | # optimizer
99 | opt_cls = get_attribute(config.optimizer)
100 |     if config.optimizer == 'torch.optim.SGD':
101 | opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
102 | else:
103 | opt_args = {}
104 | opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)
105 |
106 | if config.lr_scheduler == 'cosine':
107 | assert config.T_max is not None and config.eta_min is not None
108 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
109 | elif config.lr_scheduler == 'warmup_cosine':
110 | lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
111 | else:
112 | lr_scheduler = None
113 |
114 | batch_size, max_iterations = config.batch_size, config.max_iterations
115 |
116 | loss_fn = get_attribute(config.loss)
117 |
118 | if config.amp:
119 | log.info('Using AMP')
120 | autocast_fn = autocast
121 | scaler = GradScaler()
122 | else:
123 | autocast_fn, scaler = nullcontext, None
124 |
125 |
126 | save_only_trainable = True
127 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
128 |
129 | # disable config when hyperparam. opt. to avoid writing logs.
130 | tracker_config = config if not config.hyperparameter_optimization else None
131 |
132 | with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:
133 |
134 | i = 0
135 | while True:
136 | for data_x, data_y in data_loader:
137 |
138 | # between caption and output feature.
139 | # 1. Sample random captions
140 | # 2. Check alignment with CLIP
141 |
142 | # randomly mix text and visual support conditionals
143 | if config.mix:
144 |
145 | assert config.mask.startswith('text_and')
146 |
147 | with autocast_fn():
148 | # data_x[1] = text label
149 | prompts = model.sample_prompts(data_x[1])
150 |
151 | # model.clip_model()
152 |
153 | text_cond = model.compute_conditional(prompts)
154 | if model.__class__.__name__ == 'CLIPDensePredTMasked':
155 | # when mask=='separate'
156 | visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
157 | else:
158 | # data_x[2] = visual prompt
159 | visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())
160 |
161 | max_txt = config.mix_text_max if config.mix_text_max is not None else 1
162 | batch_size = text_cond.shape[0]
163 |
164 | # sample weights for each element in batch
165 | text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None]
166 | text_weights = text_weights.cuda()
167 |
168 | if dataset.__class__.__name__ == 'PhraseCut':
169 | # give full weight to text where support_image is invalid
170 | visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
171 | text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)
172 |
173 | cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)
174 |
175 | else:
176 | # no mix
177 |
178 | if model.__class__.__name__ == 'CLIPDensePredTMasked':
179 | # compute conditional vector using CLIP masking
180 | with autocast_fn():
181 | assert config.mask == 'separate'
182 | cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
183 | else:
184 | cond = data_x[1]
185 | if isinstance(cond, torch.Tensor):
186 | cond = cond.cuda()
187 |
188 | with autocast_fn():
189 | visual_q = None
190 |
191 | pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True)
192 |
193 | loss = loss_fn(pred, data_y[0].cuda())
194 |
195 | if torch.isnan(loss) or torch.isinf(loss):
196 |                         # stop training if the loss is inf/nan
197 | log.warning('Training stopped due to inf/nan loss.')
198 | sys.exit(-1)
199 |
200 | extra_loss = 0
201 | loss += extra_loss
202 |
203 | opt.zero_grad()
204 |
205 | if scaler is None:
206 | loss.backward()
207 | opt.step()
208 | else:
209 | scaler.scale(loss).backward()
210 | scaler.step(opt)
211 | scaler.update()
212 |
213 | if lr_scheduler is not None:
214 | lr_scheduler.step()
215 | if i % 2000 == 0:
216 | current_lr = [g['lr'] for g in opt.param_groups][0]
217 | log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')
218 |
219 | logger.iter(i=i, loss=loss)
220 | i += 1
221 |
222 | if i >= max_iterations:
223 |
224 | if not isfile(join(logger.base_path, 'weights.pth')):
225 | # only write if no weights were already written
226 | logger.save_weights(only_trainable=save_only_trainable)
227 |
228 | sys.exit(0)
229 |
230 |
231 | if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
232 | logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')
233 |
234 |
235 | if val_interval is not None and i % val_interval == val_interval - 1:
236 |
237 | val_loss, val_scores, maximize = validate(model, dataset_val, config)
238 |
239 | if len(val_scores) > 0:
240 |
241 | score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())
242 |
243 | if maximize and val_scores[config.use_val_metric] > best_val_score:
244 | logger.save_weights(only_trainable=save_only_trainable)
245 | best_val_score = val_scores[config.use_val_metric]
246 |
247 | elif not maximize and val_scores[config.use_val_metric] < best_val_score:
248 | logger.save_weights(only_trainable=save_only_trainable)
249 | best_val_score = val_scores[config.use_val_metric]
250 |
251 | else:
252 | score_str = ''
253 | # if no score is used, fall back to loss
254 | if val_loss < best_val_loss:
255 | logger.save_weights(only_trainable=save_only_trainable)
256 | best_val_loss = val_loss
257 |
258 | log.info(f'Validation loss: {val_loss}' + score_str)
259 | logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
260 | model.train()
261 |
262 | print('epoch complete')
263 |
264 |
265 | if __name__ == '__main__':
266 | main()
--------------------------------------------------------------------------------
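The warmup_cosine branch above plugs cosine_warmup_lr into LambdaLR as a multiplicative schedule; a self-contained sketch of the resulting curve with a toy optimizer (the warmup/max_iter values here are arbitrary):

import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

def cosine_warmup_lr(i, warmup=10, max_iter=90):
    """ Linear warmup of the LR multiplier, then a half-cosine decay to zero. """
    if i < warmup:
        return (i + 1) / (warmup + 1)
    return 0.5 + 0.5 * math.cos(math.pi * (i - warmup) / (max_iter - warmup))

params = [torch.nn.Parameter(torch.zeros(1))]      # toy parameter just to drive the scheduler
opt = torch.optim.SGD(params, lr=1e-3)
scheduler = LambdaLR(opt, partial(cosine_warmup_lr, warmup=10, max_iter=100))

lrs = []
for _ in range(100):
    opt.step()
    scheduler.step()
    lrs.append(opt.param_groups[0]['lr'])
# lrs rises linearly over the first 10 steps, then decays along a cosine towards 0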
/clipseg/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/overview.png
--------------------------------------------------------------------------------
/clipseg/sample_rd64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/sample_rd64.png
--------------------------------------------------------------------------------
/clipseg/sample_rd64_refined.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/sample_rd64_refined.png
--------------------------------------------------------------------------------
/clipseg/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | with open("README.md", "r", encoding="utf-8") as readme_file:
4 | readme = readme_file.read()
5 |
6 | requirements = [
7 | "numpy",
8 | "scipy",
9 | "matplotlib",
10 | "torch",
11 | "torchvision",
12 | "opencv-python",
13 | "CLIP @ git+https://github.com/openai/CLIP.git"
14 | ]
15 |
16 | setup(
17 | name='clipseg',
18 | packages=['clipseg'],
19 | package_dir={'clipseg': 'models'},
20 | package_data={'clipseg': [
21 | "../weights/*.pth",
22 | ]},
23 | version='0.0.1',
24 | url='https://github.com/timojl/clipseg',
25 | python_requires='>=3.9',
26 | install_requires=requirements,
27 | description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".',
28 | long_description=readme,
29 | long_description_content_type="text/markdown",
30 | )
31 |
--------------------------------------------------------------------------------
/clipseg/training.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import inspect
3 | import json
4 | import yaml
5 | import math
6 | import os
7 | import sys
8 |
9 | from general_utils import log
10 |
11 | import numpy as np
12 | from functools import partial
13 | from os.path import expanduser, join, isfile, basename
14 |
15 | from torch.cuda.amp import autocast, GradScaler
16 | from torch.optim.lr_scheduler import LambdaLR
17 | from contextlib import nullcontext
18 | from torch.utils.data import DataLoader
19 |
20 | from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args
21 |
22 |
23 | def cosine_warmup_lr(i, warmup=10, max_iter=90):
24 | """ Cosine LR with Warmup """
25 | if i < warmup:
26 | return (i+1)/(warmup+1)
27 | else:
28 | return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup))))
29 |
30 |
31 | def validate(model, dataset, config):
32 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
33 |
34 | metric_class, use_metric = config.val_metric_class, config.use_val_metric
35 | loss_fn = get_attribute(config.loss)
36 |
37 | model.eval()
38 | model.cuda()
39 |
40 | if metric_class is not None:
41 | metric = get_attribute(metric_class)()
42 |
43 | with torch.no_grad():
44 |
45 | i, losses = 0, []
46 | for data_x, data_y in data_loader:
47 |
48 | data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
49 | data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
50 |
51 | prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
52 | pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True)
53 |
54 | if metric_class is not None:
55 | metric.add([pred], data_y)
56 |
57 | # pred = model(data_x[0], prompts)
58 | # loss = loss_fn(pred[0], data_y[0])
59 | loss = loss_fn(pred, data_y[0])
60 | losses += [float(loss)]
61 |
62 | i += 1
63 |
64 | if config.val_max_iterations is not None and i > config.val_max_iterations:
65 | break
66 |
67 | if use_metric is None:
68 | return np.mean(losses), {}, False
69 | else:
70 | metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
71 | return np.mean(losses), metric_scores, True
72 |
73 |
74 | def main():
75 |
76 | config = training_config_from_cli_args()
77 |
78 | val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')
79 |
80 | model_cls = get_attribute(config.model)
81 | _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
82 | model = model_cls(**model_args).cuda()
83 |
84 | dataset_cls = get_attribute(config.dataset)
85 | _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
86 |
87 | dataset = dataset_cls(**dataset_args)
88 |
89 | log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')
90 |
91 | if val_interval is not None:
92 | dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'}
93 | _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
94 | print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
95 |
96 | dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
97 |
98 | # optimizer
99 | opt_cls = get_attribute(config.optimizer)
100 |     if config.optimizer == 'torch.optim.SGD':
101 | opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
102 | else:
103 | opt_args = {}
104 | opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)
105 |
106 | if config.lr_scheduler == 'cosine':
107 | assert config.T_max is not None and config.eta_min is not None
108 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
109 | elif config.lr_scheduler == 'warmup_cosine':
110 | lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
111 | else:
112 | lr_scheduler = None
113 |
114 | batch_size, max_iterations = config.batch_size, config.max_iterations
115 |
116 | loss_fn = get_attribute(config.loss)
117 |
118 | if config.amp:
119 | log.info('Using AMP')
120 | autocast_fn = autocast
121 | scaler = GradScaler()
122 | else:
123 | autocast_fn, scaler = nullcontext, None
124 |
125 |
126 | save_only_trainable = True
127 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
128 |
129 | # disable config when hyperparam. opt. to avoid writing logs.
130 | tracker_config = config if not config.hyperparameter_optimization else None
131 |
132 | with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:
133 |
134 | i = 0
135 | while True:
136 | for data_x, data_y in data_loader:
137 |
138 | # between caption and output feature.
139 | # 1. Sample random captions
140 | # 2. Check alignment with CLIP
141 |
142 | # randomly mix text and visual support conditionals
143 | if config.mix:
144 |
145 | assert config.mask.startswith('text_and')
146 |
147 | with autocast_fn():
148 | # data_x[1] = text label
149 | prompts = model.sample_prompts(data_x[1])
150 |
151 | # model.clip_model()
152 |
153 | text_cond = model.compute_conditional(prompts)
154 | if model.__class__.__name__ == 'CLIPDensePredTMasked':
155 | # when mask=='separate'
156 | visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
157 | else:
158 | # data_x[2] = visual prompt
159 | visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())
160 |
161 | max_txt = config.mix_text_max if config.mix_text_max is not None else 1
162 | batch_size = text_cond.shape[0]
163 |
164 | # sample weights for each element in batch
165 | text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None]
166 | text_weights = text_weights.cuda()
167 |
168 | if dataset.__class__.__name__ == 'PhraseCut':
169 | # give full weight to text where support_image is invalid
170 | visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
171 | text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)
172 |
173 | cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)
174 |
175 | else:
176 | # no mix
177 |
178 | if model.__class__.__name__ == 'CLIPDensePredTMasked':
179 | # compute conditional vector using CLIP masking
180 | with autocast_fn():
181 | assert config.mask == 'separate'
182 | cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
183 | else:
184 | cond = data_x[1]
185 | if isinstance(cond, torch.Tensor):
186 | cond = cond.cuda()
187 |
188 | with autocast_fn():
189 | visual_q = None
190 |
191 | pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True)
192 |
193 | loss = loss_fn(pred, data_y[0].cuda())
194 |
195 | if torch.isnan(loss) or torch.isinf(loss):
196 |                         # stop training if the loss is inf/nan
197 | log.warning('Training stopped due to inf/nan loss.')
198 | sys.exit(-1)
199 |
200 | extra_loss = 0
201 | loss += extra_loss
202 |
203 | opt.zero_grad()
204 |
205 | if scaler is None:
206 | loss.backward()
207 | opt.step()
208 | else:
209 | scaler.scale(loss).backward()
210 | scaler.step(opt)
211 | scaler.update()
212 |
213 | if lr_scheduler is not None:
214 | lr_scheduler.step()
215 | if i % 2000 == 0:
216 | current_lr = [g['lr'] for g in opt.param_groups][0]
217 | log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')
218 |
219 | logger.iter(i=i, loss=loss)
220 | i += 1
221 |
222 | if i >= max_iterations:
223 |
224 | if not isfile(join(logger.base_path, 'weights.pth')):
225 | # only write if no weights were already written
226 | logger.save_weights(only_trainable=save_only_trainable)
227 |
228 | sys.exit(0)
229 |
230 |
231 | if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
232 | logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')
233 |
234 |
235 | if val_interval is not None and i % val_interval == val_interval - 1:
236 |
237 | val_loss, val_scores, maximize = validate(model, dataset_val, config)
238 |
239 | if len(val_scores) > 0:
240 |
241 | score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())
242 |
243 | if maximize and val_scores[config.use_val_metric] > best_val_score:
244 | logger.save_weights(only_trainable=save_only_trainable)
245 | best_val_score = val_scores[config.use_val_metric]
246 |
247 | elif not maximize and val_scores[config.use_val_metric] < best_val_score:
248 | logger.save_weights(only_trainable=save_only_trainable)
249 | best_val_score = val_scores[config.use_val_metric]
250 |
251 | else:
252 | score_str = ''
253 | # if no score is used, fall back to loss
254 | if val_loss < best_val_loss:
255 | logger.save_weights(only_trainable=save_only_trainable)
256 | best_val_loss = val_loss
257 |
258 | log.info(f'Validation loss: {val_loss}' + score_str)
259 | logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
260 | model.train()
261 |
262 | print('epoch complete')
263 |
264 |
265 | if __name__ == '__main__':
266 | main()
--------------------------------------------------------------------------------
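The training loop above toggles between autocast/GradScaler and nullcontext depending on config.amp; a minimal, self-contained sketch of that exact pattern with a toy model (the names below are illustrative, not part of the repo):

import torch
from torch import nn
from contextlib import nullcontext
from torch.cuda.amp import autocast, GradScaler

use_amp = torch.cuda.is_available()
device = 'cuda' if use_amp else 'cpu'

model = nn.Linear(16, 1).to(device)                 # stand-in for the segmentation model
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
autocast_fn = autocast if use_amp else nullcontext  # mirrors the config.amp switch
scaler = GradScaler() if use_amp else None

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

with autocast_fn():
    loss = nn.functional.mse_loss(model(x), y)

opt.zero_grad()
if scaler is None:           # plain fp32 path
    loss.backward()
    opt.step()
else:                        # mixed-precision path: scale the loss, step through the scaler, update its scale
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()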
/clipseg/weights/rd16-uni.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/weights/rd16-uni.pth
--------------------------------------------------------------------------------
/clipseg/weights/rd64-uni-refined.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/weights/rd64-uni-refined.pth
--------------------------------------------------------------------------------
/clipseg/weights/rd64-uni.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/clipseg/weights/rd64-uni.pth
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | UNET_LAYERS = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID',
2 | 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
3 |
4 | SD_INFERENCE_TIMESTEPS = [999, 979, 959, 939, 919, 899, 879, 859, 839, 819, 799, 779, 759, 739, 719, 699, 679, 659,
5 | 639, 619, 599, 579, 559, 539, 519, 500, 480, 460, 440, 420, 400, 380, 360, 340, 320, 300,
6 | 280, 260, 240, 220, 200, 180, 160, 140, 120, 100, 80, 60, 40, 20]
7 |
8 | # PROMPTS = [
9 | # "A photo of {} in the jungle",
10 | # "A photo of {} on a beach",
11 | # "A photo of {} in Times Square",
12 | # "A photo of {} in the moon",
13 | # "A painting of {} in the style of Monet",
14 | # "Oil painting of {}",
15 | # "A Marc Chagall painting of {}",
16 | # "A manga drawing of {}",
17 | # 'A watercolor painting of {}',
18 | # "A statue of {}",
19 | # "App icon of {}",
20 | # "A sand sculpture of {}",
21 | # "Colorful graffiti of {}",
22 | # "{} wearing a Santa hat",
23 | # "{} in a construction outfit",
24 | # "{} on the cover of Time magazine",
25 | # "A movie poster of {}",
26 | # "{} is wearing a doctoral gown and a doctoral cap",
27 | # "{} in an astronaut suit",
28 | # "A photo of {} wearing a life jacket",
29 | # "{} as a jedi master, inside hogwarts with a mystical scepter",
30 | # "{} wearing a sombrero",
31 | # "{} dressed like a wizard, reading a grimoire",
32 | # "{} as a character in the Sherlock Holmes' movie",
33 | # "{} buckled in a seat on a plane",
34 | # "A selfie of {} in Times Square",
35 | # "An oil painting of {}",
36 | # "A photo of {} as a cowboy",
37 | # "{} in a chef's outfit, cooking in a kitchen",
38 | # "{} selfie standing under the pink blossoms of a cherry tree",
39 | # ]
40 |
41 | PROMPTS = [
42 | "A {} is giving a speech on the podium",
43 | "A {} is walking on the red carpet at the evening gala",
44 | "A {} is holding a Halloween pumpkin lantern",
45 | "A {} is reading a book",
46 | "A {} wearing winter camo military gear in the snow",
47 | "A {} sitting in a hammock",
48 |     "A photo of {} as a knight in armor",
49 | "A {} as a police officer",
50 | "A photo of {} in the lab",
51 |     "A {} sitting on a leather sofa",
52 |     "A {} is reading a newspaper, sitting in a busy train",
53 |     "A {} in a jungle",
54 |     "A {} is playing the piano",
55 |     "A {} is riding a skateboard",
56 | "A {} is riding a motorcycle, road background",
57 | "{} is working on a beautiful design project, creating design projects, in a beautiful workspace",
58 | "{} driving a racing car",
59 | "{} floats on a raft along the street of a flooded city",
60 | "{} baking pizza in front of a wood-fired oven",
61 | "{} plays bass guitar on stage",
62 | ]
63 |
64 | VALIDATION_PROMPTS = [
65 | "A photo of a {}",
66 | "A photo of a {} on a beach",
67 | "App icon of {}",
68 | "A painting of {} in the style of Monet",
69 | ]
70 |
71 | IMAGENET_TEMPLATES_SMALL = [
72 | "a photo of a {}",
73 | "a rendering of a {}",
74 | "a cropped photo of the {}",
75 | "the photo of a {}",
76 | "a photo of a clean {}",
77 | "a photo of a dirty {}",
78 | "a dark photo of the {}",
79 | "a photo of my {}",
80 | "a photo of the cool {}",
81 | "a close-up photo of a {}",
82 | "a bright photo of the {}",
83 | "a cropped photo of a {}",
84 | "a photo of the {}",
85 | "a good photo of the {}",
86 | "a photo of one {}",
87 | "a close-up photo of the {}",
88 | "a rendition of the {}",
89 | "a photo of the clean {}",
90 | "a rendition of a {}",
91 | "a photo of a nice {}",
92 | "a good photo of a {}",
93 | "a photo of the nice {}",
94 | "a photo of the small {}",
95 | "a photo of the weird {}",
96 | "a photo of the large {}",
97 | "a photo of a cool {}",
98 | "a photo of a small {}",
99 | ]
100 |
101 | IMAGENET_STYLE_TEMPLATES_SMALL = [
102 | "a painting in the style of {}",
103 | "a rendering in the style of {}",
104 | "a cropped painting in the style of {}",
105 | "the painting in the style of {}",
106 | "a clean painting in the style of {}",
107 | "a dirty painting in the style of {}",
108 | "a dark painting in the style of {}",
109 | "a picture in the style of {}",
110 | "a cool painting in the style of {}",
111 | "a close-up painting in the style of {}",
112 | "a bright painting in the style of {}",
113 | "a cropped painting in the style of {}",
114 | "a good painting in the style of {}",
115 | "a close-up painting in the style of {}",
116 | "a rendition in the style of {}",
117 | "a nice painting in the style of {}",
118 | "a small painting in the style of {}",
119 | "a weird painting in the style of {}",
120 | "a large painting in the style of {}",
121 | ]
122 |
--------------------------------------------------------------------------------
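A small sketch of how these `{}` templates are typically filled in, assuming a multi-token placeholder like the one PromptManager joins with spaces (the token strings below are made up for illustration):

import random
from constants import PROMPTS, VALIDATION_PROMPTS

# hypothetical placeholder tokens; the real ones come from the tokenizer / training config
placeholder_tokens = ["<person_1_0>", "<person_1_1>"]
placeholder = " ".join(placeholder_tokens)

train_prompt = random.choice(PROMPTS).format(placeholder)   # e.g. "A <person_1_0> <person_1_1> is reading a book"
val_prompt = VALIDATION_PROMPTS[0].format(placeholder)      # "A photo of a <person_1_0> <person_1_1>"
print(train_prompt)
print(val_prompt)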
/data/person_1/person_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/data/person_1/person_1.jpg
--------------------------------------------------------------------------------
/environment/environment.yaml:
--------------------------------------------------------------------------------
1 | name: persona
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.11.0
10 | - torchvision=0.12.0
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - diffusers
15 | - opencv-python==4.1.2.30
16 | - pudb==2019.2
17 | - invisible-watermark
18 | - imageio==2.9.0
19 | - imageio-ffmpeg==0.4.2
20 | - pytorch-lightning==1.4.2
21 | - omegaconf==2.1.1
22 | - test-tube>=0.7.5
23 | - streamlit>=0.73.1
24 | - einops==0.3.0
25 | - torch-fidelity==0.3.0
26 | - torchmetrics==0.6.0
27 | - kornia==0.6
28 | - -r requirements.txt
29 |
--------------------------------------------------------------------------------
/environment/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==0.4.3
2 | diffusers
3 | opencv-python==4.1.2.30
4 | pudb==2019.2
5 | invisible-watermark
6 | imageio==2.9.0
7 | imageio-ffmpeg==0.4.2
8 | pytorch-lightning==1.4.2
9 | omegaconf==2.1.1
10 | test-tube>=0.7.5
11 | streamlit>=0.73.1
12 | einops==0.3.0
13 | torch-fidelity==0.3.0
14 | torchmetrics==0.6.0
15 | kornia==0.6
16 | ftfy
17 | opencv-python
18 | ipywidgets
19 | matplotlib
20 | jupyter
21 | pyrallis==0.3.1
22 | loguru==0.7.0
23 | diffusers==0.14.0
24 | transformers==4.27.4
25 | accelerate==0.18.0
26 |
--------------------------------------------------------------------------------
/inference_text.txt:
--------------------------------------------------------------------------------
1 | {} wearing a Santa hat
2 | A photo of {} as a cowboy
3 | {} in an astronaut suit
--------------------------------------------------------------------------------
/input_configs/inference.yaml:
--------------------------------------------------------------------------------
1 | input_dir: ./logs/person_1
2 | prompts_file_path: ./inference_text.txt
3 | clip_ckpt_path: "/path/to/clip/ckpt/"
4 | iteration: 1000
5 | seeds: [ 233, 234, 235, 236 ]
6 | torch_dtype: fp16
7 | inference_dir: ./results/
8 | super_category_token: face
9 | image_path: ./data/person_1/person_1.jpg
--------------------------------------------------------------------------------
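scripts/inference.py (shown further below) wraps its entry point with @pyrallis.wrap(), so this file is normally supplied via --config_path; a hedged sketch of loading it, assuming the repo root is the working directory and the pyrallis 0.3.x parse() signature:

# typical CLI usage:
#   python scripts/inference.py --config_path input_configs/inference.yaml

import pyrallis
from scripts.inference import InferenceConfig

# Note: InferenceConfig.__post_init__ resolves the mapper/embeds paths from input_dir + iteration
# and creates inference_dir as a side effect.
cfg = pyrallis.parse(config_class=InferenceConfig, config_path="input_configs/inference.yaml")
print(cfg.inference_dir, cfg.seeds, cfg.torch_dtype)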
/input_configs/train.yaml:
--------------------------------------------------------------------------------
1 | log:
2 | exp_name: person_1
3 | exp_dir: ./log/person_1
4 | save_steps: 1000
5 | data:
6 | train_data_dir: ./data/person_1
7 | placeholder_token:
8 | super_category_token: face
9 | dataloader_num_workers: 2
10 | model:
11 |   pretrained_model_name_or_path: /path/to/diffusion/dreamlike-art/dreamlike-photoreal-2.0/
12 | clip_ckpt_path: "/path/to/clip/ckpt/"
13 | normalize_mapper_output: True
14 | use_positional_encoding: True
15 | num_pe_time_anchors: 200
16 | output_bypass: True
17 | eval:
18 | validation_steps: 2000
19 | optim:
20 | max_train_steps: 1000
21 | learning_rate: 5e-5
22 | train_batch_size: 1
23 | gradient_accumulation_steps: 2
24 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/models/__init__.py
--------------------------------------------------------------------------------
/models/clip_prior.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import clip
5 | import numpy as np
6 | import kornia
7 | import os
8 |
9 | class MultiCLIP(torch.nn.Module):
10 | def __init__(self, clip_ckpt_path, device="cpu"):
11 | super().__init__()
12 | model_32, _ = clip.load(os.path.join(clip_ckpt_path,"ViT-B-32.pt"), device=device)
13 | # model_16, _ = clip.load(os.path.join(clip_ckpt_path,"ViT-B-16.pt"), device=device)
14 | # model_101, _ = clip.load(os.path.join(clip_ckpt_path,"RN101.pt"), device=device)
15 | self.model_32 = model_32
16 | # self.model_16 = model_16
17 | # self.model_101 = model_101
18 |
19 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
20 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
21 |
22 | def preprocess(self, x):
23 | # normalize to [0,1]
24 | x = kornia.geometry.resize(x, (224, 224),
25 | interpolation='bicubic',align_corners=True,
26 | antialias=False)
27 | x = (x + 1.) / 2.
28 | x = kornia.enhance.normalize(x, self.mean, self.std)
29 | return x
30 |
31 | def encode_image(self, image, dtype):
32 | with torch.no_grad():
33 | image = self.preprocess(image)
34 | vectors = [self.model_32.encode_image(image.to(dtype))]
35 | return torch.cat(vectors, dim=-1).to(dtype)
36 |
--------------------------------------------------------------------------------
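A minimal usage sketch for MultiCLIP, assuming ViT-B-32.pt has already been downloaded into a local CLIP checkpoint directory (the path below is a placeholder, matching the clip_ckpt_path entries in the configs):

import torch
from models.clip_prior import MultiCLIP

clip_prior = MultiCLIP(clip_ckpt_path="/path/to/clip/ckpt/", device="cpu")

# inputs are expected in [-1, 1]; preprocess() resizes to 224x224, rescales to [0, 1]
# and applies the CLIP mean/std normalization
images = torch.rand(2, 3, 512, 512) * 2 - 1
image_embeds = clip_prior.encode_image(images, dtype=torch.float32)
print(image_embeds.shape)  # (2, 512) for ViT-B/32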
/models/clip_text_embedding.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | from torch import nn
5 | from transformers import CLIPTextConfig
6 | from typing import Union
7 | from models.mapper import Mapper
8 | from utils.types import Batch
9 |
10 |
11 | class PersonaCLIPTextEmbeddings(nn.Module):
12 | """ Modification of CLIPTextEmbedding to allow for the use of a Mapper to overwrite the concept token. """
13 |
14 | def __init__(self, config: CLIPTextConfig):
15 | super().__init__()
16 | embed_dim = config.hidden_size
17 | self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
18 | self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
19 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
20 |
21 | def set_mapper(self, mapper: Mapper):
22 | self.mapper = mapper
23 |
24 | def forward(self, input_ids: Optional[torch.LongTensor] = None,
25 | position_ids: Optional[torch.LongTensor] = None,
26 | inputs_embeds: Optional[torch.FloatTensor] = None,
27 | batch: Optional[Batch] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
28 |
29 | if batch is not None:
30 | input_ids = batch.input_ids
31 |
32 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
33 |
34 | if position_ids is None:
35 | position_ids = self.position_ids[:, :seq_length]
36 |
37 | if inputs_embeds is None:
38 | inputs_embeds = self.token_embedding(input_ids)
39 |
40 | bypass_outputs = None
41 | if batch is not None:
42 | mapper_outputs = self.mapper(input=batch.mapper_input)
43 |
44 | mapper_outputs = mapper_outputs.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
45 | if self.mapper.output_bypass:
46 | bypass_outputs = mapper_outputs[:, mapper_outputs.shape[1] // 2:]
47 | mapper_outputs = mapper_outputs[:, :mapper_outputs.shape[1] // 2]
48 |
49 | # Overwrite the index of the placeholder token with the mapper output for each entry in the batch
50 |
51 | for i in range(len(batch.placeholder_token_id)):
52 | learnable_idxs = (input_ids == batch.placeholder_token_id[i]).nonzero(as_tuple=True)[1]
53 | inputs_embeds[torch.arange(input_ids.shape[0]), learnable_idxs] = mapper_outputs[i]
54 |
55 | position_embeddings = self.position_embedding(position_ids)
56 | embeddings = inputs_embeds + position_embeddings
57 |
58 | return embeddings, bypass_outputs
59 |
--------------------------------------------------------------------------------
/models/clip_text_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.utils.checkpoint
5 | from torch import nn
6 | from transformers.modeling_outputs import BaseModelOutputWithPooling
7 | from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPEncoder
8 | from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask
9 | import torch.nn.functional as F
10 | from models.clip_text_embedding import PersonaCLIPTextEmbeddings
11 | from utils.types import Batch
12 | import numpy as np
13 | from PIL import Image
14 |
15 | class PersonaCLIPTextModel(CLIPTextModel):
16 | def __init__(self, config: CLIPTextConfig):
17 | super().__init__(config)
18 | self.text_model = PersonaCLIPTextTransformer(config)
19 | self.post_init()
20 |
21 | def forward(self, input_ids: Optional[torch.Tensor] = None,
22 | attention_mask: Optional[torch.Tensor] = None,
23 | position_ids: Optional[torch.Tensor] = None,
24 | output_attentions: Optional[bool] = None,
25 | output_hidden_states: Optional[bool] = None,
26 | return_dict: Optional[bool] = None,
27 | batch: Optional[Batch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
28 | return self.text_model.forward(
29 | batch=batch,
30 | input_ids=input_ids,
31 | attention_mask=attention_mask,
32 | position_ids=position_ids,
33 | output_attentions=output_attentions,
34 | output_hidden_states=output_hidden_states,
35 | return_dict=return_dict,
36 | )
37 |
38 |
39 | class PersonaCLIPTextTransformer(CLIPTextTransformer):
40 | def __init__(self, config: CLIPTextConfig):
41 | super().__init__(config=config)
42 | self.config = config
43 | embed_dim = config.hidden_size
44 | self.embeddings = PersonaCLIPTextEmbeddings(config)
45 | self.encoder = CLIPEncoder(config)
46 | self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
47 |
48 | def forward(self, input_ids: Optional[torch.Tensor] = None,
49 | attention_mask: Optional[torch.Tensor] = None,
50 | position_ids: Optional[torch.Tensor] = None,
51 | output_attentions: Optional[bool] = None,
52 | output_hidden_states: Optional[bool] = None,
53 | return_dict: Optional[bool] = None,
54 | batch: Optional[Batch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
55 |
56 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
57 | output_hidden_states = (
58 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
59 | )
60 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
61 |
62 | bypass_output = None
63 |
64 | if input_ids is not None:
65 | input_shape = input_ids.size()
66 | input_ids = input_ids.view(-1, input_shape[-1])
67 | hidden_states, _ = self.embeddings(input_ids=input_ids, position_ids=position_ids)
68 |
69 | elif batch is not None:
70 | input_shape = batch.input_ids.size()
71 | batch.input_ids = batch.input_ids.view(-1, input_shape[-1])
72 | hidden_states, bypass_output = self.embeddings(batch=batch, position_ids=position_ids)
73 |
74 | else:
75 | raise ValueError("You have to specify either batch or input_ids!")
76 |
77 | bsz, seq_len = input_shape
78 | # CLIP's text model uses causal mask, prepare it here.
79 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
80 | causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
81 | hidden_states.device
82 | )
83 |
84 | # expand attention_mask
85 | if attention_mask is not None:
86 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
87 | attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
88 |
89 | output_attentions = True
90 |
91 | encoder_outputs = self.encoder(
92 | inputs_embeds=hidden_states,
93 | attention_mask=attention_mask,
94 | causal_attention_mask=causal_attention_mask,
95 | output_attentions=output_attentions,
96 | output_hidden_states=output_hidden_states,
97 | return_dict=return_dict,
98 | )
99 |
100 | last_hidden_state = encoder_outputs[0]
101 | last_hidden_state_with_bypass = last_hidden_state.clone()
102 |
103 | if bypass_output is not None:
104 | for i in range(len(batch.placeholder_token_id)):
105 | learnable_idxs = (batch.input_ids == batch.placeholder_token_id[i]).nonzero(as_tuple=True)[1]
106 | existing_state = last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs]
107 |
108 | new_state = F.normalize(bypass_output[i] + existing_state, dim=-1) * existing_state.norm(dim=1, keepdim=True)
109 | new_state = new_state.to(dtype=hidden_states.dtype)
110 | last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs] = new_state
111 |
112 | last_hidden_state = self.final_layer_norm(last_hidden_state)
113 | last_hidden_state_with_bypass = self.final_layer_norm(last_hidden_state_with_bypass)
114 |
115 | if input_ids is not None:
116 | pooled_output = last_hidden_state[
117 | torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
118 | ]
119 | pooled_output_with_bypass = last_hidden_state_with_bypass[
120 | torch.arange(last_hidden_state_with_bypass.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
121 | ]
122 | elif batch is not None:
123 | pooled_output = last_hidden_state[
124 | torch.arange(last_hidden_state.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
125 | ]
126 | pooled_output_with_bypass = last_hidden_state_with_bypass[
127 | torch.arange(last_hidden_state_with_bypass.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
128 | ]
129 | else:
130 | raise ValueError("You have to specify either batch or input_ids!")
131 |
132 | if bypass_output is not None:
133 | return BaseModelOutputWithPooling(
134 | last_hidden_state=last_hidden_state,
135 | pooler_output=pooled_output,
136 | hidden_states=encoder_outputs.hidden_states,
137 | attentions=encoder_outputs.attentions,
138 | ), BaseModelOutputWithPooling(
139 | last_hidden_state=last_hidden_state_with_bypass,
140 | pooler_output=pooled_output_with_bypass,
141 | hidden_states=encoder_outputs.hidden_states,
142 | attentions=encoder_outputs.attentions,
143 | )
144 | else:
145 | return BaseModelOutputWithPooling(
146 | last_hidden_state=last_hidden_state,
147 | pooler_output=pooled_output,
148 | hidden_states=encoder_outputs.hidden_states,
149 | attentions=encoder_outputs.attentions,
150 | ), None
151 |
--------------------------------------------------------------------------------
/models/mapper.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Optional, List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from models.positional_encoding import BasicEncoder, TimePositionalEncoding
9 | from utils.types import Mapper_input
10 |
11 | class Mapper(nn.Module):
12 | def __init__(self, output_dim: int = 768,
13 | norm_scale: Optional[torch.Tensor] = None,
14 | use_positional_encoding: bool = True,
15 | num_pe_time_anchors: int = 200,
16 | sigma_t: float = 1.0,
17 | output_bypass: bool = True,
18 | token_num = 1):
19 | super().__init__()
20 | self.norm_scale = norm_scale
21 | self.output_bypass = output_bypass
22 | self.token_num = token_num
23 |
24 |
25 | self.use_positional_encoding = use_positional_encoding
26 | if self.use_positional_encoding:
27 | self.encoder = TimePositionalEncoding(sigma_t=sigma_t).cuda()
28 | self.input_dim = num_pe_time_anchors
29 | else:
30 | self.encoder = BasicEncoder().cuda()
31 | self.input_dim = 2
32 |
33 | self.input_layer = self.set_input_layer(num_time_anchors=num_pe_time_anchors)
34 |
35 | self.timestep_proj = nn.Sequential(
36 | nn.Linear(self.input_dim, 128),
37 | nn.LayerNorm(128),
38 | nn.LeakyReLU(),
39 | nn.Linear(128, 128),
40 | nn.LayerNorm(128),
41 | nn.LeakyReLU())
42 | if self.output_bypass:
43 | self.image_proj = nn.Sequential(
44 | nn.Linear(512 + 128, 128),
45 | nn.LayerNorm(128),
46 | nn.LeakyReLU(),
47 | nn.Linear(128, 128),
48 | nn.LayerNorm(128),
49 | nn.LeakyReLU())
50 | self.image_output_layer = nn.Sequential(nn.Linear(128, token_num * output_dim))
51 |
52 | self.output_layer = nn.Sequential(nn.Linear(128, token_num * output_dim))
53 |
54 |
55 | def set_input_layer(self, num_time_anchors: int) -> nn.Module:
56 | if self.use_positional_encoding:
57 | input_layer = nn.Linear(self.encoder.num_w, self.input_dim)
58 | input_layer.weight.data = self.encoder.init_layer(num_time_anchors)
59 | else:
60 | input_layer = nn.Identity()
61 | return input_layer
62 |
63 | def get_time_embedding(self, timestep: torch.Tensor) -> torch.Tensor:
64 | time_embedding = self.encoder.encode(timestep)
65 | time_embedding = self.input_layer(time_embedding)
66 | return time_embedding
67 |
68 | def forward(self, input: Mapper_input) -> torch.Tensor:
69 | timestep = input.timesteps.float()
70 | word_embedding = input.word_embedding
71 | image_embedding = input.image_embedding
72 | time_embedding = self.get_time_embedding(timestep)
73 | embedding = self.timestep_proj(time_embedding)
74 |
75 | if self.output_bypass:
76 | bypass = torch.cat([embedding, image_embedding],dim=-1)
77 | bypass = self.image_proj(bypass)
78 | bypass = self.image_output_layer(bypass)
79 | if self.training and random.random() < 0.5:
80 | for idx in torch.arange(bypass.shape[0]):
81 | bypass[idx][0:] = 0
82 | bypass = bypass.view(-1,768)
83 |
84 | embedding = self.output_layer(embedding)
85 | embedding = embedding.view(-1,768)
86 | embedding = F.normalize(embedding + word_embedding, dim=-1) * self.norm_scale
87 | if self.output_bypass:
88 | embedding = torch.cat([embedding, bypass],dim=-1)
89 | return embedding
90 |
--------------------------------------------------------------------------------
/models/positional_encoding.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import math
3 | import torch
4 | from torch import nn
5 |
6 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
7 | """
8 | Create sinusoidal timestep embeddings.
9 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
10 | These may be fractional.
11 | :param dim: the dimension of the output.
12 | :param max_period: controls the minimum frequency of the embeddings.
13 | :return: an [N x dim] Tensor of positional embeddings.
14 | """
15 | if not repeat_only:
16 | half = dim // 2
17 | freqs = torch.exp(
18 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
19 | ).to(device=timesteps.device)
20 | args = timesteps[:, None].float() * freqs[None]
21 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
22 | if dim % 2:
23 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
24 | else:
25 |         embedding = timesteps[:, None].expand(-1, dim)  # einops.repeat equivalent; repeat was never imported here
26 | return embedding
27 |
28 | class BasicEncoder(nn.Module):
29 | """ Simply normalizes the given timestep and unet layer to be between -1 and 1. """
30 |
31 | def __init__(self, num_denoising_timesteps: int = 1000, num_unet_layers: int = 16):
32 | super().__init__()
33 | self.normalized_timesteps = (torch.arange(num_denoising_timesteps) / (num_denoising_timesteps - 1)) * 2 - 1
34 | self.normalized_unet_layers = (torch.arange(num_unet_layers) / (num_unet_layers - 1)) * 2 - 1
35 | self.normalized_timesteps = nn.Parameter(self.normalized_timesteps).cuda()
36 | self.normalized_unet_layers = nn.Parameter(self.normalized_unet_layers).cuda()
37 |
38 | def encode(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
39 | normalized_input = torch.stack([self.normalized_timesteps[timestep.long()],
40 | self.normalized_unet_layers[unet_layer.long()]]).T
41 | return normalized_input
42 |
43 |
44 | class TimePositionalEncoding(nn.Module):
45 |
46 | def __init__(self, sigma_t: float, num_w: int = 1024):
47 | super().__init__()
48 | self.sigma_t = sigma_t
49 | self.num_w = num_w
50 | self.w = torch.randn((num_w, 1))
51 | self.w[:, 0] *= sigma_t
52 | self.w = nn.Parameter(self.w).cuda()
53 |
54 | def encode(self, t: torch.Tensor):
55 |         """ Maps the given timestep into a 2048-dimensional sinusoidal vector. """
56 | if type(t) == int or t.ndim == 0:
57 | x = torch.tensor([t]).float()
58 | else:
59 | x = t.unsqueeze(0)
60 | x = x.cuda()
61 | v_norm = timestep_embedding(x, 2048).squeeze(0)
62 | return v_norm
63 |
64 | def init_layer(self, num_time_anchors: int) -> torch.Tensor:
65 | """ Computes the weights for the positional encoding layer of size 200x2048."""
66 | anchor_vectors = []
67 | for t_anchor in range(400, 800, 400 // num_time_anchors):
68 | anchor_vectors.append(self.encode(t_anchor).float())
69 | A = torch.stack(anchor_vectors)
70 | return A
--------------------------------------------------------------------------------
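A quick CPU sanity check for timestep_embedding above (the shapes follow directly from the function: half cosine features, half sine features):

import torch
from models.positional_encoding import timestep_embedding

t = torch.tensor([0.0, 250.0, 999.0])
emb = timestep_embedding(t, dim=2048)
print(emb.shape)   # torch.Size([3, 2048])
# emb[:, :1024] holds cos(t * freq), emb[:, 1024:] holds sin(t * freq),
# with frequencies decaying geometrically from 1 down to ~1/max_period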
/models/xti_attention_processor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | import torch
3 | from diffusers.models.cross_attention import CrossAttention
4 | import numpy as np
5 | from PIL import Image
6 |
7 | class XTIAttenProc:
8 |
9 | def __call__(self, attn: CrossAttention,
10 | hidden_states: torch.Tensor,
11 | encoder_hidden_states: Optional[Dict[str, torch.Tensor]] = None,
12 | attention_mask: Optional[torch.Tensor] = None):
13 | _ehs_bypass = None
14 | if encoder_hidden_states is not None:
15 | if isinstance(encoder_hidden_states, dict):
16 | this_idx = encoder_hidden_states["this_idx"]
17 | _ehs = encoder_hidden_states[f"CONTEXT_TENSOR_{this_idx}"]
18 | if f"CONTEXT_TENSOR_BYPASS_{this_idx}" in encoder_hidden_states:
19 | _ehs_bypass = encoder_hidden_states[f"CONTEXT_TENSOR_BYPASS_{this_idx}"]
20 | encoder_hidden_states["this_idx"] += 1
21 | encoder_hidden_states["this_idx"] %= 16
22 | else:
23 | _ehs = encoder_hidden_states
24 | else:
25 | _ehs = None
26 |
27 | batch_size, sequence_length, _ = (hidden_states.shape if _ehs is None else _ehs.shape)
28 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
29 | query = attn.to_q(hidden_states)
30 |
31 | if _ehs is None:
32 | _ehs = hidden_states
33 | elif attn.cross_attention_norm:
34 | _ehs = attn.norm_cross(_ehs)
35 | _ehs_bypass = attn.norm_cross(_ehs_bypass)
36 |
37 | key = attn.to_k(_ehs)
38 | if _ehs_bypass is not None:
39 | value = attn.to_v(_ehs_bypass)
40 | else:
41 | value = attn.to_v(_ehs)
42 |
43 | query = attn.head_to_batch_dim(query)
44 | key = attn.head_to_batch_dim(key)
45 | value = attn.head_to_batch_dim(value)
46 |
47 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
48 | hidden_states = torch.bmm(attention_probs, value)
49 | hidden_states = attn.batch_to_head_dim(hidden_states)
50 |
51 | # linear proj
52 | hidden_states = attn.to_out[0](hidden_states)
53 | # dropout
54 | hidden_states = attn.to_out[1](hidden_states)
55 |
56 | return hidden_states
57 |
58 |
59 | class MyAttenProc:
60 | def __call__(self, attn: CrossAttention,
61 | hidden_states: torch.Tensor,
62 | encoder_hidden_states: Optional[Dict[str, torch.Tensor]] = None,
63 | attention_mask: Optional[torch.Tensor] = None):
64 |
65 | _ehs_bypass = None
66 | if encoder_hidden_states is not None:
67 | if isinstance(encoder_hidden_states, dict):
68 | _ehs = encoder_hidden_states[f"CONTEXT_TENSOR"]
69 | if f"CONTEXT_TENSOR_BYPASS" in encoder_hidden_states:
70 | _ehs_bypass = encoder_hidden_states[f"CONTEXT_TENSOR_BYPASS"]
71 | else:
72 | _ehs = encoder_hidden_states
73 | else:
74 | _ehs = None
75 |
76 | batch_size, sequence_length, _ = (hidden_states.shape if _ehs is None else _ehs.shape)
77 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
78 | query = attn.to_q(hidden_states)
79 |
80 | if _ehs is None:
81 | _ehs = hidden_states
82 | elif attn.cross_attention_norm:
83 | _ehs = attn.norm_cross(_ehs)
84 | _ehs_bypass = attn.norm_cross(_ehs_bypass)
85 |
86 | key = attn.to_k(_ehs)
87 | if _ehs_bypass is not None:
88 | value = attn.to_v(_ehs_bypass)
89 | else:
90 | value = attn.to_v(_ehs)
91 |
92 | query = attn.head_to_batch_dim(query)
93 | key = attn.head_to_batch_dim(key)
94 | value = attn.head_to_batch_dim(value)
95 |
96 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
97 | hidden_states = torch.bmm(attention_probs, value)
98 | hidden_states = attn.batch_to_head_dim(hidden_states)
99 |
100 | # linear proj
101 | hidden_states = attn.to_out[0](hidden_states)
102 | # dropout
103 | hidden_states = attn.to_out[1](hidden_states)
104 |
105 | return hidden_states
--------------------------------------------------------------------------------
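A hedged sketch of attaching a processor such as MyAttenProc to every attention layer of a diffusers UNet (set_attn_processor is the diffusers ~0.14 API pinned in environment/requirements.txt; the model path is a placeholder):

import torch
from diffusers import StableDiffusionPipeline
from models.xti_attention_processor import MyAttenProc

pipeline = StableDiffusionPipeline.from_pretrained("/path/to/diffusion/model",
                                                   torch_dtype=torch.float16).to("cuda")
pipeline.unet.set_attn_processor(MyAttenProc())

# From here on, whatever is passed as encoder_hidden_states reaches the processor directly,
# e.g. the {"CONTEXT_TENSOR": ..., "CONTEXT_TENSOR_BYPASS": ...} dicts built by PromptManager.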
/prompt_manager.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Dict, Any
2 |
3 | import torch
4 | from tqdm import tqdm
5 | from transformers import CLIPTokenizer
6 |
7 | import constants
8 | from models.clip_text_encoder import PersonaCLIPTextModel
9 | from utils.types import Mapper_input, Batch
10 |
11 |
12 | class PromptManager:
13 | """ Class for computing all time and space embeddings for a given prompt. """
14 | def __init__(self, tokenizer: CLIPTokenizer,
15 | text_encoder: PersonaCLIPTextModel,
16 | timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
17 | unet_layers: List[str] = constants.UNET_LAYERS,
18 | placeholder_token_id: Optional[List] = None,
19 | placeholder_token: Optional[List] = None,
20 | torch_dtype: torch.dtype = torch.float32):
21 | self.tokenizer = tokenizer
22 | self.text_encoder = text_encoder
23 | self.timesteps = timesteps
24 | self.unet_layers = unet_layers
25 | self.placeholder_token = placeholder_token
26 | self.placeholder_token_id = placeholder_token_id
27 | self.dtype = torch_dtype
28 |
29 | def my_embed_prompt(self, text: str,
30 | word_embedding: torch.Tensor,
31 | image_embedding: torch.Tensor,
32 | num_images_per_prompt: int = 1,
33 | super_category_token: str = 'face') -> List[Dict[str, Any]]:
34 |
35 | constant_text = text.format(super_category_token)
36 | constant_ids = self.tokenizer(
37 | constant_text,
38 | padding="max_length",
39 | max_length=self.tokenizer.model_max_length,
40 | return_tensors="pt",
41 | ).input_ids
42 |
43 | dynamic_text = text.format(' '.join(self.placeholder_token))
44 | dynamic_ids = self.tokenizer(
45 | dynamic_text,
46 | padding="max_length",
47 | max_length=self.tokenizer.model_max_length,
48 | return_tensors="pt",
49 | ).input_ids
50 |
51 | # print(dynamic_text, dynamic_ids)
52 | # Compute embeddings for each timestep and each U-Net layer
53 | print(f"Computing embeddings over {len(self.timesteps)} timesteps.")
54 |
55 | hidden_states_per_timestep = []
56 | for timestep in tqdm(self.timesteps):
57 |             _hs = {}
58 |             if timestep > 800:
59 |                 ids = constant_ids
60 |             elif 800 >= timestep >= 400:
61 | ids = dynamic_ids
62 | else:
63 | _hs = hidden_states_per_timestep[-1]
64 | hidden_states_per_timestep.append(_hs)
65 | continue
66 |
67 | mapper_input = Mapper_input(timesteps=timestep.unsqueeze(0),
68 | word_embedding=word_embedding.unsqueeze(0),
69 | image_embedding=image_embedding.unsqueeze(0))
70 | batch = Batch(
71 | input_ids=ids.to(device=self.text_encoder.device),
72 | placeholder_token_id=self.placeholder_token_id,
73 | mapper_input=mapper_input)
74 | layer_hidden_state, layer_hidden_state_bypass = self.text_encoder(batch=batch)
75 | layer_hidden_state = layer_hidden_state[0].to(dtype=self.dtype)
76 | _hs[f"CONTEXT_TENSOR"] = layer_hidden_state.repeat(num_images_per_prompt, 1, 1)
77 | if layer_hidden_state_bypass is not None:
78 | layer_hidden_state_bypass = layer_hidden_state_bypass[0].to(dtype=self.dtype)
79 | _hs[f"CONTEXT_TENSOR_BYPASS"] = layer_hidden_state_bypass.repeat(num_images_per_prompt, 1, 1)
80 | # _hs['timestep'] = timestep
81 | hidden_states_per_timestep.append(_hs)
82 | return hidden_states_per_timestep
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/inference.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from dataclasses import dataclass, field
3 | from pathlib import Path
4 | from typing import Optional, List, Tuple, Union
5 |
6 | import numpy as np
7 | import pyrallis
8 | import torch
9 | import PIL
10 | from PIL import Image
11 | from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
12 | from transformers import CLIPTokenizer
13 |
14 | sys.path.append(".")
15 | sys.path.append("..")
16 |
17 | import constants
18 | from models.clip_text_encoder import PersonaCLIPTextModel
19 | from models.mapper import Mapper
20 | from prompt_manager import PromptManager
21 | from sd_pipeline_call import sd_pipeline_call
22 | from models.xti_attention_processor import XTIAttenProc, MyAttenProc
23 | from checkpoint_handler import CheckpointHandler
24 | from utils import vis_utils
25 | from models.clip_prior import MultiCLIP
26 |
27 | import time
28 | @dataclass
29 | class InferenceConfig:
30 | # Specifies which checkpoint iteration we want to load
31 | iteration: Optional[int] = None
32 | # The input directory containing the saved models and embeddings
33 | input_dir: Optional[Path] = None
34 |     # Where to save the inference results
35 | inference_dir: Optional[Path] = None
36 | # Specific path to the mapper you want to load, overrides `input_dir`
37 | mapper_checkpoint_path: Optional[Path] = None
38 | # Specific path to the embeddings you want to load, overrides `input_dir`
39 | learned_embeds_path: Optional[Path] = None
40 | # List of prompts to run inference on
41 | prompts: Optional[List[str]] = None
42 |     # Text file containing the prompts to run inference on (one prompt per line), overrides `prompts`
43 | prompts_file_path: Optional[Path] = None
44 | # List of random seeds to run on
45 | seeds: List[int] = field(default_factory=lambda: [42])
46 | # If you want to run with dropout at inference time, this specifies the truncation indices for applying dropout.
47 | # None indicates that no dropout will be performed. If a list of indices is provided, will run all indices.
48 | truncation_idxs: Optional[Union[int, List[int]]] = None
49 | # Whether to run with torch.float16 or torch.float32
50 | torch_dtype: str = "fp16"
51 | clip_ckpt_path: Optional[Path] = "/path/to/clip/ckpt/"
52 | super_category_token: str = "face"
53 | image_path: Optional[str] = None
54 |
55 | def __post_init__(self):
56 | assert bool(self.prompts) != bool(self.prompts_file_path), \
57 | "You must provide either prompts or prompts_file_path, but not both!"
58 | self._set_prompts()
59 | self._set_input_paths()
60 | self.inference_dir.mkdir(exist_ok=True, parents=True)
61 | if type(self.truncation_idxs) == int:
62 | self.truncation_idxs = [self.truncation_idxs]
63 | self.torch_dtype = torch.float16 if self.torch_dtype == "fp16" else torch.float32
64 |
65 | def _set_input_paths(self):
66 | if self.inference_dir is None:
67 | assert self.input_dir is not None, "You must pass an input_dir if you do not specify inference_dir"
68 | self.inference_dir = self.input_dir / f"inference_{self.iteration}"
69 | if self.mapper_checkpoint_path is None:
70 | assert self.input_dir is not None, "You must pass an input_dir if you do not specify mapper_checkpoint_path"
71 | self.mapper_checkpoint_path = self.input_dir / f"mapper-steps-{self.iteration}.pt"
72 | if self.learned_embeds_path is None:
73 | assert self.input_dir is not None, "You must pass an input_dir if you do not specify learned_embeds_path"
74 | self.learned_embeds_path = self.input_dir / f"learned_embeds-steps-{self.iteration}.bin"
75 |
76 | def _set_prompts(self):
77 | if self.prompts_file_path is not None:
78 | assert self.prompts_file_path.exists(), f"Prompts file {self.prompts_file_path} does not exist!"
79 | self.prompts = self.prompts_file_path.read_text().splitlines()
80 |
81 |
82 | @pyrallis.wrap()
83 | def main(infer_cfg: InferenceConfig):
84 | train_cfg, mapper = CheckpointHandler.load_my_mapper(infer_cfg.mapper_checkpoint_path)
85 | pipeline, placeholder_token, placeholder_token_id = load_stable_diffusion_model(
86 | pretrained_model_name_or_path=train_cfg.model.pretrained_model_name_or_path,
87 | mapper=mapper,
88 | learned_embeds_path=infer_cfg.learned_embeds_path,
89 | torch_dtype=infer_cfg.torch_dtype
90 | )
91 | clip = MultiCLIP(clip_ckpt_path=infer_cfg.clip_ckpt_path).to(pipeline.device, dtype=infer_cfg.torch_dtype)
92 | prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
93 | text_encoder=pipeline.text_encoder,
94 | timesteps=pipeline.scheduler.timesteps,
95 | unet_layers=constants.UNET_LAYERS,
96 | placeholder_token=placeholder_token,
97 | placeholder_token_id=placeholder_token_id,
98 | torch_dtype=infer_cfg.torch_dtype)
99 |
100 | with torch.autocast("cuda"):
101 | with torch.no_grad():
102 | token_embeds = pipeline.text_encoder.get_input_embeddings().weight.data
103 | super_category_token_id = pipeline.tokenizer.encode(infer_cfg.super_category_token, add_special_tokens=False)[0]
104 | word_embedding = token_embeds[super_category_token_id].clone().detach()
105 | image = read_image(infer_cfg.image_path).unsqueeze(0).to(pipeline.device)
106 | image_embedding = clip.encode_image(image=image, dtype=infer_cfg.torch_dtype).detach()[0]
107 |
108 |
109 | for prompt in infer_cfg.prompts:
110 | output_path = infer_cfg.inference_dir
111 | output_path.mkdir(exist_ok=True, parents=True)
112 | prompt_image = run_inference(prompt=prompt,
113 | pipeline=pipeline,
114 | prompt_manager=prompt_manager,
115 | word_embedding=word_embedding,
116 | image_embedding=image_embedding,
117 | seeds=infer_cfg.seeds,
118 | output_path=output_path,
119 | num_images_per_prompt=1,
120 | super_category_token=infer_cfg.super_category_token)
121 |
122 |
123 | def run_inference(prompt: str,
124 | pipeline: StableDiffusionPipeline,
125 | prompt_manager: PromptManager,
126 | word_embedding: torch.Tensor,
127 | image_embedding: torch.Tensor,
128 | seeds: List[int],
129 | output_path: Optional[Path] = None,
130 | num_images_per_prompt: int = 1,
131 | super_category_token: str = 'face') -> Image.Image:
132 | with torch.autocast("cuda"):
133 | with torch.no_grad():
134 | prompt_embeds = prompt_manager.my_embed_prompt(prompt,
135 | word_embedding=word_embedding,
136 | image_embedding=image_embedding,
137 | num_images_per_prompt=num_images_per_prompt,
138 | super_category_token=super_category_token)
139 | joined_images = []
140 |
141 | for seed in seeds:
142 | generator = torch.Generator(device='cuda').manual_seed(seed)
143 | images = sd_pipeline_call(pipeline,
144 | prompt_embeds=prompt_embeds,
145 | generator=generator,
146 | num_images_per_prompt=num_images_per_prompt).images
147 | seed_image = Image.fromarray(np.concatenate(images, axis=1)).convert("RGB")
148 | if output_path is not None:
149 | tmp_prompt = prompt.format(super_category_token).replace(' ', '_')
150 | image_name = output_path.name
151 | save_name = f'{tmp_prompt}_{image_name}_{seed}.jpg'
152 | seed_image.save(output_path / save_name)
153 | joined_images.append(seed_image)
154 | joined_image = vis_utils.get_image_grid(joined_images)
155 |
156 |
157 | return joined_image
158 |
159 |
160 | def load_stable_diffusion_model(pretrained_model_name_or_path: str,
161 | learned_embeds_path: Path,
162 | mapper: Optional[Mapper] = None,
163 | num_denoising_steps: int = 50,
164 | torch_dtype: torch.dtype = torch.float16) -> Tuple[StableDiffusionPipeline, str, int]:
165 | tokenizer = CLIPTokenizer.from_pretrained(
166 | pretrained_model_name_or_path, subfolder="tokenizer")
167 | text_encoder = PersonaCLIPTextModel.from_pretrained(
168 | pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype,
169 | )
170 | if mapper is not None:
171 | text_encoder.text_model.embeddings.set_mapper(mapper)
172 | placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
173 | learned_embeds_path=learned_embeds_path,
174 | text_encoder=text_encoder,
175 | tokenizer=tokenizer
176 | )
177 | pipeline = StableDiffusionPipeline.from_pretrained(
178 | pretrained_model_name_or_path,
179 | torch_dtype=torch_dtype,
180 | text_encoder=text_encoder,
181 | tokenizer=tokenizer
182 | ).to("cuda")
183 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
184 | pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)
185 | # pipeline.unet.set_attn_processor(XTIAttenProc())
186 | pipeline.unet.set_attn_processor(MyAttenProc())
187 | return pipeline, placeholder_token, placeholder_token_id
188 |
189 | def read_image(image_path):
190 | image = Image.open(image_path)
191 | if not image.mode == "RGB":
192 | image = image.convert("RGB")
193 | img = np.array(image).astype(np.uint8)
194 |
195 | image = Image.fromarray(img)
196 | image = image.resize((512, 512), resample=PIL.Image.BICUBIC)
197 |
198 | image = np.array(image).astype(np.uint8)
199 | image = (image / 127.5 - 1.0).astype(np.float32)
200 |
201 | image = torch.from_numpy(image).permute(2, 0, 1)
202 | return image
203 |
204 | if __name__ == '__main__':
205 | main()
206 |
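# Illustrative invocation (not part of the original file). pyrallis exposes the InferenceConfig
# fields above as CLI flags and can also read them from a YAML file; the paths below are placeholders:
#
#     python scripts/inference.py --config_path input_configs/inference.yaml
#     python scripts/inference.py --input_dir logs/person_1 --iteration 1000 \
#         --prompts_file_path inference_text.txt --image_path data/person_1/person_1.jpg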
--------------------------------------------------------------------------------
/scripts/seg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 | import argparse
5 |
6 | import numpy as np
7 | import torch
8 | from torchvision import transforms
9 | from tqdm import tqdm
10 | from PIL import Image
11 |
12 | from clipseg.models.clipseg import CLIPDensePredT
13 |
14 | PROMPT_TEMPLATE = {
15 | 'a photo of a {}': 4,
16 | 'a good photo of a {}': 5,
17 | 'the photo of a {}': 4,
18 | 'a good photo of the {}': 5,
19 | 'image of a {}': 3,
20 | 'image of the {}': 3,
21 | 'A photograph of {}': 3,
22 | 'A {} shown in a photo': 1,
23 | 'A photo of {}': 3,
24 | }
25 |
26 | if __name__ == '__main__':
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--image_dir', type=str, required=True)
29 | parser.add_argument('--super_class', type=str, required=True)
30 | parser.add_argument('--model_weight', type=str, default='./clipseg/weights/rd64-uni-refined.pth')
31 | parser.add_argument('--device', type=str, default='cuda')
32 | args = parser.parse_args()
33 |
34 | image_dir = args.image_dir
35 | super_class = args.super_class
36 | model_weight_path = args.model_weight
37 | device = args.device
38 |
39 | files = [f for f in os.listdir(image_dir) if f.split('.')[-1].lower() in ['png', 'jpg', 'jpeg']]
40 | image_paths = [os.path.join(image_dir, f) for f in files]
41 | mask_save_paths = [os.path.join(image_dir, 'mask', f) for f in files]
42 |
43 |     os.makedirs(os.path.join(image_dir, 'mask'), exist_ok=True)
44 |
45 | prompts = [prompt.format(super_class) for prompt in PROMPT_TEMPLATE]
46 |
47 | model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
48 | model.load_state_dict(torch.load(model_weight_path, map_location='cpu'), strict=False)
49 | model = model.eval().to(device)
50 | transform = transforms.Compose([
51 | transforms.ToTensor(),
52 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
53 | transforms.Resize((352, 352)),
54 | ])
55 |
56 | with torch.no_grad():
57 | for img_p, save_p in tqdm(zip(image_paths, mask_save_paths)):
58 | img = Image.open(img_p).convert('RGB')
59 | img = transform(img).unsqueeze(0).to(device)
60 | preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
61 | preds = torch.sigmoid(preds)
62 | pred = preds.mean(dim=0)[0]
63 | pred = Image.fromarray((pred.cpu().numpy() * 255).astype(np.uint8))
64 | pred.save(save_p)
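# Illustrative invocation (not part of the original file); masks are written to `<image_dir>/mask/`
# under the same file names:
#
#     python scripts/seg.py --image_dir data/person_1 --super_class face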
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import pyrallis
4 | from diffusers.utils import check_min_version
5 |
6 | sys.path.append(".")
7 | sys.path.append("..")
8 |
9 | from training.coach_2 import Coach
10 | from training.config import RunConfig
11 |
12 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
13 | check_min_version("0.14.0")
14 |
15 |
16 | @pyrallis.wrap()
17 | def main(cfg: RunConfig):
18 | prepare_directories(cfg=cfg)
19 | coach = Coach(cfg)
20 | coach.train()
21 |
22 |
23 | def prepare_directories(cfg: RunConfig):
24 | cfg.log.exp_dir.mkdir(parents=True, exist_ok=True)
25 | cfg.log.logging_dir = cfg.log.exp_dir / cfg.log.logging_dir
26 | cfg.log.logging_dir.mkdir(parents=True, exist_ok=True)
27 |
28 | if __name__ == '__main__':
29 | main()
30 |
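# Illustrative invocation (not part of the original file). pyrallis reads RunConfig either from
# dotted CLI flags or from a YAML file; the paths and names below are placeholders:
#
#     python scripts/train.py --config_path input_configs/train.yaml
#     python scripts/train.py --log.exp_name person_1 --data.train_data_dir data/person_1 \
#         --data.placeholder_token "<person_1>"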
--------------------------------------------------------------------------------
/sd_pipeline_call.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional, Union
2 | import tqdm
3 | import torch
4 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionPipeline
5 |
6 | @torch.no_grad()
7 | def sd_pipeline_call(
8 | pipeline: StableDiffusionPipeline,
9 | prompt_embeds: torch.FloatTensor,
10 | height: Optional[int] = None,
11 | width: Optional[int] = None,
12 | num_inference_steps: int = 50,
13 | guidance_scale: float = 8.0,
14 | negative_prompt: Optional[Union[str, List[str]]] = None,
15 | num_images_per_prompt: Optional[int] = 1,
16 | eta: float = 0.0,
17 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
18 | latents: Optional[torch.FloatTensor] = None,
19 | output_type: Optional[str] = "pil",
20 | return_dict: bool = True,
21 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
22 | callback_steps: int = 1,
23 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
24 | noise_latents: Optional[torch.FloatTensor] = None):
25 |     """ Modification of the standard SD pipeline call to support NeTI embeddings passed via the prompt_embeds argument. """
26 |
27 | # 0. Default height and width to unet
28 | height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
29 | width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
30 |
31 | # 2. Define call parameters
32 | batch_size = 1
33 | device = pipeline._execution_device
34 |
35 |     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"  # note: overrides the negative_prompt argument with a fixed quality prompt
36 | neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
37 | negative_prompt_embeds, _ = pipeline.text_encoder(
38 | input_ids=neg_prompt.input_ids.to(device),
39 | attention_mask=None,
40 | )
41 | negative_prompt_embeds = negative_prompt_embeds[0]
42 |
43 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
44 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
45 | # corresponds to doing no classifier free guidance.
46 | do_classifier_free_guidance = guidance_scale > 1.0
47 |
48 | # 4. Prepare timesteps
49 | pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
50 | timesteps = pipeline.scheduler.timesteps
51 |
52 | # 5. Prepare latent variables
53 | if noise_latents is not None:
54 | latents = noise_latents
55 | else:
56 | num_channels_latents = pipeline.unet.in_channels
57 | latents = pipeline.prepare_latents(
58 | batch_size * num_images_per_prompt,
59 | num_channels_latents,
60 | height,
61 | width,
62 | pipeline.text_encoder.dtype,
63 | device,
64 | generator,
65 | latents,
66 | )
67 |
68 | # 6. Prepare extra step kwargs.
69 | extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
70 |
71 | # 7. Denoising loop
72 | num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
73 | with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
74 | for i, t in enumerate(timesteps):
75 |
76 |             # scale the latents, then predict the noise residual(s)
77 |             latent_model_input = pipeline.scheduler.scale_model_input(latents, t)
78 |             noise_pred_uncond = None
79 |             if do_classifier_free_guidance:
80 |                 # unconditional (negative-prompt) prediction for classifier-free guidance
81 |                 noise_pred_uncond = pipeline.unet(
82 |                     latent_model_input,
83 |                     t,
84 |                     encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1),
85 |                     cross_attention_kwargs=cross_attention_kwargs,
86 |                 ).sample
87 |
88 | embed = prompt_embeds[i] if type(prompt_embeds) == list else prompt_embeds
89 |
90 | noise_pred_text = pipeline.unet(
91 | latent_model_input,
92 | t,
93 | encoder_hidden_states=embed,
94 | cross_attention_kwargs=cross_attention_kwargs,
95 | ).sample
96 |
97 |             # perform guidance; fall back to the text-conditioned prediction when CFG is disabled
98 |             noise_pred = noise_pred_text
99 |             if do_classifier_free_guidance:
100 |                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
101 | # compute the previous noisy sample x_t -> x_t-1
102 | latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
103 |
104 | # call the callback, if provided
105 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
106 | progress_bar.update()
107 | if callback is not None and i % callback_steps == 0:
108 | callback(i, t, latents)
109 |
110 | if output_type == "latent":
111 | image = latents
112 | has_nsfw_concept = None
113 | elif output_type == "pil":
114 | # 8. Post-processing
115 | image = pipeline.decode_latents(latents)
116 | # 9. Run safety checker
117 | # image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
118 | has_nsfw_concept = None
119 | # 10. Convert to PIL
120 | image = pipeline.numpy_to_pil(image)
121 | else:
122 | # 8. Post-processing
123 | image = pipeline.decode_latents(latents)
124 | # 9. Run safety checker
125 | image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
126 |
127 |
128 | # Offload last model to CPU
129 | if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
130 | pipeline.final_offload_hook.offload()
131 |
132 | if not return_dict:
133 | return image, has_nsfw_concept
134 |
135 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
136 |
137 |
138 | def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline,
139 | negative_prompt: Optional[Union[str, List[str]]] = None):
140 | if negative_prompt is None:
141 | negative_prompt = ""
142 | uncond_tokens = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
143 | uncond_input = pipeline.tokenizer(
144 | uncond_tokens,
145 | padding="max_length",
146 | max_length=pipeline.tokenizer.model_max_length,
147 | truncation=True,
148 | return_tensors="pt",
149 | )
150 | return uncond_input
151 |
152 | @torch.no_grad()
153 | def invert(
154 | pipeline: StableDiffusionPipeline,
155 | prompt_embeds: torch.FloatTensor,
156 | start_latents: torch.FloatTensor,
157 | num_inference_steps: int = 50,
158 | guidance_scale: float = 8.0,
159 | negative_prompt: Optional[Union[str, List[str]]] = None,
160 | num_images_per_prompt: Optional[int] = 1):
161 |
162 | batch_size = 1
163 | device = pipeline._execution_device
164 |
165 | neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
166 | negative_prompt_embeds, _ = pipeline.text_encoder(
167 | input_ids=neg_prompt.input_ids.to(device),
168 | attention_mask=None,
169 | )
170 | negative_prompt_embeds = negative_prompt_embeds[0]
171 |
172 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
173 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
174 | # corresponds to doing no classifier free guidance.
175 | do_classifier_free_guidance = guidance_scale > 1.0
176 |
177 | pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
178 | timesteps = reversed(pipeline.scheduler.timesteps)
179 |
180 | latents = start_latents.clone()
181 | intermediate_latents = []
182 |
183 | for i, t in enumerate(timesteps):
184 | if i >= num_inference_steps - 1: continue
185 |
186 | t = timesteps[i]
187 |
188 | # predict the noise residual
189 | latent_model_input = latents
190 | latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
191 |
192 | embed = prompt_embeds[0] if type(prompt_embeds) == list else prompt_embeds
193 |
194 | noise_pred_text = pipeline.unet(
195 | latent_model_input,
196 | t,
197 | encoder_hidden_states=embed,
198 | ).sample
199 |
200 | if do_classifier_free_guidance:
201 | noise_pred_uncond = pipeline.unet(
202 | latent_model_input,
203 | t,
204 | encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1),
205 | ).sample
206 |
207 | noise_pred_text = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
208 |
209 |         current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
210 |         next_t = t  # min(999, t.item() + (1000 // num_inference_steps)), i.e. t + 1
211 | alpha_t = pipeline.scheduler.alphas_cumprod[current_t]
212 | alpha_t_next = pipeline.scheduler.alphas_cumprod[next_t]
213 |
214 | latents = (latents - (1-alpha_t).sqrt()*noise_pred_text)*(alpha_t_next.sqrt()/alpha_t.sqrt()) + (1-alpha_t_next).sqrt()*noise_pred_text
215 | intermediate_latents.append(latents)
216 |
217 | return torch.cat(intermediate_latents)
218 |
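# Note on the update inside `invert` (added for clarity, not in the original file): each iteration
# applies the deterministic DDIM step run backwards. With eps = noise_pred_text and alpha denoting
# alphas_cumprod,
#
#     x0_hat  = (x_t - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t)
#     x_{t+1} = sqrt(alpha_{t+1}) * x0_hat + sqrt(1 - alpha_{t+1}) * eps
#
# which is exactly the line
#     latents = (latents - (1 - alpha_t).sqrt() * eps) * (alpha_t_next.sqrt() / alpha_t.sqrt()) \
#               + (1 - alpha_t_next).sqrt() * eps
# used in the loop above.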
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/training/__init__.py
--------------------------------------------------------------------------------
/training/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from pathlib import Path
3 | from typing import List, Optional, Dict
4 |
5 | from constants import VALIDATION_PROMPTS
6 |
7 |
8 | @dataclass
9 | class LogConfig:
10 | """ Parameters for logging and saving """
11 | # Name of experiment. This will be the name of the output folder
12 | exp_name: str
13 | # The output directory where the model predictions and checkpoints will be written
14 | exp_dir: Path = Path("./logs")
15 | # Save interval
16 | save_steps: int = 250
17 | # [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
18 | # `output_dir/runs/**CURRENT_DATETIME_HOSTNAME`
19 | logging_dir: Path = Path("logs")
20 |     # The integration to report the results to. Supported platforms are "tensorboard"
21 |     # (default), "wandb" and "comet_ml". Use "all" to report to all integrations.
22 | report_to: str = "tensorboard"
23 | # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator`
24 | checkpoints_total_limit: Optional[int] = None
25 |
26 |
27 | @dataclass
28 | class DataConfig:
29 | """ Parameters for data """
30 | # A folder containing the training data
31 | train_data_dir: Path
32 | # A token to use as a placeholder for the concept
33 | placeholder_token: str
34 | # Super category token to use for normalizing the mapper output
35 | super_category_token: Optional[str] = "object"
36 | # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process
37 | dataloader_num_workers: int = 8
38 | # Choose between 'object' and 'style' - used for selecting the prompts for training
39 | learnable_property: str = "object"
40 | # How many times to repeat the training data
41 | repeats: int = 100
42 | # The resolution for input images, all the images in the train/validation dataset will be resized to this resolution
43 | resolution: int = 512
44 | # Whether to center crop images before resizing to resolution
45 | center_crop: bool = False
46 |
47 |
48 | @dataclass
49 | class ModelConfig:
50 | """ Parameters for defining all models """
51 | # Path to pretrained model or model identifier from huggingface.co/models
52 | pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4"
53 | # Path to pretrained clip model
54 | clip_ckpt_path: str = "/path/to/clip/ckpt"
55 | # Whether to use our Nested Dropout technique
56 | use_nested_dropout: bool = True
57 | # Probability to apply nested dropout during training
58 | nested_dropout_prob: float = 0.5
59 | # Whether to normalize the norm of the mapper's output vector
60 | normalize_mapper_output: bool = True
61 | # Target norm for the mapper's output vector
62 | target_norm: Optional[float] = None
63 | # Whether to use positional encoding over the input to the mapper
64 | use_positional_encoding: bool = True
65 | # Sigmas used for computing positional encoding
66 | sigma_t: float = 1.0
67 | # Number of time anchors for computing our positional encodings
68 | num_pe_time_anchors: int = 10
69 | # Whether to output the textual bypass vector
70 | output_bypass: bool = True
71 | # Revision of pretrained model identifier from huggingface.co/models
72 | revision: Optional[str] = None
73 | # Whether training should be resumed from a previous checkpoint.
74 | mapper_checkpoint_path: Optional[Path] = None
75 |
76 |
77 | @dataclass
78 | class EvalConfig:
79 | """ Parameters for validation """
80 | # A list of prompts that will be used during validation to verify that the model is learning
81 | validation_prompts: List[str] = field(default_factory=lambda: VALIDATION_PROMPTS)
82 | # Number of images that should be generated during validation with `validation_prompt`
83 | num_validation_images: int = 4
84 | # Seeds to use for generating the validation images
85 | validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456])
86 | # Run validation every X steps.
87 | validation_steps: int = 100
88 | # Number of denoising steps
89 | num_denoising_steps: int = 50
90 |
91 | def __post_init__(self):
92 | if self.validation_seeds is None:
93 | self.validation_seeds = list(range(self.num_validation_images))
94 | assert len(self.validation_seeds) == self.num_validation_images, \
95 | "Length of validation_seeds should equal num_validation_images"
96 |
97 | @dataclass
98 | class OptimConfig:
99 | """ Parameters for the optimization process """
100 | # Total number of training steps to perform.
101 | max_train_steps: Optional[int] = 1_000
102 | # Learning rate
103 | learning_rate: float = 1e-3
104 | # Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size
105 | scale_lr: bool = True
106 | # Batch size (per device) for the training dataloader
107 | train_batch_size: int = 2
108 | # Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass
109 | gradient_checkpointing: bool = False
110 | # Number of updates steps to accumulate before performing a backward/update pass
111 | gradient_accumulation_steps: int = 4
112 | # A seed for reproducible training
113 | seed: Optional[int] = None
114 | # The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",
115 | # "constant", "constant_with_warmup"]
116 | lr_scheduler: str = "constant"
117 | # Number of steps for the warmup in the lr scheduler
118 | lr_warmup_steps: int = 0
119 | # The beta1 parameter for the Adam optimizer
120 | adam_beta1: float = 0.9
121 | # The beta2 parameter for the Adam optimizer
122 | adam_beta2: float = 0.999
123 | # Weight decay to use
124 | adam_weight_decay: float = 1e-2
125 | # Epsilon value for the Adam optimizer
126 | adam_epsilon: float = 1e-08
127 |     # Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10
128 |     # and an Nvidia Ampere GPU.
129 | mixed_precision: str = "no"
130 | # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see
131 | # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
132 | allow_tf32: bool = False
133 |
134 |
135 | @dataclass
136 | class RunConfig:
137 | """ The main configuration for the coach trainer """
138 | log: LogConfig = field(default_factory=LogConfig)
139 | data: DataConfig = field(default_factory=DataConfig)
140 | model: ModelConfig = field(default_factory=ModelConfig)
141 | eval: EvalConfig = field(default_factory=EvalConfig)
142 | optim: OptimConfig = field(default_factory=OptimConfig)
143 |
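# Illustrative config sketch (not part of the original file). pyrallis materializes RunConfig from
# a YAML file whose top-level keys mirror the nested dataclasses above (e.g. `input_configs/train.yaml`).
# Field names come from the classes above; the values below are placeholders:
#
#     log:
#       exp_name: person_1
#     data:
#       train_data_dir: data/person_1
#       placeholder_token: <person_1>
#       super_category_token: face
#     model:
#       pretrained_model_name_or_path: CompVis/stable-diffusion-v1-4
#       clip_ckpt_path: /path/to/clip/ckpt
#     optim:
#       max_train_steps: 1000
#       train_batch_size: 2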
--------------------------------------------------------------------------------
/training/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from pathlib import Path
3 | from typing import Any, Dict, List, Optional
4 | import PIL
5 | import numpy as np
6 | import torch
7 | import torch.utils.checkpoint
8 | from PIL import Image
9 | from packaging import version
10 | from torch.utils.data import Dataset
11 | from torchvision import transforms
12 | from transformers import CLIPTokenizer
13 | from constants import IMAGENET_STYLE_TEMPLATES_SMALL, IMAGENET_TEMPLATES_SMALL, PROMPTS
14 |
15 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
16 | PIL_INTERPOLATION = {
17 | "linear": PIL.Image.Resampling.BILINEAR,
18 | "bilinear": PIL.Image.Resampling.BILINEAR,
19 | "bicubic": PIL.Image.Resampling.BICUBIC,
20 | "lanczos": PIL.Image.Resampling.LANCZOS,
21 | "nearest": PIL.Image.Resampling.NEAREST,
22 | }
23 | else:
24 | PIL_INTERPOLATION = {
25 | "linear": PIL.Image.LINEAR,
26 | "bilinear": PIL.Image.BILINEAR,
27 | "bicubic": PIL.Image.BICUBIC,
28 | "lanczos": PIL.Image.LANCZOS,
29 | "nearest": PIL.Image.NEAREST,
30 | }
31 |
32 |
33 | class TextualInversionDataset(Dataset):
34 |
35 | def __init__(self, data_root: Path,
36 | tokenizer: CLIPTokenizer,
37 | learnable_property: str = "object", # [object, style]
38 | size: int = 512,
39 | repeats: int = 100,
40 | interpolation: str = "bicubic",
41 | flip_p: float = 0.5,
42 | set: str = "train",
43 | super_category_token: str = "face",
44 |                  placeholder_token: Optional[List[str]] = None,
45 | center_crop: bool = False):
46 | self.data_root = data_root
47 | self.tokenizer = tokenizer
48 | self.learnable_property = learnable_property
49 | self.size = size
50 | self.super_category_token = super_category_token
51 |         self.placeholder_token = placeholder_token if placeholder_token is not None else []
52 | self.center_crop = center_crop
53 | self.flip_p = flip_p
54 |
55 | if "." in str(self.data_root) and str(self.data_root).split(".")[-1].lower() in ["jpg", "jpeg", "png"]:
56 | self.image_paths = [self.data_root]
57 | else:
58 | self.image_paths = list(self.data_root.glob("*.*"))
59 |
60 | self.num_images = len(self.image_paths)
61 | self._length = self.num_images
62 |
63 | print(f"Running on {self.num_images} images")
64 |
65 | if set == "train":
66 | self._length = self.num_images * repeats
67 |
68 | self.interpolation = {
69 | "linear": PIL_INTERPOLATION["linear"],
70 | "bilinear": PIL_INTERPOLATION["bilinear"],
71 | "bicubic": PIL_INTERPOLATION["bicubic"],
72 | "lanczos": PIL_INTERPOLATION["lanczos"],
73 | }[interpolation]
74 |
75 | self.templates = IMAGENET_STYLE_TEMPLATES_SMALL if learnable_property == "style" else IMAGENET_TEMPLATES_SMALL
76 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
77 |
78 | self.templates_2 = PROMPTS
79 |
80 | def __len__(self) -> int:
81 | return self._length
82 |
83 | def __getitem__(self, i: int) -> Dict[str, Any]:
84 | image_path = self.image_paths[i % self.num_images]
85 | image = Image.open(image_path)
86 |
87 | if not image.mode == "RGB":
88 | image = image.convert("RGB")
89 |
90 | example = dict()
91 |
92 | image_name = image_path.name
93 | pre_path = image_path.parent.joinpath('mask')
94 | if pre_path.exists():
95 | mask_path = pre_path.joinpath(image_name)
96 | org_mask = Image.open(mask_path).convert("L")
97 | mask = org_mask.resize((self.size//8, self.size//8), Image.NEAREST)
98 | mask = np.array(mask) / 255
99 | else:
100 | mask = np.ones((self.size//8, self.size//8))
101 | mask = mask[np.newaxis, ...]
102 | mask = torch.from_numpy(mask)
103 | example['mask'] = mask
104 |
105 | text = random.choice(self.templates)
106 | example['text'] = text.format(' '.join(self.placeholder_token))
107 | example["input_ids"] = self.tokenizer(
108 | example['text'],
109 | padding="max_length",
110 | truncation=True,
111 | max_length=self.tokenizer.model_max_length,
112 | return_tensors="pt",
113 | ).input_ids[0]
114 |
115 | example["input_ids_length"] = len(self.tokenizer(example['text'])['input_ids'])-2
116 |
117 | open_text = random.choice(self.templates_2).format(' '.join(self.placeholder_token))
118 | example["open_input_ids"] = self.tokenizer(
119 | open_text,
120 | padding="max_length",
121 | truncation=True,
122 | max_length=self.tokenizer.model_max_length,
123 | return_tensors="pt",
124 | ).input_ids[0]
125 | example["open_input_ids_length"] = len(self.tokenizer(open_text)['input_ids'])-2
126 |
127 | # default to score-sde preprocessing
128 | img = np.array(image).astype(np.uint8)
129 |
130 | if self.center_crop:
131 | crop = min(img.shape[0], img.shape[1])
132 | h, w = img.shape[0], img.shape[1]
133 | img = img[(h - crop) // 2: (h + crop) // 2, (w - crop) // 2: (w + crop) // 2]
134 |
135 | image = Image.fromarray(img)
136 | image = image.resize((self.size, self.size), resample=self.interpolation)
137 |
138 | # image = self.flip_transform(image)
139 | image = np.array(image).astype(np.uint8)
140 | image = (image / 127.5 - 1.0).astype(np.float32)
141 |
142 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
143 | return example
144 |
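# Illustrative usage sketch (not part of the original file); the tokenizer checkpoint, data path
# and placeholder token below are assumptions made for the example:
#
#     from pathlib import Path
#     from torch.utils.data import DataLoader
#     from transformers import CLIPTokenizer
#
#     tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
#     dataset = TextualInversionDataset(data_root=Path("data/person_1"),
#                                       tokenizer=tokenizer,
#                                       super_category_token="face",
#                                       placeholder_token=["<person_1>"])
#     loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=8)
#     batch = next(iter(loader))  # keys include 'pixel_values', 'input_ids', 'mask', 'open_input_ids'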
--------------------------------------------------------------------------------
/training/logger.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pyrallis
4 | from diffusers.utils import is_wandb_available
5 | from loguru import logger
6 |
7 | from training.config import RunConfig
8 |
9 |
10 | class CoachLogger:
11 |
12 | def __init__(self, cfg: RunConfig):
13 | self.cfg = cfg
14 | self.step = 0
15 | self.configure_loguru()
16 | self.log_config()
17 | self.validate_wandb()
18 |
19 | def configure_loguru(self):
20 | logger.remove()
21 | format = '{time:YYYY-MM-DD HH:mm:ss} {message}'
22 | logger.add(sys.stdout, colorize=True, format=format)
23 | logger.add(self.cfg.log.logging_dir / 'log.txt', colorize=False, format=format)
24 |
25 | def log_config(self):
26 | with (self.cfg.log.exp_dir / 'config.yaml').open('w') as f:
27 | pyrallis.dump(self.cfg, f)
28 | self.log_message('\n' + pyrallis.dump(self.cfg))
29 |
30 | def validate_wandb(self):
31 | if self.cfg.log.report_to == "wandb":
32 | if not is_wandb_available():
33 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
34 |
35 | @staticmethod
36 | def log_message(msg: str):
37 | logger.info(msg)
38 |
39 | def log_start_of_training(self, total_batch_size: int, num_samples: int):
40 | self.log_message("***** Running training *****")
41 | self.log_message(f" Num examples = {num_samples}")
42 | self.log_message(f" Instantaneous batch size per device = {self.cfg.optim.train_batch_size}")
43 | self.log_message(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
44 | self.log_message(f" Gradient Accumulation steps = {self.cfg.optim.gradient_accumulation_steps}")
45 | self.log_message(f" Total optimization steps = {self.cfg.optim.max_train_steps}")
46 |
47 | def update_step(self, step: int):
48 | self.step = step
49 |
--------------------------------------------------------------------------------
/training/validate.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from accelerate import Accelerator
8 | from accelerate.utils import set_seed
9 | from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, AutoencoderKL
10 | from diffusers.utils import is_wandb_available
11 | from tqdm import tqdm
12 | from transformers import CLIPTokenizer
13 |
14 | from training.config import RunConfig
15 | from models.clip_text_encoder import PersonaCLIPTextModel
16 | from models.xti_attention_processor import XTIAttenProc, MyAttenProc
17 | from prompt_manager import PromptManager
18 | from sd_pipeline_call import sd_pipeline_call
19 |
20 | if is_wandb_available():
21 | import wandb
22 |
23 |
24 | class ValidationHandler:
25 |
26 | def __init__(self, cfg: RunConfig, placeholder_token_id: int, weights_dtype: torch.dtype):
27 | self.cfg = cfg
28 | self.placeholder_token_id = placeholder_token_id
29 | self.weight_dtype = weights_dtype
30 |
31 | def infer(self, accelerator: Accelerator,
32 | tokenizer: CLIPTokenizer,
33 | text_encoder: PersonaCLIPTextModel,
34 | unet: UNet2DConditionModel, vae: AutoencoderKL,
35 | prompts: List[str],
36 | num_images_per_prompt: int,
37 | seeds: List[int],
38 | step: int,
39 | word_embedding: torch.Tensor,
40 | image_embedding: torch.Tensor):
41 | """ Runs inference during our training scheme. """
42 | pipeline = self.load_stable_diffusion_model(accelerator, tokenizer, text_encoder, unet, vae)
43 | prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
44 | text_encoder=pipeline.text_encoder,
45 | timesteps=pipeline.scheduler.timesteps,
46 | placeholder_token=self.cfg.data.placeholder_token,
47 | placeholder_token_id=self.placeholder_token_id)
48 | joined_images = []
49 | for prompt in prompts:
50 | images = self.infer_on_prompt(pipeline=pipeline,
51 | prompt_manager=prompt_manager,
52 | prompt=prompt,
53 | num_images_per_prompt=num_images_per_prompt,
54 | seeds=seeds,
55 | word_embedding=word_embedding,
56 | image_embedding=image_embedding)
57 | prompt_image = Image.fromarray(np.concatenate(images, axis=1))
58 | joined_images.append(prompt_image)
59 | final_image = Image.fromarray(np.concatenate(joined_images, axis=0))
60 | final_image.save(self.cfg.log.exp_dir / f"val-image-{step}.png")
61 | self.log_with_accelerator(accelerator, joined_images, step=step)
62 | del pipeline
63 | torch.cuda.empty_cache()
64 | text_encoder.text_model.embeddings.mapper.train()
65 | if self.cfg.optim.seed is not None:
66 | set_seed(self.cfg.optim.seed)
67 | return final_image
68 |
69 | def infer_on_prompt(self, pipeline: StableDiffusionPipeline,
70 | prompt_manager: PromptManager,
71 | prompt: str,
72 | seeds: List[int],
73 | word_embedding: torch.Tensor,
74 | image_embedding: torch.Tensor,
75 | num_images_per_prompt: int = 1) -> List[Image.Image]:
76 | prompt_embeds = self.compute_embeddings(prompt_manager=prompt_manager, prompt=prompt, word_embedding=word_embedding, image_embedding=image_embedding)
77 | all_images = []
78 | for idx in tqdm(range(num_images_per_prompt)):
79 | generator = torch.Generator(device='cuda').manual_seed(seeds[idx])
80 | images = sd_pipeline_call(pipeline,
81 | prompt_embeds=prompt_embeds,
82 | generator=generator,
83 | num_images_per_prompt=1).images
84 | all_images.extend(images)
85 | return all_images
86 |
87 | @staticmethod
88 | def compute_embeddings(prompt_manager: PromptManager, prompt: str, word_embedding: torch.Tensor, image_embedding: torch.Tensor) -> torch.Tensor:
89 | with torch.autocast("cuda"):
90 | with torch.no_grad():
91 | prompt_embeds = prompt_manager.my_embed_prompt(prompt, word_embedding, image_embedding)
92 | return prompt_embeds
93 |
94 | def load_stable_diffusion_model(self, accelerator: Accelerator,
95 | tokenizer: CLIPTokenizer,
96 | text_encoder: PersonaCLIPTextModel,
97 | unet: UNet2DConditionModel,
98 | vae: AutoencoderKL) -> StableDiffusionPipeline:
99 | """ Loads SD model given the current text encoder and our mapper. """
100 | pipeline = StableDiffusionPipeline.from_pretrained(self.cfg.model.pretrained_model_name_or_path,
101 | text_encoder=accelerator.unwrap_model(text_encoder),
102 | tokenizer=tokenizer,
103 | unet=unet,
104 | vae=vae,
105 | torch_dtype=self.weight_dtype)
106 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
107 | pipeline = pipeline.to(accelerator.device)
108 | pipeline.set_progress_bar_config(disable=True)
109 | pipeline.scheduler.set_timesteps(self.cfg.eval.num_denoising_steps, device=pipeline.device)
110 | # pipeline.unet.set_attn_processor(XTIAttenProc())
111 | pipeline.unet.set_attn_processor(MyAttenProc())
112 | text_encoder.text_model.embeddings.mapper.eval()
113 | return pipeline
114 |
115 | def log_with_accelerator(self, accelerator: Accelerator, images: List[Image.Image], step: int):
116 | for tracker in accelerator.trackers:
117 | if tracker.name == "tensorboard":
118 | np_images = np.stack([np.asarray(img) for img in images])
119 | tracker.writer.add_images("validation", np_images, step, dataformats="NHWC")
120 | if tracker.name == "wandb":
121 | tracker.log({"validation": [wandb.Image(image, caption=f"{i}: {self.cfg.eval.validation_prompts[i]}")
122 | for i, image in enumerate(images)]})
123 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xzhe-Vision/PersonaMagic/64e5369ca97d5adca3c381e44a8eb769c89cb823/utils/__init__.py
--------------------------------------------------------------------------------
/utils/types.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import torch
6 |
7 |
8 | @dataclass
9 | class Mapper_input:
10 | timesteps: torch.Tensor
11 | word_embedding: torch.Tensor
12 | image_embedding: torch.Tensor
13 |
14 | @dataclass
15 | class Batch:
16 | input_ids: torch.Tensor
17 | placeholder_token_id: int
18 | mapper_input: Mapper_input
19 |
20 |
--------------------------------------------------------------------------------
/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 |
4 | from PIL import Image
5 |
6 |
7 | def get_image_grid(images: List[Image.Image]) -> Image.Image:
8 | num_images = len(images)
9 | cols = int(math.ceil(math.sqrt(num_images)))
10 | rows = int(math.ceil(num_images / cols))
11 | width, height = images[0].size
12 | grid_image = Image.new('RGB', (cols * width, rows * height))
13 | for i, img in enumerate(images):
14 | x = i % cols
15 | y = i // cols
16 | grid_image.paste(img, (x * width, y * height))
17 | return grid_image
18 |
--------------------------------------------------------------------------------