├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── asset ├── Fig_app.png ├── Fig_detection_results.png ├── Table_industrial.png ├── Table_medical.png ├── framework.png ├── img.png ├── img2.png └── img3.png ├── config.py ├── data_preprocess ├── br35h.py ├── brain_mri.py ├── btad.py ├── clinicdb.py ├── colondb.py ├── dagm-pre.py ├── dagm.py ├── dtd.py ├── endo.py ├── headct-pre.py ├── headct.py ├── isic.py ├── mpdd.py ├── mvtec.py ├── sdd-pre.py ├── sdd.py ├── tn3k.py └── visa.py ├── dataset ├── __init__.py ├── base_dataset.py ├── br35h.py ├── brain_mri.py ├── btad.py ├── clinicdb.py ├── colondb.py ├── dagm.py ├── dtd.py ├── headct.py ├── isic.py ├── mpdd.py ├── mvtec.py ├── sdd.py ├── tn3k.py └── visa.py ├── install.sh ├── loss.py ├── method ├── __init__.py ├── adaclip.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip_model.py ├── custom_clip.py ├── simple_tokenizer.py ├── tokenizer.py ├── trainer.py ├── transformer.py └── utils.py ├── model_configs ├── ViT-B-16.json ├── ViT-B-32.json ├── ViT-H-14.json ├── ViT-L-14-336.json ├── ViT-L-14.json ├── ViT-bigG-14.json └── ViT-g-14.json ├── requirements.txt ├── test.py ├── test.sh ├── test_single_image.sh ├── tools ├── __init__.py ├── csv_tools.py ├── logger.py ├── metrics.py ├── training_tools.py └── visualization.py ├── train.py └── train.sh /.gitattributes: -------------------------------------------------------------------------------- 1 | *.png filter=lfs diff=lfs merge=lfs -text 2 | *.txt.gz filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /result/ 2 | /.idea/ 3 | /__pycache__/ 4 | /weights/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yunkang Cao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaCLIP (Detecting Anomalies for Novel Categories) 2 | [![HuggingFace Space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/Caoyunkang/AdaCLIP) 3 | 4 | > [**ECCV 24**] [**AdaCLIP: Adapting CLIP with Hybrid Learnable Prompts for Zero-Shot Anomaly Detection**](https://arxiv.org/abs/2407.15795). 5 | > 6 | > by [Yunkang Cao](https://caoyunkang.github.io/), [Jiangning Zhang](https://zhangzjn.github.io/), [Luca Frittoli](https://scholar.google.com/citations?user=cdML_XUAAAAJ), 7 | > [Yuqi Cheng](https://scholar.google.com/citations?user=02BC-WgAAAAJ&hl=en), [Weiming Shen](https://scholar.google.com/citations?user=FuSHsx4AAAAJ&hl=en), [Giacomo Boracchi](https://boracchi.faculty.polimi.it/) 8 | > 9 | 10 | ## Introduction 11 | Zero-shot anomaly detection (ZSAD) targets the identification of anomalies within images from arbitrary novel categories. 12 | This study introduces AdaCLIP for the ZSAD task, leveraging a pre-trained vision-language model (VLM), CLIP. 13 | AdaCLIP incorporates learnable prompts into CLIP and optimizes them through training on auxiliary annotated anomaly detection data. 14 | Two types of learnable prompts are proposed: *static* and *dynamic*. Static prompts are shared across all images, serving to preliminarily adapt CLIP for ZSAD. 15 | In contrast, dynamic prompts are generated for each test image, providing CLIP with dynamic adaptation capabilities. 16 | The combination of static and dynamic prompts, referred to as hybrid prompts, yields enhanced ZSAD performance. 17 | Extensive experiments conducted across 14 real-world anomaly detection datasets from industrial and medical domains indicate that AdaCLIP outperforms other ZSAD methods and generalizes better to unseen categories and even domains. 18 | Finally, our analysis highlights the importance of diverse auxiliary data and optimized prompts for enhanced generalization capacity. 19 | 20 | ## Corrections 21 | - The description of the training set utilized in our paper is not accurate. By default, we utilize MVTec AD & ColonDB for training, 22 | and VisA & ClinicDB are utilized for evaluations on MVTec AD & ColonDB.
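To make the hybrid prompts described in the introduction more concrete, the sketch below illustrates the general idea in PyTorch. It is only a conceptual illustration, not the actual AdaCLIP implementation (the real prompting modules live in `./method/` and are configured through the prompting depth/length/type/branch options used in `app.py`); all class names and dimensions below are illustrative assumptions.

```python
import torch
import torch.nn as nn


class HybridPromptSketch(nn.Module):
    """Toy illustration of static + dynamic (= hybrid) learnable prompts."""

    def __init__(self, prompt_length=5, embed_dim=768):
        super().__init__()
        # Static prompts: learnable tokens shared by every image.
        self.static_prompts = nn.Parameter(0.02 * torch.randn(prompt_length, embed_dim))
        # Dynamic prompts: predicted from the features of each individual image.
        self.dynamic_head = nn.Linear(embed_dim, prompt_length * embed_dim)
        self.prompt_length = prompt_length
        self.embed_dim = embed_dim

    def forward(self, image_features):
        # image_features: (batch, embed_dim), e.g. CLIP image embeddings.
        b = image_features.shape[0]
        static = self.static_prompts.unsqueeze(0).expand(b, -1, -1)
        dynamic = self.dynamic_head(image_features).view(b, self.prompt_length, self.embed_dim)
        # Hybrid prompts = image-agnostic tokens + image-conditioned tokens.
        return torch.cat([static, dynamic], dim=1)
```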
23 | 24 | ## Overview of AdaCLIP 25 | ![overview](asset/framework.png) 26 | 27 | ## 🛠️ Getting Started 28 | 29 | ### Installation 30 | To set up the AdaCLIP environment, follow one of the methods below: 31 | 32 | - Clone this repo: 33 | ```shell 34 | git clone https://github.com/caoyunkang/AdaCLIP.git && cd AdaCLIP 35 | ``` 36 | - You can use our provided installation script for an automated setup:: 37 | ```shell 38 | sh install.sh 39 | ``` 40 | - If you prefer to construct the experimental environment manually, follow these steps: 41 | ```shell 42 | conda create -n AdaCLIP python=3.9.5 -y 43 | conda activate AdaCLIP 44 | pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html 45 | pip install tqdm tensorboard setuptools==58.0.4 opencv-python scikit-image scikit-learn matplotlib seaborn ftfy regex numpy==1.26.4 46 | pip install gradio # Optional, for app 47 | ``` 48 | - Remember to update the dataset root in config.py according to your preference: 49 | ```python 50 | DATA_ROOT = '../datasets' # Original setting 51 | ``` 52 | 53 | ### Dataset Preparation 54 | Please download our processed visual anomaly detection datasets to your `DATA_ROOT` as needed. 55 | 56 | #### Industrial Visual Anomaly Detection Datasets 57 | Note: some links are still in processing... 58 | 59 | | Dataset | Google Drive | Baidu Drive | Task 60 | |------------|------------------|------------------| ------------------| 61 | | MVTec AD | [Google Drive](https://drive.google.com/file/d/12IukAqxOj497J4F0Mel-FvaONM030qwP/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1k36IMP4w32hY9BXOUM5ZmA?pwd=kxud) | Anomaly Detection & Localization | 62 | | VisA | [Google Drive](https://drive.google.com/file/d/1U0MZVro5yGgaHNQ8kWb3U1a0Qlz4HiHI/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/15CIsP-ulZ1AN0_3quA068w?pwd=lmgc) | Anomaly Detection & Localization | 63 | | MPDD | [Google Drive](https://drive.google.com/file/d/1cLkZs8pN8onQzfyNskeU_836JLjrtJz1/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/11T3mkloDCl7Hze5znkXOQA?pwd=4p7m) | Anomaly Detection & Localization | 64 | | BTAD | [Google Drive](https://drive.google.com/file/d/19Kd8jJLxZExwiTc9__6_r_jPqkmTXt4h/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1f4Tq-EXRz6iAswygH2WbFg?pwd=a60n) | Anomaly Detection & Localization | 65 | | KSDD | [Google Drive](https://drive.google.com/file/d/13UidsM1taqEAVV_JJTBiCV1D3KUBpmpj/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/12EaOdkSbdK85WX5ajrfjQw?pwd=6n3z) | Anomaly Detection & Localization | 66 | | DAGM | [Google Drive](https://drive.google.com/file/d/1f4sm8hpWQRzZMpvM-j7Q3xPG2vtdwvTy/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1JpDUJIksD99t003dNF1y9g?pwd=u3aq) | Anomaly Detection & Localization | 67 | | DTD-Synthetic | [Google Drive](https://drive.google.com/file/d/1em51XXz5_aBNRJlJxxv3-Ed1dO9H3QgS/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/16FlvIBWtjaDzWxlZfWjNeg?pwd=aq5c) | Anomaly Detection & Localization | 68 | 69 | 70 | 71 | 72 | #### Medical Visual Anomaly Detection Datasets 73 | | Dataset | Google Drive | Baidu Drive | Task 74 | |------------|------------------|------------------| ------------------| 75 | | HeadCT | [Google Drive](https://drive.google.com/file/d/1ore0yCV31oLwwC--YUuTQfij-f2V32O2/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/16PfXWJlh6Y9vkecY9IownA?pwd=svsl) | Anomaly Detection | 76 | | BrainMRI | [Google 
Drive](https://drive.google.com/file/d/1JLYyzcPG3ULY2J_aw1SY9esNujYm9GKd/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1UgGlTR-ABWAEiVUX-QSPhA?pwd=vh9e) | Anomaly Detection | 77 | | Br35H | [Google Drive](https://drive.google.com/file/d/1qaZ6VJDRk3Ix3oVp3NpFyTsqXLJ_JjQy/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1yCS6t3ht6qwJgM06YsU3mg?pwd=ps1e) | Anomaly Detection | 78 | | ISIC | [Google Drive](https://drive.google.com/file/d/1atZwmnFsz7mCsHWBZ8pkL_-Eul9bKFEx/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1Mf0w8RFY9ECZBEoNTyV3ZA?pwd=p954) | Anomaly Localization | 79 | | ColonDB | [Google Drive](https://drive.google.com/file/d/1tjZ0o5dgzka3wf_p4ErSRJ9fcC-RJK8R/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1nJ4L65vfNFGpkK_OJjLoVg?pwd=v8q7) | Anomaly Localization | 80 | | ClinicDB | [Google Drive](https://drive.google.com/file/d/1ciqZwMs1smSGDlwQ6tsr6YzylrqQBn9n/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1TPysfqhA_sXRPLGNwWBX6Q?pwd=3da6) | Anomaly Localization | 81 | | TN3K | [Google Drive](https://drive.google.com/file/d/1LuKEMhrUGwFBlGCaej46WoooH89V3O8_/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1i5jMofCcRFcUdteq8VMEOQ?pwd=aoez) | Anomaly Localization | 82 | 83 | #### Custom Datasets 84 | To use your custom dataset, follow these steps: 85 | 86 | 1. Refer to the instructions in `./data_preprocess` to generate the JSON file for your dataset. 87 | 2. Use `./dataset/base_dataset.py` to construct your own dataset. 88 | 89 | 90 | ### Weight Preparation 91 | 92 | We offer various pre-trained weights on different auxiliary datasets. 93 | Please download the pre-trained weights in `./weights`. 94 | 95 | | Pre-trained Datasets | Google Drive | Baidu Drive 96 | |------------|------------------|------------------| 97 | | MVTec AD & ClinicDB | [Google Drive](https://drive.google.com/file/d/1xVXANHGuJBRx59rqPRir7iqbkYzq45W0/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1K9JhNAmmDt4n5Sqlq4-5hQ?pwd=fks1) | 98 | | VisA & ColonDB | [Google Drive](https://drive.google.com/file/d/1QGmPB0ByPZQ7FucvGODMSz7r5Ke5wx9W/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1GmRCylpboPseT9lguCO9nw?pwd=fvvf) | 99 | | All Datasets Mentioned Above | [Google Drive](https://drive.google.com/file/d/1Cgkfx3GAaSYnXPLolx-P7pFqYV0IVzZF/view?usp=drive_link) | [Baidu Drive](https://pan.baidu.com/s/1J4aFAOhUbeYOBfZFbkOixA?pwd=0ts3) | 100 | 101 | 102 | ### Train 103 | 104 | By default, we use MVTec AD & Colondb for training and VisA for validation: 105 | ```shell 106 | CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data mvtec colondb --testing_data visa 107 | ``` 108 | 109 | 110 | Alternatively, for evaluation on MVTec AD & Colondb, we use VisA & ClinicDB for training and MVTec AD for validation. 111 | ```shell 112 | CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data visa clinicdb --testing_data mvtec 113 | ``` 114 | Since we have utilized half-precision (FP16) for training, the training process can occasionally be unstable. 115 | It is recommended to run the training process multiple times and choose the best model based on performance 116 | on the validation set as the final model. 
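One simple way to follow this recommendation is to script several independent runs, as in the unofficial sketch below. It only reuses the training command shown above; depending on how `train.py` names its outputs, you may need to copy the saved weights to a separate folder between runs before comparing validation results.

```python
# Unofficial sketch: repeat training several times because FP16 training can be
# unstable, then manually keep the run whose validation performance is best.
import os
import subprocess

NUM_RUNS = 3  # assumption: a handful of restarts is usually enough
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}  # same GPU selection as above

for run_idx in range(NUM_RUNS):
    print(f'=== training run {run_idx} ===')
    subprocess.run(
        ['python', 'train.py', '--save_fig', 'True',
         '--training_data', 'mvtec', 'colondb',
         '--testing_data', 'visa'],
        check=True, env=env,
    )
    # NOTE: if train.py overwrites its outputs, move/rename the saved weights
    # and logs here so each run can later be compared on the validation set.
```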
117 | 118 | 119 | To construct a robust ZSAD model for demonstration, we also train our AdaCLIP on all AD datasets mentioned above: 120 | ```shell 121 | CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True \ 122 | --training_data \ 123 | br35h brain_mri btad clinicdb colondb \ 124 | dagm dtd headct isic mpdd mvtec sdd tn3k visa \ 125 | --testing_data mvtec 126 | ``` 127 | 128 | ### Test 129 | 130 | Manually select the best models from the validation set and place them in the `weights/` directory. Then, run the following testing script: 131 | ```shell 132 | sh test.sh 133 | ``` 134 | 135 | If you want to test on a single image, you can refer to `test_single_image.sh`: 136 | ```shell 137 | CUDA_VISIBLE_DEVICES=0 python test.py --testing_model image --ckt_path weights/pretrained_all.pth --save_fig True \ 138 | --image_path asset/img.png --class_name candle --save_name test.png 139 | ``` 140 | 141 | ## Main Results 142 | 143 | Due to differences in versions utilized, the reported performance may vary slightly compared to the detection performance 144 | with the provided pre-trained weights. Some categories may show higher performance while others may show lower. 145 | 146 | ![Table_industrial](./asset/Table_industrial.png) 147 | ![Table_medical](./asset/Table_medical.png) 148 | ![Fig_detection_results](./asset/Fig_detection_results.png) 149 | 150 | ### :page_facing_up: Demo App 151 | 152 | To run the demo application, use the following command: 153 | 154 | ```bash 155 | python app.py 156 | ``` 157 | 158 | Or visit our [Online Demo](https://huggingface.co/spaces/Caoyunkang/AdaCLIP) for a quick start. The three pre-trained weights mentioned are available there. Feel free to test them with your own data! 159 | 160 | Please note that we currently do not have a GPU environment for our Hugging Face Space, so inference for a single image may take approximately 50 seconds. 161 | 162 | ![Demo](./asset/Fig_app.png) 163 | 164 | ## 💘 Acknowledgements 165 | Our work is largely inspired by the following projects. Thanks for their admiring contribution. 166 | 167 | - [VAND-APRIL-GAN](https://github.com/ByChelsea/VAND-APRIL-GAN) 168 | - [AnomalyCLIP](https://github.com/zqhang/AnomalyCLIP) 169 | - [SAA](https://github.com/caoyunkang/Segment-Any-Anomaly) 170 | 171 | 172 | ## Stargazers over time 173 | [![Stargazers over time](https://starchart.cc/caoyunkang/AdaCLIP.svg?variant=adaptive)](https://starchart.cc/caoyunkang/AdaCLIP) 174 | 175 | 176 | ## Citation 177 | 178 | If you find this project helpful for your research, please consider citing the following BibTeX entry. 
179 | 180 | ```BibTex 181 | 182 | @inproceedings{AdaCLIP, 183 | title={AdaCLIP: Adapting CLIP with Hybrid Learnable Prompts for Zero-Shot Anomaly Detection}, 184 | author={Cao, Yunkang and Zhang, Jiangning and Frittoli, Luca and Cheng, Yuqi and Shen, Weiming and Boracchi, Giacomo}, 185 | booktitle={European Conference on Computer Vision}, 186 | year={2024} 187 | } 188 | 189 | ``` 190 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from PIL import Image, ImageDraw, ImageFont 3 | import warnings 4 | import os 5 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 6 | import json 7 | import os 8 | import torch 9 | from scipy.ndimage import gaussian_filter 10 | import cv2 11 | from method import AdaCLIP_Trainer 12 | import numpy as np 13 | 14 | ############ Init Model 15 | ckt_path1 = 'weights/pretrained_mvtec_colondb.pth' 16 | ckt_path2 = "weights/pretrained_visa_clinicdb.pth" 17 | ckt_path3 = 'weights/pretrained_all.pth' 18 | 19 | # Configurations 20 | image_size = 518 21 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 22 | # device = 'cpu' 23 | model = "ViT-L-14-336" 24 | prompting_depth = 4 25 | prompting_length = 5 26 | prompting_type = 'SD' 27 | prompting_branch = 'VL' 28 | use_hsf = True 29 | k_clusters = 20 30 | 31 | config_path = os.path.join('./model_configs', f'{model}.json') 32 | 33 | # Prepare model 34 | with open(config_path, 'r') as f: 35 | model_configs = json.load(f) 36 | 37 | # Set up the feature hierarchy 38 | n_layers = model_configs['vision_cfg']['layers'] 39 | substage = n_layers // 4 40 | features_list = [substage, substage * 2, substage * 3, substage * 4] 41 | 42 | model = AdaCLIP_Trainer( 43 | backbone=model, 44 | feat_list=features_list, 45 | input_dim=model_configs['vision_cfg']['width'], 46 | output_dim=model_configs['embed_dim'], 47 | learning_rate=0., 48 | device=device, 49 | image_size=image_size, 50 | prompting_depth=prompting_depth, 51 | prompting_length=prompting_length, 52 | prompting_branch=prompting_branch, 53 | prompting_type=prompting_type, 54 | use_hsf=use_hsf, 55 | k_clusters=k_clusters 56 | ).to(device) 57 | 58 | 59 | def process_image(image, text, options): 60 | # Load the model based on selected options 61 | if 'MVTec AD+Colondb' in options: 62 | model.load(ckt_path1) 63 | elif 'VisA+Clinicdb' in options: 64 | model.load(ckt_path2) 65 | elif 'All' in options: 66 | model.load(ckt_path3) 67 | else: 68 | # Default to 'All' if no valid option is provided 69 | model.load(ckt_path3) 70 | print('Invalid option. 
Defaulting to All.') 71 | 72 | # Ensure image is in RGB mode 73 | image = image.convert('RGB') 74 | 75 | # Convert PIL image to NumPy array 76 | np_image = np.array(image) 77 | 78 | # Convert RGB to BGR for OpenCV 79 | np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR) 80 | np_image = cv2.resize(np_image, (image_size, image_size)) 81 | # Preprocess the image and run the model 82 | img_input = model.preprocess(image).unsqueeze(0) 83 | img_input = img_input.to(model.device) 84 | 85 | with torch.no_grad(): 86 | anomaly_map, anomaly_score = model.clip_model(img_input, [text], aggregation=True) 87 | 88 | # Process anomaly map 89 | anomaly_map = anomaly_map[0, :, :].cpu().numpy() 90 | anomaly_score = anomaly_score[0].cpu().numpy() 91 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 92 | anomaly_map = (anomaly_map * 255).astype(np.uint8) 93 | 94 | # Apply color map and blend with original image 95 | heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET) 96 | vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0) 97 | 98 | # Convert OpenCV image back to PIL image for Gradio 99 | vis_map_pil = Image.fromarray(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB)) 100 | 101 | return vis_map_pil, f'{anomaly_score:.3f}' 102 | 103 | # Define examples 104 | examples = [ 105 | ["asset/img.png", "candle", "MVTec AD+Colondb"], 106 | ["asset/img2.png", "bottle", "VisA+Clinicdb"], 107 | ["asset/img3.png", "button", "All"], 108 | ] 109 | 110 | # Gradio interface layout 111 | demo = gr.Interface( 112 | fn=process_image, 113 | inputs=[ 114 | gr.Image(type="pil", label="Upload Image"), 115 | gr.Textbox(label="Class Name"), 116 | gr.Radio(["MVTec AD+Colondb", 117 | "VisA+Clinicdb", 118 | "All"], 119 | label="Pre-trained Datasets") 120 | ], 121 | outputs=[ 122 | gr.Image(type="pil", label="Output Image"), 123 | gr.Textbox(label="Anomaly Score"), 124 | ], 125 | examples=examples, 126 | title="AdaCLIP -- Zero-shot Anomaly Detection", 127 | description="Upload an image, enter class name, and select pre-trained datasets to do zero-shot anomaly detection" 128 | ) 129 | 130 | # Launch the demo 131 | demo.launch() 132 | # demo.launch(server_name="0.0.0.0", server_port=10002) 133 | 134 | -------------------------------------------------------------------------------- /asset/Fig_app.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f71ab8be0e45353c1660526ff450754e82ddf4a2b7f18bb5a33ac3b704b0d76b 3 | size 268551 4 | -------------------------------------------------------------------------------- /asset/Fig_detection_results.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c00bd303a99d981d964b12e981bd1f2954d469766839523e76f7d7162fbb24cb 3 | size 363123 4 | -------------------------------------------------------------------------------- /asset/Table_industrial.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5fa4d9ab1ff1b3ca90b45f4b92ee7b12a89e5327cb22621d4081fb5f160d3d68 3 | size 401841 4 | -------------------------------------------------------------------------------- /asset/Table_medical.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d2424190619dbbd134b943ef9e38a6523635ab0d279f2445da6bdd266d3dafac 3 | size 291004 4 | 
-------------------------------------------------------------------------------- /asset/framework.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3804c7f5ae141257dbe5dd43cb20f4216a1061051fd8754d6f0c730dd085ad7d 3 | size 439936 4 | -------------------------------------------------------------------------------- /asset/img.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3eaff97d07132f9b06998737b976d4a0e0a3a2168b40aee43aad6e62d040f87e 3 | size 1421232 4 | -------------------------------------------------------------------------------- /asset/img2.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a3918b94553a8922b3c16d064ef73e9062710b35639a949c56d926037e4c0d0a 3 | size 547657 4 | -------------------------------------------------------------------------------- /asset/img3.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9394757293585aa9de542f3e70025788e5a3e1ad5a1277a8648f8050f8d7e868 3 | size 624200 4 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | DATA_ROOT = '../datasets' -------------------------------------------------------------------------------- /data_preprocess/br35h.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection') 7 | class Br35hSolver(object): 8 | CLSNAMES = [ 9 | 'br35h', 10 | ] 11 | 12 | def __init__(self, root=Br35h_ROOT, train_ratio=0.5): 13 | self.root = root 14 | self.meta_path = f'{root}/meta.json' 15 | self.train_ratio = train_ratio 16 | 17 | def run(self): 18 | self.generate_meta_info() 19 | 20 | def generate_meta_info(self): 21 | info = dict(train={}, test={}) 22 | for cls_name in self.CLSNAMES: 23 | cls_dir = f'{self.root}/{cls_name}' 24 | for phase in ['train', 'test']: 25 | cls_info = [] 26 | species = os.listdir(f'{cls_dir}/{phase}') 27 | for specie in species: 28 | is_abnormal = True if specie not in ['good'] else False 29 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 30 | img_names.sort() 31 | 32 | for idx, img_name in enumerate(img_names): 33 | info_img = dict( 34 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 35 | mask_path=f'', 36 | cls_name=cls_name, 37 | specie_name=specie, 38 | anomaly=1 if is_abnormal else 0, 39 | ) 40 | cls_info.append(info_img) 41 | 42 | info[phase][cls_name] = cls_info 43 | 44 | with open(self.meta_path, 'w') as f: 45 | f.write(json.dumps(info, indent=4) + "\n") 46 | 47 | 48 | if __name__ == '__main__': 49 | runner = Br35hSolver(root=Br35h_ROOT) 50 | runner.run() 51 | -------------------------------------------------------------------------------- /data_preprocess/brain_mri.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | BrainMRI_ROOT = os.path.join(DATA_ROOT, 'BrainMRI') 7 | 8 | class BrainMRISolver(object): 9 | CLSNAMES = [ 10 | 'brain_mri', 11 | ] 12 | 13 | def __init__(self, 
root=BrainMRI_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | img_names.sort() 32 | 33 | for idx, img_name in enumerate(img_names): 34 | info_img = dict( 35 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 36 | mask_path=f'', 37 | cls_name=cls_name, 38 | specie_name=specie, 39 | anomaly=1 if is_abnormal else 0, 40 | ) 41 | cls_info.append(info_img) 42 | 43 | info[phase][cls_name] = cls_info 44 | 45 | with open(self.meta_path, 'w') as f: 46 | f.write(json.dumps(info, indent=4) + "\n") 47 | 48 | 49 | if __name__ == '__main__': 50 | runner = BrainMRISolver(root=BrainMRI_ROOT) 51 | runner.run() 52 | -------------------------------------------------------------------------------- /data_preprocess/btad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | BTAD_ROOT = os.path.join(DATA_ROOT, 'BTech_Dataset_transformed') 7 | 8 | class BTADSolver(object): 9 | CLSNAMES = [ 10 | '01', '02', '03', 11 | ] 12 | 13 | def __init__(self, root=BTAD_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['ok'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = BTADSolver(root=BTAD_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/clinicdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | ClinicDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ClinicDB') 7 | 8 | class ClinicDBSolver(object): 9 | CLSNAMES = [ 10 | 'ClinicDB', 11 | ] 12 | 13 | def __init__(self, root=ClinicDB_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def 
run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = ClinicDBSolver(root=ClinicDB_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/colondb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | ColonDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ColonDB') 7 | 8 | class ColonDBSolver(object): 9 | CLSNAMES = [ 10 | 'ColonDB', 11 | ] 12 | 13 | def __init__(self, root=ColonDB_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = ColonDBSolver(root=ColonDB_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/dagm-pre.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import cv2 5 | import argparse 6 | from config import DATA_ROOT 7 | 8 | dataset_root = os.path.join(DATA_ROOT, 'DAGM2007') 9 | 10 | class_names = os.listdir(dataset_root) 11 | 12 | 13 | for class_name in class_names: 14 | states = 
os.listdir(os.path.join(dataset_root, class_name)) 15 | for state in states: 16 | images = list() 17 | mask = list() 18 | files = os.listdir(os.path.join(dataset_root, class_name,state)) 19 | for f in files: 20 | if 'PNG' in f[-3:]: 21 | images.append(f) 22 | files = os.listdir(os.path.join(dataset_root, class_name, state,'Label')) 23 | for f in files: 24 | if 'PNG' in f[-3:]: 25 | mask.append(f) 26 | normal_image_path_train = list() 27 | normal_image_path_test = list() 28 | normal_image_path = list() 29 | abnormal_image_path = list() 30 | abnormal_image_label = list() 31 | for f in images: 32 | id = f[-8:-4] 33 | flag = 0 34 | for y in mask: 35 | if id in y: 36 | abnormal_image_path.append(f) 37 | abnormal_image_label.append(y) 38 | flag = 1 39 | break 40 | if flag == 0: 41 | normal_image_path.append(f) 42 | 43 | if len(abnormal_image_path) != len(abnormal_image_label): 44 | raise ValueError 45 | length = len(abnormal_image_path) 46 | 47 | normal_image_path_test = normal_image_path[:length] 48 | normal_image_path_train = normal_image_path[length:] 49 | 50 | target_root = '../datasets/DAGM_anomaly_detection' 51 | 52 | train_root = os.path.join(target_root, class_name, 'train','good') 53 | if not os.path.exists(train_root): 54 | os.makedirs(train_root) 55 | for f in normal_image_path_train: 56 | image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f)) 57 | cv2.imwrite(os.path.join(train_root,f), image_data) 58 | 59 | test_root = os.path.join(target_root, class_name, 'test','good') 60 | if not os.path.exists(test_root): 61 | os.makedirs(test_root) 62 | for f in normal_image_path_test: 63 | image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f)) 64 | cv2.imwrite(os.path.join(test_root,f), image_data) 65 | 66 | test_root = os.path.join(target_root, class_name, 'test','defect') 67 | if not os.path.exists(test_root): 68 | os.makedirs(test_root) 69 | for f in abnormal_image_path: 70 | image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f)) 71 | cv2.imwrite(os.path.join(test_root,f), image_data) 72 | 73 | test_root = os.path.join(target_root, class_name, 'ground_truth','defect') 74 | if not os.path.exists(test_root): 75 | os.makedirs(test_root) 76 | for f in mask: 77 | image_data = cv2.imread(os.path.join(dataset_root, class_name, state,'Label',f)) 78 | cv2.imwrite(os.path.join(test_root,f), image_data) 79 | 80 | 81 | 82 | print("Done") -------------------------------------------------------------------------------- /data_preprocess/dagm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | DAGM_ROOT = os.path.join(DATA_ROOT, 'DAGM_anomaly_detection') 7 | 8 | class DAGMSolver(object): 9 | CLSNAMES = [ 10 | 'Class1', 'Class2', 'Class3', 'Class4', 'Class5','Class6','Class7','Class8','Class9','Class10', 11 | ] 12 | 13 | def __init__(self, root=DAGM_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | 
mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = DAGMSolver(root=DAGM_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | DTD_ROOT = os.path.join(DATA_ROOT, 'DTD-Synthetic') 7 | 8 | class DTDSolver(object): 9 | CLSNAMES = [ 10 | 'Blotchy_099', 'Fibrous_183', 'Marbled_078', 'Matted_069', 'Mesh_114','Perforated_037','Stratified_154','Woven_001','Woven_068','Woven_104','Woven_125','Woven_127', 11 | ] 12 | 13 | def __init__(self, root=DTD_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = DTDSolver(root=DTD_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/endo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | ENDO_ROOT = os.path.join(DATA_ROOT, 'EndoTect') 7 | 8 | class ENDOSolver(object): 9 | CLSNAMES = [ 10 | 'endo', 11 | ] 12 | 13 | def __init__(self, root=ENDO_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = 
os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = ENDOSolver(root=ENDO_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/headct-pre.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import shutil 5 | import argparse 6 | 7 | from config import DATA_ROOT 8 | 9 | dataset_root = os.path.join(DATA_ROOT, 'head_ct') 10 | 11 | label_file = os.path.join(dataset_root, 'labels.csv') 12 | 13 | data = np.loadtxt(label_file, dtype=int, delimiter=',', skiprows=1) 14 | 15 | fnames = data[:, 0] 16 | label = data[:, 1] 17 | 18 | normal_fnames = fnames[label==0] 19 | outlier_fnames = fnames[label==1] 20 | 21 | 22 | target_root = '../datasets/HeadCT_anomaly_detection/headct' 23 | train_root = os.path.join(target_root, 'train/good') 24 | if not os.path.exists(train_root): 25 | os.makedirs(train_root) 26 | 27 | test_normal_root = os.path.join(target_root, 'test/good') 28 | if not os.path.exists(test_normal_root): 29 | os.makedirs(test_normal_root) 30 | for f in normal_fnames: 31 | source = os.path.join(dataset_root, 'head_ct/', '{:0>3d}.png'.format(f)) 32 | shutil.copy(source, test_normal_root) 33 | 34 | test_outlier_root = os.path.join(target_root, 'test/defect') 35 | if not os.path.exists(test_outlier_root): 36 | os.makedirs(test_outlier_root) 37 | for f in outlier_fnames: 38 | source = os.path.join(dataset_root, 'head_ct/', '{:0>3d}.png'.format(f)) 39 | shutil.copy(source, test_outlier_root) 40 | 41 | print('Done') -------------------------------------------------------------------------------- /data_preprocess/headct.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | # from dataset import MPDD_ROOT 5 | # from dataset.mpdd import MPDD_ROOT 6 | 7 | 8 | HEADCT_ROOT = '../datasets/HeadCT_anomaly_detection' 9 | class HEADCTSolver(object): 10 | CLSNAMES = [ 11 | 'headct', 12 | ] 13 | 14 | def __init__(self, root=HEADCT_ROOT, train_ratio=0.5): 15 | self.root = root 16 | self.meta_path = f'{root}/meta.json' 17 | self.train_ratio = train_ratio 18 | 19 | def run(self): 20 | self.generate_meta_info() 21 | 22 | def generate_meta_info(self): 23 | info = dict(train={}, test={}) 24 | for cls_name in self.CLSNAMES: 25 | cls_dir = f'{self.root}/{cls_name}' 26 | for phase in ['train', 'test']: 27 | cls_info = [] 28 | species = os.listdir(f'{cls_dir}/{phase}') 29 | for specie in species: 30 | is_abnormal = True if specie not in ['good'] else False 31 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 32 | 
img_names.sort() 33 | 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = HEADCTSolver(root=HEADCT_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/isic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | ISIC_ROOT = os.path.join(DATA_ROOT, 'ISIC') 7 | 8 | class ISICSolver(object): 9 | CLSNAMES = [ 10 | 'isic', 11 | ] 12 | 13 | def __init__(self, root=ISIC_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = ISICSolver(root=ISIC_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/mpdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | MPDD_ROOT = os.path.join(DATA_ROOT, 'MPDD') 7 | 8 | class MPDDSolver(object): 9 | CLSNAMES = [ 10 | 'bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate','tubes', 11 | ] 12 | 13 | def __init__(self, root=MPDD_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if 
mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = MPDDSolver(root=MPDD_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from dataset import MVTEC_ROOT 5 | 6 | class MVTecSolver(object): 7 | CLSNAMES = [ 8 | 'bottle', 'cable', 'capsule', 'carpet', 'grid', 9 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 10 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 11 | ] 12 | 13 | def __init__(self, root=MVTEC_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = MVTecSolver(root=MVTEC_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/sdd-pre.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import cv2 5 | import argparse 6 | 7 | from config import DATA_ROOT 8 | 9 | dataset_root = os.path.join(DATA_ROOT, 'KolektorSDD') 10 | 11 | dirs = os.listdir(dataset_root) 12 | normal_images = list() 13 | normal_labels = list() 14 | normal_fname = list() 15 | outlier_images = list() 16 | outlier_labels = list() 17 | outlier_fname = list() 18 | for d in dirs: 19 | files = os.listdir(os.path.join(dataset_root, d)) 20 | images = list() 21 | for f in files: 22 | if 'jpg' in f[-3:]: 23 | images.append(f) 24 | 25 | for image in images: 26 | split_images = list() 27 | split_labels = list() 28 | image_name = image.split('.')[0] 29 | image_data = cv2.imread(os.path.join(dataset_root, d, image)) 30 | label_data = cv2.imread(os.path.join(dataset_root, d, 
image_name + '_label.bmp')) 31 | if image_data.shape != label_data.shape: 32 | raise ValueError 33 | image_length = image_data.shape[0] 34 | split_images.append(image_data[:image_length // 3, :, :]) 35 | split_images.append(image_data[image_length // 3:image_length * 2 // 3, :, :]) 36 | split_images.append(image_data[image_length * 2 // 3:, :, :]) 37 | split_labels.append(label_data[:image_length // 3, :, :]) 38 | split_labels.append(label_data[image_length // 3:image_length * 2 // 3, :, :]) 39 | split_labels.append(label_data[image_length * 2 // 3:, :, :]) 40 | for i, (im, la) in enumerate(zip(split_images, split_labels)): 41 | if np.max(la) != 0: 42 | outlier_images.append(im) 43 | outlier_labels.append(la) 44 | outlier_fname.append(d + '_' + image_name + '_' + str(i)) 45 | else: 46 | normal_images.append(im) 47 | normal_labels.append(la) 48 | normal_fname.append(d + '_' + image_name + '_' + str(i)) 49 | 50 | normal_train, normal_test, normal_name_train, normal_name_test = train_test_split(normal_images, normal_fname, test_size=0.25, random_state=42) 51 | 52 | target_root = '../datasets/SDD_anomaly_detection/SDD' 53 | train_root = os.path.join(target_root, 'train/good') 54 | if not os.path.exists(train_root): 55 | os.makedirs(train_root) 56 | for image, name in zip(normal_train, normal_name_train): 57 | cv2.imwrite(os.path.join(train_root, name + '.png'), image) 58 | 59 | test_root = os.path.join(target_root, 'test/good') 60 | if not os.path.exists(test_root): 61 | os.makedirs(test_root) 62 | for image, name in zip(normal_test, normal_name_test): 63 | cv2.imwrite(os.path.join(test_root, name + '.png'), image) 64 | 65 | defect_root = os.path.join(target_root, 'test/defect') 66 | label_root = os.path.join(target_root, 'ground_truth/defect') 67 | if not os.path.exists(defect_root): 68 | os.makedirs(defect_root) 69 | if not os.path.exists(label_root): 70 | os.makedirs(label_root) 71 | for image, label, name in zip(outlier_images, outlier_labels, outlier_fname): 72 | cv2.imwrite(os.path.join(defect_root, name + '.png'), image) 73 | cv2.imwrite(os.path.join(label_root, name + '_mask.png'), label) 74 | 75 | print("Done") -------------------------------------------------------------------------------- /data_preprocess/sdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | SDD_ROOT = os.path.join(DATA_ROOT, 'SDD_anomaly_detection') 7 | 8 | class SDDSolver(object): 9 | CLSNAMES = [ 10 | 'SDD', 11 | ] 12 | 13 | def __init__(self, root=SDD_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if 
is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = SDDSolver(root=SDD_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/tn3k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from config import DATA_ROOT 5 | 6 | TN3K_ROOT = os.path.join(DATA_ROOT, 'TN3K') 7 | 8 | class TN3KSolver(object): 9 | CLSNAMES = [ 10 | 'tn3k', 11 | ] 12 | 13 | def __init__(self, root=TN3K_ROOT, train_ratio=0.5): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.train_ratio = train_ratio 17 | 18 | def run(self): 19 | self.generate_meta_info() 20 | 21 | def generate_meta_info(self): 22 | info = dict(train={}, test={}) 23 | for cls_name in self.CLSNAMES: 24 | cls_dir = f'{self.root}/{cls_name}' 25 | for phase in ['train', 'test']: 26 | cls_info = [] 27 | species = os.listdir(f'{cls_dir}/{phase}') 28 | for specie in species: 29 | is_abnormal = True if specie not in ['good'] else False 30 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 31 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 32 | img_names.sort() 33 | mask_names.sort() if mask_names is not None else None 34 | for idx, img_name in enumerate(img_names): 35 | info_img = dict( 36 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 37 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name=specie, 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | 44 | info[phase][cls_name] = cls_info 45 | 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = TN3KSolver(root=TN3K_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /data_preprocess/visa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | import random 5 | from dataset import VISA_ROOT 6 | 7 | class VisASolver(object): 8 | CLSNAMES = [ 9 | 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 10 | 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', 11 | 'pcb4', 'pipe_fryum', 12 | ] 13 | 14 | def __init__(self, root=VISA_ROOT, train_ratio=0.5): 15 | self.root = root 16 | self.meta_path = f'{root}/meta.json' 17 | self.phases = ['train', 'test'] 18 | self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0) 19 | self.train_ratio = train_ratio 20 | 21 | def run(self): 22 | self.generate_meta_info() 23 | 24 | def generate_meta_info(self): 25 | columns = self.csv_data.columns # [object, split, label, image, mask] 26 | info = {phase: {} for phase in self.phases} 27 | for cls_name in self.CLSNAMES: 28 | cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name] 29 | for phase in self.phases: 30 | cls_info = [] 31 | cls_data_phase = cls_data[cls_data[columns[1]] == phase] 32 | cls_data_phase.index = list(range(len(cls_data_phase))) 33 | for idx in range(cls_data_phase.shape[0]): 34 | data = cls_data_phase.loc[idx] 35 | is_abnormal = True if data[2] == 
'anomaly' else False 36 | info_img = dict( 37 | img_path=data[3], 38 | mask_path=data[4] if is_abnormal else '', 39 | cls_name=cls_name, 40 | specie_name='', 41 | anomaly=1 if is_abnormal else 0, 42 | ) 43 | cls_info.append(info_img) 44 | info[phase][cls_name] = cls_info 45 | with open(self.meta_path, 'w') as f: 46 | f.write(json.dumps(info, indent=4) + "\n") 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = VisASolver(root=VISA_ROOT) 52 | runner.run() 53 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .mvtec import MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT 2 | from .visa import VISA_CLS_NAMES, VisaDataset, VISA_ROOT 3 | from .mpdd import MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT 4 | from .btad import BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT 5 | from .sdd import SDD_CLS_NAMES, SDDDataset, SDD_ROOT 6 | from .dagm import DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT 7 | from .dtd import DTD_CLS_NAMES,DTDDataset,DTD_ROOT 8 | from .isic import ISIC_CLS_NAMES,ISICDataset,ISIC_ROOT 9 | from .colondb import ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT 10 | from .clinicdb import ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT 11 | from .tn3k import TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT 12 | from .headct import HEADCT_CLS_NAMES,HEADCTDataset,HEADCT_ROOT 13 | from .brain_mri import BrainMRI_CLS_NAMES,BrainMRIDataset,BrainMRI_ROOT 14 | from .br35h import Br35h_CLS_NAMES,Br35hDataset,Br35h_ROOT 15 | from torch.utils.data import ConcatDataset 16 | 17 | dataset_dict = { 18 | 'br35h': (Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT), 19 | 'brain_mri': (BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT), 20 | 'btad': (BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT), 21 | 'clinicdb': (ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT), 22 | 'colondb': (ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT), 23 | 'dagm': (DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT), 24 | 'dtd': (DTD_CLS_NAMES, DTDDataset, DTD_ROOT), 25 | 'headct': (HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT), 26 | 'isic': (ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT), 27 | 'mpdd': (MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT), 28 | 'mvtec': (MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT), 29 | 'sdd': (SDD_CLS_NAMES, SDDDataset, SDD_ROOT), 30 | 'tn3k': (TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT), 31 | 'visa': (VISA_CLS_NAMES, VisaDataset, VISA_ROOT), 32 | } 33 | 34 | def get_data(dataset_type_list, transform, target_transform, training): 35 | if not isinstance(dataset_type_list, list): 36 | dataset_type_list = [dataset_type_list] 37 | 38 | dataset_cls_names_list = [] 39 | dataset_instance_list = [] 40 | dataset_root_list = [] 41 | for dataset_type in dataset_type_list: 42 | if dataset_dict.get(dataset_type, ''): 43 | dataset_cls_names, dataset_instance, dataset_root = dataset_dict[dataset_type] 44 | dataset_instance = dataset_instance( 45 | clsnames=dataset_cls_names, 46 | transform=transform, 47 | target_transform=target_transform, 48 | training=training 49 | ) 50 | 51 | dataset_cls_names_list.append(dataset_cls_names) 52 | dataset_instance_list.append(dataset_instance) 53 | dataset_root_list.append(dataset_root) 54 | 55 | else: 56 | print(f'Only support {list(dataset_dict.keys())}, but entered {dataset_type}...') 57 | raise NotImplementedError 58 | 59 | if len(dataset_type_list) > 1: 60 | dataset_instance = ConcatDataset(dataset_instance_list) 61 | dataset_cls_names = dataset_cls_names_list 62 | dataset_root = 
dataset_root_list 63 | else: 64 | dataset_instance = dataset_instance_list[0] 65 | dataset_cls_names = dataset_cls_names_list[0] 66 | dataset_root = dataset_root_list[0] 67 | 68 | return dataset_cls_names, dataset_instance, dataset_root -------------------------------------------------------------------------------- /dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for our zero-shot anomaly detection dataset 3 | """ 4 | import json 5 | import os 6 | import random 7 | import numpy as np 8 | import torch.utils.data as data 9 | from PIL import Image 10 | import cv2 11 | from config import DATA_ROOT 12 | 13 | 14 | class DataSolver: 15 | def __init__(self, root, clsnames): 16 | self.root = root 17 | self.clsnames = clsnames 18 | self.path = os.path.join(root, 'meta.json') 19 | 20 | def run(self): 21 | with open(self.path, 'r') as f: 22 | info = json.load(f) 23 | 24 | info_required = dict(train={}, test={}) 25 | for cls in self.clsnames: 26 | for k in info.keys(): 27 | info_required[k][cls] = info[k][cls] 28 | 29 | return info_required 30 | 31 | 32 | class BaseDataset(data.Dataset): 33 | def __init__(self, clsnames, transform, target_transform, root, aug_rate=0., training=True): 34 | self.root = root 35 | self.transform = transform 36 | self.target_transform = target_transform 37 | self.aug_rate = aug_rate 38 | self.training = training 39 | self.data_all = [] 40 | self.cls_names = clsnames 41 | 42 | solver = DataSolver(root, clsnames) 43 | meta_info = solver.run() 44 | 45 | self.meta_info = meta_info['test'] # Only utilize the test dataset for both training and testing 46 | for cls_name in self.cls_names: 47 | self.data_all.extend(self.meta_info[cls_name]) 48 | 49 | self.length = len(self.data_all) 50 | 51 | def __len__(self): 52 | return self.length 53 | 54 | def combine_img(self, cls_name): 55 | """ 56 | From April-GAN: https://github.com/ByChelsea/VAND-APRIL-GAN 57 | Here we combine four images into a single image for data augmentation. 
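The four sampled images are pasted into a 2x2 grid, and their masks are tiled in the same layout so that image and mask remain aligned.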
58 | """ 59 | img_info = random.sample(self.meta_info[cls_name], 4) 60 | 61 | img_ls = [] 62 | mask_ls = [] 63 | 64 | for data in img_info: 65 | img_path = os.path.join(self.root, data['img_path']) 66 | mask_path = os.path.join(self.root, data['mask_path']) 67 | 68 | img = Image.open(img_path).convert('RGB') 69 | img_ls.append(img) 70 | 71 | if not data['anomaly']: 72 | img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') 73 | else: 74 | img_mask = np.array(Image.open(mask_path).convert('L')) > 0 75 | img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') 76 | 77 | mask_ls.append(img_mask) 78 | 79 | # Image 80 | image_width, image_height = img_ls[0].size 81 | result_image = Image.new("RGB", (2 * image_width, 2 * image_height)) 82 | for i, img in enumerate(img_ls): 83 | row = i // 2 84 | col = i % 2 85 | x = col * image_width 86 | y = row * image_height 87 | result_image.paste(img, (x, y)) 88 | 89 | # Mask 90 | result_mask = Image.new("L", (2 * image_width, 2 * image_height)) 91 | for i, img in enumerate(mask_ls): 92 | row = i // 2 93 | col = i % 2 94 | x = col * image_width 95 | y = row * image_height 96 | result_mask.paste(img, (x, y)) 97 | 98 | return result_image, result_mask 99 | 100 | def __getitem__(self, index): 101 | data = self.data_all[index] 102 | img_path = os.path.join(self.root, data['img_path']) 103 | mask_path = os.path.join(self.root, data['mask_path']) 104 | cls_name = data['cls_name'] 105 | anomaly = data['anomaly'] 106 | random_number = random.random() 107 | 108 | if self.training and random_number < self.aug_rate: 109 | img, img_mask = self.combine_img(cls_name) 110 | else: 111 | if img_path.endswith('.tif'): 112 | img = cv2.imread(img_path) 113 | img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 114 | else: 115 | img = Image.open(img_path).convert('RGB') 116 | if anomaly == 0: 117 | img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') 118 | else: 119 | if data['mask_path']: 120 | img_mask = np.array(Image.open(mask_path).convert('L')) > 0 121 | img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') 122 | else: 123 | img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') 124 | # Transforms 125 | if self.transform is not None: 126 | img = self.transform(img) 127 | if self.target_transform is not None and img_mask is not None: 128 | img_mask = self.target_transform(img_mask) 129 | if img_mask is None: 130 | img_mask = [] 131 | 132 | return { 133 | 'img': img, 134 | 'img_mask': img_mask, 135 | 'cls_name': cls_name, 136 | 'anomaly': anomaly, 137 | 'img_path': img_path 138 | } 139 | -------------------------------------------------------------------------------- /dataset/br35h.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection''' 6 | 7 | Br35h_CLS_NAMES = [ 8 | 'br35h', 9 | ] 10 | Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection') 11 | 12 | class Br35hDataset(BaseDataset): 13 | def __init__(self, transform, target_transform, clsnames=Br35h_CLS_NAMES, aug_rate=0.0, root=Br35h_ROOT, training=True): 14 | super(Br35hDataset, self).__init__( 15 | clsnames=clsnames, transform=transform, target_transform=target_transform, 16 | root=root, aug_rate=aug_rate, training=training 17 | ) 18 | 19 | 
-------------------------------------------------------------------------------- /dataset/brain_mri.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection''' 6 | BrainMRI_CLS_NAMES = [ 7 | 'brain_mri', 8 | ] 9 | BrainMRI_ROOT = os.path.join(DATA_ROOT, 'BrainMRI') 10 | 11 | class BrainMRIDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=BrainMRI_CLS_NAMES, aug_rate=0.0, root=BrainMRI_ROOT, training=True): 13 | super(BrainMRIDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | -------------------------------------------------------------------------------- /dataset/btad.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://avires.dimi.uniud.it/papers/btad/btad.zip''' 6 | BTAD_CLS_NAMES = [ 7 | '01', '02', '03', 8 | ] 9 | BTAD_ROOT = os.path.join(DATA_ROOT, 'BTech_Dataset_transformed') 10 | 11 | class BTADDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=BTAD_CLS_NAMES, aug_rate=0.0, root=BTAD_ROOT, training=True): 13 | super(BTADDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | -------------------------------------------------------------------------------- /dataset/clinicdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://paperswithcode.com/dataset/cvc-clinicdb''' 6 | ClinicDB_CLS_NAMES = [ 7 | 'ClinicDB', 8 | ] 9 | ClinicDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ClinicDB') 10 | 11 | class ClinicDBDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=ClinicDB_CLS_NAMES, aug_rate=0.0, root=ClinicDB_ROOT, training=True): 13 | super(ClinicDBDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | -------------------------------------------------------------------------------- /dataset/colondb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: http://mv.cvc.uab.es/projects/colon-qa/cvccolondb''' 6 | ColonDB_CLS_NAMES = [ 7 | 'ColonDB', 8 | ] 9 | ColonDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ColonDB') 10 | 11 | class ColonDBDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=ColonDB_CLS_NAMES, aug_rate=0.0, root=ColonDB_ROOT, training=True): 13 | super(ColonDBDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | 18 | 19 | -------------------------------------------------------------------------------- /dataset/dagm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset 
import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://hci.iwr.uni-heidelberg.de/content/weakly-supervised-learning-industrial-optical-inspection''' 6 | DAGM_CLS_NAMES = [ 7 | 'Class1', 'Class2', 'Class3', 'Class4', 'Class5','Class6','Class7','Class8','Class9','Class10', 8 | ] 9 | DAGM_ROOT = os.path.join(DATA_ROOT, 'DAGM_anomaly_detection') 10 | 11 | class DAGMDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=DAGM_CLS_NAMES, aug_rate=0.0, root=DAGM_ROOT, training=True): 13 | super(DAGMDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | -------------------------------------------------------------------------------- /dataset/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://drive.google.com/drive/folders/10OyPzvI3H6llCZBxKxFlKWt1Pw1tkMK1''' 6 | DTD_CLS_NAMES = [ 7 | 'Blotchy_099', 'Fibrous_183', 'Marbled_078', 'Matted_069', 'Mesh_114','Perforated_037','Stratified_154','Woven_001','Woven_068','Woven_104','Woven_125','Woven_127', 8 | ] 9 | DTD_ROOT = os.path.join(DATA_ROOT, 'DTD-Synthetic') 10 | 11 | class DTDDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=DTD_CLS_NAMES, aug_rate=0.0, root=DTD_ROOT, training=True): 13 | super(DTDDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | -------------------------------------------------------------------------------- /dataset/headct.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://www.kaggle.com/datasets/felipekitamura/head-ct-hemorrhage''' 6 | HEADCT_CLS_NAMES = [ 7 | 'headct', 8 | ] 9 | HEADCT_ROOT = os.path.join(DATA_ROOT, 'HeadCT_anomaly_detection') 10 | 11 | class HEADCTDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=HEADCT_CLS_NAMES, aug_rate=0.0, root=HEADCT_ROOT, training=True): 13 | super(HEADCTDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | 18 | 19 | -------------------------------------------------------------------------------- /dataset/isic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://challenge.isic-archive.com/data/''' 6 | ISIC_CLS_NAMES = [ 7 | 'isic', 8 | ] 9 | ISIC_ROOT = os.path.join(DATA_ROOT, 'ISIC') 10 | 11 | class ISICDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=ISIC_CLS_NAMES, aug_rate=0.0, root=ISIC_ROOT, training=True): 13 | super(ISICDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | 18 | 19 | -------------------------------------------------------------------------------- /dataset/mpdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from 
.base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://github.com/stepanje/MPDD''' 6 | MPDD_CLS_NAMES = [ 7 | 'bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate','tubes', 8 | ] 9 | MPDD_ROOT = os.path.join(DATA_ROOT, 'MPDD') 10 | 11 | class MPDDDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=MPDD_CLS_NAMES, aug_rate=0.0, root=MPDD_ROOT, training=True): 13 | super(MPDDDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | 18 | -------------------------------------------------------------------------------- /dataset/mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://paperswithcode.com/dataset/mvtecad''' 6 | 7 | MVTEC_CLS_NAMES = [ 8 | 'bottle', 'cable', 'capsule', 'carpet', 'grid', 9 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 10 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 11 | ] 12 | MVTEC_ROOT = os.path.join(DATA_ROOT, 'mvtec_anomaly_detection') 13 | 14 | class MVTecDataset(BaseDataset): 15 | def __init__(self, transform, target_transform, clsnames=MVTEC_CLS_NAMES, aug_rate=0.2, root=MVTEC_ROOT, training=True): 16 | super(MVTecDataset, self).__init__( 17 | clsnames=clsnames, transform=transform, target_transform=target_transform, 18 | root=root, aug_rate=aug_rate, training=training 19 | ) 20 | -------------------------------------------------------------------------------- /dataset/sdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://data.vicos.si/datasets/KSDD/KolektorSDD.zip''' 6 | SDD_CLS_NAMES = [ 7 | 'SDD', 8 | ] 9 | SDD_ROOT = os.path.join(DATA_ROOT, 'SDD_anomaly_detection') 10 | 11 | 12 | class SDDDataset(BaseDataset): 13 | def __init__(self, transform, target_transform, clsnames=SDD_CLS_NAMES, aug_rate=0.0, root=SDD_ROOT, training=True): 14 | super(SDDDataset, self).__init__( 15 | clsnames=clsnames, transform=transform, target_transform=target_transform, 16 | root=root, aug_rate=aug_rate, training=training 17 | ) 18 | 19 | -------------------------------------------------------------------------------- /dataset/tn3k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: https://ieeexplore.ieee.org/document/9434087/references#references''' 6 | TN3K_CLS_NAMES = [ 7 | 'tn3k', 8 | ] 9 | TN3K_ROOT = os.path.join(DATA_ROOT, 'TN3K') 10 | 11 | class TN3KDataset(BaseDataset): 12 | def __init__(self, transform, target_transform, clsnames=TN3K_CLS_NAMES, aug_rate=0.0, root=TN3K_ROOT, training=True): 13 | super(TN3KDataset, self).__init__( 14 | clsnames=clsnames, transform=transform, target_transform=target_transform, 15 | root=root, aug_rate=aug_rate, training=training 16 | ) 17 | 18 | 19 | -------------------------------------------------------------------------------- /dataset/visa.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from config import DATA_ROOT 4 | 5 | '''dataset source: 
https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar''' 6 | VISA_CLS_NAMES = [ 7 | 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 8 | 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', 9 | 'pcb4', 'pipe_fryum', 10 | ] 11 | 12 | VISA_ROOT = os.path.join(DATA_ROOT, 'VisA_20220922') 13 | 14 | class VisaDataset(BaseDataset): 15 | def __init__(self, transform, target_transform, clsnames=VISA_CLS_NAMES, aug_rate=0.0, root=VISA_ROOT, training=True): 16 | super(VisaDataset, self).__init__( 17 | clsnames=clsnames, transform=transform, target_transform=target_transform, 18 | root=root, aug_rate=aug_rate, training=training 19 | ) 20 | 21 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # add dependencies 2 | # python395_cuda113_pytorch1101 3 | # please change dataset root in ./config.py according to your specifications 4 | 5 | conda create -n AdaCLIP python=3.9.5 -y 6 | conda activate AdaCLIP 7 | pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html 8 | pip install tqdm tensorboard setuptools==58.0.4 opencv-python scikit-image scikit-learn matplotlib seaborn ftfy regex numpy==1.26.4 9 | pip install gradio 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from math import exp 6 | 7 | class FocalLoss(nn.Module): 8 | """ 9 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 10 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 11 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 12 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 13 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 14 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 15 | focus on hard misclassified example 16 | :param smooth: (float,double) smooth value when cross entropy 17 | :param balance_index: (int) balance class index, should be specific when alpha is float 18 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 19 | """ 20 | 21 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): 22 | super(FocalLoss, self).__init__() 23 | self.apply_nonlin = apply_nonlin 24 | self.alpha = alpha 25 | self.gamma = gamma 26 | self.balance_index = balance_index 27 | self.smooth = smooth 28 | self.size_average = size_average 29 | 30 | if self.smooth is not None: 31 | if self.smooth < 0 or self.smooth > 1.0: 32 | raise ValueError('smooth value should be in [0,1]') 33 | 34 | def forward(self, logit, target): 35 | if self.apply_nonlin is not None: 36 | logit = self.apply_nonlin(logit) 37 | num_class = logit.shape[1] 38 | 39 | if logit.dim() > 2: 40 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 
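# (flatten the spatial positions and move the class dimension last, so the focal term is applied per location)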
41 | logit = logit.view(logit.size(0), logit.size(1), -1) 42 | logit = logit.permute(0, 2, 1).contiguous() 43 | logit = logit.view(-1, logit.size(-1)) 44 | target = torch.squeeze(target, 1) 45 | target = target.view(-1, 1) 46 | alpha = self.alpha 47 | 48 | if alpha is None: 49 | alpha = torch.ones(num_class, 1) 50 | elif isinstance(alpha, (list, np.ndarray)): 51 | assert len(alpha) == num_class 52 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 53 | alpha = alpha / alpha.sum() 54 | elif isinstance(alpha, float): 55 | alpha = torch.ones(num_class, 1) 56 | alpha = alpha * (1 - self.alpha) 57 | alpha[self.balance_index] = self.alpha 58 | 59 | else: 60 | raise TypeError('Unsupported alpha type') 61 | 62 | if alpha.device != logit.device: 63 | alpha = alpha.to(logit.device) 64 | 65 | idx = target.cpu().long() 66 | 67 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 68 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 69 | if one_hot_key.device != logit.device: 70 | one_hot_key = one_hot_key.to(logit.device) 71 | 72 | if self.smooth: 73 | one_hot_key = torch.clamp( 74 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) 75 | pt = (one_hot_key * logit).sum(1) + self.smooth 76 | logpt = pt.log() 77 | 78 | gamma = self.gamma 79 | 80 | alpha = alpha[idx] 81 | alpha = torch.squeeze(alpha) 82 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 83 | 84 | if self.size_average: 85 | loss = loss.mean() 86 | return loss 87 | 88 | 89 | class BinaryDiceLoss(nn.Module): 90 | def __init__(self): 91 | super(BinaryDiceLoss, self).__init__() 92 | 93 | def forward(self, input, targets): 94 | # Get the batch size N 95 | N = targets.size()[0] 96 | # Smoothing term to avoid division by zero 97 | smooth = 1 98 | # Flatten the spatial dimensions into a single dimension 99 | input_flat = input.view(N, -1) 100 | targets_flat = targets.view(N, -1) 101 | 102 | # Compute the intersection 103 | intersection = input_flat * targets_flat 104 | N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth) 105 | # Average the per-image loss over the batch 106 | loss = 1 - N_dice_eff.sum() / N 107 | return loss 108 | 109 | 110 | 111 | 112 | class ConADLoss(nn.Module): 113 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 114 | It also supports the unsupervised contrastive loss in SimCLR""" 115 | def __init__(self, contrast_mode='all', random_anchors=10): 116 | super(ConADLoss, self).__init__() 117 | assert contrast_mode in ['all', 'mean', 'random'] 118 | self.contrast_mode = contrast_mode 119 | self.random_anchors = random_anchors 120 | def forward(self, features, labels): 121 | """Compute loss for model. If both `labels` and `mask` are None, 122 | it degenerates to SimCLR unsupervised loss: 123 | https://arxiv.org/pdf/2002.05709.pdf 124 | 125 | Args: 126 | features: hidden vector of shape [bsz, C, ...]. 127 | labels: ground truth of shape [bsz, 1, ...], where 1 denotes abnormal and 0 denotes normal 128 | Returns: 129 | A loss scalar.
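In effect, features from normal regions are pulled toward the (normal) anchor features, while features from abnormal regions are pushed away from them.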
130 | """ 131 | device = (torch.device('cuda') 132 | if features.is_cuda 133 | else torch.device('cpu')) 134 | if len(features.shape) != len(labels.shape): 135 | raise ValueError('`features` needs to have the same dimensions with labels') 136 | 137 | if len(features.shape) < 3: 138 | raise ValueError('`features` needs to be [bsz, C, ...],' 139 | 'at least 3 dimensions are required') 140 | 141 | if len(features.shape) > 3: 142 | features = features.view(features.shape[0], features.shape[1], -1) 143 | labels = labels.view(labels.shape[0], labels.shape[1], -1) 144 | 145 | labels = labels.squeeze() 146 | batch_size = features.shape[0] 147 | 148 | C = features.shape[1] 149 | normal_feats = features[:, :, labels == 0] 150 | abnormal_feats = features[:, :, labels == 1] 151 | 152 | normal_feats = normal_feats.permute((1, 0, 2)).contiguous().view(C, -1) 153 | abnormal_feats = abnormal_feats.permute((1, 0, 2)).contiguous().view(C, -1) 154 | 155 | contrast_count = normal_feats.shape[1] 156 | contrast_feature = normal_feats 157 | 158 | if self.contrast_mode == 'mean': 159 | anchor_feature = torch.mean(normal_feats, dim=1) 160 | anchor_feature = F.normalize(anchor_feature, dim=0, p=2) 161 | anchor_count = 1 162 | elif self.contrast_mode == 'all': 163 | anchor_feature = contrast_feature 164 | anchor_count = contrast_count 165 | elif self.contrast_mode == 'random': 166 | dim_to_sample = 1 167 | num_samples = min(self.random_anchors, contrast_count) 168 | permuted_indices = torch.randperm(normal_feats.size(dim_to_sample)).to(normal_feats.device) 169 | selected_indices = permuted_indices[:num_samples] 170 | anchor_feature = normal_feats.index_select(dim_to_sample, selected_indices) 171 | else: 172 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 173 | 174 | # compute logits 175 | # maximize similarity 176 | anchor_dot_normal = torch.matmul(anchor_feature.T, normal_feats).mean() 177 | 178 | # minimize similarity 179 | anchor_dot_abnormal = torch.matmul(anchor_feature.T, abnormal_feats).mean() 180 | 181 | loss = 0 182 | if normal_feats.shape[1] > 0: 183 | loss -= anchor_dot_normal 184 | if abnormal_feats.shape[1] > 0: 185 | loss += anchor_dot_abnormal 186 | 187 | loss = torch.exp(loss) 188 | 189 | return loss 190 | -------------------------------------------------------------------------------- /method/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaclip import AdaCLIP 2 | from .trainer import AdaCLIP_Trainer -------------------------------------------------------------------------------- /method/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a 3 | size 1356917 4 | -------------------------------------------------------------------------------- /method/clip_model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
4 | """ 5 | from dataclasses import dataclass 6 | import logging 7 | import math 8 | from typing import Optional, Tuple, Union 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer 15 | from .utils import to_2tuple 16 | 17 | @dataclass 18 | class CLIPVisionCfg: 19 | layers: Union[Tuple[int, int, int, int], int] = 12 20 | width: int = 768 21 | head_width: int = 64 22 | mlp_ratio: float = 4.0 23 | patch_size: int = 16 24 | image_size: Union[Tuple[int, int], int] = 224 25 | ls_init_value: Optional[float] = None # layer scale initial value 26 | patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results 27 | input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design 28 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) 29 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer 30 | n_queries: int = 256 # n_queries for attentional pooler 31 | attn_pooler_heads: int = 8 # n heads for attentional_pooling 32 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 33 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 34 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 35 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 36 | timm_proj_bias: bool = False # enable bias final projection 37 | timm_drop: float = 0. # head dropout 38 | timm_drop_path: Optional[float] = None # backbone stochastic depth 39 | output_tokens: bool = False 40 | 41 | 42 | @dataclass 43 | class CLIPTextCfg: 44 | context_length: int = 77 45 | vocab_size: int = 49408 46 | width: int = 512 47 | heads: int = 8 48 | layers: int = 12 49 | ls_init_value: Optional[float] = None # layer scale initial value 50 | hf_model_name: str = None 51 | hf_tokenizer_name: str = None 52 | hf_model_pretrained: bool = True 53 | proj: str = 'mlp' 54 | pooler_type: str = 'mean_pooler' 55 | embed_cls: bool = False 56 | pad_id: int = 0 57 | output_tokens: bool = False 58 | 59 | 60 | def get_cast_dtype(precision: str): 61 | cast_dtype = None 62 | if precision == 'bf16': 63 | cast_dtype = torch.bfloat16 64 | elif precision == 'fp16': 65 | cast_dtype = torch.float16 66 | return cast_dtype 67 | 68 | 69 | def _build_vision_tower( 70 | embed_dim: int, 71 | vision_cfg: CLIPVisionCfg, 72 | quick_gelu: bool = False, 73 | cast_dtype: Optional[torch.dtype] = None, 74 | ): 75 | if isinstance(vision_cfg, dict): 76 | vision_cfg = CLIPVisionCfg(**vision_cfg) 77 | 78 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 79 | # memory efficient in recent PyTorch releases (>= 1.10). 80 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
81 | act_layer = QuickGELU if quick_gelu else nn.GELU 82 | 83 | vision_heads = vision_cfg.width // vision_cfg.head_width 84 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 85 | visual = VisionTransformer( 86 | image_size=vision_cfg.image_size, 87 | patch_size=vision_cfg.patch_size, 88 | width=vision_cfg.width, 89 | layers=vision_cfg.layers, 90 | heads=vision_heads, 91 | mlp_ratio=vision_cfg.mlp_ratio, 92 | ls_init_value=vision_cfg.ls_init_value, 93 | patch_dropout=vision_cfg.patch_dropout, 94 | input_patchnorm=vision_cfg.input_patchnorm, 95 | global_average_pool=vision_cfg.global_average_pool, 96 | attentional_pool=vision_cfg.attentional_pool, 97 | n_queries=vision_cfg.n_queries, 98 | attn_pooler_heads=vision_cfg.attn_pooler_heads, 99 | output_tokens=vision_cfg.output_tokens, 100 | output_dim=embed_dim, 101 | act_layer=act_layer, 102 | norm_layer=norm_layer 103 | ) 104 | 105 | return visual 106 | 107 | 108 | def _build_text_tower( 109 | embed_dim: int, 110 | text_cfg: CLIPTextCfg, 111 | quick_gelu: bool = False, 112 | cast_dtype: Optional[torch.dtype] = None, 113 | ): 114 | if isinstance(text_cfg, dict): 115 | text_cfg = CLIPTextCfg(**text_cfg) 116 | 117 | act_layer = QuickGELU if quick_gelu else nn.GELU 118 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 119 | 120 | text = TextTransformer( 121 | context_length=text_cfg.context_length, 122 | vocab_size=text_cfg.vocab_size, 123 | width=text_cfg.width, 124 | heads=text_cfg.heads, 125 | layers=text_cfg.layers, 126 | ls_init_value=text_cfg.ls_init_value, 127 | output_dim=embed_dim, 128 | embed_cls=text_cfg.embed_cls, 129 | output_tokens=text_cfg.output_tokens, 130 | pad_id=text_cfg.pad_id, 131 | act_layer=act_layer, 132 | norm_layer=norm_layer 133 | ) 134 | 135 | return text 136 | 137 | 138 | class CLIP(nn.Module): 139 | output_dict: torch.jit.Final[bool] 140 | 141 | def __init__( 142 | self, 143 | embed_dim: int, 144 | vision_cfg: CLIPVisionCfg, 145 | text_cfg: CLIPTextCfg, 146 | quick_gelu: bool = False, 147 | cast_dtype: Optional[torch.dtype] = None, 148 | output_dict: bool = False, 149 | ): 150 | super().__init__() 151 | self.output_dict = output_dict 152 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 153 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) 154 | self.transformer = text.transformer 155 | self.vocab_size = text.vocab_size 156 | self.token_embedding = text.token_embedding 157 | self.positional_embedding = text.positional_embedding 158 | self.ln_final = text.ln_final 159 | self.text_projection = text.text_projection 160 | self.register_buffer('attn_mask', text.attn_mask, persistent=False) 161 | 162 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 163 | 164 | 165 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 166 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 167 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 168 | 169 | @torch.jit.ignore 170 | def set_grad_checkpointing(self, enable=True): 171 | self.visual.set_grad_checkpointing(enable) 172 | self.transformer.grad_checkpointing = enable 173 | 174 | def encode_image(self, image, out_layers): 175 | 176 | x = image 177 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 178 | if self.visual.input_patchnorm: 179 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 180 | x = 
x.reshape(x.shape[0], x.shape[1], 181 | self.visual.grid_size[0], 182 | self.visual.patch_size[0], 183 | self.visual.grid_size[1], 184 | self.visual.patch_size[1]) 185 | x = x.permute(0, 2, 4, 1, 3, 5) 186 | x = x.reshape(x.shape[0], self.visual.grid_size[0] * self.visual.grid_size[1], -1) 187 | x = self.visual.patchnorm_pre_ln(x) 188 | x = self.visual.conv1(x) 189 | else: 190 | x = self.visual.conv1(x) # shape = [*, width, grid, grid] 191 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 192 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 193 | 194 | 195 | # class embeddings and positional embeddings 196 | x = torch.cat( 197 | [self.visual.class_embedding.to(x.dtype) + 198 | torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 199 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 200 | 201 | x = x + self.visual.positional_embedding.to(x.dtype) 202 | 203 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 204 | x = self.visual.patch_dropout(x) 205 | x = self.visual.ln_pre(x) 206 | 207 | x = x.permute(1, 0, 2) # NLD -> LND 208 | 209 | patch_tokens = [] 210 | 211 | idx = 0 212 | for r in self.visual.transformer.resblocks: 213 | idx += 1 214 | # add prompt here 215 | x, attn_tmp = r(x, attn_mask=None) 216 | if idx in out_layers: 217 | patch_tokens.append(x) 218 | 219 | x = x.permute(1, 0, 2) # LND -> NLD 220 | patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD 221 | 222 | if self.visual.attn_pool is not None: 223 | x = self.visual.attn_pool(x) 224 | x = self.visual.ln_post(x) 225 | pooled, tokens = self.visual._global_pool(x) 226 | else: 227 | pooled, tokens = self.visual._global_pool(x) 228 | pooled = self.visual.ln_post(pooled) 229 | 230 | if self.visual.proj is not None: 231 | pooled = pooled @ self.visual.proj 232 | 233 | if self.visual.output_tokens: 234 | return pooled, patch_tokens 235 | 236 | return pooled, patch_tokens 237 | 238 | def encode_text(self, text): 239 | cast_dtype = self.transformer.get_cast_dtype() 240 | 241 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 242 | 243 | x = x + self.positional_embedding.to(cast_dtype) 244 | x = x.permute(1, 0, 2) # NLD -> LND 245 | 246 | for r in self.visual.transformer.resblocks: 247 | # add prompt here 248 | 249 | x, attn_tmp = r(x, attn_mask=self.attn_mask) 250 | 251 | x = x.permute(1, 0, 2) # LND -> NLD 252 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] 253 | 254 | # take features from the eot embedding (eot_token is the highest number in each sequence) 255 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 256 | return x 257 | 258 | 259 | 260 | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): 261 | """Convert applicable model parameters to low-precision (bf16 or fp16)""" 262 | 263 | def _convert_weights(l): 264 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 265 | l.weight.data = l.weight.data.to(dtype) 266 | if l.bias is not None: 267 | l.bias.data = l.bias.data.to(dtype) 268 | 269 | if isinstance(l, (nn.MultiheadAttention, Attention)): 270 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 271 | tensor = getattr(l, attr) 272 | if tensor is not None: 273 | tensor.data = tensor.data.to(dtype) 274 | 275 | for name in ["text_projection", "proj"]: 276 | if hasattr(l, name): 277 | attr = getattr(l, name) 278 | if attr is not 
None: 279 | attr.data = attr.data.to(dtype) 280 | 281 | model.apply(_convert_weights) 282 | 283 | 284 | convert_weights_to_fp16 = convert_weights_to_lp # backwards compat 285 | 286 | 287 | # used to maintain checkpoint compatibility 288 | def convert_to_custom_text_state_dict(state_dict: dict): 289 | if 'text_projection' in state_dict: 290 | # old format state_dict, move text tower -> .text 291 | new_state_dict = {} 292 | for k, v in state_dict.items(): 293 | if any(k.startswith(p) for p in ( 294 | 'text_projection', 295 | 'positional_embedding', 296 | 'token_embedding', 297 | 'transformer', 298 | 'ln_final', 299 | )): 300 | k = 'text.' + k 301 | new_state_dict[k] = v 302 | return new_state_dict 303 | return state_dict 304 | 305 | 306 | def build_model_from_openai_state_dict( 307 | state_dict: dict, 308 | quick_gelu=True, 309 | cast_dtype=torch.float16, 310 | ): 311 | vit = "visual.proj" in state_dict 312 | 313 | if vit: 314 | vision_width = state_dict["visual.conv1.weight"].shape[0] 315 | vision_layers = len( 316 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 317 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 318 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 319 | image_size = vision_patch_size * grid_size 320 | else: 321 | counts: list = [ 322 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 323 | vision_layers = tuple(counts) 324 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 325 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 326 | vision_patch_size = None 327 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 328 | image_size = output_width * 32 329 | 330 | embed_dim = state_dict["text_projection"].shape[1] 331 | context_length = state_dict["positional_embedding"].shape[0] 332 | vocab_size = state_dict["token_embedding.weight"].shape[0] 333 | transformer_width = state_dict["ln_final.weight"].shape[0] 334 | transformer_heads = transformer_width // 64 335 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 336 | 337 | vision_cfg = CLIPVisionCfg( 338 | layers=vision_layers, 339 | width=vision_width, 340 | patch_size=vision_patch_size, 341 | image_size=image_size, 342 | ) 343 | text_cfg = CLIPTextCfg( 344 | context_length=context_length, 345 | vocab_size=vocab_size, 346 | width=transformer_width, 347 | heads=transformer_heads, 348 | layers=transformer_layers, 349 | ) 350 | model = CLIP( 351 | embed_dim, 352 | vision_cfg=vision_cfg, 353 | text_cfg=text_cfg, 354 | quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU 355 | cast_dtype=cast_dtype, 356 | ) 357 | 358 | for key in ["input_resolution", "context_length", "vocab_size"]: 359 | state_dict.pop(key, None) 360 | 361 | convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 362 | model.load_state_dict(state_dict) 363 | return model.eval() 364 | 365 | 366 | def trace_model(model, batch_size=256, device=torch.device('cpu')): 367 | model.eval() 368 | image_size = model.visual.image_size 369 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) 370 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) 371 | model = torch.jit.trace_module( 372 | model, 373 | inputs=dict( 374 | 
forward=(example_images, example_text), 375 | encode_text=(example_text,), 376 | encode_image=(example_images,) 377 | )) 378 | model.visual.image_size = image_size 379 | return model 380 | 381 | 382 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic'): 383 | # Rescale the grid of position embeddings when loading from state_dict 384 | old_pos_embed = state_dict.get('visual.positional_embedding', None) 385 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): 386 | return 387 | grid_size = to_2tuple(model.visual.grid_size) 388 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) 389 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 390 | if new_seq_len == old_pos_embed.shape[0]: 391 | return 392 | 393 | if extra_tokens: 394 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 395 | else: 396 | pos_emb_tok, pos_emb_img = None, old_pos_embed 397 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 398 | 399 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 400 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 401 | pos_emb_img = F.interpolate( 402 | pos_emb_img, 403 | size=grid_size, 404 | mode=interpolation, 405 | align_corners=False, 406 | ) 407 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 408 | if pos_emb_tok is not None: 409 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 410 | else: 411 | new_pos_embed = pos_emb_img 412 | state_dict['visual.positional_embedding'] = new_pos_embed 413 | -------------------------------------------------------------------------------- /method/custom_clip.py: -------------------------------------------------------------------------------- 1 | # This file is largely borrowed from open clip 2 | import hashlib 3 | import json 4 | import logging 5 | import os 6 | import re 7 | import urllib 8 | import warnings 9 | from copy import deepcopy 10 | from dataclasses import dataclass, asdict 11 | from functools import partial 12 | from pathlib import Path 13 | from typing import Any, Optional, Tuple 14 | from typing import Dict, Union 15 | from typing import List 16 | import torch 17 | import torch.nn as nn 18 | import torchvision.transforms.functional as F 19 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 20 | CenterCrop 21 | from tqdm import tqdm 22 | from .clip_model import CLIP, convert_to_custom_text_state_dict, \ 23 | resize_pos_embed 24 | from .clip_model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 25 | from .tokenizer import HFTokenizer, tokenize 26 | 27 | __version__ = '2.16.0' 28 | 29 | try: 30 | from huggingface_hub import hf_hub_download 31 | 32 | hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) 33 | _has_hf_hub = True 34 | except ImportError: 35 | hf_hub_download = None 36 | _has_hf_hub = False 37 | 38 | 39 | def _pcfg(url='', hf_hub='', mean=None, std=None): 40 | return dict( 41 | url=url, 42 | hf_hub=hf_hub, 43 | mean=mean, 44 | std=std, 45 | ) 46 | 47 | 48 | _VITB32 = dict( 49 | openai=_pcfg( 50 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 51 | laion400m_e31=_pcfg( 52 | 
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 53 | laion400m_e32=_pcfg( 54 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 55 | laion2b_e16=_pcfg( 56 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), 57 | laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') 58 | ) 59 | 60 | 61 | _VITB16 = dict( 62 | openai=_pcfg( 63 | "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), 64 | laion400m_e31=_pcfg( 65 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), 66 | laion400m_e32=_pcfg( 67 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), 68 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), 69 | ) 70 | 71 | _VITL14 = dict( 72 | openai=_pcfg( 73 | "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), 74 | laion400m_e31=_pcfg( 75 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), 76 | laion400m_e32=_pcfg( 77 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), 78 | laion2b_s32b_b82k=_pcfg( 79 | hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', 80 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 81 | ) 82 | 83 | _VITL14_336 = dict( 84 | openai=_pcfg( 85 | "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), 86 | ) 87 | 88 | _VITH14 = dict( 89 | laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), 90 | ) 91 | 92 | _VITg14 = dict( 93 | laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), 94 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), 95 | ) 96 | 97 | _VITbigG14 = dict( 98 | laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), 99 | ) 100 | 101 | 102 | 103 | _PRETRAINED = { 104 | "ViT-B-32": _VITB32, 105 | "ViT-B-16": _VITB16, 106 | "ViT-L-14": _VITL14, 107 | "ViT-L-14-336": _VITL14_336, 108 | "ViT-H-14": _VITH14, 109 | "ViT-g-14": _VITg14, 110 | "ViT-bigG-14": _VITbigG14, 111 | } 112 | 113 | 114 | def _clean_tag(tag: str): 115 | # normalize pretrained tags 116 | return tag.lower().replace('-', '_') 117 | 118 | 119 | def list_pretrained(as_str: bool = False): 120 | """ returns list of pretrained models 121 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 122 | """ 123 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 124 | 125 | 126 | def list_pretrained_models_by_tag(tag: str): 127 | """ return all models having the specified pretrain tag """ 128 | models = [] 129 | tag = _clean_tag(tag) 130 | for k in _PRETRAINED.keys(): 131 | if tag in _PRETRAINED[k]: 132 | models.append(k) 133 | return models 134 | 135 | 136 | def list_pretrained_tags_by_model(model: str): 137 | """ return all pretrain tags for the specified model architecture """ 138 | tags = [] 139 | if model in _PRETRAINED: 140 | tags.extend(_PRETRAINED[model].keys()) 141 | return tags 142 | 143 | 144 | def 
is_pretrained_cfg(model: str, tag: str): 145 | if model not in _PRETRAINED: 146 | return False 147 | return _clean_tag(tag) in _PRETRAINED[model] 148 | 149 | 150 | def get_pretrained_cfg(model: str, tag: str): 151 | if model not in _PRETRAINED: 152 | return {} 153 | model_pretrained = _PRETRAINED[model] 154 | if 'openai' in model_pretrained.keys(): 155 | tag = 'openai' 156 | else: 157 | tag = list(model_pretrained.keys())[0] 158 | print('*' * 50) 159 | print(f'Use pretrained model from {tag}...') 160 | print('*' * 50) 161 | return model_pretrained.get(_clean_tag(tag), {}) 162 | 163 | 164 | def get_pretrained_url(model: str, tag: str): 165 | cfg = get_pretrained_cfg(model, _clean_tag(tag)) 166 | return cfg.get('url', '') 167 | 168 | 169 | def download_pretrained_from_url( 170 | url: str, 171 | cache_dir: Union[str, None] = None, 172 | ): 173 | if not cache_dir: 174 | cache_dir = os.path.expanduser("~/.cache/clip") 175 | os.makedirs(cache_dir, exist_ok=True) 176 | filename = os.path.basename(url) 177 | 178 | if 'openaipublic' in url: 179 | expected_sha256 = url.split("/")[-2] 180 | elif 'mlfoundations' in url: 181 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] 182 | else: 183 | expected_sha256 = '' 184 | 185 | download_target = os.path.join(cache_dir, filename) 186 | 187 | if os.path.exists(download_target) and not os.path.isfile(download_target): 188 | raise RuntimeError(f"{download_target} exists and is not a regular file") 189 | 190 | if os.path.isfile(download_target): 191 | if expected_sha256: 192 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 193 | return download_target 194 | else: 195 | warnings.warn( 196 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 197 | else: 198 | return download_target 199 | 200 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 201 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 202 | while True: 203 | buffer = source.read(8192) 204 | if not buffer: 205 | break 206 | 207 | output.write(buffer) 208 | loop.update(len(buffer)) 209 | 210 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith( 211 | expected_sha256): 212 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 213 | 214 | return download_target 215 | 216 | 217 | def has_hf_hub(necessary=False): 218 | if not _has_hf_hub and necessary: 219 | # if no HF Hub module installed, and it is necessary to continue, raise error 220 | raise RuntimeError( 221 | 'Hugging Face hub model specified but package not installed. 
Run `pip install huggingface_hub`.') 222 | return _has_hf_hub 223 | 224 | 225 | def download_pretrained_from_hf( 226 | model_id: str, 227 | filename: str = 'open_clip_pytorch_model.bin', 228 | revision=None, 229 | cache_dir: Union[str, None] = None, 230 | ): 231 | has_hf_hub(True) 232 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) 233 | return cached_file 234 | 235 | 236 | def download_pretrained( 237 | cfg: Dict, 238 | force_hf_hub: bool = False, 239 | cache_dir: Union[str, None] = None, 240 | ): 241 | target = '' 242 | if not cfg: 243 | return target 244 | 245 | download_url = cfg.get('url', '') 246 | download_hf_hub = cfg.get('hf_hub', '') 247 | if download_hf_hub and force_hf_hub: 248 | # use HF hub even if url exists 249 | download_url = '' 250 | 251 | if download_url: 252 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir) 253 | elif download_hf_hub: 254 | has_hf_hub(True) 255 | # we assume the hf_hub entries in pretrained config combine model_id + filename in 256 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and 257 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 258 | model_id, filename = os.path.split(download_hf_hub) 259 | if filename: 260 | target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) 261 | else: 262 | target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 263 | 264 | return target 265 | 266 | 267 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 268 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 269 | 270 | 271 | @dataclass 272 | class AugmentationCfg: 273 | scale: Tuple[float, float] = (0.9, 1.0) 274 | ratio: Optional[Tuple[float, float]] = None 275 | color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None 276 | interpolation: Optional[str] = None 277 | re_prob: Optional[float] = None 278 | re_count: Optional[int] = None 279 | use_timm: bool = False 280 | 281 | 282 | class ResizeMaxSize(nn.Module): 283 | 284 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 285 | super().__init__() 286 | if not isinstance(max_size, int): 287 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 288 | self.max_size = max_size 289 | self.interpolation = interpolation 290 | self.fn = min if fn == 'min' else min 291 | self.fill = fill 292 | 293 | def forward(self, img): 294 | if isinstance(img, torch.Tensor): 295 | height, width = img.shape[:2] 296 | else: 297 | width, height = img.size 298 | scale = self.max_size / float(max(height, width)) 299 | if scale != 1.0: 300 | new_size = tuple(round(dim * scale) for dim in (height, width)) 301 | img = F.resize(img, new_size, self.interpolation) 302 | pad_h = self.max_size - new_size[0] 303 | pad_w = self.max_size - new_size[1] 304 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 305 | return img 306 | 307 | 308 | def _convert_to_rgb(image): 309 | return image.convert('RGB') 310 | 311 | 312 | def image_transform( 313 | image_size: int, 314 | is_train: bool, 315 | mean: Optional[Tuple[float, ...]] = None, 316 | std: Optional[Tuple[float, ...]] = None, 317 | resize_longest_max: bool = False, 318 | fill_color: int = 0, 319 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 320 | ): 321 | mean = mean or OPENAI_DATASET_MEAN 322 | if not isinstance(mean, (list, tuple)): 323 | mean = (mean,) * 3 324 | 325 | std = std or OPENAI_DATASET_STD 326 | if not isinstance(std, (list, tuple)): 327 | std = (std,) * 3 328 | 329 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 330 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 331 | image_size = image_size[0] 332 | 333 | if isinstance(aug_cfg, dict): 334 | aug_cfg = AugmentationCfg(**aug_cfg) 335 | else: 336 | aug_cfg = aug_cfg or AugmentationCfg() 337 | normalize = Normalize(mean=mean, std=std) 338 | if is_train: 339 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} 340 | use_timm = aug_cfg_dict.pop('use_timm', False) 341 | if use_timm: 342 | from timm.data import create_transform # timm can still be optional 343 | if isinstance(image_size, (tuple, list)): 344 | assert len(image_size) >= 2 345 | input_size = (3,) + image_size[-2:] 346 | else: 347 | input_size = (3, image_size, image_size) 348 | # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time 349 | aug_cfg_dict.setdefault('interpolation', 'random') 350 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default 351 | train_transform = create_transform( 352 | input_size=input_size, 353 | is_training=True, 354 | hflip=0., 355 | mean=mean, 356 | std=std, 357 | re_mode='pixel', 358 | **aug_cfg_dict, 359 | ) 360 | else: 361 | train_transform = Compose([ 362 | RandomResizedCrop( 363 | image_size, 364 | scale=aug_cfg_dict.pop('scale'), 365 | interpolation=InterpolationMode.BICUBIC, 366 | ), 367 | _convert_to_rgb, 368 | ToTensor(), 369 | normalize, 370 | ]) 371 | if aug_cfg_dict: 372 | warnings.warn( 373 | f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') 374 | return train_transform 375 | else: 376 | if resize_longest_max: 377 | transforms = [ 378 | ResizeMaxSize(image_size, fill=fill_color) 379 | ] 380 | else: 381 | transforms = [ 382 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 383 | CenterCrop(image_size), 384 | ] 385 | transforms.extend([ 386 | _convert_to_rgb, 387 | ToTensor(), 388 | normalize, 389 | ]) 390 | return Compose(transforms) 391 | 392 | 393 | def list_openai_models() -> List[str]: 394 | """Returns the names of available CLIP models""" 
395 | return list_pretrained_models_by_tag('openai') 396 | 397 | 398 | def load_openai_model( 399 | name: str, 400 | precision: Optional[str] = None, 401 | device: Optional[Union[str, torch.device]] = None, 402 | jit: bool = True, 403 | cache_dir: Optional[str] = None, 404 | ): 405 | """Load a CLIP model 406 | 407 | Parameters 408 | ---------- 409 | name : str 410 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 411 | precision: str 412 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 413 | device : Union[str, torch.device] 414 | The device to put the loaded model 415 | jit : bool 416 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 417 | cache_dir : Optional[str] 418 | The directory to cache the downloaded model weights 419 | 420 | Returns 421 | ------- 422 | model : torch.nn.Module 423 | The CLIP model 424 | preprocess : Callable[[PIL.Image], torch.Tensor] 425 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 426 | """ 427 | if device is None: 428 | device = "cuda" if torch.cuda.is_available() else "cpu" 429 | if precision is None: 430 | precision = 'fp32' if device == 'cpu' else 'fp16' 431 | 432 | cfg = get_pretrained_cfg(name, 'openai') 433 | if cfg: 434 | model_path = download_pretrained(cfg, cache_dir=cache_dir) 435 | elif os.path.isfile(name): 436 | model_path = name 437 | else: 438 | raise RuntimeError(f"Model {name} not found; available models = {list_pretrained()}") 439 | 440 | try: 441 | # loading JIT archive 442 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 443 | state_dict = None 444 | except RuntimeError: 445 | # loading saved state dict 446 | if jit: 447 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 448 | jit = False 449 | state_dict = torch.load(model_path, map_location="cpu") 450 | 451 | # JIT -> Just In Time 452 | if not jit: 453 | # Build a non-jit model from the OpenAI jitted model state dict 454 | cast_dtype = get_cast_dtype(precision) 455 | try: 456 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 457 | except KeyError: 458 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 459 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 460 | 461 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 462 | model = model.to(device) 463 | if precision.startswith('amp') or precision == 'fp32': 464 | model.float() 465 | elif precision == 'bf16': 466 | convert_weights_to_lp(model, dtype=torch.bfloat16) 467 | 468 | return model 469 | 470 | # patch the device names 471 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 472 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 473 | 474 | def patch_device(module): 475 | try: 476 | graphs = [module.graph] if hasattr(module, "graph") else [] 477 | except RuntimeError: 478 | graphs = [] 479 | 480 | if hasattr(module, "forward1"): 481 | graphs.append(module.forward1.graph) 482 | 483 | for graph in graphs: 484 | for node in graph.findAllNodes("prim::Constant"): 485 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 486 | node.copyAttributes(device_node) 487 | 488 | model.apply(patch_device) 489 | patch_device(model.encode_image) 490 | patch_device(model.encode_text) 491 | 492 | # patch dtype to float32 (typically for CPU) 493 | if precision == 'fp32': 494 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 495 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 496 | float_node = float_input.node() 497 | 498 | def patch_float(module): 499 | try: 500 | graphs = [module.graph] if hasattr(module, "graph") else [] 501 | except RuntimeError: 502 | graphs = [] 503 | 504 | if hasattr(module, "forward1"): 505 | graphs.append(module.forward1.graph) 506 | 507 | for graph in graphs: 508 | for node in graph.findAllNodes("aten::to"): 509 | inputs = list(node.inputs()) 510 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 511 | if inputs[i].node()["value"] == 5: 512 | inputs[i].node().copyAttributes(float_node) 513 | 514 | model.apply(patch_float) 515 | patch_float(model.encode_image) 516 | patch_float(model.encode_text) 517 | model.float() 518 | 519 | # ensure image_size attr available at consistent location for both jit and non-jit 520 | model.visual.image_size = model.input_resolution.item() 521 | return model 522 | 523 | 524 | HF_HUB_PREFIX = 'hf-hub:' 525 | _MODEL_CONFIG_PATHS = [Path(__file__).parent.parent / f"./model_configs/"] 526 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 527 | 528 | 529 | def _natural_key(string_): 530 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 531 | 532 | 533 | def _rescan_model_configs(): 534 | global _MODEL_CONFIGS 535 | 536 | config_ext = ('.json',) 537 | config_files = [] 538 | for config_path in _MODEL_CONFIG_PATHS: 539 | if config_path.is_file() and config_path.suffix in config_ext: 540 | config_files.append(config_path) 541 | elif config_path.is_dir(): 542 | for ext in 
config_ext: 543 | config_files.extend(config_path.glob(f'*{ext}')) 544 | 545 | for cf in config_files: 546 | with open(cf, 'r') as f: 547 | model_cfg = json.load(f) 548 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 549 | _MODEL_CONFIGS[cf.stem] = model_cfg 550 | 551 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 552 | 553 | 554 | _rescan_model_configs() # initial populate of model config registry 555 | 556 | 557 | def list_models(): 558 | """ enumerate available model architectures based on config files """ 559 | return list(_MODEL_CONFIGS.keys()) 560 | 561 | 562 | def add_model_config(path): 563 | """ add model config path or file and update registry """ 564 | if not isinstance(path, Path): 565 | path = Path(path) 566 | _MODEL_CONFIG_PATHS.append(path) 567 | _rescan_model_configs() 568 | 569 | 570 | def get_model_config(model_name): 571 | if model_name in _MODEL_CONFIGS: 572 | return deepcopy(_MODEL_CONFIGS[model_name]) 573 | else: 574 | return None 575 | 576 | 577 | def get_tokenizer(model_name): 578 | if model_name.startswith(HF_HUB_PREFIX): 579 | tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) 580 | else: 581 | config = get_model_config(model_name) 582 | tokenizer = HFTokenizer( 583 | config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize 584 | return tokenizer 585 | 586 | 587 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 588 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 589 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 590 | state_dict = checkpoint['state_dict'] 591 | else: 592 | state_dict = checkpoint 593 | if next(iter(state_dict.items()))[0].startswith('module'): 594 | state_dict = {k[7:]: v for k, v in state_dict.items()} 595 | return state_dict 596 | 597 | 598 | def load_checkpoint(model, checkpoint_path, strict=True): 599 | state_dict = load_state_dict(checkpoint_path) 600 | # detect old format and make compatible with new format 601 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): 602 | state_dict = convert_to_custom_text_state_dict(state_dict) 603 | resize_pos_embed(state_dict, model) 604 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 605 | return incompatible_keys 606 | 607 | 608 | def create_model( 609 | model_name: str, 610 | img_size: int, 611 | pretrained: Optional[str] = None, 612 | precision: str = 'fp32', 613 | device: Union[str, torch.device] = 'cpu', 614 | jit: bool = False, 615 | cache_dir: Optional[str] = None, 616 | output_dict: Optional[bool] = None, 617 | ): 618 | if model_name.count('ViT') < 1: 619 | print('only support ViT model..') 620 | raise NotImplementedError 621 | 622 | # in which means, we can also use old naming rules. 623 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 624 | checkpoint_path = None 625 | pretrained_cfg = {} 626 | model_cfg = None 627 | 628 | if isinstance(device, str): 629 | device = torch.device(device) 630 | 631 | # our default version are borrowed from openai 632 | assert pretrained and pretrained.lower() == 'openai', 'only support openai module.' 
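As a quick reference (not part of this file), a hedged sketch of how the loaders and the config registry above can be used directly; the module path `method.custom_clip` is assumed from the import in `method/trainer.py`, and loading the OpenAI checkpoint triggers a download on first use.

```python
from method.custom_clip import (  # assumed module path
    list_openai_models, load_openai_model, get_model_config,
)

print(list_openai_models())       # checkpoint names registered under the 'openai' tag

# Load the frozen OpenAI backbone as a plain (non-JIT) fp32 module on CPU.
clip = load_openai_model("ViT-L-14-336", precision="fp32", device="cpu", jit=False)

# The registry mirrors ./model_configs/*.json and is what create_model() queries below.
cfg = get_model_config("ViT-L-14-336")
print(cfg["embed_dim"], cfg["vision_cfg"]["layers"], cfg["vision_cfg"]["width"])
```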
633 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 634 | model_cfg = model_cfg or get_model_config(model_name) 635 | 636 | model_cfg['vision_cfg']['image_size'] = img_size 637 | cast_dtype = get_cast_dtype(precision) 638 | 639 | model_pre = load_openai_model( 640 | model_name, 641 | precision=precision, 642 | device=device, 643 | jit=jit, 644 | cache_dir=cache_dir, 645 | ) 646 | state_dict = model_pre.state_dict() 647 | 648 | # to always output dict even if it is clip 649 | if output_dict and hasattr(model_pre, "output_dict"): 650 | model_pre.output_dict = True 651 | 652 | model = CLIP(**model_cfg, cast_dtype=cast_dtype) 653 | 654 | # mainly need to resize the position embeddings 655 | resize_pos_embed(state_dict, model) 656 | incompatible_keys = model.load_state_dict(state_dict, strict=True) 657 | model.to(device=device) 658 | if precision in ("fp16", "bf16"): 659 | convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) 660 | 661 | # set image / mean metadata from pretrained_cfg if available, or use default 662 | model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN 663 | model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD 664 | 665 | # to always output dict even if it is clip 666 | if output_dict and hasattr(model, "output_dict"): 667 | model.output_dict = True 668 | 669 | if jit: 670 | model = torch.jit.script(model) 671 | 672 | return model 673 | 674 | 675 | def create_model_and_transforms( 676 | model_name: str, 677 | img_size: int, 678 | pretrained: Optional[str] = None, 679 | precision: str = 'fp32', 680 | device: Union[str, torch.device] = 'cpu', 681 | jit: bool = False, 682 | image_mean: Optional[Tuple[float, ...]] = None, 683 | image_std: Optional[Tuple[float, ...]] = None, 684 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 685 | cache_dir: Optional[str] = None, 686 | output_dict: Optional[bool] = None, 687 | ): 688 | ######### create the clip model 689 | model = create_model( 690 | model_name, 691 | img_size, 692 | pretrained, 693 | precision=precision, 694 | device=device, 695 | jit=jit, 696 | cache_dir=cache_dir, 697 | output_dict=output_dict, 698 | ) 699 | 700 | image_mean = image_mean or getattr(model.visual, 'image_mean', None) 701 | image_std = image_std or getattr(model.visual, 'image_std', None) 702 | preprocess_train = image_transform( 703 | model.visual.image_size, 704 | is_train=True, 705 | mean=image_mean, 706 | std=image_std, 707 | aug_cfg=aug_cfg, 708 | ) 709 | preprocess_val = image_transform( 710 | model.visual.image_size, 711 | is_train=False, 712 | mean=image_mean, 713 | std=image_std, 714 | ) 715 | 716 | return model, preprocess_train, preprocess_val 717 | -------------------------------------------------------------------------------- /method/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 
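A minimal usage sketch for this tokenizer (not part of the file; `decode()` is defined just below). The import path assumes the package layout shown in the tree above, with `bpe_simple_vocab_16e6.txt.gz` sitting next to the module.

```python
from method.simple_tokenizer import SimpleTokenizer  # assumed module path

tokenizer = SimpleTokenizer()      # reads the gzipped BPE merge list via default_bpe()
ids = tokenizer.encode("a photo of a damaged transistor")
print(ids)                         # list of integer BPE token ids
print(tokenizer.decode(ids))       # reconstructs the lower-cased, cleaned text
```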
127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /method/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 
51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['', ''] 85 | else: 86 | special_tokens = ['', ''] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text = whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | def decode(output_ids: torch.Tensor): 156 | output_ids = output_ids.cpu().numpy() 157 | return _tokenizer.decode(output_ids) 158 | 159 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 160 | """ 161 | Returns the tokenized 
representation of given input string(s) 162 | 163 | Parameters 164 | ---------- 165 | texts : Union[str, List[str]] 166 | An input string or a list of input strings to tokenize 167 | context_length : int 168 | The context length to use; all CLIP models use 77 as the context length 169 | 170 | Returns 171 | ------- 172 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 173 | """ 174 | if isinstance(texts, str): 175 | texts = [texts] 176 | 177 | sot_token = _tokenizer.encoder[""] 178 | eot_token = _tokenizer.encoder[""] 179 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 180 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 181 | 182 | for i, tokens in enumerate(all_tokens): 183 | if len(tokens) > context_length: 184 | tokens = tokens[:context_length] # Truncate 185 | tokens[-1] = eot_token 186 | result[i, :len(tokens)] = torch.tensor(tokens) 187 | 188 | return result 189 | 190 | 191 | class HFTokenizer: 192 | """HuggingFace tokenizer wrapper""" 193 | 194 | def __init__(self, tokenizer_name: str): 195 | from transformers import AutoTokenizer 196 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 197 | 198 | def save_pretrained(self, dest): 199 | self.tokenizer.save_pretrained(dest) 200 | 201 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: 202 | # same cleaning as for default tokenizer, except lowercasing 203 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 207 | input_ids = self.tokenizer( 208 | texts, 209 | return_tensors='pt', 210 | max_length=context_length, 211 | padding='max_length', 212 | truncation=True, 213 | ).input_ids 214 | return input_ids 215 | -------------------------------------------------------------------------------- /method/trainer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torchvision.transforms as transforms 3 | from scipy.ndimage import gaussian_filter 4 | 5 | from loss import FocalLoss, BinaryDiceLoss 6 | from tools import visualization, calculate_metric, calculate_average_metric 7 | from .adaclip import * 8 | from .custom_clip import create_model_and_transforms 9 | 10 | 11 | class AdaCLIP_Trainer(nn.Module): 12 | def __init__( 13 | self, 14 | # clip-related 15 | backbone, feat_list, input_dim, output_dim, 16 | 17 | # learning-related 18 | learning_rate, device, image_size, 19 | 20 | # model settings 21 | prompting_depth=3, prompting_length=2, 22 | prompting_branch='VL', prompting_type='SD', 23 | use_hsf=True, k_clusters=20, 24 | ): 25 | 26 | super(AdaCLIP_Trainer, self).__init__() 27 | 28 | self.device = device 29 | self.feat_list = feat_list 30 | self.image_size = image_size 31 | self.prompting_branch = prompting_branch 32 | self.prompting_type = prompting_type 33 | 34 | self.loss_focal = FocalLoss() 35 | self.loss_dice = BinaryDiceLoss() 36 | 37 | ########### different model choices 38 | freeze_clip, _, self.preprocess = create_model_and_transforms(backbone, image_size, 39 | pretrained='openai') 40 | freeze_clip = freeze_clip.to(device) 41 | freeze_clip.eval() 42 | 43 | self.clip_model = AdaCLIP(freeze_clip=freeze_clip, 44 | text_channel=output_dim, 45 | visual_channel=input_dim, 46 | prompting_length=prompting_length, 47 | 
prompting_depth=prompting_depth, 48 | prompting_branch=prompting_branch, 49 | prompting_type=prompting_type, 50 | use_hsf=use_hsf, 51 | k_clusters=k_clusters, 52 | output_layers=feat_list, 53 | device=device, 54 | image_size=image_size).to(device) 55 | 56 | self.transform = transforms.Compose([ 57 | transforms.Resize((image_size, image_size)), 58 | transforms.CenterCrop(image_size), 59 | transforms.ToTensor() 60 | ]) 61 | 62 | self.preprocess.transforms[0] = transforms.Resize(size=(image_size, image_size), 63 | interpolation=transforms.InterpolationMode.BICUBIC, 64 | max_size=None) 65 | 66 | self.preprocess.transforms[1] = transforms.CenterCrop(size=(image_size, image_size)) 67 | 68 | # update parameters 69 | self.learnable_paramter_list = [ 70 | 'text_prompter', 71 | 'visual_prompter', 72 | 'patch_token_layer', 73 | 'cls_token_layer', 74 | 'dynamic_visual_prompt_generator', 75 | 'dynamic_text_prompt_generator' 76 | ] 77 | 78 | self.params_to_update = [] 79 | for name, param in self.clip_model.named_parameters(): 80 | # print(name) 81 | for update_name in self.learnable_paramter_list: 82 | if update_name in name: 83 | # print(f'updated parameters--{name}: {update_name}') 84 | self.params_to_update.append(param) 85 | 86 | # build the optimizer 87 | self.optimizer = torch.optim.AdamW(self.params_to_update, lr=learning_rate, betas=(0.5, 0.999)) 88 | 89 | def save(self, path): 90 | self.save_dict = {} 91 | for param, value in self.state_dict().items(): 92 | for update_name in self.learnable_paramter_list: 93 | if update_name in param: 94 | # print(f'{param}: {update_name}') 95 | self.save_dict[param] = value 96 | break 97 | 98 | torch.save(self.save_dict, path) 99 | 100 | def load(self, path): 101 | self.load_state_dict(torch.load(path, map_location=self.device), strict=False) 102 | 103 | def train_one_batch(self, items): 104 | image = items['img'].to(self.device) 105 | cls_name = items['cls_name'] 106 | 107 | # pixel level 108 | anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=False) 109 | 110 | if not isinstance(anomaly_map, list): 111 | anomaly_map = [anomaly_map] 112 | 113 | # losses 114 | gt = items['img_mask'].to(self.device) 115 | gt = gt.squeeze() 116 | 117 | gt[gt > 0.5] = 1 118 | gt[gt <= 0.5] = 0 119 | 120 | is_anomaly = items['anomaly'].to(self.device) 121 | is_anomaly[is_anomaly > 0.5] = 1 122 | is_anomaly[is_anomaly <= 0.5] = 0 123 | loss = 0 124 | 125 | # classification loss 126 | classification_loss = self.loss_focal(anomaly_score, is_anomaly.unsqueeze(1)) 127 | loss += classification_loss 128 | 129 | # seg loss 130 | seg_loss = 0 131 | for am, in zip(anomaly_map): 132 | seg_loss += (self.loss_focal(am, gt) + self.loss_dice(am[:, 1, :, :], gt) + 133 | self.loss_dice(am[:, 0, :, :], 1-gt)) 134 | 135 | loss += seg_loss 136 | 137 | self.optimizer.zero_grad() 138 | loss.backward() 139 | self.optimizer.step() 140 | 141 | return loss 142 | 143 | def train_epoch(self, loader): 144 | self.clip_model.train() 145 | loss_list = [] 146 | for items in loader: 147 | loss = self.train_one_batch(items) 148 | loss_list.append(loss.item()) 149 | 150 | return np.mean(loss_list) 151 | 152 | @torch.no_grad() 153 | def evaluation(self, dataloader, obj_list, save_fig, save_fig_dir=None): 154 | self.clip_model.eval() 155 | 156 | results = {} 157 | results['cls_names'] = [] 158 | results['imgs_gts'] = [] 159 | results['anomaly_scores'] = [] 160 | results['imgs_masks'] = [] 161 | results['anomaly_maps'] = [] 162 | results['imgs'] = [] 163 | results['names'] = [] 164 | 165 | with 
torch.no_grad(), torch.cuda.amp.autocast(): 166 | image_indx = 0 167 | for indx, items in enumerate(dataloader): 168 | if save_fig: 169 | path = items['img_path'] 170 | for _path in path: 171 | vis_image = cv2.resize(cv2.imread(_path), (self.image_size, self.image_size)) 172 | results['imgs'].append(vis_image) 173 | cls_name = items['cls_name'] 174 | for _cls_name in cls_name: 175 | image_indx += 1 176 | results['names'].append('{:}-{:03d}'.format(_cls_name, image_indx)) 177 | 178 | image = items['img'].to(self.device) 179 | cls_name = items['cls_name'] 180 | results['cls_names'].extend(cls_name) 181 | gt_mask = items['img_mask'] 182 | gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0 183 | 184 | for _gt_mask in gt_mask: 185 | results['imgs_masks'].append(_gt_mask.squeeze(0).numpy()) # px 186 | 187 | # pixel level 188 | anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=True) 189 | 190 | anomaly_map = anomaly_map.cpu().numpy() 191 | anomaly_score = anomaly_score.cpu().numpy() 192 | 193 | for _anomaly_map, _anomaly_score in zip(anomaly_map, anomaly_score): 194 | _anomaly_map = gaussian_filter(_anomaly_map, sigma=4) 195 | results['anomaly_maps'].append(_anomaly_map) 196 | results['anomaly_scores'].append(_anomaly_score) 197 | 198 | is_anomaly = np.array(items['anomaly']) 199 | for _is_anomaly in is_anomaly: 200 | results['imgs_gts'].append(_is_anomaly) 201 | 202 | # visualization 203 | if save_fig: 204 | print('saving fig.....') 205 | visualization.plot_sample_cv2( 206 | results['names'], 207 | results['imgs'], 208 | {'AdaCLIP': results['anomaly_maps']}, 209 | results['imgs_masks'], 210 | save_fig_dir 211 | ) 212 | 213 | metric_dict = dict() 214 | for obj in obj_list: 215 | metric_dict[obj] = dict() 216 | 217 | for obj in obj_list: 218 | metric = calculate_metric(results, obj) 219 | obj_full_name = f'{obj}' 220 | metric_dict[obj_full_name] = metric 221 | 222 | metric_dict['Average'] = calculate_average_metric(metric_dict) 223 | 224 | return metric_dict 225 | 226 | -------------------------------------------------------------------------------- /method/transformer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import math 3 | from typing import Callable, Optional, Sequence, Tuple 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.utils.checkpoint import checkpoint 9 | 10 | from .utils import to_2tuple 11 | import numpy as np 12 | 13 | 14 | class LayerNormFp32(nn.LayerNorm): 15 | """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" 16 | 17 | def forward(self, x: torch.Tensor): 18 | orig_type = x.dtype 19 | x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) 20 | return x.to(orig_type) 21 | 22 | 23 | class LayerNorm(nn.LayerNorm): 24 | """Subclass torch's LayerNorm (with cast back to input dtype).""" 25 | 26 | def forward(self, x: torch.Tensor): 27 | orig_type = x.dtype 28 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 29 | return x.to(orig_type) 30 | 31 | 32 | class QuickGELU(nn.Module): 33 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 34 | def forward(self, x: torch.Tensor): 35 | return x * torch.sigmoid(1.702 * x) 36 | 37 | 38 | class LayerScale(nn.Module): 39 | def __init__(self, dim, init_values=1e-5, inplace=False): 40 | super().__init__() 41 | self.inplace = inplace 42 | self.gamma = 
nn.Parameter(init_values * torch.ones(dim)) 43 | 44 | def forward(self, x): 45 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 46 | 47 | 48 | class PatchDropout(nn.Module): 49 | """ 50 | https://arxiv.org/abs/2212.00794 51 | """ 52 | 53 | def __init__(self, prob, exclude_first_token=True): 54 | super().__init__() 55 | assert 0 <= prob < 1. 56 | self.prob = prob 57 | self.exclude_first_token = exclude_first_token # exclude CLS token 58 | 59 | def forward(self, x): 60 | if not self.training or self.prob == 0.: 61 | return x 62 | 63 | if self.exclude_first_token: 64 | cls_tokens, x = x[:, :1], x[:, 1:] 65 | else: 66 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) 67 | 68 | batch = x.size()[0] 69 | num_tokens = x.size()[1] 70 | 71 | batch_indices = torch.arange(batch) 72 | batch_indices = batch_indices[..., None] 73 | 74 | keep_prob = 1 - self.prob 75 | num_patches_keep = max(1, int(num_tokens * keep_prob)) 76 | 77 | rand = torch.randn(batch, num_tokens) 78 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices 79 | 80 | x = x[batch_indices, patch_indices_keep] 81 | 82 | if self.exclude_first_token: 83 | x = torch.cat((cls_tokens, x), dim=1) 84 | 85 | return x 86 | 87 | 88 | class Attention(nn.Module): 89 | def __init__( 90 | self, 91 | dim, 92 | num_heads=8, 93 | qkv_bias=True, 94 | scaled_cosine=False, 95 | scale_heads=False, 96 | logit_scale_max=math.log(1. / 0.01), 97 | attn_drop=0., 98 | proj_drop=0. 99 | ): 100 | super().__init__() 101 | self.scaled_cosine = scaled_cosine 102 | self.scale_heads = scale_heads 103 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 104 | self.num_heads = num_heads 105 | self.head_dim = dim // num_heads 106 | self.scale = self.head_dim ** -0.5 107 | self.logit_scale_max = logit_scale_max 108 | 109 | # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original 110 | self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) 111 | if qkv_bias: 112 | self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) 113 | else: 114 | self.in_proj_bias = None 115 | 116 | if self.scaled_cosine: 117 | self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) 118 | else: 119 | self.logit_scale = None 120 | self.attn_drop = nn.Dropout(attn_drop) 121 | if self.scale_heads: 122 | self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) 123 | else: 124 | self.head_scale = None 125 | self.out_proj = nn.Linear(dim, dim) 126 | self.out_drop = nn.Dropout(proj_drop) 127 | 128 | def forward(self, x, attn_mask: Optional[torch.Tensor] = None): 129 | L, N, C = x.shape 130 | q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) 131 | q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 132 | k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 133 | v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 134 | 135 | if self.logit_scale is not None: 136 | attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) 137 | logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() 138 | attn = attn.view(N, self.num_heads, L, L) * logit_scale 139 | attn = attn.view(-1, L, L) 140 | else: 141 | q = q * self.scale 142 | attn = torch.bmm(q, k.transpose(-1, -2)) 143 | 144 | if attn_mask is not None: 145 | if attn_mask.dtype == torch.bool: 146 | new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) 147 | new_attn_mask.masked_fill_(attn_mask, float("-inf")) 148 | 
attn_mask = new_attn_mask 149 | attn += attn_mask 150 | 151 | attn = attn.softmax(dim=-1) 152 | attn = self.attn_drop(attn) 153 | 154 | x = torch.bmm(attn, v) 155 | if self.head_scale is not None: 156 | x = x.view(N, self.num_heads, L, C) * self.head_scale 157 | x = x.view(-1, L, C) 158 | x = x.transpose(0, 1).reshape(L, N, C) 159 | x = self.out_proj(x) 160 | x = self.out_drop(x) 161 | return x 162 | 163 | 164 | class AttentionalPooler(nn.Module): 165 | def __init__( 166 | self, 167 | d_model: int, 168 | context_dim: int, 169 | n_head: int = 8, 170 | n_queries: int = 256, 171 | norm_layer: Callable = LayerNorm 172 | ): 173 | super().__init__() 174 | self.query = nn.Parameter(torch.randn(n_queries, d_model)) 175 | self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) 176 | self.ln_q = norm_layer(d_model) 177 | self.ln_k = norm_layer(context_dim) 178 | 179 | def forward(self, x: torch.Tensor): 180 | x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND 181 | N = x.shape[1] 182 | q = self.ln_q(self.query) 183 | out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] 184 | return out.permute(1, 0, 2) # LND -> NLD 185 | 186 | def _repeat(self, query, N: int): 187 | return query.unsqueeze(1).repeat(1, N, 1) 188 | 189 | 190 | class ResidualAttentionBlock(nn.Module): 191 | def __init__( 192 | self, 193 | d_model: int, 194 | n_head: int, 195 | mlp_ratio: float = 4.0, 196 | ls_init_value: float = None, 197 | act_layer: Callable = nn.GELU, 198 | norm_layer: Callable = LayerNorm, 199 | is_cross_attention: bool = False, 200 | idx: int = 12, 201 | ): 202 | super().__init__() 203 | 204 | self.idx = idx 205 | 206 | self.ln_1 = norm_layer(d_model) 207 | self.attn = nn.MultiheadAttention(d_model, n_head) 208 | self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 209 | if is_cross_attention: 210 | self.ln_1_kv = norm_layer(d_model) 211 | 212 | self.ln_2 = norm_layer(d_model) 213 | mlp_width = int(d_model * mlp_ratio) 214 | self.mlp = nn.Sequential(OrderedDict([ 215 | ("c_fc", nn.Linear(d_model, mlp_width)), 216 | ("gelu", act_layer()), 217 | ("c_proj", nn.Linear(mlp_width, d_model)) 218 | ])) 219 | self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 220 | 221 | def attention( 222 | self, 223 | q_x: torch.Tensor, 224 | k_x: Optional[torch.Tensor] = None, 225 | v_x: Optional[torch.Tensor] = None, 226 | attn_mask: Optional[torch.Tensor] = None, 227 | ): 228 | k_x = k_x if k_x is not None else q_x 229 | v_x = v_x if v_x is not None else q_x 230 | 231 | attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None 232 | return self.attn( 233 | q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask 234 | ) 235 | 236 | def forward( 237 | self, 238 | q_x: torch.Tensor, 239 | k_x: Optional[torch.Tensor] = None, 240 | v_x: Optional[torch.Tensor] = None, 241 | attn_mask: Optional[torch.Tensor] = None, 242 | ): 243 | k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None 244 | v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None 245 | 246 | tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) 247 | x = q_x + self.ls_1(tmp) 248 | x = x + self.ls_2(self.mlp(self.ln_2(x))) 249 | return x, attn 250 | 251 | 252 | 253 | class Transformer(nn.Module): 254 | def __init__( 255 | self, 256 | width: int, 257 | layers: int, 258 | heads: int, 259 | mlp_ratio: float = 4.0, 260 | ls_init_value: float = None, 261 | 
act_layer: Callable = nn.GELU, 262 | norm_layer: Callable = LayerNorm, 263 | ): 264 | super().__init__() 265 | self.width = width 266 | self.layers = layers 267 | self.grad_checkpointing = False 268 | 269 | self.resblocks = nn.ModuleList([ 270 | ResidualAttentionBlock( 271 | width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, 272 | idx=idx) 273 | for idx in range(layers) 274 | ]) 275 | 276 | def get_cast_dtype(self) -> torch.dtype: 277 | return self.resblocks[0].mlp.c_fc.weight.dtype 278 | 279 | def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9], 280 | attn_mask: Optional[torch.Tensor] = None): 281 | idx = 0 282 | out_tokens = [] 283 | for r in self.resblocks: 284 | idx += 1 285 | if self.grad_checkpointing and not torch.jit.is_scripting(): 286 | # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 287 | x = checkpoint(r, x, None, None, attn_mask) 288 | else: 289 | x, attn_tmp = r(x, attn_mask=attn_mask) 290 | if idx in out_layers: 291 | out_tokens.append(x) 292 | return x, out_tokens 293 | 294 | 295 | 296 | class VisionTransformer(nn.Module): 297 | output_tokens: torch.jit.Final[bool] 298 | 299 | def __init__( 300 | self, 301 | image_size: int, 302 | patch_size: int, 303 | width: int, 304 | layers: int, 305 | heads: int, 306 | mlp_ratio: float, 307 | ls_init_value: float = None, 308 | global_average_pool: bool = False, 309 | attentional_pool: bool = False, 310 | n_queries: int = 256, 311 | attn_pooler_heads: int = 8, 312 | output_dim: int = 512, 313 | patch_dropout: float = 0., 314 | input_patchnorm: bool = False, 315 | act_layer: Callable = nn.GELU, 316 | norm_layer: Callable = LayerNorm, 317 | output_tokens: bool = False, 318 | ): 319 | super().__init__() 320 | self.output_tokens = output_tokens 321 | image_height, image_width = self.image_size = to_2tuple(image_size) 322 | patch_height, patch_width = self.patch_size = to_2tuple(patch_size) 323 | self.grid_size = (image_height // patch_height, image_width // patch_width) 324 | self.output_dim = output_dim 325 | 326 | # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 327 | self.input_patchnorm = input_patchnorm 328 | 329 | if input_patchnorm: 330 | patch_input_dim = patch_height * patch_width * 3 331 | self.patchnorm_pre_ln = LayerNorm(patch_input_dim) 332 | self.conv1 = nn.Linear(patch_input_dim, width) 333 | else: 334 | self.patchnorm_pre_ln = nn.Identity() 335 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, 336 | bias=False) 337 | 338 | # class embeddings and positional embeddings 339 | scale = width ** -0.5 340 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 341 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) 342 | 343 | # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn 344 | self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() 345 | 346 | self.ln_pre = norm_layer(width) 347 | 348 | self.transformer = Transformer( 349 | width, 350 | layers, 351 | heads, 352 | mlp_ratio, 353 | ls_init_value=ls_init_value, 354 | act_layer=act_layer, 355 | norm_layer=norm_layer, 356 | ) 357 | 358 | self.global_average_pool = global_average_pool 359 | if attentional_pool: 360 | self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) 361 | self.ln_post = norm_layer(output_dim) 362 | self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) 363 | else: 364 | self.attn_pool = None 365 | self.ln_post = norm_layer(width) 366 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 367 | 368 | self.init_parameters() 369 | 370 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 371 | for param in self.parameters(): 372 | param.requires_grad = False 373 | 374 | if unlocked_groups != 0: 375 | groups = [ 376 | [ 377 | self.conv1, 378 | self.class_embedding, 379 | self.positional_embedding, 380 | self.ln_pre, 381 | ], 382 | *self.transformer.resblocks[:-1], 383 | [ 384 | self.transformer.resblocks[-1], 385 | self.ln_post, 386 | ], 387 | self.proj, 388 | ] 389 | 390 | def _unlock(x): 391 | if isinstance(x, Sequence): 392 | for g in x: 393 | _unlock(g) 394 | else: 395 | if isinstance(x, torch.nn.Parameter): 396 | x.requires_grad = True 397 | else: 398 | for p in x.parameters(): 399 | p.requires_grad = True 400 | 401 | _unlock(groups[-unlocked_groups:]) 402 | 403 | def init_parameters(self): 404 | # FIXME OpenAI CLIP did not define an init for the VisualTransformer 405 | # TODO experiment if default PyTorch init, below, or alternate init is best. 406 | 407 | # nn.init.normal_(self.class_embedding, std=self.scale) 408 | # nn.init.normal_(self.positional_embedding, std=self.scale) 409 | # 410 | # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 411 | # attn_std = self.transformer.width ** -0.5 412 | # fc_std = (2 * self.transformer.width) ** -0.5 413 | # for block in self.transformer.resblocks: 414 | # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 415 | # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 416 | # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 417 | # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 418 | # 419 | # if self.text_projection is not None: 420 | # nn.init.normal_(self.text_projection, std=self.scale) 421 | pass 422 | 423 | @torch.jit.ignore 424 | def set_grad_checkpointing(self, enable=True): 425 | self.transformer.grad_checkpointing = enable 426 | 427 | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 428 | if self.global_average_pool: 429 | return x.mean(dim=1), x 430 | else: 431 | return x[:, 0], x[:, 1:] 432 | 433 | def forward(self, x: torch.Tensor, out_layers: list): 434 | 435 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 436 | if self.input_patchnorm: 437 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 438 | x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], 439 | self.patch_size[1]) 440 | x = x.permute(0, 2, 4, 1, 3, 5) 441 | x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) 442 | x = self.patchnorm_pre_ln(x) 443 | x = self.conv1(x) 444 | else: 445 | x = self.conv1(x) # shape = [*, width, grid, grid] 446 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 447 | x = x.permute(0, 2, 1) # 
shape = [*, grid ** 2, width] 448 | 449 | # class embeddings and positional embeddings 450 | x = torch.cat( 451 | [self.class_embedding.to(x.dtype) + 452 | torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 453 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 454 | x = x + self.positional_embedding.to(x.dtype) 455 | 456 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 457 | x = self.patch_dropout(x) 458 | x = self.ln_pre(x) 459 | 460 | x = x.permute(1, 0, 2) # NLD -> LND 461 | x, patch_tokens = self.transformer(x, out_layers) 462 | x = x.permute(1, 0, 2) # LND -> NLD 463 | patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD 464 | # patch_tokens = patch_tokens.permute(1, 0, 2) # LND -> NLD 465 | 466 | if self.attn_pool is not None: 467 | x = self.attn_pool(x) 468 | x = self.ln_post(x) 469 | pooled, tokens = self._global_pool(x) 470 | else: 471 | pooled, tokens = self._global_pool(x) 472 | pooled = self.ln_post(pooled) 473 | # patch_pooled, patch_tokens = self._global_pool(patch_tokens) 474 | # tokens = self.ln_post(tokens) 475 | 476 | if self.proj is not None: 477 | pooled = pooled @ self.proj 478 | # patch_tokens = patch_tokens @ self.proj # 不知道能不能行 479 | # tokens = tokens @ self.proj 480 | 481 | if self.output_tokens: 482 | return pooled, patch_tokens 483 | 484 | return pooled, patch_tokens 485 | 486 | 487 | class TextTransformer(nn.Module): 488 | output_tokens: torch.jit.Final[bool] 489 | 490 | def __init__( 491 | self, 492 | context_length: int = 77, 493 | vocab_size: int = 49408, 494 | width: int = 512, 495 | heads: int = 8, 496 | layers: int = 12, 497 | ls_init_value: float = None, 498 | output_dim: int = 512, 499 | act_layer: Callable = nn.GELU, 500 | norm_layer: Callable = LayerNorm, 501 | embed_cls: bool = False, 502 | pad_id: int = 0, 503 | output_tokens: bool = False, 504 | ): 505 | super().__init__() 506 | self.output_tokens = output_tokens 507 | self.num_pos = self.context_length = context_length 508 | self.vocab_size = vocab_size 509 | self.width = width 510 | self.output_dim = output_dim 511 | self.heads = heads 512 | self.pad_id = pad_id 513 | 514 | self.text_projection = nn.Parameter(torch.empty(width, output_dim)) 515 | 516 | if embed_cls: 517 | self.cls_emb = nn.Parameter(torch.empty(width)) 518 | self.num_pos += 1 519 | else: 520 | self.cls_emb = None 521 | 522 | self.token_embedding = nn.Embedding(vocab_size, width) 523 | self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) 524 | 525 | self.transformer = Transformer( 526 | width=width, 527 | layers=layers, 528 | heads=heads, 529 | ls_init_value=ls_init_value, 530 | act_layer=act_layer, 531 | norm_layer=norm_layer, 532 | ) 533 | 534 | self.ln_final = norm_layer(width) 535 | 536 | self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) 537 | 538 | self.init_parameters() 539 | 540 | def init_parameters(self): 541 | nn.init.normal_(self.token_embedding.weight, std=0.02) 542 | nn.init.normal_(self.positional_embedding, std=0.01) 543 | if self.cls_emb is not None: 544 | nn.init.normal_(self.cls_emb, std=0.01) 545 | 546 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 547 | attn_std = self.transformer.width ** -0.5 548 | fc_std = (2 * self.transformer.width) ** -0.5 549 | for block in self.transformer.resblocks: 550 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 551 | 
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 552 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 553 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 554 | 555 | if self.text_projection is not None: 556 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 557 | 558 | @torch.jit.ignore 559 | def set_grad_checkpointing(self, enable=True): 560 | self.transformer.grad_checkpointing = enable 561 | 562 | def build_attention_mask(self): 563 | # lazily create causal attention mask, with full attention between the tokens 564 | # pytorch uses additive attention mask; fill with -inf 565 | mask = torch.empty(self.num_pos, self.num_pos) 566 | mask.fill_(float("-inf")) 567 | mask.triu_(1) # zero out the lower diagonal 568 | return mask 569 | 570 | def build_cls_mask(self, text, cast_dtype: torch.dtype): 571 | cls_mask = (text != self.pad_id).unsqueeze(1) 572 | cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) 573 | additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) 574 | additive_mask.fill_(0) 575 | additive_mask.masked_fill_(~cls_mask, float("-inf")) 576 | additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) 577 | return additive_mask 578 | 579 | def _repeat(self, t, N: int): 580 | return t.reshape(1, 1, -1).repeat(N, 1, 1) 581 | 582 | def forward(self, text): 583 | cast_dtype = self.transformer.get_cast_dtype() 584 | seq_len = text.shape[1] 585 | 586 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 587 | attn_mask = self.attn_mask 588 | if self.cls_emb is not None: 589 | seq_len += 1 590 | x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) 591 | cls_mask = self.build_cls_mask(text, cast_dtype) 592 | attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] 593 | 594 | x = x + self.positional_embedding[:seq_len].to(cast_dtype) 595 | x = x.permute(1, 0, 2) # NLD -> LND 596 | x, attn, patch_tokens = self.transformer(x, attn_mask=attn_mask) 597 | x = x.permute(1, 0, 2) # LND -> NLD 598 | 599 | # x.shape = [batch_size, n_ctx, transformer.width] 600 | # take features from the eot embedding (eot_token is the highest number in each sequence) 601 | if self.cls_emb is not None: 602 | pooled, tokens = x[:, -1], x[:, :-1] 603 | pooled = self.ln_final(pooled) 604 | else: 605 | x = self.ln_final(x) 606 | pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x 607 | 608 | if self.text_projection is not None: 609 | pooled = pooled @ self.text_projection 610 | 611 | if self.output_tokens: 612 | return pooled, tokens 613 | 614 | return pooled 615 | 616 | -------------------------------------------------------------------------------- /method/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 
16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } 
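For orientation (not part of the config files), a short sketch of how these JSON configs are consumed to size AdaCLIP; it mirrors the feature-hierarchy logic in `test.py` further below, and the noted values follow from `ViT-L-14-336.json` above.

```python
import json

# ViT-L-14-336 is the default backbone in test.py.
with open("model_configs/ViT-L-14-336.json", "r") as f:
    model_configs = json.load(f)

# Four evenly spaced transformer stages supply the multi-level patch features.
n_layers = model_configs["vision_cfg"]["layers"]     # 24
substage = n_layers // 4
features_list = [substage, substage * 2, substage * 3, substage * 4]  # [6, 12, 18, 24]

input_dim = model_configs["vision_cfg"]["width"]     # 1024 -> visual feature width
output_dim = model_configs["embed_dim"]              # 768  -> joint embedding width
print(features_list, input_dim, output_dim)
```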
-------------------------------------------------------------------------------- /model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | torchvision==0.11.2 3 | torchaudio==0.10.1 4 | tqdm 5 | tensorboard 6 | setuptools==58.0.4 7 | opencv-python 8 | scikit-image 9 | scikit-learn 10 | matplotlib 11 | seaborn 12 | ftfy 13 | regex 14 | numpy==1.26.4 15 | gradio 16 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore", category=RuntimeWarning) 3 | import os 4 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | import argparse 8 | import json 9 | import os 10 | import torch 11 | from scipy.ndimage import gaussian_filter 12 | import cv2 13 | 14 | # Importing from local modules 15 | from tools import write2csv, setup_seed, Logger 16 | from dataset import get_data, dataset_dict 17 | from method import AdaCLIP_Trainer 18 | from PIL import Image 19 | import numpy as np 20 | 21 | setup_seed(111) 22 | 23 | def train(args): 24 | assert os.path.isfile(args.ckt_path), f"Please check the path of pre-trained model, {args.ckt_path} is not valid." 
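    # Example invocations (comment sketch; the flags mirror the argparse options defined
    # at the bottom of this file, see also test.sh / test_single_image.sh):
    #   python test.py --testing_model dataset --testing_data visa --save_fig True
    #   python test.py --testing_model image --image_path asset/img.png \
    #       --class_name candle --save_name test.png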
25 | 26 | # Configurations 27 | batch_size = args.batch_size 28 | image_size = args.image_size 29 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 30 | 31 | save_fig = args.save_fig 32 | 33 | # Logger 34 | logger = Logger('log.txt') 35 | 36 | # Print basic information 37 | for key, value in sorted(vars(args).items()): 38 | logger.info(f'{key} = {value}') 39 | 40 | 41 | config_path = os.path.join('./model_configs', f'{args.model}.json') 42 | 43 | # Prepare model 44 | with open(config_path, 'r') as f: 45 | model_configs = json.load(f) 46 | 47 | # Set up the feature hierarchy 48 | n_layers = model_configs['vision_cfg']['layers'] 49 | substage = n_layers // 4 50 | features_list = [substage, substage * 2, substage * 3, substage * 4] 51 | 52 | model = AdaCLIP_Trainer( 53 | backbone=args.model, 54 | feat_list=features_list, 55 | input_dim=model_configs['vision_cfg']['width'], 56 | output_dim=model_configs['embed_dim'], 57 | learning_rate=0., 58 | device=device, 59 | image_size=image_size, 60 | prompting_depth=args.prompting_depth, 61 | prompting_length=args.prompting_length, 62 | prompting_branch=args.prompting_branch, 63 | prompting_type=args.prompting_type, 64 | use_hsf=args.use_hsf, 65 | k_clusters=args.k_clusters 66 | ).to(device) 67 | 68 | model.load(args.ckt_path) 69 | 70 | if args.testing_model == 'dataset': 71 | assert args.testing_data in dataset_dict.keys(), f"You entered {args.testing_data}, but we only support " \ 72 | f"{dataset_dict.keys()}" 73 | 74 | save_root = args.save_path 75 | csv_root = os.path.join(save_root, 'csvs') 76 | image_root = os.path.join(save_root, 'images') 77 | csv_path = os.path.join(csv_root, f'{args.testing_data}.csv') 78 | image_dir = os.path.join(image_root, f'{args.testing_data}') 79 | os.makedirs(image_dir, exist_ok=True) 80 | 81 | test_data_cls_names, test_data, test_data_root = get_data( 82 | dataset_type_list=args.testing_data, 83 | transform=model.preprocess, 84 | target_transform=model.transform, 85 | training=False) 86 | 87 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False) 88 | save_fig_flag = save_fig 89 | 90 | metric_dict = model.evaluation( 91 | test_dataloader, 92 | test_data_cls_names, 93 | save_fig_flag, 94 | image_dir, 95 | ) 96 | 97 | for tag, data in metric_dict.items(): 98 | logger.info( 99 | '{:>15} \t\tI-Auroc:{:.2f} \tI-F1:{:.2f} \tI-AP:{:.2f} \tP-Auroc:{:.2f} \tP-F1:{:.2f} \tP-AP:{:.2f}'. 
100 | format(tag, 101 | data['auroc_im'], 102 | data['f1_im'], 103 | data['ap_im'], 104 | data['auroc_px'], 105 | data['f1_px'], 106 | data['ap_px']) 107 | ) 108 | 109 | 110 | for k in metric_dict.keys(): 111 | write2csv(metric_dict[k], test_data_cls_names, k, csv_path) 112 | 113 | elif args.testing_model == 'image': 114 | assert os.path.isfile(args.image_path), f"Please verify the input image path: {args.image_path}" 115 | ori_image = cv2.resize(cv2.imread(args.image_path), (args.image_size, args.image_size)) 116 | pil_img = Image.open(args.image_path).convert('RGB') 117 | 118 | img_input = model.preprocess(pil_img).unsqueeze(0) 119 | img_input = img_input.to(model.device) 120 | 121 | with torch.no_grad(): 122 | anomaly_map, anomaly_score = model.clip_model(img_input, [args.class_name], aggregation=True) 123 | 124 | anomaly_map = anomaly_map[0, :, :] 125 | anomaly_score = anomaly_score[0] 126 | anomaly_map = anomaly_map.cpu().numpy() 127 | anomaly_score = anomaly_score.cpu().numpy() 128 | 129 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 130 | anomaly_map = anomaly_map * 255 131 | anomaly_map = anomaly_map.astype(np.uint8) 132 | 133 | heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET) 134 | vis_map = cv2.addWeighted(heat_map, 0.5, ori_image, 0.5, 0) 135 | 136 | vis_map = cv2.hconcat([ori_image, vis_map]) 137 | save_path = os.path.join(args.save_path, args.save_name) 138 | print(f"Anomaly detection results are saved in {save_path}, with an anomaly score of {anomaly_score:.3f}") 139 | cv2.imwrite(save_path, vis_map) 140 | 141 | def str2bool(v): 142 | return v.lower() in ("yes", "true", "t", "1") 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser("AdaCLIP", add_help=True) 146 | 147 | # Paths and configurations 148 | parser.add_argument("--ckt_path", type=str, default='weights/pretrained_mvtec_colondb.pth', 149 | help="Path to the pre-trained model (default: weights/pretrained_mvtec_colondb.pth)") 150 | 151 | parser.add_argument("--testing_model", type=str, default="dataset", choices=["dataset", "image"], 152 | help="Testing mode: evaluate a whole dataset or a single image (default: 'dataset')") 153 | 154 | # for the dataset model 155 | parser.add_argument("--testing_data", type=str, default="visa", help="Dataset for testing (default: 'visa')") 156 | 157 | # for the image model 158 | parser.add_argument("--image_path", type=str, default="asset/img.png", 159 | help="Path to the testing image (default: 'asset/img.png')") 160 | parser.add_argument("--class_name", type=str, default="candle", 161 | help="The class name of the testing image (default: 'candle')") 162 | parser.add_argument("--save_name", type=str, default="test.png", 163 | help="Filename for the saved visualization (default: 'test.png')") 164 | 165 | 166 | parser.add_argument("--save_path", type=str, default='./workspaces', 167 | help="Directory to save results (default: './workspaces')") 168 | 169 | parser.add_argument("--model", type=str, default="ViT-L-14-336", 170 | choices=["ViT-B-16", "ViT-B-32", "ViT-L-14", "ViT-L-14-336"], 171 | help="The CLIP model to be used (default: 'ViT-L-14-336')") 172 | 173 | parser.add_argument("--save_fig", type=str2bool, default=False, 174 | help="Save figures for visualizations (default: False)") 175 | 176 | # Hyper-parameters 177 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size (default: 1)") 178 | parser.add_argument("--image_size", type=int, default=518, help="Size of the input images (default: 518)") 179 | 180 | # Prompting parameters 181 | parser.add_argument("--prompting_depth", type=int,
default=4, help="Depth of prompting (default: 4)") 182 | parser.add_argument("--prompting_length", type=int, default=5, help="Length of prompting (default: 5)") 183 | parser.add_argument("--prompting_type", type=str, default='SD', choices=['', 'S', 'D', 'SD'], 184 | help="Type of prompting. 'S' for Static, 'D' for Dynamic, 'SD' for both (default: 'SD')") 185 | parser.add_argument("--prompting_branch", type=str, default='VL', choices=['', 'V', 'L', 'VL'], 186 | help="Branch of prompting. 'V' for Visual, 'L' for Language, 'VL' for both (default: 'VL')") 187 | 188 | parser.add_argument("--use_hsf", type=str2bool, default=True, 189 | help="Use HSF for aggregation. If False, original class embedding is used (default: True)") 190 | parser.add_argument("--k_clusters", type=int, default=20, help="Number of clusters (default: 20)") 191 | 192 | args = parser.parse_args() 193 | 194 | if args.batch_size != 1: 195 | raise NotImplementedError( 196 | "Currently, only batch size of 1 is supported due to unresolved bugs. Please set --batch_size to 1.") 197 | 198 | train(args) 199 | 200 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # pre-trained from MVTec and ColonDB 2 | ckt_path="weights/pretrained_mvtec_colondb.pth" 3 | gpu_id=0 4 | 5 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data br35h 6 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data brain_mri 7 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data btad 8 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data clinicdb 9 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data dagm 10 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data dtd 11 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data headct 12 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data isic 13 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data mpdd 14 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data sdd 15 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data tn3k 16 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data visa 17 | 18 | # pre-trained from Visa and Clinicdb 19 | ckt_path="weights/pretrained_visa_clinicdb.pth" 20 | gpu_id=0 21 | 22 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data colondb 23 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data mvtec 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /test_single_image.sh: -------------------------------------------------------------------------------- 1 | 
ckt_path="weights/pretrained_all.pth" 2 | gpu_id=0 3 | 4 | # demo: do zero-shot anomaly detection for a single image 5 | CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model image --ckt_path $ckt_path --save_fig True \ 6 | --image_path asset/img.png --class_name candle --save_name test.png -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .csv_tools import write2csv 2 | from .logger import Logger, log_metrics 3 | from .metrics import calculate_metric, calculate_average_metric 4 | from .training_tools import setup_seed, setup_paths 5 | from .visualization import plot_sample_cv2 -------------------------------------------------------------------------------- /tools/csv_tools.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | 4 | def write2csv(results:dict, total_classes, cur_class, csv_path): 5 | keys = list(results.keys()) 6 | 7 | if not os.path.exists(csv_path): 8 | df_all = None 9 | for class_name in total_classes: 10 | r = dict() 11 | for k in keys: 12 | r[k] = 0.00 13 | df_temp = pd.DataFrame(r, index=[f'{class_name}']) 14 | 15 | if df_all is None: 16 | df_all = df_temp 17 | else: 18 | df_all = pd.concat([df_all, df_temp], axis=0) 19 | 20 | df_all.to_csv(csv_path, header=True, float_format='%.2f') 21 | 22 | df = pd.read_csv(csv_path, index_col=0) 23 | 24 | for k in keys: 25 | df.loc[f'{cur_class}', k] = results[k] 26 | 27 | df.to_csv(csv_path, header=True, float_format='%.2f') 28 | 29 | -------------------------------------------------------------------------------- /tools/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | class Logger(object): 4 | def __init__(self, txt_path): 5 | root_logger = logging.getLogger() 6 | for handler in root_logger.handlers[:]: 7 | root_logger.removeHandler(handler) 8 | root_logger.setLevel(logging.WARNING) 9 | self.txt_path = txt_path 10 | self.logger = logging.getLogger('train') 11 | self.formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') 12 | self.logger.setLevel(logging.INFO) 13 | 14 | def __console(self, level, message): 15 | root_logger = logging.getLogger() 16 | for handler in root_logger.handlers[:]: 17 | root_logger.removeHandler(handler) 18 | 19 | file_handler = logging.FileHandler(self.txt_path, mode='a') 20 | console_handler = logging.StreamHandler() 21 | 22 | file_handler.setFormatter(self.formatter) 23 | console_handler.setFormatter(self.formatter) 24 | 25 | self.logger.addHandler(file_handler) 26 | self.logger.addHandler(console_handler) 27 | 28 | if level == 'info': 29 | self.logger.info(message) 30 | elif level == 'debug': 31 | self.logger.debug(message) 32 | elif level == 'warning': 33 | self.logger.warning(message) 34 | elif level == 'error': 35 | self.logger.error(message) 36 | 37 | self.logger.removeHandler(file_handler) 38 | self.logger.removeHandler(console_handler) 39 | 40 | file_handler.close() 41 | 42 | def debug(self, message): 43 | self.__console('debug', message) 44 | 45 | def info(self, message): 46 | self.__console('info', message) 47 | 48 | def warning(self, message): 49 | self.__console('warning', message) 50 | 51 | def error(self, message): 52 | self.__console('error', message) 53 | 54 | def log_metrics(metrics, logger, tensorboard_logger, epoch): 55 | def 
log_single_class(data, tag): 56 | logger.info( 57 | '{:>15} \t\tI-Auroc:{:.2f} \tI-F1:{:.2f} \tI-AP:{:.2f} \tP-Auroc:{:.2f} \tP-F1:{:.2f} \tP-AP:{:.2f}'. 58 | format(tag, 59 | data['auroc_im'], 60 | data['f1_im'], 61 | data['ap_im'], 62 | data['auroc_px'], 63 | data['f1_px'], 64 | data['ap_px']) 65 | ) 66 | # Adding scalar metrics to TensorBoard 67 | for metric_name in ['auroc_im', 'f1_im', 'ap_im', 'auroc_px', 'f1_px', 'ap_px']: 68 | tensorboard_logger.add_scalar(f'{tag}-{metric_name}', data[metric_name], epoch) 69 | 70 | for tag, data in metrics.items(): 71 | log_single_class(data, tag) 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /tools/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import auc, roc_auc_score, precision_recall_curve, average_precision_score 3 | 4 | 5 | def rescale(x): 6 | return (x - x.min()) / (x.max() - x.min()) 7 | 8 | 9 | def is_one_class(gt: np.ndarray): 10 | gt_ravel = gt.ravel() 11 | return gt_ravel.sum() == 0 or gt_ravel.sum() == gt_ravel.shape[0] 12 | 13 | 14 | def calculate_px_metrics(gt_px, pr_px): 15 | if is_one_class(gt_px): # In case there are only normal pixels or no pixel-level labels 16 | return 0, 0, 0 17 | 18 | auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel()) 19 | precisions, recalls, _ = precision_recall_curve(gt_px.ravel(), pr_px.ravel()) 20 | f1_scores = (2 * precisions * recalls) / (precisions + recalls) 21 | f1_px = np.max(f1_scores[np.isfinite(f1_scores)]) 22 | ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel()) 23 | 24 | return auroc_px * 100, f1_px * 100, ap_px * 100 25 | 26 | 27 | def calculate_im_metrics(gt_im, pr_im): 28 | if is_one_class(gt_im): # In case there are only normal samples or no image-level labels 29 | return 0, 0, 0 30 | 31 | auroc_im = roc_auc_score(gt_im.ravel(), pr_im.ravel()) 32 | precisions, recalls, _ = precision_recall_curve(gt_im.ravel(), pr_im.ravel()) 33 | f1_scores = (2 * precisions * recalls) / (precisions + recalls) 34 | f1_im = np.max(f1_scores[np.isfinite(f1_scores)]) 35 | ap_im = average_precision_score(gt_im, pr_im) 36 | 37 | return ap_im * 100, auroc_im * 100, f1_im * 100 38 | 39 | 40 | def calculate_average_metric(metrics: dict): 41 | average = {} 42 | for obj, metric in metrics.items(): 43 | for k, v in metric.items(): 44 | if k not in average: 45 | average[k] = [] 46 | average[k].append(v) 47 | 48 | for k, v in average.items(): 49 | average[k] = np.mean(v) 50 | 51 | return average 52 | 53 | 54 | def calculate_metric(results, obj): 55 | gt_px = [] 56 | pr_px = [] 57 | 58 | gt_im = [] 59 | pr_im = [] 60 | 61 | for idx in range(len(results['cls_names'])): 62 | if results['cls_names'][idx] == obj: 63 | gt_px.append(results['imgs_masks'][idx]) 64 | pr_px.append(results['anomaly_maps'][idx]) 65 | 66 | gt_im.append(results['imgs_gts'][idx]) 67 | pr_im.append(results['anomaly_scores'][idx]) 68 | 69 | gt_px = np.array(gt_px) 70 | pr_px = np.array(pr_px) 71 | 72 | gt_im = np.array(gt_im) 73 | pr_im = np.array(pr_im) 74 | 75 | auroc_px, f1_px, ap_px = calculate_px_metrics(gt_px, pr_px) 76 | ap_im, auroc_im, f1_im = calculate_im_metrics(gt_im, pr_im) 77 | 78 | metric = { 79 | 'auroc_px': auroc_px, 80 | 'auroc_im': auroc_im, 81 | 'f1_px': f1_px, 82 | 'f1_im': f1_im, 83 | 'ap_px': ap_px, 84 | 'ap_im': ap_im, 85 | } 86 | 87 | return metric 88 | -------------------------------------------------------------------------------- /tools/training_tools.py: 
-------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | from torch.utils.tensorboard import SummaryWriter 3 | import os 4 | import random 5 | import torch 6 | import numpy as np 7 | 8 | 9 | def setup_seed(seed): 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed(seed) 12 | torch.cuda.manual_seed_all(seed) 13 | np.random.seed(seed) 14 | random.seed(seed) 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | 18 | 19 | def setup_paths(args): 20 | save_root = args.save_path 21 | model_root = os.path.join(save_root, 'models') 22 | log_root = os.path.join(save_root, 'logs') 23 | csv_root = os.path.join(save_root, 'csvs') 24 | image_root = os.path.join(save_root, 'images') 25 | tensorboard_root = os.path.join(save_root, 'tensorboard') 26 | 27 | os.makedirs(model_root, exist_ok=True) 28 | os.makedirs(log_root, exist_ok=True) 29 | os.makedirs(csv_root, exist_ok=True) 30 | os.makedirs(image_root, exist_ok=True) 31 | os.makedirs(tensorboard_root, exist_ok=True) 32 | 33 | if args.use_hsf: 34 | # prepare model name 35 | model_name = f'{args.exp_indx}s-pretrained-{args.training_data}-{args.model}-' \ 36 | f'{args.prompting_type}-{args.prompting_branch}-' \ 37 | f'D{args.prompting_depth}-L{args.prompting_length}-HSF-K{args.k_clusters}' 38 | else: 39 | # prepare model name 40 | model_name = f'{args.exp_indx}s-pretrained-{args.training_data}-{args.model}-' \ 41 | f'{args.prompting_type}-{args.prompting_branch}-' \ 42 | f'D{args.prompting_depth}-L{args.prompting_length}-WO-HSF' 43 | 44 | 45 | # prepare model path 46 | ckp_path = os.path.join(model_root, model_name) 47 | 48 | # prepare tensorboard dir 49 | tensorboard_dir = os.path.join(tensorboard_root, f'{model_name}-{args.testing_data}') 50 | if os.path.exists(tensorboard_dir): 51 | import shutil 52 | shutil.rmtree(tensorboard_dir) 53 | tensorboard_logger = SummaryWriter(log_dir=tensorboard_dir) 54 | 55 | # prepare csv path 56 | csv_path = os.path.join(csv_root, f'{model_name}-{args.testing_data}.csv') 57 | 58 | # prepare image path 59 | image_dir = os.path.join(image_root, f'{model_name}-{args.testing_data}') 60 | os.makedirs(image_dir, exist_ok=True) 61 | 62 | # prepare log path 63 | log_path = os.path.join(log_root, f'{model_name}-{args.testing_data}.txt') 64 | 65 | return model_name, image_dir, csv_path, log_path, ckp_path, tensorboard_logger 66 | 67 | 68 | -------------------------------------------------------------------------------- /tools/visualization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib 3 | 4 | matplotlib.use("Agg") 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | import seaborn as sns 9 | 10 | ## 11 | from sklearn.manifold import TSNE 12 | from sklearn.decomposition import PCA 13 | 14 | ## 15 | import matplotlib.ticker as mtick 16 | 17 | 18 | def plot_sample_cv2(names, imgs, scores_: dict, gts, save_folder=None): 19 | os.makedirs(save_folder, exist_ok=True) 20 | 21 | # get subplot number 22 | total_number = len(imgs) 23 | 24 | scores = scores_.copy() 25 | # min-max normalize anomaly maps to [0, 255] 26 | for k, v in scores.items(): 27 | max_value = np.max(v) 28 | min_value = np.min(v) 29 | 30 | scores[k] = (scores[k] - min_value) / (max_value - min_value + 1e-8) * 255 31 | scores[k] = scores[k].astype(np.uint8) 32 | # draw gts 33 | mask_imgs = [] 34 | for idx in range(total_number): 35 | gts_ = gts[idx] 36 | mask_imgs_ = imgs[idx].copy() 37 | mask_imgs_[gts_
> 0.5] = (0, 0, 255) 38 | mask_imgs.append(mask_imgs_) 39 | 40 | # save imgs 41 | for idx in range(total_number): 42 | 43 | cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_ori.jpg'), imgs[idx]) 44 | cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_gt.jpg'), mask_imgs[idx]) 45 | 46 | for key in scores: 47 | heat_map = cv2.applyColorMap(scores[key][idx], cv2.COLORMAP_JET) 48 | visz_map = cv2.addWeighted(heat_map, 0.5, imgs[idx], 0.5, 0) 49 | cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_{key}.jpg'), 50 | visz_map) 51 | 52 | 53 | 54 | 55 | def plot_feat_cv2(names, feat, save_folder=None): 56 | # get subplot number 57 | total_number = len(feat) 58 | 59 | # save imgs 60 | for idx in range(total_number): 61 | feat[idx] = cv2.resize(feat[idx], (256, 256), interpolation=cv2.INTER_NEAREST) 62 | cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_feat.jpg'), feat[idx]) 63 | 64 | 65 | 66 | valid_feature_visualization_methods = ['TSNE', 'PCA'] 67 | 68 | def visualize_feature(features, labels, legends, n_components=3, method='TSNE'): 69 | assert method in valid_feature_visualization_methods 70 | assert n_components in [2, 3] 71 | 72 | if method == 'TSNE': 73 | model = TSNE(n_components=n_components) 74 | elif method == 'PCA': 75 | model = PCA(n_components=n_components) 76 | 77 | else: 78 | raise NotImplementedError 79 | 80 | feat_proj = model.fit_transform(features) 81 | 82 | if n_components == 2: 83 | ax = scatter_2d(feat_proj, labels) 84 | elif n_components == 3: 85 | ax = scatter_3d(feat_proj, labels) 86 | else: 87 | raise NotImplementedError 88 | 89 | plt.legend(legends) 90 | plt.axis('off') 91 | 92 | 93 | def scatter_3d(feat_proj, label): 94 | plt.clf() 95 | ax1 = plt.axes(projection='3d') 96 | 97 | label_unique = np.unique(label) 98 | 99 | for l in label_unique: 100 | ax1.scatter3D(feat_proj[label == l, 0], 101 | feat_proj[label == l, 1], 102 | feat_proj[label == l, 2], s=5) 103 | 104 | return ax1 105 | 106 | 107 | def scatter_2d(feat_proj, label): 108 | plt.clf() 109 | ax1 = plt.axes() 110 | 111 | label_unique = np.unique(label) 112 | 113 | for l in label_unique: 114 | ax1.scatter(feat_proj[label == l, 0], 115 | feat_proj[label == l, 1], s=5) 116 | 117 | return ax1 118 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore", category=RuntimeWarning) 3 | import os 4 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | import argparse 8 | import json 9 | import os 10 | import torch 11 | 12 | # Importing from local modules 13 | from tools import write2csv, setup_paths, setup_seed, log_metrics, Logger 14 | from dataset import get_data 15 | from method import AdaCLIP_Trainer 16 | 17 | setup_seed(111) 18 | 19 | def train(args): 20 | # Configurations 21 | epochs = args.epoch 22 | learning_rate = args.learning_rate 23 | batch_size = args.batch_size 24 | image_size = args.image_size 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | 27 | save_fig = args.save_fig 28 | 29 | # Set up paths 30 | model_name, image_dir, csv_path, log_path, ckp_path, tensorboard_logger = setup_paths(args) 31 | # Logger 32 | logger = Logger(log_path) 33 | 34 | # Print basic information 35 | for key, value in sorted(vars(args).items()): 36 | logger.info(f'{key} = {value}') 37 | 38 | logger.info('Model name: {:}'.format(model_name)) 39 | 40 | config_path 
= os.path.join('./model_configs', f'{args.model}.json') 41 | 42 | # Prepare model 43 | with open(config_path, 'r') as f: 44 | model_configs = json.load(f) 45 | 46 | # Set up the feature hierarchy 47 | n_layers = model_configs['vision_cfg']['layers'] 48 | substage = n_layers // 4 49 | features_list = [substage, substage * 2, substage * 3, substage * 4] 50 | 51 | model = AdaCLIP_Trainer( 52 | backbone=args.model, 53 | feat_list=features_list, 54 | input_dim=model_configs['vision_cfg']['width'], 55 | output_dim=model_configs['embed_dim'], 56 | learning_rate=learning_rate, 57 | device=device, 58 | image_size=image_size, 59 | prompting_depth=args.prompting_depth, 60 | prompting_length=args.prompting_length, 61 | prompting_branch=args.prompting_branch, 62 | prompting_type=args.prompting_type, 63 | use_hsf=args.use_hsf, 64 | k_clusters=args.k_clusters 65 | ).to(device) 66 | 67 | train_data_cls_names, train_data, train_data_root = get_data( 68 | dataset_type_list=args.training_data, 69 | transform=model.preprocess, 70 | target_transform=model.transform, 71 | training=True) 72 | 73 | test_data_cls_names, test_data, test_data_root = get_data( 74 | dataset_type_list=args.testing_data, 75 | transform=model.preprocess, 76 | target_transform=model.transform, 77 | training=False) 78 | 79 | logger.info('Data Root: training, {:}; testing, {:}'.format(train_data_root, test_data_root)) 80 | 81 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True) 82 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False) 83 | 84 | # Typically, we use MVTec or VisA as the validation set. The best model from this validation 85 | # process is then used for zero-shot anomaly detection on novel categories. 
86 | best_f1 = -1e1 87 | 88 | for epoch in tqdm(range(epochs)): 89 | loss = model.train_epoch(train_dataloader) 90 | 91 | # Logs 92 | if (epoch + 1) % args.print_freq == 0: 93 | logger.info('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, loss)) 94 | tensorboard_logger.add_scalar('loss', loss, epoch) 95 | 96 | # Validation 97 | if (epoch + 1) % args.valid_freq == 0 or (epoch == epochs - 1): 98 | if epoch == epochs - 1: 99 | save_fig_flag = save_fig 100 | else: 101 | save_fig_flag = False 102 | 103 | logger.info('=============================Testing ====================================') 104 | metric_dict = model.evaluation( 105 | test_dataloader, 106 | test_data_cls_names, 107 | save_fig_flag, 108 | image_dir, 109 | ) 110 | 111 | log_metrics( 112 | metric_dict, 113 | logger, 114 | tensorboard_logger, 115 | epoch 116 | ) 117 | 118 | f1_px = metric_dict['Average']['f1_px'] 119 | 120 | # Save best 121 | if f1_px > best_f1: 122 | for k in metric_dict.keys(): 123 | write2csv(metric_dict[k], test_data_cls_names, k, csv_path) 124 | 125 | ckp_path_best = ckp_path + '_best.pth' 126 | model.save(ckp_path_best) 127 | best_f1 = f1_px 128 | 129 | 130 | 131 | def str2bool(v): 132 | return v.lower() in ("yes", "true", "t", "1") 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser("AdaCLIP", add_help=True) 136 | 137 | # Paths and configurations 138 | parser.add_argument("--training_data", type=str, default=["mvtec", "colondb"], nargs='+', 139 | help="Datasets for training (default: ['mvtec', 'colondb'])") 140 | parser.add_argument("--testing_data", type=str, default="visa", help="Dataset for testing (default: 'visa')") 141 | 142 | parser.add_argument("--save_path", type=str, default='./workspaces', 143 | help="Directory to save results (default: './workspaces')") 144 | 145 | parser.add_argument("--model", type=str, default="ViT-L-14-336", 146 | choices=["ViT-B-16", "ViT-B-32", "ViT-L-14", "ViT-L-14-336"], 147 | help="The CLIP model to be used (default: 'ViT-L-14-336')") 148 | 149 | parser.add_argument("--save_fig", type=str2bool, default=False, 150 | help="Save figures for visualizations (default: False)") 151 | parser.add_argument("--ckt_path", type=str, default='', help="Path to the pre-trained model (default: '')") 152 | 153 | # Hyper-parameters 154 | parser.add_argument("--exp_indx", type=int, default=0, help="Index of the experiment (default: 0)") 155 | parser.add_argument("--epoch", type=int, default=5, help="Number of epochs (default: 5)") 156 | parser.add_argument("--learning_rate", type=float, default=0.01, help="Learning rate (default: 0.01)") 157 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size (default: 1)") 158 | 159 | parser.add_argument("--image_size", type=int, default=518, help="Size of the input images (default: 518)") 160 | parser.add_argument("--print_freq", type=int, default=1, help="Frequency of print statements (default: 1)") 161 | parser.add_argument("--valid_freq", type=int, default=1, help="Frequency of validation (default: 1)") 162 | 163 | # Prompting parameters 164 | parser.add_argument("--prompting_depth", type=int, default=4, help="Depth of prompting (default: 4)") 165 | parser.add_argument("--prompting_length", type=int, default=5, help="Length of prompting (default: 5)") 166 | parser.add_argument("--prompting_type", type=str, default='SD', choices=['', 'S', 'D', 'SD'], 167 | help="Type of prompting. 
'S' for Static, 'D' for Dynamic, 'SD' for both (default: 'SD')") 168 | parser.add_argument("--prompting_branch", type=str, default='VL', choices=['', 'V', 'L', 'VL'], 169 | help="Branch of prompting. 'V' for Visual, 'L' for Language, 'VL' for both (default: 'VL')") 170 | 171 | parser.add_argument("--use_hsf", type=str2bool, default=True, 172 | help="Use HSF for aggregation. If False, original class embedding is used (default: True)") 173 | parser.add_argument("--k_clusters", type=int, default=20, help="Number of clusters (default: 20)") 174 | 175 | args = parser.parse_args() 176 | 177 | if args.batch_size != 1: 178 | raise NotImplementedError( 179 | "Currently, only batch size of 1 is supported due to unresolved bugs. Please set --batch_size to 1.") 180 | 181 | train(args) 182 | 183 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | gpu_id=0 2 | 3 | # Note: Training uses half precision (FP16), so individual runs can occasionally be unstable. 4 | # It is recommended to run training several times and keep the model with the best performance 5 | # on the validation set as the final model. 6 | 7 | # pre-trained on MVTec AD and ColonDB 8 | CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True --training_data mvtec colondb --testing_data visa 9 | 10 | # pre-trained on VisA and ClinicDB 11 | CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True --training_data visa clinicdb --testing_data mvtec 12 | 13 | # This model is pre-trained on all available data to create a powerful Zero-Shot Anomaly Detection (ZSAD) model for demonstration purposes. 14 | CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True \ 15 | --training_data \ 16 | br35h brain_mri btad clinicdb colondb \ 17 | dagm dtd headct isic mpdd mvtec sdd tn3k visa \ 18 | --testing_data mvtec 19 | 20 | --------------------------------------------------------------------------------
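For reference, `tools/metrics.py` above computes image-level and pixel-level AUROC, F1-max, and AP per class, but the structure of the `results` dictionary it consumes is only implicit in the code (it is presumably assembled inside the trainer's `evaluation` routine, which is not listed here). Below is a minimal, hypothetical sketch with two illustrative samples; the key names are taken from `calculate_metric`, while the shapes and values are assumptions made for the example:

```python
import numpy as np
from tools.metrics import calculate_metric, calculate_average_metric

# Two hypothetical 'candle' samples: one normal, one anomalous.
# Small 8x8 masks keep the sketch light; real maps match the model's output resolution.
h = w = 8
results = {
    'cls_names':      ['candle', 'candle'],
    'imgs_gts':       [0, 1],                # image-level labels (0 = normal, 1 = anomalous)
    'anomaly_scores': [0.12, 0.87],          # image-level predictions
    'imgs_masks':     [np.zeros((h, w)), np.ones((h, w))],   # pixel-level ground truth
    'anomaly_maps':   [np.zeros((h, w)), np.ones((h, w))],   # pixel-level predictions
}

metric = calculate_metric(results, 'candle')
# -> {'auroc_px': ..., 'auroc_im': ..., 'f1_px': ..., 'f1_im': ..., 'ap_px': ..., 'ap_im': ...}, in percent
print(metric)
print(calculate_average_metric({'candle': metric}))   # same keys, averaged over classes
```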