├── .gitignore ├── CLIP.png ├── LICENSE ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── model.py └── tokenizer.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # work 132 | .test.py 133 | test.py 134 | .pretrained_models/ 135 | pretrained_models/ 136 | paddleclip.egg-info/ 137 | .paddleclip.egg-info/ 138 | dist/ 139 | .dist/ 140 | build/ 141 | .build/ -------------------------------------------------------------------------------- /CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AgentMaker/Paddle-CLIP/45b10f1170ac648609980cf982319b33605ce552/CLIP.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Paddle-CLIP 2 | ![GitHub forks](https://img.shields.io/github/forks/AgentMaker/Paddle-CLIP) 3 | ![GitHub Repo stars](https://img.shields.io/github/stars/AgentMaker/Paddle-CLIP) 4 | ![GitHub release (latest by date including pre-releases)](https://img.shields.io/github/v/release/AgentMaker/Paddle-CLIP?include_prereleases) 5 | ![GitHub](https://img.shields.io/github/license/AgentMaker/Paddle-CLIP) 6 | A PaddlePaddle implementation of OpenAI's CLIP.
[【origin repo】](https://github.com/openai/CLIP/) 7 | 8 | ## Install Package 9 | * Install by pip: 10 | ```shell 11 | $ pip install paddleclip 12 | ``` 13 | * Install by wheel package:[【Releases Packages】](https://github.com/AgentMaker/Paddle-CLIP/releases) 14 | 15 | ## Requirements 16 | * wget 17 | * ftfy 18 | * regex 19 | * paddlepaddle(cpu/gpu)>=2.0.1 20 | 21 | ## Quick Start 22 | ```python 23 | import paddle 24 | from PIL import Image 25 | from clip import tokenize, load_model 26 | 27 | # Load the model 28 | model, transforms = load_model('ViT_B_32', pretrained=True) 29 | 30 | # Prepare the inputs 31 | image = transforms(Image.open("CLIP.png")).unsqueeze(0) 32 | text = tokenize(["a diagram", "a dog", "a cat"]) 33 | 34 | # Calculate features and probability 35 | with paddle.no_grad(): 36 | logits_per_image, logits_per_text = model(image, text) 37 | probs = paddle.nn.functional.softmax(logits_per_image, axis=-1) 38 | 39 | # Print the result 40 | print(probs.numpy()) 41 | ``` 42 | [[0.9927937 0.00421065 0.00299568]] 43 | 44 | ## Zero-Shot Prediction 45 | ```python 46 | import paddle 47 | from clip import tokenize, load_model 48 | from paddle.vision.datasets import Cifar100 49 | 50 | # Load the model 51 | model, transforms = load_model('ViT_B_32', pretrained=True) 52 | 53 | # Load the dataset 54 | cifar100 = Cifar100(mode='test', backend='pil') 55 | classes = [ 56 | 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 57 | 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 58 | 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 59 | 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 60 | 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 61 | 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 62 | 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 63 | 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 64 | 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 65 | 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm' 66 | ] 67 | 68 | # Prepare the inputs 69 | image, class_id = cifar100[3637] 70 | image_input = transforms(image).unsqueeze(0) 71 | text_inputs = tokenize(["a photo of a %s" % c for c in classes]) 72 | 73 | # Calculate features 74 | with paddle.no_grad(): 75 | image_features = model.encode_image(image_input) 76 | text_features = model.encode_text(text_inputs) 77 | 78 | # Pick the top 5 most similar labels for the image 79 | image_features /= image_features.norm(axis=-1, keepdim=True) 80 | text_features /= text_features.norm(axis=-1, keepdim=True) 81 | similarity = (100.0 * image_features @ text_features.t()) 82 | similarity = paddle.nn.functional.softmax(similarity, axis=-1) 83 | values, indices = similarity[0].topk(5) 84 | 85 | # Print the result 86 | for value, index in zip(values, indices): 87 | print('%s: %.02f%%' % (classes[index], value*100.)) 88 | ``` 89 | snake: 65.31% 90 | turtle: 12.29% 91 | sweet_pepper: 3.83% 92 | lizard: 1.88% 93 | crocodile: 1.75% 94 | 95 | ## Linear-probe evaluation 96 | ```python 97 | import os 98 | import paddle 99 | import numpy as np 100 | from 
tqdm import tqdm 101 | from paddle.io import DataLoader 102 | from clip import tokenize, load_model 103 | from paddle.vision.datasets import Cifar100 104 | from sklearn.linear_model import LogisticRegression 105 | 106 | # Load the model 107 | model, transforms = load_model('ViT_B_32', pretrained=True) 108 | 109 | # Load the dataset 110 | train = Cifar100(mode='train', transform=transforms, backend='pil') 111 | test = Cifar100(mode='test', transform=transforms, backend='pil') 112 | 113 | # Get features 114 | def get_features(dataset): 115 | all_features = [] 116 | all_labels = [] 117 | 118 | with paddle.no_grad(): 119 | for images, labels in tqdm(DataLoader(dataset, batch_size=100)): 120 | features = model.encode_image(images) 121 | all_features.append(features) 122 | all_labels.append(labels) 123 | 124 | return paddle.concat(all_features).numpy(), paddle.concat(all_labels).numpy() 125 | 126 | # Calculate the image features 127 | train_features, train_labels = get_features(train) 128 | test_features, test_labels = get_features(test) 129 | 130 | # Perform logistic regression 131 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=0) 132 | classifier.fit(train_features, train_labels) 133 | 134 | # Evaluate using the logistic regression classifier 135 | predictions = classifier.predict(test_features) 136 | accuracy = np.mean((test_labels == predictions).astype(float)) * 100. 137 | 138 | # Print the result 139 | print(f"Accuracy = {accuracy:.3f}") 140 | ``` 141 | Accuracy = 79.900 142 | 143 | ## Pretrained Models Download 144 | * [RN50](https://bj.bcebos.com/v1/ai-studio-online/6ffc89246e974a809e6e4b40fdb58063a112a0153e674dae8ed5b6dfe5d46d86?responseContentDisposition=attachment%3B%20filename%3DRN50.pdparams) 145 | * [RN50x4](https://bj.bcebos.com/v1/ai-studio-online/9f874e0174da48ffbd7c17e77b1fb278632620a9995e476ba873e334caec9037?responseContentDisposition=attachment%3B%20filename%3DRN50x4.pdparams) 146 | * [RN101](https://bj.bcebos.com/v1/ai-studio-online/484592d98c584785bc8f6f9f7badbf4a9fb7a96f6102470697ed974e8eeee2a9?responseContentDisposition=attachment%3B%20filename%3DRN101.pdparams) 147 | * [ViT_B_32](https://bj.bcebos.com/v1/ai-studio-online/eb5e4dbf1ec142caa003a27cefd510ef46a8a6c3932a4d60bfecb3f3ab746c02?responseContentDisposition=attachment%3B%20filename%3DViT-B-32.pdparams) 148 | 149 | ## Contact us 150 | Email : [agentmaker@163.com](mailto:agentmaker@163.com)
151 | QQ Group : 1005109853 152 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import paddle 4 | from .tokenizer import Tokenizer 5 | from .model import CLIP 6 | from paddle.vision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 7 | 8 | 9 | tokenizer = Tokenizer() 10 | 11 | 12 | def get_transforms(image_resolution): 13 | transforms = Compose([ 14 | Resize(image_resolution, interpolation='bicubic'), 15 | CenterCrop(image_resolution), 16 | lambda image: image.convert("RGB"), 17 | ToTensor(), 18 | Normalize((0.48145466, 0.4578275, 0.40821073), 19 | (0.26862954, 0.26130258, 0.27577711)), 20 | ]) 21 | return transforms 22 | 23 | 24 | def clip_rn50(): 25 | model = CLIP( 26 | embed_dim=1024, 27 | image_resolution=224, 28 | vision_layers=(3, 4, 6, 3), 29 | vision_width=64, 30 | vision_patch_size=None, 31 | context_length=77, 32 | vocab_size=49408, 33 | transformer_width=512, 34 | transformer_heads=8, 35 | transformer_layers=12 36 | ) 37 | return model, get_transforms(224) 38 | 39 | 40 | def clip_rn101(): 41 | model = CLIP( 42 | embed_dim=512, 43 | image_resolution=224, 44 | vision_layers=(3, 4, 23, 3), 45 | vision_width=64, 46 | vision_patch_size=None, 47 | context_length=77, 48 | vocab_size=49408, 49 | transformer_width=512, 50 | transformer_heads=8, 51 | transformer_layers=12 52 | ) 53 | return model, get_transforms(224) 54 | 55 | 56 | def clip_rn50x4(): 57 | model = CLIP( 58 | embed_dim=640, 59 | image_resolution=288, 60 | vision_layers=(4, 6, 10, 6), 61 | vision_width=80, 62 | vision_patch_size=None, 63 | context_length=77, 64 | vocab_size=49408, 65 | transformer_width=640, 66 | transformer_heads=10, 67 | transformer_layers=12 68 | ) 69 | return model, get_transforms(288) 70 | 71 | 72 | def clip_vit_b_32(): 73 | model = CLIP( 74 | embed_dim=512, 75 | image_resolution=224, 76 | vision_layers=12, 77 | vision_width=768, 78 | vision_patch_size=32, 79 | context_length=77, 80 | vocab_size=49408, 81 | transformer_width=512, 82 | transformer_heads=8, 83 | transformer_layers=12 84 | ) 85 | return model, get_transforms(224) 86 | 87 | 88 | def tokenize(texts, context_length=77): 89 | """ 90 | Returns the tokenized representation of given input string(s) 91 | Parameters 92 | ---------- 93 | texts : Union[str, List[str]] 94 | An input string or a list of input strings to tokenize 95 | context_length : int 96 | The context length to use; all CLIP models use 77 as the context length 97 | Returns 98 | ------- 99 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 100 | """ 101 | if isinstance(texts, str): 102 | texts = [texts] 103 | 104 | sot_token = tokenizer.encoder["<|startoftext|>"] 105 | eot_token = tokenizer.encoder["<|endoftext|>"] 106 | all_tokens = [[sot_token] + 107 | tokenizer.encode(text) + [eot_token] for text in texts] 108 | result = paddle.zeros((len(all_tokens), context_length), dtype='int64') 109 | 110 | for i, tokens in enumerate(all_tokens): 111 | if len(tokens) > context_length: 112 | raise RuntimeError( 113 | f"Input {texts[i]} is too long for context length {context_length}") 114 | result[i, :len(tokens)] = paddle.to_tensor(tokens) 115 | 116 | return result 117 | 118 | 119 | model_dict = { 120 | 'RN50': [clip_rn50, 
r'https://bj.bcebos.com/v1/ai-studio-online/6ffc89246e974a809e6e4b40fdb58063a112a0153e674dae8ed5b6dfe5d46d86?responseContentDisposition=attachment%3B%20filename%3DRN50.pdparams', 'RN50.pdparams'], 121 | 'RN50x4': [clip_rn50x4, r'https://bj.bcebos.com/v1/ai-studio-online/9f874e0174da48ffbd7c17e77b1fb278632620a9995e476ba873e334caec9037?responseContentDisposition=attachment%3B%20filename%3DRN50x4.pdparams', 'RN50x4.pdparams'], 122 | 'RN101': [clip_rn101, r'https://bj.bcebos.com/v1/ai-studio-online/484592d98c584785bc8f6f9f7badbf4a9fb7a96f6102470697ed974e8eeee2a9?responseContentDisposition=attachment%3B%20filename%3DRN101.pdparams', 'RN101.pdparams'], 123 | 'ViT_B_32': [clip_vit_b_32, r'https://bj.bcebos.com/v1/ai-studio-online/eb5e4dbf1ec142caa003a27cefd510ef46a8a6c3932a4d60bfecb3f3ab746c02?responseContentDisposition=attachment%3B%20filename%3DViT-B-32.pdparams', 'ViT-B-32.pdparams'] 124 | } 125 | 126 | 127 | def load_model(model_name, pretrained=False): 128 | model_fn, url, file_name = model_dict[model_name] 129 | model, transforms = model_fn() 130 | 131 | if pretrained: 132 | model_path = os.path.join('pretrained_models', file_name) 133 | if not os.path.isfile(model_path): 134 | if not os.path.exists('pretrained_models'): 135 | os.mkdir('pretrained_models') 136 | wget.download(url, out=model_path) 137 | params = paddle.load(model_path) 138 | model.set_dict(params) 139 | 140 | model.eval() 141 | return model, transforms 142 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AgentMaker/Paddle-CLIP/45b10f1170ac648609980cf982319b33605ce552/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | import paddle.nn as nn 3 | from paddle.nn.initializer import Assign, Normal, Constant 4 | 5 | 6 | class Identity(nn.Layer): 7 | def __init__(self): 8 | super(Identity, self).__init__() 9 | 10 | def forward(self, inputs): 11 | return inputs 12 | 13 | 14 | class QuickGELU(nn.Layer): 15 | def forward(self, x): 16 | return x * nn.functional.sigmoid(1.702 * x) 17 | 18 | 19 | class MultiHeadAttention(nn.MultiHeadAttention): 20 | def __init__(self, 21 | embed_dim, 22 | num_heads, 23 | output_dim=None): 24 | super(MultiHeadAttention, self).__init__(embed_dim, num_heads) 25 | self.out_proj = nn.Linear(embed_dim, output_dim or embed_dim) 26 | 27 | 28 | class Bottleneck(nn.Layer): 29 | expansion = 4 30 | 31 | def __init__(self, inplanes, planes, stride=1): 32 | super().__init__() 33 | self.conv1 = nn.Conv2D(inplanes, planes, 1, bias_attr=False) 34 | self.bn1 = nn.BatchNorm2D(planes) 35 | 36 | self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) 37 | self.bn2 = nn.BatchNorm2D(planes) 38 | 39 | self.avgpool = nn.AvgPool2D(stride) if stride > 1 else Identity() 40 | 41 | self.conv3 = nn.Conv2D( 42 | planes, planes * self.expansion, 1, bias_attr=False) 43 | self.bn3 = nn.BatchNorm2D(planes * self.expansion) 44 | 45 | self.relu = nn.ReLU() 46 | self.downsample = None 47 | self.stride = stride 48 | 49 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 50 | self.downsample = nn.Sequential( 51 | ("-1", nn.AvgPool2D(stride)), 52 | ("0", nn.Conv2D(inplanes, planes * 53 | self.expansion, 1, stride=1, bias_attr=False)), 54 | ("1", 
nn.BatchNorm2D(planes * self.expansion)) 55 | ) 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.relu(self.bn1(self.conv1(x))) 61 | out = self.relu(self.bn2(self.conv2(out))) 62 | out = self.avgpool(out) 63 | out = self.bn3(self.conv3(out)) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | return out 71 | 72 | 73 | class AttentionPool2D(nn.Layer): 74 | def __init__(self, spacial_dim, embed_dim, num_heads, output_dim=None): 75 | super().__init__() 76 | positional_embedding = self.create_parameter( 77 | shape=(spacial_dim ** 2 + 1, embed_dim), 78 | default_initializer=Assign( 79 | paddle.randn((spacial_dim ** 2 + 1, embed_dim)) / 80 | embed_dim ** 0.5 81 | ) 82 | ) 83 | self.add_parameter("positional_embedding", positional_embedding) 84 | 85 | self.attn = MultiHeadAttention(embed_dim, num_heads, output_dim) 86 | 87 | def forward(self, x): 88 | x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * 89 | x.shape[3])).transpose((2, 0, 1)) 90 | x = paddle.concat([x.mean(axis=0, keepdim=True), x], axis=0) 91 | x = x + self.positional_embedding.unsqueeze(1) 92 | x = x.transpose((1, 0, 2)) 93 | x = self.attn(query=x, key=x, value=x) 94 | x = x.transpose((1, 0, 2)) 95 | return x[0] 96 | 97 | 98 | class ModifiedResNet(nn.Layer): 99 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 100 | super().__init__() 101 | self.output_dim = output_dim 102 | self.input_resolution = input_resolution 103 | 104 | self.conv1 = nn.Conv2D(3, width // 2, kernel_size=3, 105 | stride=2, padding=1, bias_attr=False) 106 | self.bn1 = nn.BatchNorm2D(width // 2) 107 | 108 | self.conv2 = nn.Conv2D(width // 2, width // 2, 109 | kernel_size=3, padding=1, bias_attr=False) 110 | self.bn2 = nn.BatchNorm2D(width // 2) 111 | 112 | self.conv3 = nn.Conv2D( 113 | width // 2, width, kernel_size=3, padding=1, bias_attr=False) 114 | self.bn3 = nn.BatchNorm2D(width) 115 | 116 | self.avgpool = nn.AvgPool2D(2) 117 | self.relu = nn.ReLU() 118 | 119 | # residual layers 120 | self._inplanes = width 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 127 | self.attnpool = AttentionPool2D( 128 | input_resolution // 32, embed_dim, heads, output_dim) 129 | 130 | def _make_layer(self, planes, blocks, stride=1): 131 | layers = [Bottleneck(self._inplanes, planes, stride)] 132 | 133 | self._inplanes = planes * Bottleneck.expansion 134 | for _ in range(1, blocks): 135 | layers.append(Bottleneck(self._inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def stem(self, x): 140 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 141 | x = self.relu(bn(conv(x))) 142 | 143 | x = self.avgpool(x) 144 | return x 145 | 146 | def forward(self, x): 147 | x = self.stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | return x 154 | 155 | 156 | class ResidualAttentionBlock(nn.Layer): 157 | def __init__(self, d_model, n_head, attn_mask=None): 158 | super().__init__() 159 | self.attn = MultiHeadAttention(d_model, n_head) 160 | self.ln_1 = nn.LayerNorm(d_model) 161 | self.mlp = nn.Sequential( 162 | ("c_fc", nn.Linear(d_model, d_model * 4)), 163 | ("gelu", 
QuickGELU()), 164 | ("c_proj", nn.Linear(d_model * 4, d_model)) 165 | ) 166 | self.ln_2 = nn.LayerNorm(d_model) 167 | self.attn_mask = attn_mask 168 | 169 | def attention(self, x): 170 | self.attn_mask = self.attn_mask if self.attn_mask is not None else None 171 | return self.attn(x, x, x, attn_mask=self.attn_mask) 172 | 173 | def forward(self, x): 174 | x = x + self.attention(self.ln_1(x)) 175 | x = x + self.mlp(self.ln_2(x)) 176 | return x 177 | 178 | 179 | class Transformer(nn.Layer): 180 | def __init__(self, width, layers, heads, attn_mask=None): 181 | super().__init__() 182 | self.width = width 183 | self.layers = layers 184 | self.resblocks = nn.Sequential( 185 | *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 186 | 187 | def forward(self, x): 188 | return self.resblocks(x) 189 | 190 | 191 | class VisualTransformer(nn.Layer): 192 | def __init__(self, input_resolution, patch_size, width, layers, heads, output_dim): 193 | super().__init__() 194 | self.input_resolution = input_resolution 195 | self.output_dim = output_dim 196 | self.conv1 = nn.Conv2D(in_channels=3, out_channels=width, 197 | kernel_size=patch_size, stride=patch_size, bias_attr=False) 198 | 199 | scale = width ** -0.5 200 | 201 | class_embedding = self.create_parameter( 202 | shape=(width,), 203 | default_initializer=Assign( 204 | scale * paddle.randn((width,)) 205 | ) 206 | ) 207 | self.add_parameter("class_embedding", class_embedding) 208 | 209 | positional_embedding = self.create_parameter( 210 | shape=(width,), 211 | default_initializer=Assign( 212 | scale * 213 | paddle.randn( 214 | ((input_resolution // patch_size) ** 2 + 1, width)) 215 | ) 216 | ) 217 | self.add_parameter("positional_embedding", positional_embedding) 218 | 219 | self.ln_pre = nn.LayerNorm(width) 220 | 221 | self.transformer = Transformer(width, layers, heads) 222 | 223 | self.ln_post = nn.LayerNorm(width) 224 | 225 | proj = self.create_parameter( 226 | shape=(width,), 227 | default_initializer=Assign( 228 | scale * paddle.randn(((width, output_dim))) 229 | ) 230 | ) 231 | self.add_parameter("proj", proj) 232 | 233 | def forward(self, x): 234 | x = self.conv1(x) 235 | x = x.reshape((x.shape[0], x.shape[1], -1)) 236 | x = x.transpose((0, 2, 1)) 237 | zeros = paddle.zeros((x.shape[0], 1, x.shape[-1]), dtype='float32') 238 | x = paddle.concat([self.class_embedding + zeros, x], axis=1) 239 | x = x + self.positional_embedding 240 | x = self.ln_pre(x) 241 | x = self.transformer(x) 242 | x = self.ln_post(x[:, 0, :]) 243 | 244 | if self.proj is not None: 245 | x = x @ self.proj 246 | 247 | return x 248 | 249 | 250 | class CLIP(nn.Layer): 251 | def __init__(self, 252 | embed_dim, 253 | # vision 254 | image_resolution, 255 | vision_layers, 256 | vision_width, 257 | vision_patch_size, 258 | # text 259 | context_length, 260 | vocab_size, 261 | transformer_width, 262 | transformer_heads, 263 | transformer_layers 264 | ): 265 | super().__init__() 266 | self.context_length = context_length 267 | self.embed_dim = embed_dim 268 | 269 | if isinstance(vision_layers, (tuple, list)): 270 | vision_heads = vision_width * 32 // 64 271 | self.visual = ModifiedResNet( 272 | layers=vision_layers, 273 | output_dim=embed_dim, 274 | heads=vision_heads, 275 | input_resolution=image_resolution, 276 | width=vision_width 277 | ) 278 | else: 279 | vision_heads = vision_width // 64 280 | self.visual = VisualTransformer( 281 | input_resolution=image_resolution, 282 | patch_size=vision_patch_size, 283 | width=vision_width, 284 | layers=vision_layers, 285 | 
heads=vision_heads, 286 | output_dim=embed_dim 287 | ) 288 | 289 | self.transformer = Transformer( 290 | width=transformer_width, 291 | layers=transformer_layers, 292 | heads=transformer_heads, 293 | attn_mask=self.build_attention_mask() 294 | ) 295 | 296 | self.vocab_size = vocab_size 297 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 298 | 299 | positional_embedding = self.create_parameter( 300 | shape=(self.context_length, transformer_width), 301 | default_initializer=Assign( 302 | paddle.empty((self.context_length, transformer_width)) 303 | ) 304 | ) 305 | self.add_parameter("positional_embedding", positional_embedding) 306 | 307 | self.ln_final = nn.LayerNorm(transformer_width) 308 | 309 | text_projection = self.create_parameter( 310 | shape=(transformer_width, embed_dim), 311 | default_initializer=Assign( 312 | paddle.empty((transformer_width, embed_dim)) 313 | ) 314 | ) 315 | self.add_parameter("text_projection", text_projection) 316 | 317 | logit_scale = self.create_parameter( 318 | shape=(1,), 319 | default_initializer=Assign(paddle.ones([1])) 320 | ) 321 | self.add_parameter("logit_scale", logit_scale) 322 | 323 | self.initialize_parameters() 324 | 325 | def initialize_parameters(self): 326 | Normal(std=0.02)(self.token_embedding.weight) 327 | Normal(std=0.01)(self.positional_embedding) 328 | 329 | if isinstance(self.visual, ModifiedResNet): 330 | if self.visual.attnpool is not None: 331 | std = self.embed_dim ** -0.5 332 | normal_ = Normal(std=std) 333 | normal_(self.visual.attnpool.attn.q_proj.weight) 334 | normal_(self.visual.attnpool.attn.k_proj.weight) 335 | normal_(self.visual.attnpool.attn.v_proj.weight) 336 | normal_(self.visual.attnpool.attn.out_proj.weight) 337 | 338 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 339 | for name, param in resnet_block.named_parameters(): 340 | if name.endswith("bn3.weight"): 341 | Constant(value=0.0)(param) 342 | 343 | proj_std = (self.transformer.width ** -0.5) * \ 344 | ((2 * self.transformer.layers) ** -0.5) 345 | attn_std = self.transformer.width ** -0.5 346 | fc_std = (2 * self.transformer.width) ** -0.5 347 | 348 | for resblock in self.transformer.resblocks: 349 | normal_ = Normal(std=attn_std) 350 | normal_(resblock.attn.q_proj.weight) 351 | normal_(resblock.attn.k_proj.weight) 352 | normal_(resblock.attn.v_proj.weight) 353 | Normal(std=proj_std)(resblock.attn.out_proj.weight) 354 | Normal(std=fc_std)(resblock.mlp.c_fc.weight) 355 | Normal(std=proj_std)(resblock.mlp.c_proj.weight) 356 | 357 | if self.text_projection is not None: 358 | Normal(std=self.transformer.width ** -0.5)(self.text_projection) 359 | 360 | def build_attention_mask(self): 361 | mask = paddle.full( 362 | (self.context_length, self.context_length), float("-inf") 363 | ) 364 | mask = paddle.triu(mask, diagonal=1) 365 | return mask 366 | 367 | def encode_image(self, image): 368 | return self.visual(image) 369 | 370 | def encode_text(self, text): 371 | x = self.token_embedding(text) 372 | x = x + self.positional_embedding 373 | x = self.transformer(x) 374 | x = self.ln_final(x) 375 | 376 | select = [] 377 | index = zip( 378 | paddle.arange(x.shape[0]).numpy(), 379 | text.argmax(axis=-1).numpy() 380 | ) 381 | for i, j in index: 382 | select.append(x[int(i), int(j)]) 383 | 384 | x = paddle.stack(select) @ self.text_projection 385 | 386 | return x 387 | 388 | def forward(self, image, text): 389 | image_features = self.encode_image(image) 390 | text_features = self.encode_text(text) 391 | 392 
| # normalized features 393 | image_features = image_features / \ 394 | image_features.norm(axis=-1, keepdim=True) 395 | text_features = text_features / \ 396 | text_features.norm(axis=-1, keepdim=True) 397 | 398 | # cosine similarity as logits 399 | logit_scale = self.logit_scale.exp() 400 | logits_per_image = logit_scale * image_features @ text_features.t() 401 | logits_per_text = logit_scale * text_features @ image_features.t() 402 | 403 | # shape = [global_batch_size, global_batch_size] 404 | return logits_per_image, logits_per_text 405 | -------------------------------------------------------------------------------- /clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | It also avoids mapping to whitespace/control characters that the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1)) + \ 27 | list(range(ord("¡"), ord("¬")+1)) + \ 28 | list(range(ord("®"), ord("ÿ")+1)) 29 | cs = bs[:] 30 | n = 0 31 | for b in range(2**8): 32 | if b not in bs: 33 | bs.append(b) 34 | cs.append(2**8+n) 35 | n += 1 36 | cs = [chr(n) for n in cs] 37 | return dict(zip(bs, cs)) 38 | 39 | 40 | def get_pairs(word): 41 | """Return set of symbol pairs in a word. 42 | Word is represented as tuple of symbols (symbols being variable-length strings).
43 | """ 44 | pairs = set() 45 | prev_char = word[0] 46 | for char in word[1:]: 47 | pairs.add((prev_char, char)) 48 | prev_char = char 49 | return pairs 50 | 51 | 52 | def basic_clean(text): 53 | text = ftfy.fix_text(text) 54 | text = html.unescape(html.unescape(text)) 55 | return text.strip() 56 | 57 | 58 | def whitespace_clean(text): 59 | text = re.sub(r'\s+', ' ', text) 60 | text = text.strip() 61 | return text 62 | 63 | 64 | class Tokenizer(object): 65 | def __init__(self, bpe_path: str = default_bpe()): 66 | self.byte_encoder = bytes_to_unicode() 67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 69 | merges = merges[1:49152-256-2+1] 70 | merges = [tuple(merge.split()) for merge in merges] 71 | vocab = list(bytes_to_unicode().values()) 72 | vocab = vocab + [v+'</w>' for v in vocab] 73 | for merge in merges: 74 | vocab.append(''.join(merge)) 75 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 76 | self.encoder = dict(zip(vocab, range(len(vocab)))) 77 | self.decoder = {v: k for k, v in self.encoder.items()} 78 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 79 | self.cache = {'<|startoftext|>': '<|startoftext|>', 80 | '<|endoftext|>': '<|endoftext|>'} 81 | self.pat = re.compile( 82 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 83 | 84 | def bpe(self, token): 85 | if token in self.cache: 86 | return self.cache[token] 87 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 88 | pairs = get_pairs(word) 89 | 90 | if not pairs: 91 | return token+'</w>' 92 | 93 | while True: 94 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get( 95 | pair, float('inf'))) 96 | if bigram not in self.bpe_ranks: 97 | break 98 | first, second = bigram 99 | new_word = [] 100 | i = 0 101 | while i < len(word): 102 | try: 103 | j = word.index(first, i) 104 | new_word.extend(word[i:j]) 105 | i = j 106 | except: 107 | new_word.extend(word[i:]) 108 | break 109 | 110 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 111 | new_word.append(first+second) 112 | i += 2 113 | else: 114 | new_word.append(word[i]) 115 | i += 1 116 | new_word = tuple(new_word) 117 | word = new_word 118 | if len(word) == 1: 119 | break 120 | else: 121 | pairs = get_pairs(word) 122 | word = ' '.join(word) 123 | self.cache[token] = word 124 | return word 125 | 126 | def encode(self, text): 127 | bpe_tokens = [] 128 | text = whitespace_clean(basic_clean(text)).lower() 129 | for token in re.findall(self.pat, text): 130 | token = ''.join(self.byte_encoder[b] 131 | for b in token.encode('utf-8')) 132 | bpe_tokens.extend(self.encoder[bpe_token] 133 | for bpe_token in self.bpe(token).split(' ')) 134 | return bpe_tokens 135 | 136 | def decode(self, tokens): 137 | text = ''.join([self.decoder[token] for token in tokens]) 138 | text = bytearray([self.byte_decoder[c] for c in text]).decode( 139 | 'utf-8', errors="replace").replace('</w>', ' ') 140 | return text 141 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wget 2 | ftfy 3 | regex 4 | paddlepaddle>=2.0.1 5 | # paddlepaddle-gpu>=2.0.1 6 | 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | setup( 3 | name='paddleclip', 4 |
version='1.0.0', 5 | author='jm12138', 6 | author_email='2286040843@qq.com', 7 | packages=['clip'], 8 | license='Apache-2.0 License', 9 | description='Paddle CLIP', 10 | install_requires=['wget', 'ftfy', 'regex'], 11 | package_data={'': ['bpe_simple_vocab_16e6.txt.gz']} 12 | ) 13 | --------------------------------------------------------------------------------
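For reference, a minimal usage sketch of the bundled tokenizer (an illustrative addition, not a file from the repository; it assumes `paddleclip` and its dependencies are installed, and uses only `tokenize` from `clip/__init__.py` and `Tokenizer` from `clip/tokenizer.py` above):
```python
from clip import tokenize
from clip.tokenizer import Tokenizer

tokenizer = Tokenizer()

# tokenize() wraps each string with <|startoftext|>/<|endoftext|> and zero-pads
# to context_length=77, returning an int64 tensor of shape [batch, 77].
tokens = tokenize(["a diagram", "a dog", "a cat"])
print(tokens.shape)  # [3, 77]

# Tokenizer.encode()/decode() expose the raw BPE ids, without padding or the
# start/end markers.
ids = tokenizer.encode("a photo of a cat")
print(ids)
print(tokenizer.decode(ids))  # "a photo of a cat " (end-of-word markers map to spaces)
```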