├── .gitignore
├── LICENSE
├── README.md
├── examples
│   ├── image_captioning
│   │   ├── README.md
│   │   ├── requirements.txt
│   │   └── run.py
│   ├── object_detection
│   │   ├── README.md
│   │   ├── clip_rcnn.py
│   │   ├── requirements.txt
│   │   ├── resources
│   │   │   └── preds.jpg
│   │   └── run.py
│   └── vqgan_clip
│       ├── README.md
│       ├── __init__.py
│       ├── configs
│       │   └── vqvae.yaml
│       ├── requirements.txt
│       ├── run.py
│       └── vqvae
│           ├── __init__.py
│           ├── image_processor.py
│           ├── utils.py
│           └── vqvae.py
├── pytorch_clip_guided_loss
│   ├── __init__.py
│   ├── clip_guided_loss.py
│   └── utils.py
├── requirements.txt
├── resources
│   └── preview.png
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch_clip_guided_loss: PyTorch implementation of the CLIP guided loss for Text-To-Image, Image-To-Image, or Image-To-Text generation.
2 |
3 | A simple library that implements CLIP guided loss in PyTorch.
4 |
5 |
6 |
7 |
8 |
9 | [](https://pepy.tech/project/pytorch_clip_guided_loss)
10 | [](https://pepy.tech/project/pytorch_clip_guided_loss)
11 | [](https://pepy.tech/project/pytorch_clip_guided_loss)
12 |
13 |
14 | ## Install package
15 |
16 | ```bash
17 | pip install pytorch_clip_guided_loss
18 | ```
19 |
20 | ## Install the latest version
21 |
22 | ```bash
23 | pip install --upgrade git+https://github.com/bes-dev/pytorch_clip_guided_loss.git
24 | ```
25 |
26 | ## Features
27 | - The library supports multiple prompts (images or texts) as targets for optimization.
28 | - The library automatically detects the language of the input text and translates multilingual input via Google Translate.
29 | - The library supports the original CLIP model by OpenAI and ruCLIP model by SberAI.
30 |
31 | ## Usage
32 |
33 | ### Simple code
34 |
35 | ```python
36 | import torch
37 | from pytorch_clip_guided_loss import get_clip_guided_loss
38 |
39 | loss_fn = get_clip_guided_loss(clip_type="ruclip", input_range = (-1, 1)).eval().requires_grad_(False)
40 | # text prompt
41 | loss_fn.add_prompt(text="text description of what we would like to generate")
42 | # image prompt
43 | loss_fn.add_prompt(image=torch.randn(1, 3, 224, 224))
44 |
45 | # variable
46 | var = torch.randn(1, 3, 224, 224).requires_grad_(True)
47 | loss = loss_fn.image_loss(image=var)["loss"]
48 | loss.backward()
49 | print(var.grad)
50 | ```
51 |
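The `add_prompt` interface used above also accepts a relative `weight` and an optional `label` per prompt, and all registered prompts can be dropped with `clear_prompts()` (both calls appear in `examples/object_detection/clip_rcnn.py`). A minimal sketch, assuming the keyword names match that positional call:

```python
import torch
from pytorch_clip_guided_loss import get_clip_guided_loss

loss_fn = get_clip_guided_loss(clip_type="ruclip", input_range=(-1, 1)).eval().requires_grad_(False)
# combine two prompts with different importance
loss_fn.add_prompt(text="a red cup on a table", weight=1.0, label="text")
loss_fn.add_prompt(image=torch.randn(1, 3, 224, 224), weight=0.5, label="image")

var = torch.randn(1, 3, 224, 224).requires_grad_(True)
loss = loss_fn.image_loss(image=var)["loss"]
loss.backward()

# drop all registered prompts before switching to new targets
loss_fn.clear_prompts()
```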
52 | ### VQGAN-CLIP
53 |
54 | We provide a tiny implementation of the VQGAN-CLIP pipeline for image generation as an example of how to use our library.
55 | To start using our implementation of VQGAN-CLIP, please follow the [documentation](examples/vqgan_clip).
56 |
57 | ### Zero-shot Object Detection
58 |
59 | We provide a tiny implementation of an object detector based on Selective Search region proposals and CLIP guided loss.
60 | To start using our implementation of ClipRCNN, please follow the [documentation](examples/object_detection).
--------------------------------------------------------------------------------
/examples/image_captioning/README.md:
--------------------------------------------------------------------------------
1 | # Simple CLIP guided image captioning
2 |
3 | Simple gradient-based CLIP guided image captioning.
4 |
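At a high level, `run.py` optimizes a batch of randomly initialized token embeddings against a single image prompt: the embeddings are snapped to the tokenizer vocabulary with a straight-through estimator, scored with the CLIP guided text loss, and the lowest-loss decoded string is reported as the caption. A condensed, hedged sketch of that loop (it reuses the helpers defined in `run.py`, uses `torch.rand` as a stand-in for a real image, and omits the embedding clamping and the final caption decoding):

```python
import torch
from pytorch_clip_guided_loss import get_clip_guided_loss
from run import init_params, ste_quantize, masked_grad  # helpers from this directory

device = "cpu"
loss_fn = get_clip_guided_loss("ruclip", input_range=(0, 1)).to(device)
loss_fn.add_prompt(image=torch.rand(1, 3, 224, 224).to(device))  # stand-in image in [0, 1]

tokenizer = loss_fn.text_processor
dictionary = loss_fn.model.get_text_dictionary()
# one candidate caption per target length in [1, 32]
embeds, attn_mask, grad_mask = init_params(tokenizer, 1, 32, dictionary, device)

opt = torch.optim.Adam([embeds], lr=0.1)
for _ in range(100):
    opt.zero_grad()
    # mask gradients of BOS/EOS/PAD positions, then snap to the nearest vocabulary embeddings
    x, ids = ste_quantize(masked_grad(embeds, grad_mask), dictionary.weight)
    loss = loss_fn.text_loss(input_ids=ids, attention_mask=attn_mask, embed=x)["loss"]
    loss.backward()
    opt.step()
```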
5 | ## Usage
6 |
7 | ### Install requirements
8 |
9 | ```bash
10 | $ pip install -r requirements.txt
11 | ```
12 |
13 | ### Generate image caption
14 |
15 | ```bash
16 | $ python run.py --help
17 |
18 | usage: run.py [-h] [--device DEVICE] [--clip-type CLIP_TYPE] [--lr LR] [--n-steps N_STEPS] [--length-min LENGTH_MIN] [--length-max LENGTH_MAX] [--img-path IMG_PATH]
19 |
20 | optional arguments:
21 | -h, --help show this help message and exit
22 | --device DEVICE inference device.
23 | --clip-type CLIP_TYPE
24 | Type of CLIP model [clip, ruclip].
25 | --lr LR Learning rate.
26 | --n-steps N_STEPS Number steps of optimization.
27 | --length-min LENGTH_MIN
28 | Minimum sequence length
29 | --length-max LENGTH_MAX
30 | Maximum sequence length
31 | --img-path IMG_PATH Path to input image
32 | ```
33 |
34 | ```bash
35 | $ python run.py --img-path path/to/image.jpg --length-min 1 --length-max 32 --clip-type ruclip
36 | ```
--------------------------------------------------------------------------------
/examples/image_captioning/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | opencv-python
3 | pytorch_clip
4 | pytorch_clip_guided_loss
5 | numpy
6 | tqdm
7 | albumentations
--------------------------------------------------------------------------------
/examples/image_captioning/run.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | import argparse
14 | import typing
15 | import cv2
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from pytorch_clip.processor.text_processor import TextProcessor
20 | from pytorch_clip_guided_loss import get_clip_guided_loss
21 | from tqdm.autonotebook import tqdm
22 | # image transforms
23 | import albumentations as A
24 | from albumentations.pytorch import ToTensorV2
25 |
26 |
27 | class STEQuantize(torch.autograd.Function):
28 | """ Quantize embeddings to codebook with
29 | gradients in style of Straight-Through Estimators.
30 | """
31 | @staticmethod
32 | def forward(ctx, embs: torch.Tensor, codebook: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]:
33 |         """ Forward pass.
34 | Arguments:
35 | embs (torch.Tensor): input embeddings.
36 | codebook (torch.Tensor): codebook.
37 | Returns:
38 | embs_q (torch.Tensor): quantized embeddings
39 | """
40 | d = embs.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * embs @ codebook.T
41 | indices = d.argmin(-1)
42 | embs_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
43 | return embs_q, indices
44 |
45 | @staticmethod
46 | def backward(ctx, grad_in: torch.Tensor, grad_ids: torch.Tensor) -> typing.Tuple[torch.Tensor, None]:
47 |         """ Backward pass in the style of a Straight-Through Estimator.
48 | Arguments:
49 | grad_in (torch.Tensor): input gradients.
50 | Returns:
51 | grad_out (torch.Tensor): STE gradients.
52 | """
53 | return grad_in, None
54 |
55 | def ste_quantize(x: torch.Tensor, codebook: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]:
56 | """ Quantize embeddings to codebook with
57 | gradients in style of Straight-Through Estimators.
58 | Arguments:
59 | embs (torch.Tensor): input embeddings.
60 | codebook (torch.Tensor): codebook.
61 | Returns:
62 | embs_q (torch.Tensor): quantized embeddings
63 | """
64 | return STEQuantize.apply(x, codebook)
65 |
66 |
67 | class MaskedGrad(torch.autograd.Function):
68 | """ Apply masked gradients
69 | """
70 | @staticmethod
71 | def forward(ctx, var: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
72 |         """ Forward pass.
73 | Arguments:
74 | var (torch.Tensor): input variable.
75 | mask (torch.Tensor): mask of the gradient.
76 | Returns:
77 | var (torch.Tensor): input variable.
78 | """
79 | ctx.save_for_backward(mask)
80 | return var
81 |
82 | @staticmethod
83 | def backward(ctx, grad_in: torch.Tensor) -> typing.Tuple[torch.Tensor, None]:
84 |         """ Backward pass: returns the masked gradient for the variable.
85 | Arguments:
86 | grad_in (torch.Tensor): input gradients.
87 | Returns:
88 | grad_out (torch.Tensor): masked gradients.
89 | """
90 | mask, = ctx.saved_tensors
91 | grad_out = grad_in * mask
92 | return grad_out, None
93 |
94 |
95 | def masked_grad(var: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
96 | """ Apply masked gradients.
97 | Arguments:
98 | var (torch.Tensor): input variable.
99 | mask (torch.Tensor): mask of the gradient.
100 | Returns:
101 | var (torch.Tensor): input variable.
102 | """
103 | return MaskedGrad.apply(var, mask)
104 |
105 |
106 | def init_params(
107 | tokenizer: TextProcessor,
108 | length_min: int,
109 | length_max: int,
110 | dictionary: nn.Module,
111 | device: str
112 | ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
113 | """ Initialize random embeddings.
114 | Arguments:
115 | tokenizer (TextProcessor): text tokenizer.
116 | length_min (int): minimum length of the target text.
117 | length_max (int): maximum length of the target text.
118 | dictionary (nn.Module): dictionary of the embeddings.
119 | device (str): target device.
120 | Returns:
121 | embeds (torch.Tensor): random embeddings.
122 | attention_mask (torch.Tensor): attention mask.
123 | gradient_mask (torch.Tensor): gradient mask.
124 | """
125 | length_min = min(length_min, tokenizer.get_max_length() - 2)
126 | length_max = min(length_max, tokenizer.get_max_length() - 2)
127 | vocab_size = tokenizer.vocab_size()
128 | bos_emb = dictionary(torch.LongTensor([[tokenizer.bos_id]]).to(device))
129 | eos_emb = dictionary(torch.LongTensor([[tokenizer.eos_id]]).to(device))
130 | pad_emb = dictionary(torch.LongTensor([[tokenizer.pad_id]]).to(device))
131 |
132 | embeds = []
133 | attention_masks = []
134 | grad_masks = []
135 | for l in range(length_min, length_max + 1):
136 | ids = torch.randint(0, vocab_size, (1, l), device=device)
137 | embed = dictionary(ids)
138 | embed = torch.cat([bos_emb, embed, eos_emb, pad_emb.repeat_interleave(tokenizer.get_max_length() - 2 - l, dim=1)], dim=1)
139 | embeds.append(embed)
140 | attention_masks.append(
141 | torch.LongTensor([1] * (l + 2) + [0] * (tokenizer.get_max_length() - 2 - l)).unsqueeze(0)
142 | )
143 | grad_masks.append(
144 | torch.LongTensor([0] + [1] * l + [0] * (tokenizer.get_max_length() - 1 - l)).unsqueeze(0)
145 | )
146 | embeds = torch.cat(embeds, dim=0).detach()
147 | embeds.requires_grad = True
148 | attention_masks = torch.cat(attention_masks, dim=0).to(device)
149 | grad_masks = torch.cat(grad_masks, dim=0).unsqueeze(-1).to(device)
150 | return embeds, attention_masks, grad_masks
151 |
152 |
153 | def remove_repeats(strings: typing.List[str]) -> typing.List[str]:
154 | """ remove repeated words and sentences
155 | Arguments:
156 | strings (typing.List[str]): input strings.
157 | Returns:
158 | strings (typing.List[str]): output strings.
159 | """
160 | out = []
161 | for s in strings:
162 | words = s.split()
163 | if len(words):
164 | output_words = [words[0]]
165 | for i in range(1, len(words)):
166 | if output_words[-1] != words[i]:
167 | output_words.append(words[i])
168 | s_out = " ".join(output_words)
169 | if not s_out in out:
170 | out.append(s_out)
171 | return out
172 |
173 |
174 | def main(args):
175 | # load model
176 | clip_guided_loss = get_clip_guided_loss(args.clip_type, input_range = (0, 1))
177 | img_transforms = A.Compose([
178 | ToTensorV2()
179 | ])
180 | tokenizer = clip_guided_loss.text_processor
181 | dictionary = clip_guided_loss.model.get_text_dictionary()
182 | # model to inference device
183 | clip_guided_loss.to(args.device)
184 | # initialize prompt
185 | image = (img_transforms(image=cv2.imread(args.img_path))["image"].unsqueeze(0) / 255.0).to(args.device)
186 | clip_guided_loss.add_prompt(image=image)
187 | # initialize text
188 | embeds, attention_mask, grad_mask = init_params(
189 | tokenizer,
190 | args.length_min,
191 | args.length_max,
192 | dictionary,
193 | args.device
194 | )
195 |     # initialize valid range for embeddings
196 | range_min = dictionary.weight.min(dim=0).values[None, None, :]
197 | range_max = dictionary.weight.max(dim=0).values[None, None, :]
198 | # initialize optimizer
199 | opt = torch.optim.Adam([embeds], lr=args.lr)
200 | # start optimization
201 | iterator = tqdm(range(args.n_steps))
202 | for i in iterator:
203 | opt.zero_grad()
204 | x = masked_grad(embeds, grad_mask)
205 | x, ids = ste_quantize(x, dictionary.weight)
206 | loss = clip_guided_loss.text_loss(
207 | input_ids=ids,
208 | attention_mask=attention_mask,
209 | embed=x
210 | )["loss"]
211 | loss.backward()
212 | opt.step()
213 | with torch.inference_mode():
214 | embeds.copy_(embeds.maximum(range_min).minimum(range_max))
215 | iterator.set_description(f"loss: {loss.item()}")
216 | # print outputs
217 | x, ids = ste_quantize(embeds, dictionary.weight)
218 | strings = remove_repeats(tokenizer.decode(ids))
219 | input_ids, attention_mask = [], []
220 | for s in strings:
221 | out = tokenizer.encode(s, return_mask=True)
222 | input_ids.append(out["input_ids"])
223 | attention_mask.append(out["attention_mask"])
224 | input_ids = torch.cat(input_ids, dim=0).to(args.device)
225 | attention_mask = torch.cat(attention_mask, dim=0).to(args.device)
226 | embeds = dictionary(input_ids)
227 | loss = clip_guided_loss.text_loss(
228 | input_ids=input_ids,
229 | attention_mask=attention_mask,
230 | embed=embeds,
231 | reduce=None
232 | )["loss"]
233 | print(f"best caption: {strings[loss.argmin()]}")
234 |
235 |
236 | if __name__ == "__main__":
237 | parser = argparse.ArgumentParser()
238 | parser.add_argument("--device", type=str, default="cuda:0", help="inference device.")
239 | parser.add_argument("--clip-type", type=str, default="ruclip", help="Type of CLIP model [clip, ruclip].")
240 | parser.add_argument("--lr", type=float, default=0.1, help="Learning rate.")
241 | parser.add_argument("--n-steps", type=int, default=100, help="Number steps of optimization.")
242 | parser.add_argument("--length-min", type=int, default=10, help="Minimum sequence length")
243 | parser.add_argument("--length-max", type=int, default=32, help="Maximum sequence length")
244 | parser.add_argument("--img-path", type=str, default=None, help="Path to input image")
245 | args = parser.parse_args()
246 | main(args)
247 |
--------------------------------------------------------------------------------
/examples/object_detection/README.md:
--------------------------------------------------------------------------------
1 | # ClipRCNN: CLIP guided zero-shot object detector.
2 |
3 |
4 |
5 |
6 |
7 | Tiny implementation of an object detector based on Selective Search region proposals and CLIP guided loss.
8 | The detector supports both text and image prompts.
9 |
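The detector can also be used directly from Python, mirroring `run.py` in this directory (a hedged sketch; the image paths are illustrative):

```python
import cv2
from clip_rcnn import ClipRCNN

detector = ClipRCNN(
    scale=500, sigma=0.9, min_size=0.05,  # selective search parameters
    aspect_ratio=(0.5, 1.5),
    clip_type="ruclip",
    batch_size=128,
    top_k=3,
)
detector.add_prompt(text="red cup")

image = cv2.imread("example.jpg")  # illustrative input path
for box in detector.detect(image):
    x, y, w, h = box["rect"]
    cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 4)
cv2.imwrite("preds.jpg", image)
```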
10 | ## Notice
11 |
12 | This is just a toy implementation of a text-driven object detection pipeline.
13 | If you want to use text-driven bounding box filtering in your own object-detection pipeline,
14 | please refer to our library [pytorch_clip_bbox](https://github.com/bes-dev/pytorch_clip_bbox).
15 | pytorch_clip_bbox is a powerful tool for integrating a text-driven approach into any object-detection pipeline.
16 |
17 | ## Usage
18 |
19 | ### Install requirements
20 |
21 | ```bash
22 | $ pip install -r requirements.txt
23 | ```
24 |
25 | ### Detect objects
26 |
27 | ```bash
28 | $ python run.py --help
29 |
30 | usage: run.py [-h] [-i IMAGE] [--device DEVICE] [--text-prompt TEXT_PROMPT] [--image-prompt IMAGE_PROMPT] [--clip-type CLIP_TYPE] [--batch-size BATCH_SIZE] [--scale SCALE] [--sigma SIGMA] [--min-size MIN_SIZE]
31 | [--aspect-ratio ASPECT_RATIO] [--top-k TOP_K] [--output-image OUTPUT_IMAGE]
32 |
33 | optional arguments:
34 | -h, --help show this help message and exit
35 | -i IMAGE, --image IMAGE
36 | Input image.
37 | --device DEVICE inference device.
38 | --text-prompt TEXT_PROMPT
39 | Text prompt.
40 | --image-prompt IMAGE_PROMPT
41 | Image prompt.
42 | --clip-type CLIP_TYPE
43 | Type of CLIP model [clip, ruclip].
44 | --batch-size BATCH_SIZE
45 | Batch size.
46 | --scale SCALE Scale (selective search).
47 | --sigma SIGMA Sigma (selective search).
48 | --min-size MIN_SIZE Minimum area of the region proposal (selective search).
49 | --aspect-ratio ASPECT_RATIO
50 | Aspect ratio (selective search).
51 |   --top-k TOP_K top k predictions will be returned.
52 | --output-image OUTPUT_IMAGE
53 | Output image name.
54 | ```
55 |
56 | ```bash
57 | $ python run.py -i path/to/image.jpg --text-prompt "red cup" --output-image output.png
58 | ```
59 |
--------------------------------------------------------------------------------
/examples/object_detection/clip_rcnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | from typing import List, Dict, Tuple, Optional
14 | import cv2
15 | import selectivesearch
16 | import torch
17 | import torch.nn as nn
18 | import numpy as np
19 | from pytorch_clip_guided_loss import get_clip_guided_loss
20 |
21 |
22 | class ClipRCNN(nn.Module):
23 | """ Implementation of the CLIP guided object detection model.
24 | Model is based on Selective Search region proposals and CLIP
25 | guided loss to make text/image driven object detection.
26 | Arguments:
27 | scale (int): Free parameter. Higher means larger clusters in felzenszwalb segmentation.
28 | sigma (float): Width of Gaussian kernel for felzenszwalb segmentation.
29 |         min_size (float): minimum component size for felzenszwalb segmentation, given as a fraction of the input image area.
30 | aspect_ratio (Tuple[float, float]): valid range of aspect ratios for region proposals.
31 | clip_type (str): type of the CLIP model.
32 | batch_size (int): batch size.
33 |         top_k (int): top k predictions will be returned.
34 | """
35 | def __init__(
36 | self,
37 | scale: int = 500,
38 | sigma: float = 0.9,
39 | min_size: float = 0.1,
40 | aspect_ratio: Tuple[float, float] = (0.5, 1.5),
41 | clip_type: str = "ruclip",
42 | batch_size: int = 128,
43 | top_k: int = 1
44 | ):
45 | super().__init__()
46 | # selective search parameters
47 | self.scale = scale
48 | self.sigma = sigma
49 | self.min_size = min_size
50 | self.aspect_ratio = aspect_ratio
51 | # inference params
52 | self.batch_size = batch_size
53 | # output params
54 | self.top_k = top_k
55 | # CLIP guided loss
56 | self.clip_loss = get_clip_guided_loss(clip_type, input_range=(0.0, 1.0))
57 | self.input_size = self.clip_loss.image_processor[0].size
58 | # utils
59 | self.register_buffer("device_info", torch.tensor(0))
60 |
61 | def add_prompt(
62 | self,
63 | image: Optional[torch.Tensor] = None,
64 | text: Optional[str] = None,
65 | weight: float = 1.0,
66 | label: Optional[str] = None,
67 | store_src: bool = True
68 | ) -> str:
69 | """Add prompt to loss function.
70 | Arguments:
71 | image (torch.Tensor): input image [Optional].
72 | text (str): input text [Optional].
73 | weight (float): importance of the prompt.
74 | label (str): label of the prompt [Optional].
75 | store_src (bool): store source data of the prompt.
76 | Returns:
77 | label (src): label of the prompt.
78 | """
79 | return self.clip_loss.add_prompt(image, text, weight, label, store_src)
80 |
81 | def clear_prompts(self) -> None:
82 | """Delete all available prompts."""
83 | return self.clip_loss.clear_prompts()
84 |
85 | @torch.no_grad()
86 | def detect(self, img: np.array) -> List[Dict]:
87 | """ Detect objects on the input image using CLIP guided prompts.
88 | Argument:
89 | img (np.array): input image.
90 | Returns:
91 | outputs (List[Dict]): predicts in format:
92 | [{"rect": [x, y, w, h], "loss": loss_val}]
93 | """
94 | img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
95 | # generate proposals by selective search
96 | proposals = self._generate_proposals(img_rgb)
97 | if not len(proposals):
98 | return []
99 | batch = self._prepare_batch(img_rgb, proposals).to(self.device_info.device)
100 | # predict CLIP loss
101 | loss = self._predict_clip_loss(batch)
102 | outputs = self._generate_output(proposals, loss)
103 | return outputs
104 |
105 | def _generate_proposals(self, img: np.array) -> List[Tuple[int, int, int, int]]:
106 | """ Generate region proposals using selective search algorithm.
107 | Argument:
108 | img (np.array): input image.
109 | Returns:
110 | proposals (List[Tuple[int, int, int, int]]): output proposals in format [(x, y, w, h)]
111 | """
112 | min_size = int(img.shape[0] * img.shape[1] * self.min_size)
113 | # generate proposals
114 | img_lbl, regions = selectivesearch.selective_search(
115 | img, scale=self.scale, sigma=self.sigma, min_size=min_size
116 | )
117 | # filter by aspect ratio
118 | proposals = []
119 | for region in regions:
120 | x, y, w, h = region["rect"]
121 | aspect_ratio = float(w) / float(h)
122 | if aspect_ratio > self.aspect_ratio[0] and aspect_ratio < self.aspect_ratio[1]:
123 | proposals.append([x, y, w, h])
124 | return proposals
125 |
126 | def _prepare_batch(self, img: np.array, proposals: List[Tuple[int, int, int, int]]) -> torch.Tensor:
127 | """ Crop region proposals and generate batch
128 | Argument:
129 | img (np.array): input image.
130 | proposals (List[Tuple[int, int, int, int]]): output proposals in format [(x, y, w, h)]
131 | Returns:
132 | batch (torch.Tensor): output batch (B, C, H, W).
133 | """
134 | batch = []
135 | for x, y, w, h in proposals:
136 | crop = cv2.resize(img[y:y+h, x:x+w], self.input_size)
137 | batch.append(torch.from_numpy(crop).permute(2, 0, 1).unsqueeze(0))
138 | batch = torch.cat(batch, dim=0)
139 | # normalize batch
140 | batch = batch / 255.0
141 | return batch
142 |
143 | def _predict_clip_loss(self, batch_full: torch.Tensor) -> torch.Tensor:
144 | """ Predict CLIP loss for region proposals using user's prompts.
145 | Argument:
146 | batch_full (torch.Tensor): input batch (B, C, H, W).
147 | Returns:
148 | loss (torch.Tensor): output batch (B, ).
149 | """
150 | loss = []
151 | id_start = 0
152 | while id_start < batch_full.size(0):
153 | id_stop = min(id_start + self.batch_size, batch_full.size(0))
154 | batch = batch_full[id_start:id_stop]
155 | loss.append(self.clip_loss.image_loss(image=batch, reduce=None)["loss"].cpu())
156 | id_start = id_stop
157 | loss = torch.cat(loss, dim=0)
158 | return loss
159 |
160 | def _generate_output(self, proposals: List[Tuple[int, int, int, int]], loss: torch.Tensor) -> List[Dict]:
161 | """ Generate top_k predictions as an output of the model.
162 | Argument:
163 | proposals (List[Tuple[int, int, int, int]]): output proposals in format [(x, y, w, h)]
164 | loss (torch.Tensor): output batch (B, ).
165 | Returns:
166 | outputs (List[Dict]): predicts in format:
167 | [{"rect": [x, y, w, h], "loss": loss_val}]
168 | """
169 | output = []
170 | vals, ids = loss.sort()
171 | top_k = min(self.top_k, len(proposals))
172 | for i in range(top_k):
173 | output.append({
174 | "rect": proposals[ids[i]],
175 | "loss": vals[i]
176 | })
177 | return output
178 |
--------------------------------------------------------------------------------
/examples/object_detection/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | opencv-python
3 | pytorch_clip_guided_loss
4 | numpy
5 | selectivesearch
6 |
--------------------------------------------------------------------------------
/examples/object_detection/resources/preds.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bes-dev/pytorch_clip_guided_loss/5dc906ccf981b6647346dd46a53886b313a4de77/examples/object_detection/resources/preds.jpg
--------------------------------------------------------------------------------
/examples/object_detection/run.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | import argparse
14 | import cv2
15 | import torch
16 | from clip_rcnn import ClipRCNN
17 |
18 |
19 | def main(args):
20 | # build detector
21 | detector = ClipRCNN(
22 | scale=args.scale,
23 | sigma=args.sigma,
24 | min_size=args.min_size,
25 | aspect_ratio=[float(r) for r in args.aspect_ratio.split(",")],
26 | clip_type=args.clip_type,
27 | batch_size=args.batch_size,
28 | top_k=args.top_k
29 |     ).to(args.device)  # move the detector to the inference device
30 | # add prompts
31 | if args.text_prompt is not None:
32 | detector.add_prompt(text=args.text_prompt)
33 | if args.image_prompt is not None:
34 | image = cv2.cvtColor(cv2.imread(args.image_prompt), cv2.COLOR_BGR2RGB)
35 | image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
36 |         image = image / 255.0
37 | detector.add_prompt(image=image)
38 | image = cv2.imread(args.image)
39 | boxes = detector.detect(image)
40 | for box in boxes:
41 | x, y, w, h = box["rect"]
42 | cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 4)
43 | if args.output_image is None:
44 |         cv2.imshow("image", image)
45 | cv2.waitKey()
46 | else:
47 | cv2.imwrite(args.output_image, image)
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument("-i", "--image", type=str, help="Input image.")
53 | parser.add_argument("--device", type=str, default="cuda:0", help="inference device.")
54 | parser.add_argument("--text-prompt", type=str, default=None, help="Text prompt.")
55 | parser.add_argument("--image-prompt", type=str, default=None, help="Image prompt.")
56 | parser.add_argument("--clip-type", type=str, default="ruclip", help="Type of CLIP model [clip, ruclip].")
57 | parser.add_argument("--batch-size", type=int, default=128, help="Batch size.")
58 | parser.add_argument("--scale", type=int, default=500, help="Scale (selective search).")
59 | parser.add_argument("--sigma", type=float, default=0.9, help="Sigma (selective search).")
60 | parser.add_argument("--min-size", type=float, default=0.05, help="Minimum area of the region proposal (selective search).")
61 | parser.add_argument("--aspect-ratio", type=str, default="0.5,1.5", help="Aspect ratio (selective search).")
62 |     parser.add_argument("--top-k", type=int, default=1, help="top k predictions will be returned.")
63 | parser.add_argument("--output-image", type=str, default=None, help="Output image name.")
64 | args = parser.parse_args()
65 | main(args)
66 |
--------------------------------------------------------------------------------
/examples/vqgan_clip/README.md:
--------------------------------------------------------------------------------
1 | # VQGAN-CLIP Text-To-Image pipeline
2 |
3 | Tiny implementation of a Text-To-Image pipeline using the VQVAE by SberAI and the pytorch_clip_guided_loss library.
4 |
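In short, `run.py` samples a random grid of VQ-VAE codebook ids, decodes it to an image, and optimizes the latent embeddings so that the decoded image minimizes the CLIP guided loss for the text prompt. A condensed, hedged sketch (it omits the kornia augmentations, the latent clamping, and saving of the result; see `run.py` for the full loop):

```python
import torch
from omegaconf import OmegaConf
from vqvae import GumbelVQ, ste_quantize
from pytorch_clip_guided_loss import get_clip_guided_loss

device = "cpu"
vqvae, processor = GumbelVQ.from_pretrained(OmegaConf.load("configs/vqvae.yaml"))
vqvae.to(device)
loss_fn = get_clip_guided_loss("ruclip", input_range=(0, 1)).to(device)
loss_fn.add_prompt(text="A painting in the style of Picasso")

# random latent grid for a 256x256 output image
n_toks = 256 // 2 ** (vqvae.decoder.num_resolutions - 1)
z = vqvae.ids_to_embs(
    torch.randint(vqvae.quantize.n_embed, (1, n_toks * n_toks)).to(device)
).detach().requires_grad_(True)

opt = torch.optim.AdamW([z], lr=0.1)
for _ in range(100):
    opt.zero_grad()
    # quantize latents to the codebook (straight-through), decode and rescale to [0, 1]
    x = ste_quantize(z.movedim(1, 3), vqvae.quantize.embed.weight).movedim(3, 1)
    img = vqvae.embs_to_img(x).add(1).div(2).clamp(0, 1)
    loss = loss_fn.image_loss(image=img)["loss"]
    loss.backward()
    opt.step()
```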
5 | ## Usage
6 |
7 | ### Install requirements
8 |
9 | ```bash
10 | $ pip install -r requirements.txt
11 | ```
12 |
13 | ### Generate image from text
14 |
15 | ```bash
16 | $ python run.py --help
17 |
18 | usage: run.py [-h] [--device DEVICE] [-t TEXT] [--cfg CFG] [--clip-type CLIP_TYPE] [--output-size OUTPUT_SIZE] [--output-name OUTPUT_NAME] [--lr LR] [--n-steps N_STEPS] [--batch-size BATCH_SIZE]
19 |
20 | optional arguments:
21 | -h, --help show this help message and exit
22 | --device DEVICE inference device.
23 | -t TEXT, --text TEXT Text prompt.
24 | --cfg CFG Path to VQVAE config.
25 | --clip-type CLIP_TYPE
26 | Type of CLIP model [clip, ruclip].
27 | --output-size OUTPUT_SIZE
28 | Size of the output image.
29 | --output-name OUTPUT_NAME
30 | Name of the output image.
31 | --lr LR Learning rate.
32 | --n-steps N_STEPS Number steps of optimization.
33 | --batch-size BATCH_SIZE
34 | Batch size.
35 | ```
36 |
37 | ```bash
38 | $ python run.py --text "A painting in the style of Picasso" --output-name output.png
39 | ```
--------------------------------------------------------------------------------
/examples/vqgan_clip/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bes-dev/pytorch_clip_guided_loss/5dc906ccf981b6647346dd46a53886b313a4de77/examples/vqgan_clip/__init__.py
--------------------------------------------------------------------------------
/examples/vqgan_clip/configs/vqvae.yaml:
--------------------------------------------------------------------------------
1 | name: VQVAE
2 | description: vqgan.gumbelf8-sber.config.yml
3 | ckpt:
4 | repo_id: shonenkov/rudalle-utils
5 | filename: vqgan.gumbelf8-sber.model.ckpt
6 | params:
7 | model:
8 | base_learning_rate: 4.5e-06
9 | target: taming.models.vqgan.GumbelVQ
10 | params:
11 | kl_weight: 1.0e-08
12 | embed_dim: 256
13 | n_embed: 8192
14 | monitor: val/rec_loss
15 | temperature_scheduler_config:
16 | target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
17 | params:
18 | warm_up_steps: 0
19 | max_decay_steps: 1000001
20 | lr_start: 0.9
21 | lr_max: 0.9
22 | lr_min: 1.0e-06
23 | ddconfig:
24 | double_z: false
25 | z_channels: 256
26 | resolution: 256
27 | in_channels: 3
28 | out_ch: 3
29 | ch: 128
30 | ch_mult:
31 | - 1
32 | - 1
33 | - 2
34 | - 4
35 | num_res_blocks: 2
36 | attn_resolutions:
37 | - 32
38 | dropout: 0.0
39 | lossconfig:
40 | target: taming.modules.losses.vqperceptual.DummyLoss
41 |
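For reference, this config is consumed by `GumbelVQ.from_pretrained` (see `../run.py` and `../vqvae/vqvae.py`): the `ckpt` block is forwarded to `huggingface_hub.hf_hub_download` and `params.model.params` to the `GumbelVQ` constructor. A minimal sketch, run from `examples/vqgan_clip`:

```python
from omegaconf import OmegaConf
from vqvae import GumbelVQ

cfg = OmegaConf.load("configs/vqvae.yaml")
# downloads vqgan.gumbelf8-sber.model.ckpt from the shonenkov/rudalle-utils repo
# and builds GumbelVQ(**cfg.params.model.params) with the checkpoint weights
vqvae, image_processor = GumbelVQ.from_pretrained(cfg)
```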
--------------------------------------------------------------------------------
/examples/vqgan_clip/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | opencv-python
3 | pytorch_clip_guided_loss
4 | numpy
5 | taming-transformers
6 | huggingface_hub
7 | kornia
8 | tqdm
9 | omegaconf
10 | einops
--------------------------------------------------------------------------------
/examples/vqgan_clip/run.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | import argparse
14 | import cv2
15 | import torch
16 | import torch.nn as nn
17 | from omegaconf import OmegaConf
18 | import kornia.augmentation as K
19 | from vqvae import GumbelVQ, ste_quantize
20 | from pytorch_clip_guided_loss import get_clip_guided_loss
21 | from tqdm.autonotebook import tqdm
22 |
23 |
24 | def main(args):
25 | # load model
26 | vqvae, vqvae_processor = GumbelVQ.from_pretrained(OmegaConf.load(args.cfg))
27 | clip_guided_loss = get_clip_guided_loss(args.clip_type, input_range = (0, 1))
28 | # model to inference device
29 | vqvae.to(args.device)
30 | clip_guided_loss.to(args.device)
31 | # initialize prompt
32 | clip_guided_loss.add_prompt(text=args.text)
33 | # initialize image
34 | n_toks = args.output_size // 2 ** (vqvae.decoder.num_resolutions - 1)
35 | z = vqvae.ids_to_embs(
36 | torch.randint(vqvae.quantize.n_embed, (1, n_toks * n_toks)).to(args.device)
37 | ).detach().requires_grad_(True)
38 | # initialize augmentations
39 | augs = nn.Sequential(
40 | K.RandomAffine(degrees=15, translate=0.1, shear=5, p=0.7, padding_mode='zeros', keepdim=True),
41 | K.RandomPerspective(distortion_scale=0.7, p=0.7),
42 | K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
43 | K.RandomErasing(scale=(.1, .4), ratio=(.3, 1/.3), same_on_batch=True, p=0.7)
44 | )
45 | # initialize optimizer
46 | opt = torch.optim.AdamW([z], lr=args.lr)
47 |     # initialize valid range for embeddings
48 | z_min = vqvae.quantize.embed.weight.min(dim=0).values[None, :, None, None]
49 | z_max = vqvae.quantize.embed.weight.max(dim=0).values[None, :, None, None]
50 | # start optimization
51 | iterator = tqdm(range(args.n_steps))
52 | for i in iterator:
53 | opt.zero_grad()
54 | x = ste_quantize(z.movedim(1, 3), vqvae.quantize.embed.weight).movedim(3, 1)
55 | x = vqvae.embs_to_img(x).add(1).div(2).clamp(0, 1)
56 | x = x.repeat_interleave(args.batch_size, dim=0)
57 | x = augs(x)
58 | loss = clip_guided_loss.image_loss(image = x)["loss"]
59 | loss.backward()
60 | opt.step()
61 | with torch.inference_mode():
62 | z.copy_(z.maximum(z_min).minimum(z_max))
63 | iterator.set_description(f"loss: {loss.item()}")
64 | # save image
65 | z = ste_quantize(z.movedim(1, 3), vqvae.quantize.embed.weight).movedim(3, 1)
66 | cv2.imwrite(args.output_name, vqvae_processor.decode(vqvae.embs_to_img(z), rgb2bgr=True)[0])
67 |
68 |
69 | if __name__ == "__main__":
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument("--device", type=str, default="cuda:0", help="inference device.")
72 | parser.add_argument("-t", "--text", type=str, default=None, help="Text prompt.")
73 | parser.add_argument("--cfg", type=str, default="configs/vqvae.yaml", help="Path to VQVAE config.")
74 | parser.add_argument("--clip-type", type=str, default="ruclip", help="Type of CLIP model [clip, ruclip].")
75 | parser.add_argument("--output-size", type=int, default=256, help="Size of the output image.")
76 | parser.add_argument("--output-name", type=str, default="output.png", help="Name of the output image.")
77 | parser.add_argument("--lr", type=float, default=0.1, help="Learning rate.")
78 | parser.add_argument("--n-steps", type=int, default=100, help="Number steps of optimization.")
79 | parser.add_argument("--batch-size", type=int, default=32, help="Batch size.")
80 | args = parser.parse_args()
81 | main(args)
82 |
--------------------------------------------------------------------------------
/examples/vqgan_clip/vqvae/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | from .vqvae import GumbelVQ
14 | from .utils import ste_quantize
15 |
--------------------------------------------------------------------------------
/examples/vqgan_clip/vqvae/image_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | import typing
14 | import cv2
15 | import torch
16 | import numpy as np
17 | from .utils import img_to_tensor, tensor_to_img
18 |
19 |
20 | class OpenCVImageProcessor:
21 | """ Implementation of image encode/decode for VQVAE. """
22 | @staticmethod
23 | def encode(
24 | img: np.array,
25 | size: typing.Optional[int] = None,
26 | normalize: bool = True,
27 | input_range: typing.Tuple[float, float] = (0.0, 255.0),
28 | bgr2rgb: bool = True
29 | ) -> torch.Tensor:
30 | """ Encode input image.
31 | Arguments:
32 | img (np.array): input image.
33 | size (typing.Optional[int]): target size of the image.
34 | normalize (bool): normalize input image.
35 | input_range (typing.Tuple[float, float]): input range.
36 | bgr2rgb (bool): convert input image from BGR to RGB.
37 | Returns:
38 | tensor (torch.Tensor): encoded image.
39 | """
40 | if size is not None:
41 | img = cv2.resize(img, size)
42 | tensor = img_to_tensor(img, normalize, input_range, bgr2rgb)
43 | return tensor
44 |
45 | @staticmethod
46 | def decode(
47 | tensor: torch.Tensor,
48 | normalize: bool = True,
49 | input_range: typing.Tuple[float, float] = (-1, 1),
50 | to_numpy: bool = True,
51 | rgb2bgr: bool = True
52 | ) -> typing.List[np.array or torch.Tensor]:
53 |         """ Decode input tensor (output of the VQVAE decoder) into images.
54 | Arguments:
55 | tensor (torch.Tensor): input tensor.
56 | normalize (bool): normalize input tensor.
57 | input_range (typing.Tuple[float, float]): input range.
58 | to_numpy (bool): convert outputs to np.array.
59 | rgb2bgr (bool): convert output from RGB to BGR.
60 | Returns:
61 | imgs (typing.List[np.array or torch.Tensor]): decoded images.
62 | """
63 | imgs = []
64 | for i in range(tensor.size(0)):
65 | img = tensor_to_img(tensor[i], normalize, input_range, to_numpy, rgb2bgr)
66 | imgs.append(img)
67 | return imgs
68 |
--------------------------------------------------------------------------------
/examples/vqgan_clip/vqvae/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | import typing
14 | import cv2
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import numpy as np
19 |
20 |
21 | def tensor_to_img(
22 | tensor: torch.Tensor,
23 | normalize: bool = True,
24 | input_range: typing.Tuple[float, float] = (-1, 1),
25 | to_numpy: bool = True,
26 | rgb2bgr: bool = True
27 | ) -> np.array or torch.Tensor:
28 | """ Decode torch.Tensor to np.array.
29 | Arguments:
30 | tensor (torch.Tensor): input tensor.
31 | normalize (bool): normalize input tensor.
32 | input_range (typing.Tuple[float, float]): input range.
33 | to_numpy (bool): convert outputs to np.array.
34 | rgb2bgr (bool): convert output from RGB to BGR.
35 | Returns:
36 | img (np.array or torch.Tensor): decoded image.
37 | """
38 | if normalize:
39 | tensor = torch.clamp(tensor, min=input_range[0], max=input_range[1])
40 | tensor = (tensor - input_range[0]) / (input_range[1] - input_range[0] + 1e-5)
41 | img = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
42 | if to_numpy:
43 | img = img.to('cpu', torch.uint8).numpy()
44 | if rgb2bgr:
45 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
46 | return img
47 |
48 |
49 | def img_to_tensor(
50 | img: np.array,
51 | normalize: bool = True,
52 | input_range: typing.Tuple[float, float] = (0.0, 255.0),
53 | bgr2rgb: bool = True
54 | ) -> torch.Tensor:
55 | """ Encode np.array to torch.Tensor.
56 | Arguments:
57 | img (np.array): input image.
59 | normalize (bool): normalize input image.
60 | input_range (typing.Tuple[float, float]): input range.
61 | bgr2rgb (bool): convert input image from BGR to RGB.
62 | Returns:
63 | tensor (torch.Tensor): encoded image.
64 | """
65 | if bgr2rgb:
66 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
67 | tensor = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(dtype=torch.float32)
68 | if normalize:
69 | tensor = torch.clamp(tensor, min=input_range[0], max=input_range[1])
70 | tensor = (tensor - input_range[0]) / (input_range[1] - input_range[0] + 1e-5)
71 | tensor = 2.0 * tensor - 1.0
72 | return tensor
73 |
74 |
75 | class STEQuantize(torch.autograd.Function):
76 | """ Quantize VQVAE embeddings to VQVAE codebook with
77 | gradients in style of Straight-Through Estimators.
78 | """
79 | @staticmethod
80 | def forward(ctx, embs: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
81 |         """ Forward pass.
82 | Arguments:
83 | embs (torch.Tensor): input embeddings.
84 | codebook (torch.Tensor): VQVAE codebook.
85 | Returns:
86 | embs_q (torch.Tensor): quantized embeddings
87 | """
88 | d = embs.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * embs @ codebook.T
89 | indices = d.argmin(-1)
90 | embs_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
91 | return embs_q
92 |
93 | @staticmethod
94 | def backward(ctx, grad_in: torch.Tensor) -> typing.Tuple[torch.Tensor, None]:
95 |         """ Backward pass in the style of a Straight-Through Estimator.
96 | Arguments:
97 | grad_in (torch.Tensor): input gradients.
98 | Returns:
99 | grad_out (torch.Tensor): STE gradients.
100 | """
101 | return grad_in, None
102 |
103 |
104 | def ste_quantize(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
105 | """ Quantize VQVAE embeddings to VQVAE codebook with
106 | gradients in style of Straight-Through Estimators.
107 | Arguments:
108 | embs (torch.Tensor): input embeddings.
109 | codebook (torch.Tensor): VQVAE codebook.
110 | Returns:
111 | embs_q (torch.Tensor): quantized embeddings
112 | """
113 | return STEQuantize.apply(x, codebook)
114 |
--------------------------------------------------------------------------------
/examples/vqgan_clip/vqvae/vqvae.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from math import sqrt, log
3 | import sys
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import einsum
8 | from einops import rearrange
9 | from taming.modules.diffusionmodules.model import Encoder, Decoder
10 | from omegaconf import OmegaConf
11 | from huggingface_hub import hf_hub_download
12 | from .image_processor import OpenCVImageProcessor
13 |
14 |
15 | class GumbelQuantize(nn.Module):
16 | """
17 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
18 | Gumbel Softmax trick quantizer
19 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
20 | https://arxiv.org/abs/1611.01144
21 | """
22 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
23 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True):
24 | super().__init__()
25 | self.embedding_dim = embedding_dim
26 | self.n_embed = n_embed
27 | self.straight_through = straight_through
28 | self.temperature = temp_init
29 | self.kl_weight = kl_weight
30 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
31 | self.embed = nn.Embedding(self.n_embed, self.embedding_dim)
32 | self.use_vqinterface = use_vqinterface
33 |
34 | def forward(self, z, temp=None, return_logits=False):
35 | hard = self.straight_through if self.training else True
36 | temp = self.temperature if temp is None else temp
37 | logits = self.proj(z)
38 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
39 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
40 | # + kl divergence to the prior loss
41 | qy = F.softmax(logits, dim=1)
42 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
43 | ind = soft_one_hot.argmax(dim=1)
44 | if self.use_vqinterface:
45 | if return_logits:
46 | return z_q, diff, (None, None, ind), logits
47 | return z_q, diff, (None, None, ind)
48 | return z_q, diff, ind
49 |
50 |
51 | class GumbelVQ(nn.Module):
52 | def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8, **kwargs):
53 | super().__init__()
54 | z_channels = ddconfig['z_channels']
55 | self.encoder = Encoder(**ddconfig)
56 | self.decoder = Decoder(**ddconfig)
57 | self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0)
58 | self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
59 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
60 |
61 | def encode(self, x):
62 | h = self.encoder(x)
63 | h = self.quant_conv(h)
64 | quant, emb_loss, info = self.quantize(h)
65 | return quant, emb_loss, info
66 |
67 | def decode(self, quant):
68 | quant = self.post_quant_conv(quant)
69 | dec = self.decoder(quant)
70 | return dec
71 |
72 | def img_to_ids(self, img):
73 | _, _, [_, _, indices] = self.encode(img)
74 | return rearrange(indices, 'b h w -> b (h w)')
75 |
76 | def ids_to_embs(self, ids):
77 | b, n = ids.shape
78 | one_hot = F.one_hot(ids, num_classes=self.quantize.n_embed).float()
79 | embs = (one_hot @ self.quantize.embed.weight)
80 | embs = rearrange(embs, 'b (h w) c -> b c h w', h = int(sqrt(n)))
81 | return embs
82 |
83 | def embs_to_img(self, embs):
84 | img = self.decode(embs)
85 | return img
86 |
87 | @classmethod
88 | def from_pretrained(cls, cfg):
89 | print(f"[{cls.__name__}]: create model")
90 | model = cls(**cfg.params.model.params)
91 | print(f"[{cls.__name__}]: load checkpoint {cfg.ckpt}")
92 | model.load_state_dict(
93 | torch.load(hf_hub_download(**cfg.ckpt), map_location="cpu")["state_dict"],
94 | strict=False
95 | )
96 | print(f"[{cls.__name__}]: create processor")
97 | model_preprocess = OpenCVImageProcessor()
98 | return model, model_preprocess
99 |
--------------------------------------------------------------------------------
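A self-contained toy reproduction of the Gumbel-Softmax quantization step from GumbelQuantize.forward above (arbitrary sizes; it does not import the class, so it avoids the taming/omegaconf/huggingface_hub dependencies):

import torch
import torch.nn.functional as F

n_embed, embed_dim, num_hiddens = 32, 16, 64
proj = torch.nn.Conv2d(num_hiddens, n_embed, 1)   # logits over the codebook
embed = torch.nn.Embedding(n_embed, embed_dim)    # the codebook itself

z = torch.randn(2, num_hiddens, 8, 8)             # encoder features
logits = proj(z)                                  # (B, n_embed, H, W)
soft_one_hot = F.gumbel_softmax(logits, tau=1.0, dim=1, hard=True)

# Weighted sum of codebook entries, the same operation as the einsum in forward().
z_q = torch.einsum('b n h w, n d -> b d h w', soft_one_hot, embed.weight)
print(z_q.shape)                                  # torch.Size([2, 16, 8, 8])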
/pytorch_clip_guided_loss/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | import typing
14 | import torch.nn as nn
15 | from omegaconf import OmegaConf
16 | from .utils import res_path
17 | from .clip_guided_loss import CLIPGuidedLoss
18 | from pytorch_clip import get_clip_model, get_models_list
19 |
20 |
21 | def get_clip_guided_loss(
22 | clip_type: str = "clip",
23 | input_range: typing.Tuple[float, float] = (-1.0, 1.0),
24 | cache_dir: str = "/tmp/"
25 | ) -> nn.Module:
26 | """get CLIPGuidedloss model.
27 | Arguments:
28 | clip_type (str): type of the CLIP model ("clip" or "ruclip").
29 | input_range (tuple[float, float]): input range.
30 | cache_dir (str): path to cache dir.
31 | Returns:
32 | model (nn.Module): CLIPGuidedLoss model.
33 | """
34 | clip, tokenizer, transforms, cfg = get_clip_model(clip_type, input_range, cache_dir)
35 | return CLIPGuidedLoss(clip, tokenizer, transforms)
36 |
--------------------------------------------------------------------------------
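A short usage sketch of the factory above (not a repository file; the prompt text, image size, and gradient step are illustrative, and the add_prompt/image_loss API is the one defined in clip_guided_loss.py below):

import torch
from pytorch_clip_guided_loss import get_clip_guided_loss

loss_fn = get_clip_guided_loss(clip_type="clip", input_range=(-1.0, 1.0)).eval()
loss_fn.add_prompt(text="a watercolor painting of a fox", weight=1.0)

# Any tensor in the declared input range works; 224x224 is a typical CLIP input size.
image = torch.randn(1, 3, 224, 224, requires_grad=True)
loss = loss_fn.image_loss(image=image)["loss"]
loss.backward()                                   # gradients w.r.t. the image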
/pytorch_clip_guided_loss/clip_guided_loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | import typing
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | # clip model
18 | from pytorch_clip import get_clip_model
19 | from pytorch_clip.processor.text_processor import TextProcessor
20 | # utils
21 | from omegaconf import OmegaConf
22 |
23 |
24 | class CLIPPrompt(nn.Module):
25 | """ Implementation of CLIP prompt
26 | Arguments:
27 | embed (torch.Tensor): input embedding.
28 | weight (float): importance of the prompt.
29 | src (torch.Tensor or str): source data of the prompt.
30 | """
31 | def __init__(
32 | self,
33 | embed: torch.Tensor,
34 | weight: float,
35 | src: typing.Optional[typing.Union[torch.Tensor, str]] = None
36 | ):
37 | super().__init__()
38 | self.register_buffer("embed", embed)
39 | self.register_buffer("weight", torch.as_tensor(weight))
40 | if isinstance(src, torch.Tensor):
41 | self.register_buffer("src", src)
42 | else:
43 | self.src = src
44 |
45 | def forward(self, x: torch.Tensor) -> torch.Tensor:
46 | """Compute spherical distance loss between prompt and input embedding.
47 | Arguments:
48 | x (torch.Tensor): input embedding.
49 | Returns:
50 | loss (torch.Tensor): output spherical loss.
51 | """
52 | return self.weight * x.sub(self.embed).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
53 |
54 |
55 | class CLIPGuidedLoss(nn.Module):
56 | """ Implementation of CLIP guided loss function.
57 | Arguments:
58 | model (nn.Module): CLIP model.
59 | text_processor (TextProcessor): text processor.
60 | image_processor (nn.Module): image processor.
61 | """
62 | def __init__(self, model: nn.Module, text_processor: TextProcessor, image_processor: nn.Module):
63 | super().__init__()
64 | # clip model
65 | self.model = model
66 | self.text_processor = text_processor
67 | self.image_processor = image_processor
68 | # prompts
69 | self.prompts = nn.ModuleDict()
70 | # device info
71 | self.register_buffer("device_info", torch.tensor(1))
72 |
73 | def add_prompt(
74 | self,
75 | image: typing.Optional[torch.Tensor] = None,
76 | text: typing.Optional[str] = None,
77 | weight: float = 1.0,
78 | label: typing.Optional[str] = None,
79 | store_src: bool = True
80 | ) -> str:
81 | """Add prompt to loss function.
82 | Arguments:
83 | image (torch.Tensor): input image [Optional].
84 | text (str): input text [Optional].
85 | weight (float): importance of the prompt.
86 | label (str): label of the prompt [Optional].
87 | store_src (bool): store source data of the prompt.
88 | Returns:
89 | label (str): label of the prompt.
90 | """
91 | if text is None and image is None:
92 | return
93 | embed, src = self._get_embed(image, text)
94 | if label is None:
95 | label = str(len(self.prompts))
96 | self.prompts[label] = CLIPPrompt(
97 | embed = embed.detach(),
98 | weight = weight,
99 | src = src if store_src else None
100 | ).to(self.device_info.device)
101 | return label
102 |
103 | def delete_prompt(self, label: typing.Optional[str] = None) -> None:
104 | """Add prompt to loss function.
105 | Arguments:
106 | label (str): label of the prompt to delete [Optional].
107 | """
108 | if label in self.prompts:
109 | self.prompts.pop(label)
110 |
111 | def clear_prompts(self) -> None:
112 | """Delete all available prompts."""
113 | self.prompts.clear()
114 |
115 | def get_prompts_list(self) -> typing.List[str]:
116 | """Get list of all available prompts.
117 | Returns:
118 | prompts (list): list of prompts labels.
119 | """
120 | return list(self.prompts.keys())
121 |
122 | def get_prompt(self, label: str) -> typing.Optional[CLIPPrompt]:
123 | """Get prompt if available.
124 | Arguments:
125 | label (str): label of the prompt.
126 | Returns:
127 | prompt (CLIPPrompt or None): prompt [Optional].
128 | """
129 | if label in self.prompts:
130 | return self.prompts[label]
131 |
132 | def forward(self, embed: typing.Optional[torch.Tensor], reduce: str = "mean") -> typing.Dict[str, torch.Tensor]:
133 | """Compute CLIP guided loss between the input embedding and all available prompts.
134 | Arguments:
135 | embed (torch.Tensor): input embedding.
136 | reduce (str): reduce mode ("mean", "sum" or None)
137 | Returns:
138 | loss (dict): per-prompt losses and the total loss under the "loss" key.
139 | """
140 | assert reduce in ["mean", "sum", None]
141 | loss = {}
142 | for key, prompt in self.prompts.items():
143 | loss[key] = prompt(embed)
144 | if reduce is not None:
145 | loss[key] = loss[key].mean() if reduce == "mean" else loss[key].sum()
146 | loss["loss"] = sum(loss.values()) if len(loss) else 0
147 | return loss
148 |
149 | def image_loss(self, image: typing.Optional[torch.Tensor], reduce: str = "mean") -> typing.Dict[str, torch.Tensor]:
150 | """Compute CLIP guided loss between input image and all available prompts.
151 | Arguments:
152 | image (torch.Tensor): input image.
153 | reduce (str): reduce mode ("mean", "sum" or None)
154 | Returns:
155 | loss (dict): per-prompt losses and the total loss under the "loss" key.
156 | """
157 | embed, _ = self._get_embed(image=image)
158 | return self(embed, reduce=reduce)
159 |
160 | def text_loss(self,
161 | input_ids: typing.Optional[torch.Tensor],
162 | embed: typing.Optional[torch.Tensor],
163 | attention_mask: typing.Optional[torch.Tensor] = None,
164 | reduce: str = "mean"
165 | ) -> torch.Tensor:
166 | """Compute CLIP guided loss between input text and all available prompts.
167 | Arguments:
168 | input_ids (torch.Tensor): input text tokens indices.
169 | embed (torch.Tensor): input text embeddings.
170 | attention_mask (torch.Tensor): attention mask [Optional].
171 | reduce (str): reduce mode ("mean", "sum" or None)
172 | Returns:
173 | loss (dict): per-prompt losses and the total loss under the "loss" key.
174 | """
175 | embed = self.model.get_text_features(
176 | input_ids = input_ids,
177 | attention_mask = attention_mask,
178 | inputs_embeds = embed
179 | )
180 | return self(embed, reduce=reduce)
181 |
182 | def _get_embed(
183 | self,
184 | image: typing.Optional[torch.Tensor] = None,
185 | text: typing.Optional[typing.Union[typing.Dict[str, torch.Tensor], str]] = None
186 | ) -> typing.Tuple[torch.Tensor, typing.Union[torch.Tensor, typing.Dict[str, torch.Tensor], str]]:
187 | """Compute CLIP embedding for input data.
188 | Arguments:
189 | image (torch.Tensor): input image [Optional].
190 | text (str): input text [Optional].
191 | Returns:
192 | (embed, src): output embedding and the source data it was computed from.
193 | """
194 | if image is None and text is None:
195 | raise ValueError("You have to specify either text or images. Both cannot be none.")
196 | if image is not None:
197 | embed = self.model.get_image_features(self.image_processor(image))
198 | src = image
199 | else:
200 | if isinstance(text, str):
201 | batch = self.text_processor.encode(text, return_mask=True)
202 | batch["input_ids"] = batch["input_ids"].view(1, -1).to(self.device_info.device)
203 | if "attention_mask" in batch:
204 | batch["attention_mask"] = batch["attention_mask"].view(1, -1).to(self.device_info.device)
205 | else:
206 | batch = text
207 | embed = self.model.get_text_features(**batch)
208 | src = text
209 | embed = F.normalize(embed, dim=-1)
210 | return embed, src
211 |
--------------------------------------------------------------------------------
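The chained tensor ops in CLIPPrompt.forward compute a squared great-circle distance on the unit sphere: for unit-norm x and e, ||x - e|| = 2*sin(theta/2), so the expression equals weight * theta**2 / 2, where theta is the angle between the two embeddings. A quick numerical check of that identity (sketch, weight omitted):

import torch
import torch.nn.functional as F

x = F.normalize(torch.randn(1, 512), dim=-1)
e = F.normalize(torch.randn(1, 512), dim=-1)

loss_repo = x.sub(e).norm(dim=-1).div(2).arcsin().pow(2).mul(2)  # as in CLIPPrompt.forward
theta = torch.arccos((x * e).sum(dim=-1).clamp(-1.0, 1.0))       # angle between unit vectors
print(torch.allclose(loss_repo, theta.pow(2) / 2, atol=1e-4))    # expected: True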
/pytorch_clip_guided_loss/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 | import os
14 |
15 |
16 | def get_module_path():
17 | """ Get module path
18 | Returns:
19 | path (str): path to current module.
20 | """
21 | file_path = os.path.abspath(__file__)
22 | module_path = os.path.dirname(file_path)
23 | return module_path
24 |
25 |
26 | def res_path(path):
27 | """ Resource path
28 | Arguments:
29 | path (str): related path from module dir to some resources.
30 | Returns:
31 | path (str): absolute path to the resource inside the module dir.
32 | """
33 | return os.path.join(get_module_path(), path)
34 |
--------------------------------------------------------------------------------
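A hypothetical usage of the helpers above (the resource name is illustrative and not claimed to exist in the package):

from pytorch_clip_guided_loss.utils import res_path

# Resolve a resource shipped inside the installed package to an absolute path.
cfg_path = res_path("configs/example.yml")   # hypothetical resource name
print(cfg_path)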
/requirements.txt:
--------------------------------------------------------------------------------
1 | wheel
2 | omegaconf
3 | torch
4 | transformers
5 | huggingface_hub
6 | youtokentome
7 | googletrans
8 | numpy
9 | pytorch_clip
10 |
--------------------------------------------------------------------------------
/resources/preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bes-dev/pytorch_clip_guided_loss/5dc906ccf981b6647346dd46a53886b313a4de77/resources/preview.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 by Sergei Belousov
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | from setuptools import setup, find_packages
17 |
18 | readme = open('README.md').read()
19 |
20 | VERSION = '2021.12.25.0'
21 |
22 | requirements = [
23 | 'wheel',
24 | 'Cython',
26 | 'omegaconf',
27 | 'torch',
28 | 'transformers',
29 | 'huggingface_hub',
30 | 'youtokentome',
31 | 'googletrans',
32 | 'numpy',
33 | 'pytorch_clip'
34 | ]
35 |
36 | setup(
37 | # Metadata
38 | name='pytorch_clip_guided_loss',
39 | version=VERSION,
40 | author='Sergei Belousov aka BeS',
41 | author_email='sergei.o.belousov@gmail.com',
42 | description='Pytorch implementation of the CLIP guided loss.',
43 | long_description=readme,
44 | long_description_content_type='text/markdown',
45 |
46 | # Package info
47 | packages=find_packages(exclude=('*test*',)),
48 |
49 | #
50 | zip_safe=True,
51 | install_requires=requirements,
52 |
53 | # Classifiers
54 | classifiers=[
55 | 'Programming Language :: Python :: 3',
56 | ],
57 |
58 |
59 | # install .json configs
60 | package_data={
61 | # "pytorch_clip_guided_loss": ["configs/*.yml"]
62 | },
63 | include_package_data=True
64 | )
65 |
--------------------------------------------------------------------------------