├── .gitignore ├── LICENSE ├── README.md ├── examples ├── image_captioning │ ├── README.md │ ├── requirements.txt │ └── run.py ├── object_detection │ ├── README.md │ ├── clip_rcnn.py │ ├── requirements.txt │ ├── resources │ │ └── preds.jpg │ └── run.py └── vqgan_clip │ ├── README.md │ ├── __init__.py │ ├── configs │ └── vqvae.yaml │ ├── requirements.txt │ ├── run.py │ └── vqvae │ ├── __init__.py │ ├── image_processor.py │ ├── utils.py │ └── vqvae.py ├── pytorch_clip_guided_loss ├── __init__.py ├── clip_guided_loss.py └── utils.py ├── requirements.txt ├── resources └── preview.png ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_clip_guided_loss: Pytorch implementation of the CLIP guided loss for Text-To-Image, Image-To-Image, or Image-To-Text generation. 2 | 3 | A simple library that implements CLIP guided loss in PyTorch. 4 | 5 |
6 | 7 |
8 | 9 | [![Downloads](https://pepy.tech/badge/pytorch_clip_guided_loss)](https://pepy.tech/project/pytorch_clip_guided_loss) 10 | [![Downloads](https://pepy.tech/badge/pytorch_clip_guided_loss/month)](https://pepy.tech/project/pytorch_clip_guided_loss) 11 | [![Downloads](https://pepy.tech/badge/pytorch_clip_guided_loss/week)](https://pepy.tech/project/pytorch_clip_guided_loss) 12 | 13 | 14 | ## Install package 15 | 16 | ```bash 17 | pip install pytorch_clip_guided_loss 18 | ``` 19 | 20 | ## Install the latest version 21 | 22 | ```bash 23 | pip install --upgrade git+https://github.com/bes-dev/pytorch_clip_guided_loss.git 24 | ``` 25 | 26 | ## Features 27 | - The library supports multiple prompts (images or texts) as targets for optimization. 28 | - The library automatically detects the language of the input text and, if needed, translates it via Google Translate. 29 | - The library supports the original CLIP model by OpenAI and the ruCLIP model by SberAI. 30 | 31 | ## Usage 32 | 33 | ### Simple code 34 | 35 | ```python 36 | import torch 37 | from pytorch_clip_guided_loss import get_clip_guided_loss 38 | 39 | loss_fn = get_clip_guided_loss(clip_type="ruclip", input_range = (-1, 1)).eval().requires_grad_(False) 40 | # text prompt 41 | loss_fn.add_prompt(text="text description of what we would like to generate") 42 | # image prompt 43 | loss_fn.add_prompt(image=torch.randn(1, 3, 224, 224)) 44 | 45 | # variable 46 | var = torch.randn(1, 3, 224, 224).requires_grad_(True) 47 | loss = loss_fn.image_loss(image=var)["loss"] 48 | loss.backward() 49 | print(var.grad) 50 | ``` 51 | 52 | ### VQGAN-CLIP 53 | 54 | We provide our tiny implementation of the VQGAN-CLIP pipeline for image generation as an example of how to use our library. 55 | To start using our implementation of VQGAN-CLIP, please follow the [documentation](examples/vqgan_clip). 56 | 57 | ### Zero-shot Object Detection 58 | 59 | We provide our tiny implementation of the object detector based on Selective Search region proposals and CLIP guided loss. 60 | To start using our implementation of ClipRCNN, please follow the [documentation](examples/object_detection). -------------------------------------------------------------------------------- /examples/image_captioning/README.md: -------------------------------------------------------------------------------- 1 | # Simple CLIP guided image captioning 2 | 3 | Simple gradient-based CLIP guided image captioning. 4 | 5 | ## Usage 6 | 7 | ### Install requirements 8 | 9 | ```bash 10 | $ pip install -r requirements.txt 11 | ``` 12 | 13 | ### Generate image caption 14 | 15 | ```bash 16 | $ python run.py --help 17 | 18 | usage: run.py [-h] [--device DEVICE] [--clip-type CLIP_TYPE] [--lr LR] [--n-steps N_STEPS] [--length-min LENGTH_MIN] [--length-max LENGTH_MAX] [--img-path IMG_PATH] 19 | 20 | optional arguments: 21 | -h, --help show this help message and exit 22 | --device DEVICE inference device. 23 | --clip-type CLIP_TYPE 24 | Type of CLIP model [clip, ruclip]. 25 | --lr LR Learning rate. 26 | --n-steps N_STEPS Number steps of optimization.
27 | --length-min LENGTH_MIN 28 | Minimum sequence length 29 | --length-max LENGTH_MAX 30 | Maximum sequence length 31 | --img-path IMG_PATH Path to input image 32 | ``` 33 | 34 | ```bash 35 | $ python run.py --img-path --length-min 1 --length-max 32 --clip-type ruclip 36 | ``` -------------------------------------------------------------------------------- /examples/image_captioning/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | cv2 3 | pytorch_clip 4 | pytorch_clip_guided_loss 5 | numpy 6 | tqdm 7 | -------------------------------------------------------------------------------- /examples/image_captioning/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | import argparse 14 | import typing 15 | import cv2 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from pytorch_clip.processor.text_processor import TextProcessor 20 | from pytorch_clip_guided_loss import get_clip_guided_loss 21 | from tqdm.autonotebook import tqdm 22 | # image transforms 23 | import albumentations as A 24 | from albumentations.pytorch import ToTensorV2 25 | 26 | 27 | class STEQuantize(torch.autograd.Function): 28 | """ Quantize embeddings to codebook with 29 | gradients in style of Straight-Through Estimators. 30 | """ 31 | @staticmethod 32 | def forward(ctx, embs: torch.Tensor, codebook: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: 33 | """ Forward path. 34 | Arguments: 35 | embs (torch.Tensor): input embeddings. 36 | codebook (torch.Tensor): codebook. 37 | Returns: 38 | embs_q (torch.Tensor): quantized embeddings 39 | """ 40 | d = embs.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * embs @ codebook.T 41 | indices = d.argmin(-1) 42 | embs_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook 43 | return embs_q, indices 44 | 45 | @staticmethod 46 | def backward(ctx, grad_in: torch.Tensor, grad_ids: torch.Tensor) -> typing.Tuple[torch.Tensor, None]: 47 | """ Backward path like Straight-Through Estimators. 48 | Arguments: 49 | grad_in (torch.Tensor): input gradients. 50 | Returns: 51 | grad_out (torch.Tensor): STE gradients. 52 | """ 53 | return grad_in, None 54 | 55 | def ste_quantize(x: torch.Tensor, codebook: torch.tensor) -> torch.Tensor: 56 | """ Quantize embeddings to codebook with 57 | gradients in style of Straight-Through Estimators. 58 | Arguments: 59 | embs (torch.Tensor): input embeddings. 60 | codebook (torch.Tensor): codebook. 61 | Returns: 62 | embs_q (torch.Tensor): quantized embeddings 63 | """ 64 | return STEQuantize.apply(x, codebook) 65 | 66 | 67 | class MaskedGrad(torch.autograd.Function): 68 | """ Apply masked gradients 69 | """ 70 | @staticmethod 71 | def forward(ctx, var: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 72 | """ Forward path. 73 | Arguments: 74 | var (torch.Tensor): input variable. 
75 | mask (torch.Tensor): mask of the gradient. 76 | Returns: 77 | var (torch.Tensor): input variable. 78 | """ 79 | ctx.save_for_backward(mask) 80 | return var 81 | 82 | @staticmethod 83 | def backward(ctx, grad_in: torch.Tensor) -> typing.Tuple[torch.Tensor, None]: 84 | """ Backward path returns masked gradient for variable. 85 | Arguments: 86 | grad_in (torch.Tensor): input gradients. 87 | Returns: 88 | grad_out (torch.Tensor): masked gradients. 89 | """ 90 | mask, = ctx.saved_tensors 91 | grad_out = grad_in * mask 92 | return grad_out, None 93 | 94 | 95 | def masked_grad(var: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 96 | """ Apply masked gradients. 97 | Arguments: 98 | var (torch.Tensor): input variable. 99 | mask (torch.Tensor): mask of the gradient. 100 | Returns: 101 | var (torch.Tensor): input variable. 102 | """ 103 | return MaskedGrad.apply(var, mask) 104 | 105 | 106 | def init_params( 107 | tokenizer: TextProcessor, 108 | length_min: int, 109 | length_max: int, 110 | dictionary: nn.Module, 111 | device: str 112 | ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 113 | """ Initialize random embeddings. 114 | Arguments: 115 | tokenizer (TextProcessor): text tokenizer. 116 | length_min (int): minimum length of the target text. 117 | length_max (int): maximum length of the target text. 118 | dictionary (nn.Module): dictionary of the embeddings. 119 | device (str): target device. 120 | Returns: 121 | embeds (torch.Tensor): random embeddings. 122 | attention_mask (torch.Tensor): attention mask. 123 | gradient_mask (torch.Tensor): gradient mask. 124 | """ 125 | length_min = min(length_min, tokenizer.get_max_length() - 2) 126 | length_max = min(length_max, tokenizer.get_max_length() - 2) 127 | vocab_size = tokenizer.vocab_size() 128 | bos_emb = dictionary(torch.LongTensor([[tokenizer.bos_id]]).to(device)) 129 | eos_emb = dictionary(torch.LongTensor([[tokenizer.eos_id]]).to(device)) 130 | pad_emb = dictionary(torch.LongTensor([[tokenizer.pad_id]]).to(device)) 131 | 132 | embeds = [] 133 | attention_masks = [] 134 | grad_masks = [] 135 | for l in range(length_min, length_max + 1): 136 | ids = torch.randint(0, vocab_size, (1, l), device=device) 137 | embed = dictionary(ids) 138 | embed = torch.cat([bos_emb, embed, eos_emb, pad_emb.repeat_interleave(tokenizer.get_max_length() - 2 - l, dim=1)], dim=1) 139 | embeds.append(embed) 140 | attention_masks.append( 141 | torch.LongTensor([1] * (l + 2) + [0] * (tokenizer.get_max_length() - 2 - l)).unsqueeze(0) 142 | ) 143 | grad_masks.append( 144 | torch.LongTensor([0] + [1] * l + [0] * (tokenizer.get_max_length() - 1 - l)).unsqueeze(0) 145 | ) 146 | embeds = torch.cat(embeds, dim=0).detach() 147 | embeds.requires_grad = True 148 | attention_masks = torch.cat(attention_masks, dim=0).to(device) 149 | grad_masks = torch.cat(grad_masks, dim=0).unsqueeze(-1).to(device) 150 | return embeds, attention_masks, grad_masks 151 | 152 | 153 | def remove_repeats(strings: typing.List[str]) -> typing.List[str]: 154 | """ remove repeated words and sentences 155 | Arguments: 156 | strings (typing.List[str]): input strings. 157 | Returns: 158 | strings (typing.List[str]): output strings. 
159 | """ 160 | out = [] 161 | for s in strings: 162 | words = s.split() 163 | if len(words): 164 | output_words = [words[0]] 165 | for i in range(1, len(words)): 166 | if output_words[-1] != words[i]: 167 | output_words.append(words[i]) 168 | s_out = " ".join(output_words) 169 | if not s_out in out: 170 | out.append(s_out) 171 | return out 172 | 173 | 174 | def main(args): 175 | # load model 176 | clip_guided_loss = get_clip_guided_loss(args.clip_type, input_range = (0, 1)) 177 | img_transforms = A.Compose([ 178 | ToTensorV2() 179 | ]) 180 | tokenizer = clip_guided_loss.text_processor 181 | dictionary = clip_guided_loss.model.get_text_dictionary() 182 | # model to inference device 183 | clip_guided_loss.to(args.device) 184 | # initialize prompt 185 | image = (img_transforms(image=cv2.imread(args.img_path))["image"].unsqueeze(0) / 255.0).to(args.device) 186 | clip_guided_loss.add_prompt(image=image) 187 | # initialize text 188 | embeds, attention_mask, grad_mask = init_params( 189 | tokenizer, 190 | args.length_min, 191 | args.length_max, 192 | dictionary, 193 | args.device 194 | ) 195 | # initilize valid range for embeddings 196 | range_min = dictionary.weight.min(dim=0).values[None, None, :] 197 | range_max = dictionary.weight.max(dim=0).values[None, None, :] 198 | # initialize optimizer 199 | opt = torch.optim.Adam([embeds], lr=args.lr) 200 | # start optimization 201 | iterator = tqdm(range(args.n_steps)) 202 | for i in iterator: 203 | opt.zero_grad() 204 | x = masked_grad(embeds, grad_mask) 205 | x, ids = ste_quantize(x, dictionary.weight) 206 | loss = clip_guided_loss.text_loss( 207 | input_ids=ids, 208 | attention_mask=attention_mask, 209 | embed=x 210 | )["loss"] 211 | loss.backward() 212 | opt.step() 213 | with torch.inference_mode(): 214 | embeds.copy_(embeds.maximum(range_min).minimum(range_max)) 215 | iterator.set_description(f"loss: {loss.item()}") 216 | # print outputs 217 | x, ids = ste_quantize(embeds, dictionary.weight) 218 | strings = remove_repeats(tokenizer.decode(ids)) 219 | input_ids, attention_mask = [], [] 220 | for s in strings: 221 | out = tokenizer.encode(s, return_mask=True) 222 | input_ids.append(out["input_ids"]) 223 | attention_mask.append(out["attention_mask"]) 224 | input_ids = torch.cat(input_ids, dim=0).to(args.device) 225 | attention_mask = torch.cat(attention_mask, dim=0).to(args.device) 226 | embeds = dictionary(input_ids) 227 | loss = clip_guided_loss.text_loss( 228 | input_ids=input_ids, 229 | attention_mask=attention_mask, 230 | embed=embeds, 231 | reduce=None 232 | )["loss"] 233 | print(f"best caption: {strings[loss.argmin()]}") 234 | 235 | 236 | if __name__ == "__main__": 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument("--device", type=str, default="cuda:0", help="inference device.") 239 | parser.add_argument("--clip-type", type=str, default="ruclip", help="Type of CLIP model [clip, ruclip].") 240 | parser.add_argument("--lr", type=float, default=0.1, help="Learning rate.") 241 | parser.add_argument("--n-steps", type=int, default=100, help="Number steps of optimization.") 242 | parser.add_argument("--length-min", type=int, default=10, help="Minimum sequence length") 243 | parser.add_argument("--length-max", type=int, default=32, help="Maximum sequence length") 244 | parser.add_argument("--img-path", type=str, default=None, help="Path to input image") 245 | args = parser.parse_args() 246 | main(args) 247 | -------------------------------------------------------------------------------- /examples/object_detection/README.md: 
-------------------------------------------------------------------------------- 1 | # ClipRCNN: CLIP guided zero-shot object detector. 2 | 3 |
4 | 5 |
6 | 7 | Tiny implementation of the object detector based on Selective Search region proposals and CLIP guided loss. 8 | Our detector supports both text and image prompts. 9 | 10 | ## Notice 11 | 12 | This is just a toy implementation of the text-driven object detection pipeline. 13 | If you want to use text-driven bounding box filtering in your own object-detection pipeline, 14 | please refer to our library [pytorch_clip_bbox](https://github.com/bes-dev/pytorch_clip_bbox). 15 | pytorch_clip_bbox is a powerful tool for integrating a text-driven approach into any object-detection pipeline. 16 | 17 | ## Usage 18 | 19 | ### Install requirements 20 | 21 | ```bash 22 | $ pip install -r requirements.txt 23 | ``` 24 | 25 | ### Detect objects 26 | 27 | ```bash 28 | $ python run.py --help 29 | 30 | usage: run.py [-h] [-i IMAGE] [--device DEVICE] [--text-prompt TEXT_PROMPT] [--image-prompt IMAGE_PROMPT] [--clip-type CLIP_TYPE] [--batch-size BATCH_SIZE] [--scale SCALE] [--sigma SIGMA] [--min-size MIN_SIZE] 31 | [--aspect-ratio ASPECT_RATIO] [--top-k TOP_K] [--output-image OUTPUT_IMAGE] 32 | 33 | optional arguments: 34 | -h, --help show this help message and exit 35 | -i IMAGE, --image IMAGE 36 | Input image. 37 | --device DEVICE inference device. 38 | --text-prompt TEXT_PROMPT 39 | Text prompt. 40 | --image-prompt IMAGE_PROMPT 41 | Image prompt. 42 | --clip-type CLIP_TYPE 43 | Type of CLIP model [clip, ruclip]. 44 | --batch-size BATCH_SIZE 45 | Batch size. 46 | --scale SCALE Scale (selective search). 47 | --sigma SIGMA Sigma (selective search). 48 | --min-size MIN_SIZE Minimum area of the region proposal (selective search). 49 | --aspect-ratio ASPECT_RATIO 50 | Aspect ratio (selective search). 51 | --top-k TOP_K top k predictions will be returned. 52 | --output-image OUTPUT_IMAGE 53 | Output image name. 54 | ``` 55 | 56 | ```bash 57 | $ python run.py -i path/to/image.jpg --text-prompt "red cup" --output-image output.png 58 | ``` 59 | -------------------------------------------------------------------------------- /examples/object_detection/clip_rcnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | from typing import List, Dict, Tuple, Optional 14 | import cv2 15 | import selectivesearch 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | from pytorch_clip_guided_loss import get_clip_guided_loss 20 | 21 | 22 | class ClipRCNN(nn.Module): 23 | """ Implementation of the CLIP guided object detection model. 24 | Model is based on Selective Search region proposals and CLIP 25 | guided loss to make text/image driven object detection. 26 | Arguments: 27 | scale (int): Free parameter. Higher means larger clusters in felzenszwalb segmentation. 28 | sigma (float): Width of Gaussian kernel for felzenszwalb segmentation. 29 | min_size (float): Minimum component size for felzenszwalb segmentation, given as a fraction of the image area.
30 | aspect_ratio (Tuple[float, float]): valid range of aspect ratios for region proposals. 31 | clip_type (str): type of the CLIP model. 32 | batch_size (int): batch size. 33 | top_k (int): top k predictions will be return. 34 | """ 35 | def __init__( 36 | self, 37 | scale: int = 500, 38 | sigma: float = 0.9, 39 | min_size: float = 0.1, 40 | aspect_ratio: Tuple[float, float] = (0.5, 1.5), 41 | clip_type: str = "ruclip", 42 | batch_size: int = 128, 43 | top_k: int = 1 44 | ): 45 | super().__init__() 46 | # selective search parameters 47 | self.scale = scale 48 | self.sigma = sigma 49 | self.min_size = min_size 50 | self.aspect_ratio = aspect_ratio 51 | # inference params 52 | self.batch_size = batch_size 53 | # output params 54 | self.top_k = top_k 55 | # CLIP guided loss 56 | self.clip_loss = get_clip_guided_loss(clip_type, input_range=(0.0, 1.0)) 57 | self.input_size = self.clip_loss.image_processor[0].size 58 | # utils 59 | self.register_buffer("device_info", torch.tensor(0)) 60 | 61 | def add_prompt( 62 | self, 63 | image: Optional[torch.Tensor] = None, 64 | text: Optional[str] = None, 65 | weight: float = 1.0, 66 | label: Optional[str] = None, 67 | store_src: bool = True 68 | ) -> str: 69 | """Add prompt to loss function. 70 | Arguments: 71 | image (torch.Tensor): input image [Optional]. 72 | text (str): input text [Optional]. 73 | weight (float): importance of the prompt. 74 | label (str): label of the prompt [Optional]. 75 | store_src (bool): store source data of the prompt. 76 | Returns: 77 | label (src): label of the prompt. 78 | """ 79 | return self.clip_loss.add_prompt(image, text, weight, label, store_src) 80 | 81 | def clear_prompts(self) -> None: 82 | """Delete all available prompts.""" 83 | return self.clip_loss.clear_prompts() 84 | 85 | @torch.no_grad() 86 | def detect(self, img: np.array) -> List[Dict]: 87 | """ Detect objects on the input image using CLIP guided prompts. 88 | Argument: 89 | img (np.array): input image. 90 | Returns: 91 | outputs (List[Dict]): predicts in format: 92 | [{"rect": [x, y, w, h], "loss": loss_val}] 93 | """ 94 | img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 95 | # generate proposals by selective search 96 | proposals = self._generate_proposals(img_rgb) 97 | if not len(proposals): 98 | return [] 99 | batch = self._prepare_batch(img_rgb, proposals).to(self.device_info.device) 100 | # predict CLIP loss 101 | loss = self._predict_clip_loss(batch) 102 | outputs = self._generate_output(proposals, loss) 103 | return outputs 104 | 105 | def _generate_proposals(self, img: np.array) -> List[Tuple[int, int, int, int]]: 106 | """ Generate region proposals using selective search algorithm. 107 | Argument: 108 | img (np.array): input image. 
109 | Returns: 110 | proposals (List[Tuple[int, int, int, int]]): output proposals in format [(x, y, w, h)] 111 | """ 112 | min_size = int(img.shape[0] * img.shape[1] * self.min_size) 113 | # generate proposals 114 | img_lbl, regions = selectivesearch.selective_search( 115 | img, scale=self.scale, sigma=self.sigma, min_size=min_size 116 | ) 117 | # filter by aspect ratio 118 | proposals = [] 119 | for region in regions: 120 | x, y, w, h = region["rect"] 121 | aspect_ratio = float(w) / float(h) 122 | if aspect_ratio > self.aspect_ratio[0] and aspect_ratio < self.aspect_ratio[1]: 123 | proposals.append([x, y, w, h]) 124 | return proposals 125 | 126 | def _prepare_batch(self, img: np.array, proposals: List[Tuple[int, int, int, int]]) -> torch.Tensor: 127 | """ Crop region proposals and generate batch 128 | Argument: 129 | img (np.array): input image. 130 | proposals (List[Tuple[int, int, int, int]]): output proposals in format [(x, y, w, h)] 131 | Returns: 132 | batch (torch.Tensor): output batch (B, C, H, W). 133 | """ 134 | batch = [] 135 | for x, y, w, h in proposals: 136 | crop = cv2.resize(img[y:y+h, x:x+w], self.input_size) 137 | batch.append(torch.from_numpy(crop).permute(2, 0, 1).unsqueeze(0)) 138 | batch = torch.cat(batch, dim=0) 139 | # normalize batch 140 | batch = batch / 255.0 141 | return batch 142 | 143 | def _predict_clip_loss(self, batch_full: torch.Tensor) -> torch.Tensor: 144 | """ Predict CLIP loss for region proposals using user's prompts. 145 | Argument: 146 | batch_full (torch.Tensor): input batch (B, C, H, W). 147 | Returns: 148 | loss (torch.Tensor): output batch (B, ). 149 | """ 150 | loss = [] 151 | id_start = 0 152 | while id_start < batch_full.size(0): 153 | id_stop = min(id_start + self.batch_size, batch_full.size(0)) 154 | batch = batch_full[id_start:id_stop] 155 | loss.append(self.clip_loss.image_loss(image=batch, reduce=None)["loss"].cpu()) 156 | id_start = id_stop 157 | loss = torch.cat(loss, dim=0) 158 | return loss 159 | 160 | def _generate_output(self, proposals: List[Tuple[int, int, int, int]], loss: torch.Tensor) -> List[Dict]: 161 | """ Generate top_k predictions as an output of the model. 162 | Argument: 163 | proposals (List[Tuple[int, int, int, int]]): output proposals in format [(x, y, w, h)] 164 | loss (torch.Tensor): output batch (B, ). 
165 | Returns: 166 | outputs (List[Dict]): predicts in format: 167 | [{"rect": [x, y, w, h], "loss": loss_val}] 168 | """ 169 | output = [] 170 | vals, ids = loss.sort() 171 | top_k = min(self.top_k, len(proposals)) 172 | for i in range(top_k): 173 | output.append({ 174 | "rect": proposals[ids[i]], 175 | "loss": vals[i] 176 | }) 177 | return output 178 | -------------------------------------------------------------------------------- /examples/object_detection/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | cv2 3 | pytorch_clip_guided_loss 4 | numpy 5 | selectivesearch 6 | -------------------------------------------------------------------------------- /examples/object_detection/resources/preds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bes-dev/pytorch_clip_guided_loss/5dc906ccf981b6647346dd46a53886b313a4de77/examples/object_detection/resources/preds.jpg -------------------------------------------------------------------------------- /examples/object_detection/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | import argparse 14 | import cv2 15 | import torch 16 | from clip_rcnn import ClipRCNN 17 | 18 | 19 | def main(args): 20 | # build detector and move it to the inference device 21 | detector = ClipRCNN( 22 | scale=args.scale, 23 | sigma=args.sigma, 24 | min_size=args.min_size, 25 | aspect_ratio=[float(r) for r in args.aspect_ratio.split(",")], 26 | clip_type=args.clip_type, 27 | batch_size=args.batch_size, 28 | top_k=args.top_k 29 | ).to(args.device) 30 | # add prompts 31 | if args.text_prompt is not None: 32 | detector.add_prompt(text=args.text_prompt) 33 | if args.image_prompt is not None: 34 | image = cv2.cvtColor(cv2.imread(args.image_prompt), cv2.COLOR_BGR2RGB) 35 | image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0) 36 | image = image / 255.0 37 | detector.add_prompt(image=image) 38 | image = cv2.imread(args.image) 39 | boxes = detector.detect(image) 40 | for box in boxes: 41 | x, y, w, h = box["rect"] 42 | cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 4) 43 | if args.output_image is None: 44 | cv2.imshow("image", image) 45 | cv2.waitKey() 46 | else: 47 | cv2.imwrite(args.output_image, image) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("-i", "--image", type=str, help="Input image.") 53 | parser.add_argument("--device", type=str, default="cuda:0", help="inference device.") 54 | parser.add_argument("--text-prompt", type=str, default=None, help="Text prompt.") 55 | parser.add_argument("--image-prompt", type=str, default=None, help="Image prompt.") 56 | parser.add_argument("--clip-type", type=str, default="ruclip", help="Type of CLIP model [clip, ruclip].") 57 | parser.add_argument("--batch-size", type=int, default=128, help="Batch size.") 58 | parser.add_argument("--scale", type=int, default=500, help="Scale (selective search).") 59 | parser.add_argument("--sigma", type=float, default=0.9, help="Sigma (selective search).") 60 | parser.add_argument("--min-size", type=float, default=0.05, help="Minimum area of the region proposal (selective search).") 61 | parser.add_argument("--aspect-ratio", type=str, default="0.5,1.5", help="Aspect ratio (selective search).") 62 | parser.add_argument("--top-k", type=int, default=1, help="top k predictions will be returned.") 63 | parser.add_argument("--output-image", type=str, default=None, help="Output image name.") 64 | args = parser.parse_args() 65 | main(args) 66 | -------------------------------------------------------------------------------- /examples/vqgan_clip/README.md: -------------------------------------------------------------------------------- 1 | # VQGAN-CLIP Text-To-Image pipeline 2 | 3 | Tiny implementation of the Text-To-Image pipeline using the VQVAE by SberAI and the pytorch_clip_guided_loss library. 4 | 5 | ## Usage 6 | 7 | ### Install requirements 8 | 9 | ```bash 10 | $ pip install -r requirements.txt 11 | ``` 12 | 13 | ### Generate image from text 14 | 15 | ```bash 16 | $ python run.py --help 17 | 18 | usage: run.py [-h] [--device DEVICE] [-t TEXT] [--cfg CFG] [--clip-type CLIP_TYPE] [--output-size OUTPUT_SIZE] [--output-name OUTPUT_NAME] [--lr LR] [--n-steps N_STEPS] [--batch-size BATCH_SIZE] 19 | 20 | optional arguments: 21 | -h, --help show this help message and exit 22 | --device DEVICE inference device. 23 | -t TEXT, --text TEXT Text prompt. 24 | --cfg CFG Path to VQVAE config. 25 | --clip-type CLIP_TYPE 26 | Type of CLIP model [clip, ruclip]. 27 | --output-size OUTPUT_SIZE 28 | Size of the output image. 29 | --output-name OUTPUT_NAME 30 | Name of the output image.
31 | --lr LR Learning rate. 32 | --n-steps N_STEPS Number steps of optimization. 33 | --batch-size BATCH_SIZE 34 | Batch size. 35 | ``` 36 | 37 | ```bash 38 | $ python run.py --text "A painting in the style of Picasso" --output-name output.png 39 | ``` -------------------------------------------------------------------------------- /examples/vqgan_clip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bes-dev/pytorch_clip_guided_loss/5dc906ccf981b6647346dd46a53886b313a4de77/examples/vqgan_clip/__init__.py -------------------------------------------------------------------------------- /examples/vqgan_clip/configs/vqvae.yaml: -------------------------------------------------------------------------------- 1 | name: VQVAE 2 | description: vqgan.gumbelf8-sber.config.yml 3 | ckpt: 4 | repo_id: shonenkov/rudalle-utils 5 | filename: vqgan.gumbelf8-sber.model.ckpt 6 | params: 7 | model: 8 | base_learning_rate: 4.5e-06 9 | target: taming.models.vqgan.GumbelVQ 10 | params: 11 | kl_weight: 1.0e-08 12 | embed_dim: 256 13 | n_embed: 8192 14 | monitor: val/rec_loss 15 | temperature_scheduler_config: 16 | target: taming.lr_scheduler.LambdaWarmUpCosineScheduler 17 | params: 18 | warm_up_steps: 0 19 | max_decay_steps: 1000001 20 | lr_start: 0.9 21 | lr_max: 0.9 22 | lr_min: 1.0e-06 23 | ddconfig: 24 | double_z: false 25 | z_channels: 256 26 | resolution: 256 27 | in_channels: 3 28 | out_ch: 3 29 | ch: 128 30 | ch_mult: 31 | - 1 32 | - 1 33 | - 2 34 | - 4 35 | num_res_blocks: 2 36 | attn_resolutions: 37 | - 32 38 | dropout: 0.0 39 | lossconfig: 40 | target: taming.modules.losses.vqperceptual.DummyLoss 41 | -------------------------------------------------------------------------------- /examples/vqgan_clip/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | cv2 3 | pytorch_clip_guided_loss 4 | numpy 5 | taming-transformers 6 | huggingface_hub 7 | kornia 8 | tqdm 9 | -------------------------------------------------------------------------------- /examples/vqgan_clip/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | import argparse 14 | import cv2 15 | import torch 16 | import torch.nn as nn 17 | from omegaconf import OmegaConf 18 | import kornia.augmentation as K 19 | from vqvae import GumbelVQ, ste_quantize 20 | from pytorch_clip_guided_loss import get_clip_guided_loss 21 | from tqdm.autonotebook import tqdm 22 | 23 | 24 | def main(args): 25 | # load model 26 | vqvae, vqvae_processor = GumbelVQ.from_pretrained(OmegaConf.load(args.cfg)) 27 | clip_guided_loss = get_clip_guided_loss(args.clip_type, input_range = (0, 1)) 28 | # model to inference device 29 | vqvae.to(args.device) 30 | clip_guided_loss.to(args.device) 31 | # initialize prompt 32 | clip_guided_loss.add_prompt(text=args.text) 33 | # initialize image 34 | n_toks = args.output_size // 2 ** (vqvae.decoder.num_resolutions - 1) 35 | z = vqvae.ids_to_embs( 36 | torch.randint(vqvae.quantize.n_embed, (1, n_toks * n_toks)).to(args.device) 37 | ).detach().requires_grad_(True) 38 | # initialize augmentations 39 | augs = nn.Sequential( 40 | K.RandomAffine(degrees=15, translate=0.1, shear=5, p=0.7, padding_mode='zeros', keepdim=True), 41 | K.RandomPerspective(distortion_scale=0.7, p=0.7), 42 | K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7), 43 | K.RandomErasing(scale=(.1, .4), ratio=(.3, 1/.3), same_on_batch=True, p=0.7) 44 | ) 45 | # initialize optimizer 46 | opt = torch.optim.AdamW([z], lr=args.lr) 47 | # initilize valid range for embeddings 48 | z_min = vqvae.quantize.embed.weight.min(dim=0).values[None, :, None, None] 49 | z_max = vqvae.quantize.embed.weight.max(dim=0).values[None, :, None, None] 50 | # start optimization 51 | iterator = tqdm(range(args.n_steps)) 52 | for i in iterator: 53 | opt.zero_grad() 54 | x = ste_quantize(z.movedim(1, 3), vqvae.quantize.embed.weight).movedim(3, 1) 55 | x = vqvae.embs_to_img(x).add(1).div(2).clamp(0, 1) 56 | x = x.repeat_interleave(args.batch_size, dim=0) 57 | x = augs(x) 58 | loss = clip_guided_loss.image_loss(image = x)["loss"] 59 | loss.backward() 60 | opt.step() 61 | with torch.inference_mode(): 62 | z.copy_(z.maximum(z_min).minimum(z_max)) 63 | iterator.set_description(f"loss: {loss.item()}") 64 | # save image 65 | z = ste_quantize(z.movedim(1, 3), vqvae.quantize.embed.weight).movedim(3, 1) 66 | cv2.imwrite(args.output_name, vqvae_processor.decode(vqvae.embs_to_img(z), rgb2bgr=True)[0]) 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--device", type=str, default="cuda:0", help="inference device.") 72 | parser.add_argument("-t", "--text", type=str, default=None, help="Text prompt.") 73 | parser.add_argument("--cfg", type=str, default="configs/vqvae.yaml", help="Path to VQVAE config.") 74 | parser.add_argument("--clip-type", type=str, default="ruclip", help="Type of CLIP model [clip, ruclip].") 75 | parser.add_argument("--output-size", type=int, default=256, help="Size of the output image.") 76 | parser.add_argument("--output-name", type=str, default="output.png", help="Name of the output image.") 77 | parser.add_argument("--lr", type=float, default=0.1, help="Learning rate.") 78 | parser.add_argument("--n-steps", type=int, default=100, help="Number steps of optimization.") 79 | parser.add_argument("--batch-size", type=int, default=32, help="Batch size.") 80 | args = parser.parse_args() 81 | main(args) 82 | -------------------------------------------------------------------------------- /examples/vqgan_clip/vqvae/__init__.py: -------------------------------------------------------------------------------- 
1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | from .vqvae import GumbelVQ 14 | from .utils import ste_quantize 15 | -------------------------------------------------------------------------------- /examples/vqgan_clip/vqvae/image_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | import typing 14 | import cv2 15 | import torch 16 | import numpy as np 17 | from .utils import img_to_tensor, tensor_to_img 18 | 19 | 20 | class OpenCVImageProcessor: 21 | """ Implementation of image encode/decode for VQVAE. """ 22 | @staticmethod 23 | def encode( 24 | img: np.array, 25 | size: typing.Optional[int] = None, 26 | normalize: bool = True, 27 | input_range: typing.Tuple[float, float] = (0.0, 255.0), 28 | bgr2rgb: bool = True 29 | ) -> torch.Tensor: 30 | """ Encode input image. 31 | Arguments: 32 | img (np.array): input image. 33 | size (typing.Optional[int]): target size of the image. 34 | normalize (bool): normalize input image. 35 | input_range (typing.Tuple[float, float]): input range. 36 | bgr2rgb (bool): convert input image from BGR to RGB. 37 | Returns: 38 | tensor (torch.Tensor): encoded image. 39 | """ 40 | if size is not None: 41 | img = cv2.resize(img, size) 42 | tensor = img_to_tensor(img, normalize, input_range, bgr2rgb) 43 | return tensor 44 | 45 | @staticmethod 46 | def decode( 47 | tensor: torch.Tensor, 48 | normalize: bool = True, 49 | input_range: typing.Tuple[float, float] = (-1, 1), 50 | to_numpy: bool = True, 51 | rgb2bgr: bool = True 52 | ) -> typing.List[np.array or torch.Tensor]: 53 | """ Encode input tensor (output of the VQVAE decoder). 54 | Arguments: 55 | tensor (torch.Tensor): input tensor. 56 | normalize (bool): normalize input tensor. 57 | input_range (typing.Tuple[float, float]): input range. 58 | to_numpy (bool): convert outputs to np.array. 59 | rgb2bgr (bool): convert output from RGB to BGR. 60 | Returns: 61 | imgs (typing.List[np.array or torch.Tensor]): decoded images. 
62 | """ 63 | imgs = [] 64 | for i in range(tensor.size(0)): 65 | img = tensor_to_img(tensor[i], normalize, input_range, to_numpy, rgb2bgr) 66 | imgs.append(img) 67 | return imgs 68 | -------------------------------------------------------------------------------- /examples/vqgan_clip/vqvae/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | import typing 14 | import cv2 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import numpy as np 19 | 20 | 21 | def tensor_to_img( 22 | tensor: torch.Tensor, 23 | normalize: bool = True, 24 | input_range: typing.Tuple[float, float] = (-1, 1), 25 | to_numpy: bool = True, 26 | rgb2bgr: bool = True 27 | ) -> np.array or torch.Tensor: 28 | """ Decode torch.Tensor to np.array. 29 | Arguments: 30 | tensor (torch.Tensor): input tensor. 31 | normalize (bool): normalize input tensor. 32 | input_range (typing.Tuple[float, float]): input range. 33 | to_numpy (bool): convert outputs to np.array. 34 | rgb2bgr (bool): convert output from RGB to BGR. 35 | Returns: 36 | img (np.array or torch.Tensor): decoded image. 37 | """ 38 | if normalize: 39 | tensor = torch.clamp(tensor, min=input_range[0], max=input_range[1]) 40 | tensor = (tensor - input_range[0]) / (input_range[1] - input_range[0] + 1e-5) 41 | img = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0) 42 | if to_numpy: 43 | img = img.to('cpu', torch.uint8).numpy() 44 | if rgb2bgr: 45 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 46 | return img 47 | 48 | 49 | def img_to_tensor( 50 | img: np.array, 51 | normalize: bool = True, 52 | input_range: typing.Tuple[float, float] = (0.0, 255.0), 53 | bgr2rgb: bool = True 54 | ) -> torch.Tensor: 55 | """ Encode np.array to torch.Tensor. 56 | Arguments: 57 | img (np.array): input image. 58 | size (typing.Optional[int]): target size of the image. 59 | normalize (bool): normalize input image. 60 | input_range (typing.Tuple[float, float]): input range. 61 | bgr2rgb (bool): convert input image from BGR to RGB. 62 | Returns: 63 | tensor (torch.Tensor): encoded image. 64 | """ 65 | if bgr2rgb: 66 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 67 | tensor = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(dtype=torch.float32) 68 | if normalize: 69 | tensor = torch.clamp(tensor, min=input_range[0], max=input_range[1]) 70 | tensor = (tensor - input_range[0]) / (input_range[1] - input_range[0] + 1e-5) 71 | tensor = 2.0 * tensor - 1.0 72 | return tensor 73 | 74 | 75 | class STEQuantize(torch.autograd.Function): 76 | """ Quantize VQVAE embeddings to VQVAE codebook with 77 | gradients in style of Straight-Through Estimators. 78 | """ 79 | @staticmethod 80 | def forward(ctx, embs: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor: 81 | """ Forward path. 82 | Arguments: 83 | embs (torch.Tensor): input embeddings. 84 | codebook (torch.Tensor): VQVAE codebook. 
85 | Returns: 86 | embs_q (torch.Tensor): quantized embeddings 87 | """ 88 | d = embs.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * embs @ codebook.T 89 | indices = d.argmin(-1) 90 | embs_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook 91 | return embs_q 92 | 93 | @staticmethod 94 | def backward(ctx, grad_in: torch.Tensor) -> typing.Tuple[torch.Tensor, None]: 95 | """ Backward path like Straight-Through Estimators. 96 | Arguments: 97 | grad_in (torch.Tensor): input gradients. 98 | Returns: 99 | grad_out (torch.Tensor): STE gradients. 100 | """ 101 | return grad_in, None 102 | 103 | 104 | def ste_quantize(x: torch.Tensor, codebook: torch.tensor) -> torch.Tensor: 105 | """ Quantize VQVAE embeddings to VQVAE codebook with 106 | gradients in style of Straight-Through Estimators. 107 | Arguments: 108 | embs (torch.Tensor): input embeddings. 109 | codebook (torch.Tensor): VQVAE codebook. 110 | Returns: 111 | embs_q (torch.Tensor): quantized embeddings 112 | """ 113 | return STEQuantize.apply(x, codebook) 114 | -------------------------------------------------------------------------------- /examples/vqgan_clip/vqvae/vqvae.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from math import sqrt, log 3 | import sys 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import einsum 8 | from einops import rearrange 9 | from taming.modules.diffusionmodules.model import Encoder, Decoder 10 | from omegaconf import OmegaConf 11 | from huggingface_hub import hf_hub_download 12 | from .image_processor import OpenCVImageProcessor 13 | 14 | 15 | class GumbelQuantize(nn.Module): 16 | """ 17 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 18 | Gumbel Softmax trick quantizer 19 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 20 | https://arxiv.org/abs/1611.01144 21 | """ 22 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, 23 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True): 24 | super().__init__() 25 | self.embedding_dim = embedding_dim 26 | self.n_embed = n_embed 27 | self.straight_through = straight_through 28 | self.temperature = temp_init 29 | self.kl_weight = kl_weight 30 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 31 | self.embed = nn.Embedding(self.n_embed, self.embedding_dim) 32 | self.use_vqinterface = use_vqinterface 33 | 34 | def forward(self, z, temp=None, return_logits=False): 35 | hard = self.straight_through if self.training else True 36 | temp = self.temperature if temp is None else temp 37 | logits = self.proj(z) 38 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 39 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) 40 | # + kl divergence to the prior loss 41 | qy = F.softmax(logits, dim=1) 42 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 43 | ind = soft_one_hot.argmax(dim=1) 44 | if self.use_vqinterface: 45 | if return_logits: 46 | return z_q, diff, (None, None, ind), logits 47 | return z_q, diff, (None, None, ind) 48 | return z_q, diff, ind 49 | 50 | 51 | class GumbelVQ(nn.Module): 52 | def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8, **kwargs): 53 | super().__init__() 54 | z_channels = ddconfig['z_channels'] 55 | self.encoder = Encoder(**ddconfig) 56 | self.decoder = Decoder(**ddconfig) 57 | self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0) 58 | self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) 59 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1) 60 | 61 | def encode(self, x): 62 | h = self.encoder(x) 63 | h = self.quant_conv(h) 64 | quant, emb_loss, info = self.quantize(h) 65 | return quant, emb_loss, info 66 | 67 | def decode(self, quant): 68 | quant = self.post_quant_conv(quant) 69 | dec = self.decoder(quant) 70 | return dec 71 | 72 | def img_to_ids(self, img): 73 | _, _, [_, _, indices] = self.encode(img) 74 | return rearrange(indices, 'b h w -> b (h w)') 75 | 76 | def ids_to_embs(self, ids): 77 | b, n = ids.shape 78 | one_hot = F.one_hot(ids, num_classes=self.quantize.n_embed).float() 79 | embs = (one_hot @ self.quantize.embed.weight) 80 | embs = rearrange(embs, 'b (h w) c -> b c h w', h = int(sqrt(n))) 81 | return embs 82 | 83 | def embs_to_img(self, embs): 84 | img = self.decode(embs) 85 | return img 86 | 87 | @classmethod 88 | def from_pretrained(cls, cfg): 89 | print(f"[{cls.__name__}]: create model") 90 | model = cls(**cfg.params.model.params) 91 | print(f"[{cls.__name__}]: load checkpoint {cfg.ckpt}") 92 | model.load_state_dict( 93 | torch.load(hf_hub_download(**cfg.ckpt), map_location="cpu")["state_dict"], 94 | strict=False 95 | ) 96 | print(f"[{cls.__name__}]: create processor") 97 | model_preprocess = OpenCVImageProcessor() 98 | return model, model_preprocess 99 | -------------------------------------------------------------------------------- /pytorch_clip_guided_loss/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | import typing 14 | import torch.nn as nn 15 | from omegaconf import OmegaConf 16 | from .utils import res_path 17 | from .clip_guided_loss import CLIPGuidedLoss 18 | from pytorch_clip import get_clip_model, get_models_list 19 | 20 | 21 | def get_clip_guided_loss( 22 | clip_type: str = "clip", 23 | input_range: typing.Tuple[float, float] = (-1.0, 1.0), 24 | cache_dir: str = "/tmp/" 25 | ) -> nn.Module: 26 | """Get CLIPGuidedLoss model. 27 | Arguments: 28 | clip_type (str): type of the CLIP model ("clip" or "ruclip"). 29 | input_range (typing.Tuple[float, float]): input range. 30 | cache_dir (str): path to cache dir. 31 | Returns: 32 | model (nn.Module): CLIPGuidedLoss model. 33 | """ 34 | clip, tokenizer, transforms, cfg = get_clip_model(clip_type, input_range, cache_dir) 35 | return CLIPGuidedLoss(clip, tokenizer, transforms) 36 | -------------------------------------------------------------------------------- /pytorch_clip_guided_loss/clip_guided_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | import typing 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | # clip model 18 | from pytorch_clip import get_clip_model 19 | from pytorch_clip.processor.text_processor import TextProcessor 20 | # utils 21 | from omegaconf import OmegaConf 22 | 23 | 24 | class CLIPPrompt(nn.Module): 25 | """ Implementation of a CLIP prompt. 26 | Arguments: 27 | embed (torch.Tensor): input embedding. 28 | weight (float): importance of the prompt. 29 | src (torch.Tensor or str): source data of the prompt. 30 | """ 31 | def __init__( 32 | self, 33 | embed: torch.Tensor, 34 | weight: float, 35 | src: typing.Optional[typing.Union[torch.Tensor, str]] = None 36 | ): 37 | super().__init__() 38 | self.register_buffer("embed", embed) 39 | self.register_buffer("weight", torch.as_tensor(weight)) 40 | if isinstance(src, torch.Tensor): 41 | self.register_buffer("src", src) 42 | else: 43 | self.src = src 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | """Compute spherical distance loss between prompt and input embedding. 47 | Arguments: 48 | x (torch.Tensor): input embedding. 49 | Returns: 50 | loss (torch.Tensor): output spherical loss. 51 | """ 52 | return self.weight * x.sub(self.embed).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 53 | 54 | 55 | class CLIPGuidedLoss(nn.Module): 56 | """ Implementation of CLIP guided loss function. 57 | Arguments: 58 | model (nn.Module): CLIP model.
59 | text_processor (TextProcessor): text processor. 60 | image_processor (nn.Module): image processor. 61 | """ 62 | def __init__(self, model: nn.Module, text_processor: TextProcessor, image_processor: nn.Module): 63 | super().__init__() 64 | # clip model 65 | self.model = model 66 | self.text_processor = text_processor 67 | self.image_processor = image_processor 68 | # prompts 69 | self.prompts = nn.ModuleDict() 70 | # device info 71 | self.register_buffer("device_info", torch.tensor(1)) 72 | 73 | def add_prompt( 74 | self, 75 | image: typing.Optional[torch.Tensor] = None, 76 | text: typing.Optional[str] = None, 77 | weight: float = 1.0, 78 | label: typing.Optional[str] = None, 79 | store_src: bool = True 80 | ) -> typing.Optional[str]: 81 | """Add prompt to loss function. 82 | Arguments: 83 | image (torch.Tensor): input image [Optional]. 84 | text (str): input text [Optional]. 85 | weight (float): importance of the prompt. 86 | label (str): label of the prompt [Optional]. 87 | store_src (bool): store source data of the prompt. 88 | Returns: 89 | label (str): label of the prompt. 90 | """ 91 | if text is None and image is None: 92 | return 93 | embed, src = self._get_embed(image, text) 94 | if label is None: 95 | label = str(len(self.prompts)) 96 | self.prompts[label] = CLIPPrompt( 97 | embed = embed.detach(), 98 | weight = weight, 99 | src = src if store_src else None 100 | ).to(self.device_info.device) 101 | return label 102 | 103 | def delete_prompt(self, label: typing.Optional[str] = None) -> None: 104 | """Delete prompt from loss function. 105 | Arguments: 106 | label (str): label of the prompt to delete [Optional]. 107 | """ 108 | if label in self.prompts: 109 | self.prompts.pop(label) 110 | 111 | def clear_prompts(self) -> None: 112 | """Delete all available prompts.""" 113 | self.prompts.clear() 114 | 115 | def get_prompts_list(self) -> typing.List[str]: 116 | """Get list of all available prompts. 117 | Returns: 118 | prompts (list): list of prompt labels. 119 | """ 120 | return list(self.prompts.keys()) 121 | 122 | def get_prompt(self, label: str) -> typing.Optional[CLIPPrompt]: 123 | """Get prompt if available. 124 | Arguments: 125 | label (str): label of the prompt [Optional]. 126 | Returns: 127 | prompt (CLIPPrompt or None): prompt [Optional]. 128 | """ 129 | if label in self.prompts: 130 | return self.prompts[label] 131 | 132 | def forward(self, embed: typing.Optional[torch.Tensor], reduce: str = "mean") -> typing.Dict[str, torch.Tensor]: 133 | """Compute CLIP guided loss between input image/text and all available prompts. 134 | Arguments: 135 | embed (torch.Tensor): input embedding. 136 | reduce (str): reduce mode ("mean", "sum" or None). 137 | Returns: 138 | loss (dict): per-prompt losses and the total loss under the "loss" key. 139 | """ 140 | assert reduce in ["mean", "sum", None] 141 | loss = {} 142 | for key, prompt in self.prompts.items(): 143 | loss[key] = prompt(embed) 144 | if reduce is not None: 145 | loss[key] = loss[key].mean() if reduce == "mean" else loss[key].sum() 146 | loss["loss"] = sum(loss.values()) if len(loss) else 0 147 | return loss 148 | 149 | def image_loss(self, image: typing.Optional[torch.Tensor], reduce: str = "mean") -> typing.Dict[str, torch.Tensor]: 150 | """Compute CLIP guided loss between input image and all available prompts. 151 | Arguments: 152 | image (torch.Tensor): input image. 153 | reduce (str): reduce mode ("mean", "sum" or None). 154 | Returns: 155 | loss (dict): per-prompt losses and the total loss under the "loss" key.
156 | """ 157 | embed, _ = self._get_embed(image=image) 158 | return self(embed, reduce=reduce) 159 | 160 | def text_loss(self, 161 | input_ids: typing.Optional[torch.Tensor], 162 | embed: typing.Optional[torch.Tensor], 163 | attention_mask: typing.Optional[torch.Tensor] = None, 164 | reduce: str = "mean" 165 | ) -> torch.Tensor: 166 | """Compute CLIP guided loss between input text and all available prompts. 167 | Arguments: 168 | input_ids (torch.Tensor): input text tokens indices. 169 | embed (torch.Tensor): input text embeddings. 170 | attention_mask (torch.Tensor): attention mask [Optional]. 171 | reduce (str): reduce mode ("mean", "sum" or None) 172 | Returns: 173 | loss (torch.Tensor): CLIP guided loss. 174 | """ 175 | embed = self.model.get_text_features( 176 | input_ids = input_ids, 177 | attention_mask = attention_mask, 178 | inputs_embeds = embed 179 | ) 180 | return self(embed, reduce=reduce) 181 | 182 | def _get_embed( 183 | self, 184 | image: typing.Optional[torch.Tensor] = None, 185 | text: typing.Optional[typing.Dict[str, torch.Tensor] or str] = None 186 | ) -> typing.Tuple[torch.Tensor, torch.Tensor or typing.Dict[str, torch.Tensor] or str]: 187 | """Compute CLIP embedding for input data. 188 | Arguments: 189 | image (torch.Tensor): input image [Optional]. 190 | text (str): input text [Optional]. 191 | Returns: 192 | embed (torch.Tensor): output embedding. 193 | """ 194 | if image is None and text is None: 195 | raise ValueError("You have to specify either text or images. Both cannot be none.") 196 | if image is not None: 197 | embed = self.model.get_image_features(self.image_processor(image)) 198 | src = image 199 | else: 200 | if isinstance(text, str): 201 | batch = self.text_processor.encode(text, return_mask=True) 202 | batch["input_ids"] = batch["input_ids"].view(1, -1).to(self.device_info.device) 203 | if "attention_mask" in batch: 204 | batch["attention_mask"] = batch["attention_mask"].view(1, -1).to(self.device_info.device) 205 | else: 206 | batch = text 207 | embed = self.model.get_text_features(**batch) 208 | src = text 209 | embed = F.normalize(embed, dim=-1) 210 | return embed, src 211 | -------------------------------------------------------------------------------- /pytorch_clip_guided_loss/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | import os 14 | 15 | 16 | def get_module_path(): 17 | """ Get module path 18 | Returns: 19 | path (str): path to current module. 20 | """ 21 | file_path = os.path.abspath(__file__) 22 | module_path = os.path.dirname(file_path) 23 | return module_path 24 | 25 | 26 | def res_path(path): 27 | """ Resource path 28 | Arguments: 29 | path (str): related path from module dir to some resources. 30 | Returns: 31 | path (str): absolute path to module dir. 
32 | """ 33 | return os.path.join(get_module_path(), path) 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wheel 2 | omegaconf 3 | torch 4 | transformers 5 | huggingface_hub 6 | youtokentome 7 | googletrans 8 | numpy 9 | pytorch_clip 10 | -------------------------------------------------------------------------------- /resources/preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bes-dev/pytorch_clip_guided_loss/5dc906ccf981b6647346dd46a53886b313a4de77/resources/preview.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 by Sergei Belousov 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | from setuptools import setup, find_packages 17 | 18 | readme = open('README.md').read() 19 | 20 | VERSION = '2021.12.25.0' 21 | 22 | requirements = [ 23 | 'wheel', 24 | 'Cython', 25 | 'cython', 26 | 'omegaconf', 27 | 'torch', 28 | 'transformers', 29 | 'huggingface_hub', 30 | 'youtokentome', 31 | 'googletrans', 32 | 'numpy', 33 | 'pytorch_clip' 34 | ] 35 | 36 | setup( 37 | # Metadata 38 | name='pytorch_clip_guided_loss', 39 | version=VERSION, 40 | author='Sergei Belousov aka BeS', 41 | author_email='sergei.o.belousov@gmail.com', 42 | description='Pytorch implementation of the CLIP guided loss.', 43 | long_description=readme, 44 | long_description_content_type='text/markdown', 45 | 46 | # Package info 47 | packages=find_packages(exclude=('*test*',)), 48 | 49 | # 50 | zip_safe=True, 51 | install_requires=requirements, 52 | 53 | # Classifiers 54 | classifiers=[ 55 | 'Programming Language :: Python :: 3', 56 | ], 57 | 58 | 59 | # install .json configs 60 | package_data={ 61 | # "pytorch_clip_guided_loss": ["configs/*.yml"] 62 | }, 63 | include_package_data=True 64 | ) 65 | --------------------------------------------------------------------------------