├── .gitignore ├── 01_train_mvtec.sh ├── 02_eval_mvtec.sh ├── 03_train_eyecandies.sh ├── 04_eval_eyecandies.sh ├── LICENSE.txt ├── README.md ├── cfm_inference.py ├── cfm_training.py ├── images └── architecture.jpg ├── models ├── dataset.py ├── feature_transfer_nets.py ├── features.py └── full_models.py ├── processing ├── aggregate_results.py ├── preprocess_eyecandies.py └── preprocess_mvtec.py ├── requirements.txt └── utils ├── general_utils.py ├── metrics_utils.py ├── mvtec3d_utils.py └── pointnet2_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | models/__pycache__ 3 | utils/__pycache__ 4 | processing/__pycache__ 5 | datasets 6 | checkpoints 7 | wandb 8 | results -------------------------------------------------------------------------------- /01_train_mvtec.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | epochs=50 4 | batch_size=4 5 | 6 | class_names=("bagel" "cable_gland" "carrot" "cookie" "dowel" "foam" "peach" "potato" "rope" "tire") 7 | 8 | for class_name in "${class_names[@]}" 9 | do 10 | python cfm_training.py --class_name $class_name --epochs_no $epochs --batch_size $batch_size 11 | done -------------------------------------------------------------------------------- /02_eval_mvtec.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | epochs=250 4 | batch_size=4 5 | 6 | class_names=("bagel" "cable_gland" "carrot" "cookie" "dowel" "foam" "peach" "potato" "rope" "tire") 7 | 8 | for class_name in "${class_names[@]}" 9 | do 10 | python cfm_inference.py --class_name $class_name --epochs_no $epochs --batch_size $batch_size 11 | done -------------------------------------------------------------------------------- /03_train_eyecandies.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | dataset_path=datasets/eyecandies 4 | checkpoint_savepath=models/checkpoints_CFM_eyecandies 5 | epochs=50 6 | batch_size=2 7 | 8 | class_names=("CandyCane" "ChocolateCookie" "ChocolatePraline" "Confetto" "GummyBear" "HazelnutTruffle" "LicoriceSandwich" "Lollipop" "Marshmallow" "PeppermintCandy") 9 | 10 | for class_name in "${class_names[@]}" 11 | do 12 | python cfm_training.py --class_name $class_name --epochs_no $epochs --batch_size $batch_size --dataset_path $dataset_path --checkpoint_savepath $checkpoint_savepath 13 | done -------------------------------------------------------------------------------- /04_eval_eyecandies.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | dataset_path=datasets/eyecandies 4 | checkpoint_savepath=models/checkpoints_CFM_eyecandies 5 | epochs=50 6 | batch_size=2 7 | 8 | quantitative_folder=results/quantitatives_eyecandies 9 | 10 | class_names=("CandyCane" "ChocolateCookie" "ChocolatePraline" "Confetto" "GummyBear" "HazelnutTruffle" "LicoriceSandwich" "Lollipop" "Marshmallow" "PeppermintCandy") 11 | 12 | for class_name in "${class_names[@]}" 13 | do 14 | python cfm_inference.py --class_name $class_name --epochs_no $epochs --batch_size $batch_size --dataset_path $dataset_path --checkpoint_folder $checkpoint_savepath --quantitative_folder $quantitative_folder --produce_qualitatives 15 | done 
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | This Software is licensed under the terms of the following license 2 | which allows for non-commercial use only. For any other use of the software not 3 | covered by the terms of this license, please contact Pierluigi Zama Ramirez (pierluigi.zama@unibo.it) 4 | 5 | Article 1 – Definitions. 6 | 1. The following terms, as used herein, shall have the following meanings: 7 | - "Licensee" shall mean the person or organization agreeing to use the Software 8 | - "Licensor" shall mean Alma Mater Studiorum – Università di Bologna, organized and existing under the laws of Italy, whose principal place of business is via Zamboni 33 – 40126 Bologna (Italy). 9 | - "Software" shall mean software Crossomodal Feature Mappings (hereinafter referred to as “CFM”) for neural networks trained to detect and segment anomalies on multimodal data. This SW is the implementation of the approach described in “A. Costanzino, P. Zama Ramirez, G. Lisanti, L. Di Stefano. Multimodal Industrial Anomaly Detection by Crossmodal Feature Mapping. The IEEE/CVF Conference on Computer Vision and Pattern Recognition 2024. as uploaded by Licensor to the github repository at https://github.com/CVLAB-Unibo/crossmodal-feature-mapping on May 22 2024, in source code or object code form and any accompanying documentation as well as any modifications or additions uploaded to the same github repository by Licensor. 10 | 11 | Article 2 - Grant and scope. 12 | 1. Licensor grants Licensee a non-exclusive license authorizing Licensee to reproduce and use the Software for research and teaching activities only, with exclusion of any directly or indirectly commercial activities or purposes, without territorial limitation 13 | (hereinafter referred to as "License"). 14 | 2. The License extends to the source code of software CFM. 15 | 3. The License is non-assignable even if Licensee were to cease all activities. 16 | 4. Except as specifically set forth in this Agreement, Licensee acknowledges that this Agreement does not grant Licensee any use or rights to the Software, including, but not limited to, any author’s economic rights on the Software. 17 | 6. Licensee shall be solely responsible towards any user, specifically with regard to the installation, maintenance and documentation relative to the Software as well as the training of its users. 18 | 7. Licensor will keep the faculty to develop new versions of Software with no obligations to provide these new versions to Licensee, as well to use Software for research activities. 19 | 8. Licensor is under no obligation to create any upgrades or enhancements to the Software. 20 | 21 | Article 3 – No Fee. 22 | 1. The License is free of charge and no payment will be due by Licensee to Licensor for research purposes. 23 | 24 | Article 4 – Markings. 25 | 1. Licensee shall place the following text in all documentation and information it publishes or distributes in connection with the Software: 26 | "Software developed by Alma Mater Studiorum - University of Bologna (Italy) and SACMI IMOLA S.C. (Italy). Alma Mater Studiorum - University of Bologna (Italy) and SACMI IMOLA S.C. (Italy), makes no warranties of any kind on the software and shall in no event be liable for damages of any kind in connection with the use and exploitation of the software.” 27 | 28 | Article 5 - Limited liability. 29 | 1. 
To the extent permitted by applicable law, Licensor shall not be held liable for any damages resulting from the use of the Software by Licensee. 30 | 4. Licensee shall be solely responsible for any claims of third parties in connection with use of Software. 31 | 5. Licensee shall indemnify, defend and hold Licensor harmless against all claims by third parties arising out of any damage caused by such use. 32 | 6. The License extends to the Software "as is" and as transferred by Licensor no later than on the date of signature of this Agreement and without warranties as to the proper functioning of the Software or its functions, its fitness for a particular purpose or its merchantability. 33 | 7. To the extent permitted by applicable law, Licensor makes no warranty that the practice of the License does not infringe any trademark, patent, copyright, or similar rights of any third party. 34 | 35 | Article 6 - Claims by third parties and Infringements. 36 | 1. Licensor shall not be obligated to take measures to defend the Software against any claims made by third parties nor in case of infringements of the Software by third parties. 37 | 2. Should claims by third parties result in a limitation of the rights granted under the License, then the parties shall negotiate to adapt the provisions of this Agreement accordingly, whereby Licensee hereby waives any further claim or remedy towards Licensor. 38 | 3. Should claims by third parties result in the prohibition to use the Software, then Licensor shall be entitled to terminate this Agreement without indemnification. 39 | 40 | Article 7 - Duration and termination. 41 | 1. This Agreement is effective as long as it is available for download from the repository. 42 | 2. Upon breach or default of this Agreement by either party, a minimum two months' notice will be given by the other party to cure such breach or default. If at the end of such period the breach or default has not been cured, the other party will be entitled to terminate this Agreement by giving a one-month's notice. 43 | 3. At the end of this Agreement, Licensee shall cease any use of the Software and delete any copy of it, unless the copyrights to the Software have expired. 44 | 45 | Article 8 – Warranties and maintenance. 46 | 1. Licensor does not warrant that the Software will operate error-free. 47 | 2. No debugging service or any maintenance assistance will be provided by Licensor to Licensee for the Software. 48 | 3. Licensor will not be responsible for maintaining Licensee-modified portions of the Software. 49 | 50 | 51 | Article 10 - Nature of the Agreement. 52 | 1. Nothing herein shall be deemed to constitute either party as the agent or representative of the other party, or both parties as joint venturers or partners for any purposes. 53 | 2. Neither party shall be responsible for the acts or omissions of the other party, and neither party will have authority to speak for, represent or obligate the other party in any way without prior written authority from the other party. 54 | 55 | Article 11 - Applicable law and Jurisdiction. 56 | 1. This Agreement is governed by Italian law. 57 | 2. Any claims or disputes arising in connection with this Agreement, whether before or after it has expired, shall be brought before the court of Bologna, Italy. 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

# Multimodal Industrial Anomaly Detection by Crossmodal Feature Mapping (CVPR 2024)

3 | 4 | 5 |
6 | 7 | :rotating_light: This repository contains download links to the datasets, code snippets, and checkpoints of our work "**Multimodal Industrial Anomaly Detection by Crossmodal Feature Mapping**", [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024) 8 | 9 | by [Alex Costanzino*](https://alex-costanzino.github.io/), [Pierluigi Zama Ramirez*](https://pierlui92.github.io/), [Giuseppe Lisanti](https://www.unibo.it/sitoweb/giuseppe.lisanti), and [Luigi Di Stefano](https://www.unibo.it/sitoweb/luigi.distefano). \* _Equal Contribution_ 10 | 11 | University of Bologna 12 | 13 | 14 |
15 | 16 | 17 |

18 | 19 | [Project Page](https://cvlab-unibo.github.io/CrossmodalFeatureMapping/) | [Paper](https://arxiv.org/abs/2312.04521) 20 |

21 | 22 | 23 | ## :bookmark_tabs: Table of Contents 24 | 25 | 1. [Introduction](#clapper-introduction) 26 | 2. [Datasets](#file_cabinet) 27 | 3. [Checkpoints](#inbox_tray) 28 | 4. [Code](#memo-code) 29 | 5. [Contacts](#envelope-contacts) 30 | 31 | 
32 | 33 | ## :clapper: Introduction 34 | Recent advancements have shown the potential of leveraging both point clouds and images to localize anomalies. 35 | Nevertheless, their applicability in industrial manufacturing is often constrained by significant drawbacks, such as the use of memory banks, which lead to a substantial increase in memory footprint and inference time. 36 | We propose a novel light and fast framework that learns to map features from one modality to the other on nominal samples and detect anomalies by pinpointing inconsistencies between observed and mapped features. 37 | Extensive experiments show that our approach achieves state-of-the-art detection and segmentation performance, in both the standard and few-shot settings, on the MVTec 3D-AD dataset while achieving faster inference and occupying less memory than previous multimodal AD methods. 38 | Furthermore, we propose a layer pruning technique to improve memory and time efficiency with a marginal sacrifice in performance. 39 | 40 | 
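To make the idea concrete, below is a minimal, illustrative sketch of the crossmodal mapping principle, not the repository code: random tensors stand in for the frozen 2D and 3D patch features, two small MLPs play the role of the CFMs, and only the feature dimensions (768 for RGB, 1152 for point clouds) follow this repository. The actual implementation works on masked, L2-normalized features and smooths the final map (see `cfm_training.py` and `cfm_inference.py`).

```python
import torch

# Toy stand-ins for the frozen 2D (DINO) and 3D (Point-MAE) patch features of one
# nominal sample, on a small 32x32 grid of locations (illustrative only).
N = 32 * 32
rgb_feat = torch.randn(N, 768)
xyz_feat = torch.randn(N, 1152)

# Two lightweight MLPs learn, on nominal samples only, to map each modality onto the other.
cfm_2d_to_3d = torch.nn.Sequential(torch.nn.Linear(768, 960), torch.nn.GELU(), torch.nn.Linear(960, 1152))
cfm_3d_to_2d = torch.nn.Sequential(torch.nn.Linear(1152, 960), torch.nn.GELU(), torch.nn.Linear(960, 768))

cos = torch.nn.CosineSimilarity(dim = -1)
loss = (1 - cos(cfm_2d_to_3d(rgb_feat), xyz_feat)).mean() \
     + (1 - cos(cfm_3d_to_2d(xyz_feat), rgb_feat)).mean()   # minimized with Adam in cfm_training.py

# At test time, anomalies are the locations where observed and mapped features disagree;
# the per-location discrepancies of the two mappings are combined into a single anomaly map.
with torch.no_grad():
    d_2d_to_3d = 1 - cos(cfm_2d_to_3d(rgb_feat), xyz_feat)   # (N,)
    d_3d_to_2d = 1 - cos(cfm_3d_to_2d(xyz_feat), rgb_feat)   # (N,)
    anomaly_map = (d_2d_to_3d * d_3d_to_2d).reshape(32, 32)
```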

41 | 42 |

43 | 44 | ![Alt text](images/architecture.jpg) 45 | 46 | :fountain_pen: If you find this code useful in your research, please cite: 47 | 48 | ```bibtex 49 | @inproceedings{costanzino2024cross, 50 | title = {Multimodal Industrial Anomaly Detection by Crossmodal Feature Mapping}, 51 | author = {Costanzino, Alex and Zama Ramirez, Pierluigi and Lisanti, Giuseppe and Di Stefano, Luigi}, 52 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 53 | note = {CVPR}, 54 | year = {2024}, 55 | } 56 | ``` 57 | 58 | 

## :file_cabinet: Datasets

59 | 60 | In our experiments, we employed two datasets featuring RGB images and point clouds: [MVTec 3D-AD](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad) and [Eyecandies](https://eyecan-ai.github.io/eyecandies/). You can preprocess them with the scripts contained in `processing`. 61 | 62 | 63 | 

## :inbox_tray: Checkpoints

64 | 65 | Here, you can download the weights of **CFMs** employed in the results of Table 1 and Table 2 of our paper. 66 | 67 | To use these weights, please follow these steps: 68 | 69 | 1. Create a folder named `checkpoints/checkpoints_CFM_mvtec` in the project directory; 70 | 2. Download the weights [[Download]](https://t.ly/DZ-o1); 71 | 3. Copy the downloaded weights into the `checkpoints_CFM_mvtec` folder. 72 | 73 | 74 | ## :memo: Code 75 | 76 |
77 | 78 | **Warning**: 79 | - The code uses `wandb` during training to log results, so make sure you have a wandb account. If you prefer not to use `wandb`, disable it in `cfm_training.py` by passing `mode = 'disabled'` to `wandb.init`, as in the sketch below. 80 | 81 | 
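A minimal sketch of the change, assuming you simply edit the existing `wandb.init` call (the `mode` argument is standard `wandb` API; the run name below is only an example):

```python
import wandb

model_name = 'bagel_50ep_4bs'   # example run name, built as in cfm_training.py

wandb.init(
    project = 'crossmodal-feature-mappings',
    name = model_name,
    mode = 'disabled',   # nothing is logged or sent to the wandb servers
)
```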
82 | 83 | 84 | ### :hammer_and_wrench: Setup Instructions 85 | 86 | **Dependencies**: Ensure that you have installed all the necessary dependencies. The list of dependencies can be found in the `./requirements.txt` file. 87 | 88 | 89 | ### :rocket: Inference CFMs 90 | 91 | The `cfm_inference.py` script tests the CFMs. It can be used to generate anomaly maps. 92 | 93 | You can specify the following options: 94 | - `--dataset_path`: Path to the root directory of the dataset. 95 | - `--checkpoint_folder`: Path to the directory of the checkpoints, i.e., `checkpoints/checkpoints_CFM_mvtec`. 96 | - `--class_name`: Class on which the CFMs were trained. 97 | - `--epochs_no`: Number of epochs used in CFMs optimization. 98 | - `--batch_size`: Number of samples per batch employed for CFMs optimization. 99 | - `--qualitative_folder`: Folder in which the anomaly maps are saved. 100 | - `--quantitative_folder`: Folder in which the metrics are saved. 101 | - `--visualize_plot`: Flag to visualize qualitatives during inference. 102 | - `--produce_qualitatives`: Flag to save qualitatives during inference. 103 | 104 | You can reproduce the results of Table 1 and Table 2 of the paper by running `02_eval_mvtec.sh`. 105 | 106 | If you haven't downloaded the checkpoints yet, you can find the download links in the **Checkpoints** section above. 107 | 108 | 109 | ### :rocket: Train CFMs 110 | 111 | To train CFMs, refer to the examples in `01_train_mvtec.sh` and `03_train_eyecandies.sh`. 112 | 113 | The `cfm_training.py` script trains the CFMs (a programmatic usage sketch is given at the end of this README). 114 | 115 | You can specify the following options: 116 | - `--dataset_path`: Path to the root directory of the dataset. 117 | - `--checkpoint_savepath`: Path to the directory in which checkpoints will be saved, i.e., `checkpoints/checkpoints_CFM_mvtec`. 118 | - `--class_name`: Class on which the CFMs are trained. 119 | - `--epochs_no`: Number of epochs for CFMs optimization. 120 | - `--batch_size`: Number of samples per batch for CFMs optimization. 121 | 122 | 123 | ## :envelope: Contacts 124 | 125 | For questions, please send an email to alex.costanzino@unibo.it or pierluigi.zama@unibo.it. 126 | 127 | 128 | ## :pray: Acknowledgements 129 | 130 | We would like to extend our sincere appreciation to the authors of the following projects for making their code available, which we have utilized in our work: 131 | 132 | - We would like to thank the authors of [M3DM](https://github.com/nomewang/M3DM), [3D-ADS](https://github.com/eliahuhorwitz/3D-ADS) and [AST](https://github.com/marco-rudolph/AST) for providing their code, which has been instrumental in our experiments. 
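As mentioned in the Train CFMs section, both entry points can also be driven directly from Python: `train_CFM` and `infer_CFM` accept any object that exposes the command-line options as attributes. A minimal sketch, with illustrative paths and class name:

```python
from argparse import Namespace

from cfm_training import train_CFM
from cfm_inference import infer_CFM

# Illustrative settings; adjust paths and the class to your setup.
common = dict(dataset_path = './datasets/mvtec3d', class_name = 'bagel', epochs_no = 50, batch_size = 4)

train_CFM(Namespace(checkpoint_savepath = './checkpoints/checkpoints_CFM_mvtec', **common))

infer_CFM(Namespace(checkpoint_folder = './checkpoints/checkpoints_CFM_mvtec',
                    qualitative_folder = './results/qualitatives_mvtec',
                    quantitative_folder = './results/quantitatives_mvtec',
                    visualize_plot = False, produce_qualitatives = True, **common))
```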
-------------------------------------------------------------------------------- /cfm_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torchvision import transforms 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | 10 | from models.features import MultimodalFeatures 11 | from models.dataset import get_data_loader 12 | from models.feature_transfer_nets import FeatureProjectionMLP, FeatureProjectionMLP_big 13 | 14 | from utils.metrics_utils import calculate_au_pro 15 | from sklearn.metrics import roc_auc_score 16 | 17 | 18 | def set_seeds(sid=42): 19 | np.random.seed(sid) 20 | 21 | torch.manual_seed(sid) 22 | if torch.cuda.is_available(): 23 | torch.cuda.manual_seed(sid) 24 | torch.cuda.manual_seed_all(sid) 25 | 26 | 27 | def infer_CFM(args): 28 | 29 | set_seeds() 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | 32 | # Dataloaders. 33 | test_loader = get_data_loader("test", class_name = args.class_name, img_size = 224, dataset_path = args.dataset_path) 34 | 35 | # Feature extractors. 36 | feature_extractor = MultimodalFeatures() 37 | 38 | # Model instantiation. 39 | CFM_2Dto3D = FeatureProjectionMLP(in_features = 768, out_features = 1152) 40 | CFM_3Dto2D = FeatureProjectionMLP(in_features = 1152, out_features = 768) 41 | 42 | CFM_2Dto3D_path = rf'{args.checkpoint_folder}/{args.class_name}/CFM_2Dto3D_{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs.pth' 43 | CFM_3Dto2D_path = rf'{args.checkpoint_folder}/{args.class_name}/CFM_3Dto2D_{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs.pth' 44 | 45 | CFM_2Dto3D.load_state_dict(torch.load(CFM_2Dto3D_path)) 46 | CFM_3Dto2D.load_state_dict(torch.load(CFM_3Dto2D_path)) 47 | 48 | CFM_2Dto3D.to(device), CFM_3Dto2D.to(device) 49 | 50 | # Make CFMs non-trainable. 51 | CFM_2Dto3D.eval(), CFM_3Dto2D.eval() 52 | 53 | # Use box filters to approximate gaussian blur (https://www.peterkovesi.com/papers/FastGaussianSmoothing.pdf). 54 | w_l, w_u = 5, 7 55 | pad_l, pad_u = 2, 3 56 | weight_l = torch.ones(1, 1, w_l, w_l, device = device)/(w_l**2) 57 | weight_u = torch.ones(1, 1, w_u, w_u, device = device)/(w_u**2) 58 | 59 | predictions, gts = [], [] 60 | image_labels, pixel_labels = [], [] 61 | image_preds, pixel_preds = [], [] 62 | 63 | # ------------ [Testing Loop] ------------ # 64 | 65 | # * Return (img, resized_organized_pc, resized_depth_map_3channel), gt[:1], label, rgb_path 66 | for (rgb, pc, depth), gt, label, rgb_path in tqdm(test_loader, desc = f'Extracting feature from class: {args.class_name}.'): 67 | 68 | rgb, pc, depth = rgb.to(device), pc.to(device), depth.to(device) 69 | 70 | with torch.no_grad(): 71 | rgb_patch, xyz_patch = feature_extractor.get_features_maps(rgb, pc) 72 | 73 | rgb_feat_pred = CFM_3Dto2D(xyz_patch) 74 | xyz_feat_pred = CFM_2Dto3D(rgb_patch) 75 | 76 | xyz_mask = (xyz_patch.sum(axis = -1) == 0) # Mask only the feature vectors that are 0 everywhere. 77 | 78 | cos_3d = (torch.nn.functional.normalize(xyz_feat_pred, dim = 1) - torch.nn.functional.normalize(xyz_patch, dim = 1)).pow(2).sum(1).sqrt() 79 | cos_3d[xyz_mask] = 0. 80 | cos_3d = cos_3d.reshape(224,224) 81 | 82 | cos_2d = (torch.nn.functional.normalize(rgb_feat_pred, dim = 1) - torch.nn.functional.normalize(rgb_patch, dim = 1)).pow(2).sum(1).sqrt() 83 | cos_2d[xyz_mask] = 0. 84 | cos_2d = cos_2d.reshape(224,224) 85 | 86 | cos_comb = (cos_2d * cos_3d) 87 | cos_comb.reshape(-1)[xyz_mask] = 0. 
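# Note: despite the cos_* names, cos_2d and cos_3d above are Euclidean distances between
# L2-normalized observed and mapped features (monotonically related to cosine distance).
# Their product keeps a location only when BOTH mappings disagree with the observed
# features; locations without valid 3D points were zeroed out via xyz_mask.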
88 | 89 | # Repeated box filters to approximate a Gaussian blur. 90 | cos_comb = cos_comb.reshape(1, 1, 224, 224) 91 | 92 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l) 93 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l) 94 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l) 95 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l) 96 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l) 97 | 98 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_u, weight = weight_u) 99 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_u, weight = weight_u) 100 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_u, weight = weight_u) 101 | 102 | cos_comb = cos_comb.reshape(224,224) 103 | 104 | # Prediction and ground-truth accumulation. 105 | gts.append(gt.squeeze().cpu().detach().numpy()) # * (224,224) 106 | predictions.append((cos_comb / (cos_comb[cos_comb!=0].mean())).cpu().detach().numpy()) # * (224,224) 107 | 108 | # GTs. 109 | image_labels.append(label) # * (1,) 110 | pixel_labels.extend(gt.flatten().cpu().detach().numpy()) # * (50176,) 111 | 112 | # Predictions. 113 | image_preds.append((cos_comb / torch.sqrt(cos_comb[cos_comb!=0].mean())).cpu().detach().numpy().max()) # * number 114 | pixel_preds.extend((cos_comb / torch.sqrt(cos_comb.mean())).flatten().cpu().detach().numpy()) # * (224,224) 115 | 116 | if args.produce_qualitatives: 117 | 118 | defect_class_str = rgb_path[0].split('/')[-3] 119 | image_name_str = rgb_path[0].split('/')[-1] 120 | 121 | save_path = f'{args.qualitative_folder}/{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs/{defect_class_str}' 122 | 123 | if not os.path.exists(save_path): 124 | os.makedirs(save_path) 125 | 126 | fig, axs = plt.subplots(2,3, figsize = (7,7)) 127 | 128 | denormalize = transforms.Compose([ 129 | transforms.Normalize(mean = [0., 0., 0.], std = [1/0.229, 1/0.224, 1/0.225]), 130 | transforms.Normalize(mean = [-0.485, -0.456, -0.406], std = [1., 1., 1.]), 131 | ]) 132 | 133 | rgb = denormalize(rgb) 134 | 135 | os.path.join(save_path, image_name_str) 136 | 137 | axs[0, 0].imshow(rgb.squeeze().permute(1,2,0).cpu().detach().numpy()) 138 | axs[0, 0].set_title('RGB') 139 | 140 | axs[0, 1].imshow(gt.squeeze().cpu().detach().numpy()) 141 | axs[0, 1].set_title('Ground-truth') 142 | 143 | axs[0, 2].imshow(depth.squeeze().permute(1,2,0).mean(axis=-1).cpu().detach().numpy()) 144 | axs[0, 2].set_title('Depth') 145 | 146 | axs[1, 0].imshow(cos_3d.cpu().detach().numpy(), cmap=plt.cm.jet) 147 | axs[1, 0].set_title('3D Cosine Similarity') 148 | 149 | axs[1, 1].imshow(cos_2d.cpu().detach().numpy(), cmap=plt.cm.jet) 150 | axs[1, 1].set_title('2D Cosine Similarity') 151 | 152 | axs[1, 2].imshow(cos_comb.cpu().detach().numpy(), cmap=plt.cm.jet) 153 | axs[1, 2].set_title('Combined Cosine Similarity') 154 | 155 | # Remove ticks and labels from all subplots 156 | for ax in axs.flat: 157 | ax.set_xticks([]) 158 | ax.set_yticks([]) 159 | ax.set_xticklabels([]) 160 | ax.set_yticklabels([]) 161 | 162 | # Adjust the layout and spacing 163 | plt.tight_layout() 164 | 165 | plt.savefig(os.path.join(save_path, image_name_str), dpi = 256) 166 | 167 | if args.visualize_plot: 168 | plt.show() 169 | 170 | # Calculate AD&S metrics. 
171 | au_pros, _ = calculate_au_pro(gts, predictions) 172 | pixel_rocauc = roc_auc_score(np.stack(pixel_labels), np.stack(pixel_preds)) 173 | image_rocauc = roc_auc_score(np.stack(image_labels), np.stack(image_preds)) 174 | 175 | result_file_name = f'{args.quantitative_folder}/{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs.md' 176 | 177 | title_string = f'Metrics for class {args.class_name} with {args.epochs_no}ep_{args.batch_size}bs' 178 | header_string = 'AUPRO@30% & AUPRO@10% & AUPRO@5% & AUPRO@1% & P-AUROC & I-AUROC' 179 | results_string = f'{au_pros[0]:.3f} & {au_pros[1]:.3f} & {au_pros[2]:.3f} & {au_pros[3]:.3f} & {pixel_rocauc:.3f} & {image_rocauc:.3f}' 180 | 181 | if not os.path.exists(args.quantitative_folder): 182 | os.makedirs(args.quantitative_folder) 183 | 184 | with open(result_file_name, "w") as markdown_file: 185 | markdown_file.write(title_string + '\n' + header_string + '\n' + results_string) 186 | 187 | # Print AD&S metrics. 188 | print(title_string) 189 | print("AUPRO@30% | AUPRO@10% | AUPRO@5% | AUPRO@1% | P-AUROC | I-AUROC") 190 | print(f' {au_pros[0]:.3f} | {au_pros[1]:.3f} | {au_pros[2]:.3f} | {au_pros[3]:.3f} | {pixel_rocauc:.3f} | {image_rocauc:.3f}', end = '\n') 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser(description = 'Make inference with Crossmodal Feature Networks (CFMs) on a dataset.') 194 | 195 | parser.add_argument('--dataset_path', default = './datasets/mvtec3d', type = str, 196 | help = 'Dataset path.') 197 | 198 | parser.add_argument('--class_name', default = None, type = str, choices = ["bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire", 199 | 'CandyCane', 'ChocolateCookie', 'ChocolatePraline', 'Confetto', 'GummyBear', 'HazelnutTruffle', 'LicoriceSandwich', 'Lollipop', 'Marshmallow', 'PeppermintCandy'], 200 | help = 'Category name.') 201 | 202 | parser.add_argument('--checkpoint_folder', default = './checkpoints/checkpoints_CFM_mvtec', type = str, 203 | help = 'Path to the folder containing CFMs checkpoints.') 204 | 205 | parser.add_argument('--qualitative_folder', default = './results/qualitatives_mvtec', type = str, 206 | help = 'Path to the folder in which to save the qualitatives.') 207 | 208 | parser.add_argument('--quantitative_folder', default = './results/quantitatives_mvtec', type = str, 209 | help = 'Path to the folder in which to save the quantitatives.') 210 | 211 | parser.add_argument('--epochs_no', default = 50, type = int, 212 | help = 'Number of epochs to train the CFMs.') 213 | 214 | parser.add_argument('--batch_size', default = 4, type = int, 215 | help = 'Batch dimension. 
Usually 16 is around the max.') 216 | 217 | parser.add_argument('--visualize_plot', default = False, action = 'store_true', 218 | help = 'Whether to show plot or not.') 219 | 220 | parser.add_argument('--produce_qualitatives', default = False, action = 'store_true', 221 | help = 'Whether to produce qualitatives or not.') 222 | 223 | args = parser.parse_args() 224 | 225 | infer_CFM(args) -------------------------------------------------------------------------------- /cfm_training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import torch 5 | import wandb 6 | 7 | import numpy as np 8 | from itertools import chain 9 | 10 | from tqdm import tqdm, trange 11 | 12 | from models.features import MultimodalFeatures 13 | from models.dataset import get_data_loader 14 | from models.feature_transfer_nets import FeatureProjectionMLP, FeatureProjectionMLP_big 15 | 16 | 17 | def set_seeds(sid=115): 18 | np.random.seed(sid) 19 | 20 | torch.manual_seed(sid) 21 | if torch.cuda.is_available(): 22 | torch.cuda.manual_seed(sid) 23 | torch.cuda.manual_seed_all(sid) 24 | 25 | 26 | def train_CFM(args): 27 | 28 | set_seeds() 29 | device = "cuda" if torch.cuda.is_available() else "cpu" 30 | 31 | model_name = f'{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs' 32 | 33 | wandb.init( 34 | project = 'crossmodal-feature-mappings', 35 | name = model_name 36 | ) 37 | 38 | # Dataloader. 39 | train_loader = get_data_loader("train", class_name = args.class_name, img_size = 224, dataset_path = args.dataset_path, batch_size = args.batch_size, shuffle = True) 40 | 41 | # Feature extractors. 42 | feature_extractor = MultimodalFeatures() 43 | 44 | # Model instantiation. 45 | CFM_2Dto3D = FeatureProjectionMLP(in_features = 768, out_features = 1152) 46 | CFM_3Dto2D = FeatureProjectionMLP(in_features = 1152, out_features = 768) 47 | 48 | optimizer = torch.optim.Adam(params = chain(CFM_2Dto3D.parameters(), CFM_3Dto2D.parameters())) 49 | 50 | CFM_2Dto3D.to(device), CFM_3Dto2D.to(device) 51 | 52 | metric = torch.nn.CosineSimilarity(dim = -1, eps = 1e-06) 53 | 54 | for epoch in trange(args.epochs_no, desc = f'Training Feature Transfer Net.'): 55 | 56 | epoch_cos_sim_3Dto2D, epoch_cos_sim_2Dto3D = [], [] 57 | 58 | # ------------ [Trainig Loop] ------------ # 59 | # * Return (rgb_img, organized_pc, depth_map_3channel), globl_label 60 | for (rgb, pc, _), _ in tqdm(train_loader, desc = f'Extracting feature from class: {args.class_name}.'): 61 | rgb, pc = rgb.to(device), pc.to(device) 62 | 63 | # Make CFMs trainable. 64 | CFM_2Dto3D.train(), CFM_3Dto2D.train() 65 | 66 | if args.batch_size == 1: 67 | rgb_patch, xyz_patch = feature_extractor.get_features_maps(rgb, pc) 68 | else: 69 | rgb_patches = [] 70 | xyz_patches = [] 71 | 72 | for i in range(rgb.shape[0]): 73 | rgb_patch, xyz_patch = feature_extractor.get_features_maps(rgb[i].unsqueeze(dim=0), pc[i].unsqueeze(dim=0)) 74 | 75 | rgb_patches.append(rgb_patch) 76 | xyz_patches.append(xyz_patch) 77 | 78 | rgb_patch = torch.stack(rgb_patches, dim=0) 79 | xyz_patch = torch.stack(xyz_patches, dim=0) 80 | 81 | # Predictions. 82 | rgb_feat_pred = CFM_3Dto2D(xyz_patch) 83 | xyz_feat_pred = CFM_2Dto3D(rgb_patch) 84 | 85 | # Losses. 86 | xyz_mask = (xyz_patch.sum(axis = -1) == 0) # Mask only the feature vectors that are 0 everywhere. 
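# The two cosine-similarity losses below compare each observed feature map with the features
# mapped from the other modality, restricted to locations with valid (non-zero) 3D features;
# both CFMs are optimized jointly with Adam, on nominal samples only.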
87 | 88 | loss_3Dto2D = 1 - metric(xyz_feat_pred[~xyz_mask], xyz_patch[~xyz_mask]).mean() 89 | loss_2Dto3D = 1 - metric(rgb_feat_pred[~xyz_mask], rgb_patch[~xyz_mask]).mean() 90 | 91 | cos_sim_3Dto2D, cos_sim_2Dto3D = 1 - loss_3Dto2D.cpu(), 1 - loss_2Dto3D.cpu() 92 | 93 | epoch_cos_sim_3Dto2D.append(cos_sim_3Dto2D), epoch_cos_sim_2Dto3D.append(cos_sim_2Dto3D) 94 | 95 | # Logging. 96 | wandb.log({ 97 | "train/loss_3Dto2D" : loss_3Dto2D, 98 | "train/loss_2Dto3D" : loss_2Dto3D, 99 | "train/cosine_similarity_3Dto2D" : cos_sim_3Dto2D, 100 | "train/cosine_similarity_2Dto3D" : cos_sim_2Dto3D, 101 | }) 102 | 103 | if torch.isnan(loss_3Dto2D) or torch.isinf(loss_3Dto2D) or torch.isnan(loss_2Dto3D) or torch.isinf(loss_2Dto3D): 104 | exit() 105 | 106 | # Optimization. 107 | if not torch.isnan(loss_3Dto2D) and not torch.isinf(loss_3Dto2D) and not torch.isnan(loss_2Dto3D) and not torch.isinf(loss_2Dto3D): 108 | 109 | optimizer.zero_grad() 110 | 111 | loss_3Dto2D.backward(), loss_2Dto3D.backward() 112 | 113 | optimizer.step() 114 | 115 | # Global logging. 116 | wandb.log({ 117 | "global_train/cos_sim_3Dto2D" : torch.Tensor(epoch_cos_sim_3Dto2D, device = 'cpu').mean(), 118 | "global_train/cos_sim_2Dto3D" : torch.Tensor(epoch_cos_sim_2Dto3D, device = 'cpu').mean() 119 | }) 120 | 121 | # Model saving. 122 | directory = f'{args.checkpoint_savepath}/{args.class_name}' 123 | 124 | if not os.path.exists(directory): 125 | os.makedirs(directory) 126 | 127 | torch.save(CFM_2Dto3D.state_dict(), os.path.join(directory, 'CFM_2Dto3D_' + model_name + '.pth')) 128 | torch.save(CFM_3Dto2D.state_dict(), os.path.join(directory, 'CFM_3Dto2D_' + model_name + '.pth')) 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser(description = 'Train Crossmodal Feature Networks (CFMs) on a dataset.') 133 | 134 | parser.add_argument('--dataset_path', default = './datasets/mvtec3d', type = str, 135 | help = 'Dataset path.') 136 | 137 | parser.add_argument('--checkpoint_savepath', default = './checkpoints/checkpoints_CFM_mvtec', type = str, 138 | help = 'Where to save the model checkpoints.') 139 | 140 | parser.add_argument('--class_name', default = None, type = str, choices = ["bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire", 141 | 'CandyCane', 'ChocolateCookie', 'ChocolatePraline', 'Confetto', 'GummyBear', 'HazelnutTruffle', 'LicoriceSandwich', 'Lollipop', 'Marshmallow', 'PeppermintCandy'], 142 | help = 'Category name.') 143 | 144 | parser.add_argument('--epochs_no', default = 50, type = int, 145 | help = 'Number of epochs to train the CFMs.') 146 | 147 | parser.add_argument('--batch_size', default = 4, type = int, 148 | help = 'Batch dimension. 
Usually 16 is around the max.') 149 | 150 | args = parser.parse_args() 151 | train_CFM(args) -------------------------------------------------------------------------------- /images/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVLAB-Unibo/crossmodal-feature-mapping/10198716d65919dcc5257323af512907e63a4003/images/architecture.jpg -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torchvision import transforms 4 | import glob 5 | from torch.utils.data import Dataset 6 | from utils.mvtec3d_utils import * 7 | from torch.utils.data import DataLoader 8 | import numpy as np 9 | from utils.general_utils import SquarePad 10 | 11 | def eyecandies_classes(): 12 | return [ 13 | 'CandyCane', 14 | 'ChocolateCookie', 15 | 'ChocolatePraline', 16 | 'Confetto', 17 | 'GummyBear', 18 | 'HazelnutTruffle', 19 | 'LicoriceSandwich', 20 | 'Lollipop', 21 | 'Marshmallow', 22 | 'PeppermintCandy', 23 | ] 24 | 25 | def mvtec3d_classes(): 26 | return [ 27 | "bagel", 28 | "cable_gland", 29 | "carrot", 30 | "cookie", 31 | "dowel", 32 | "foam", 33 | "peach", 34 | "potato", 35 | "rope", 36 | "tire", 37 | ] 38 | 39 | RGB_SIZE = 224 40 | 41 | class BaseAnomalyDetectionDataset(Dataset): 42 | def __init__(self, split, class_name, img_size, dataset_path): 43 | self.IMAGENET_MEAN = [0.485, 0.456, 0.406] 44 | self.IMAGENET_STD = [0.229, 0.224, 0.225] 45 | 46 | self.cls = class_name 47 | self.size = img_size 48 | self.img_path = os.path.join(dataset_path, self.cls, split) 49 | 50 | self.rgb_transform = transforms.Compose([ 51 | SquarePad(), 52 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation = transforms.InterpolationMode.BICUBIC), 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean = self.IMAGENET_MEAN, std = self.IMAGENET_STD) 55 | ]) 56 | 57 | 58 | class TrainValDataset(BaseAnomalyDetectionDataset): 59 | def __init__(self, split, class_name, img_size, dataset_path): 60 | super().__init__(split = split, class_name = class_name, img_size = img_size, dataset_path = dataset_path) 61 | 62 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 63 | 64 | def load_dataset(self): 65 | img_tot_paths = [] 66 | tot_labels = [] 67 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png") 68 | tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff") 69 | rgb_paths.sort() 70 | tiff_paths.sort() 71 | sample_paths = list(zip(rgb_paths, tiff_paths)) 72 | img_tot_paths.extend(sample_paths) 73 | tot_labels.extend([0] * len(sample_paths)) 74 | return img_tot_paths, tot_labels 75 | 76 | def __len__(self): 77 | return len(self.img_paths) 78 | 79 | def __getitem__(self, idx): 80 | img_path, label = self.img_paths[idx], self.labels[idx] 81 | rgb_path = img_path[0] 82 | tiff_path = img_path[1] 83 | img = Image.open(rgb_path).convert('RGB') 84 | 85 | img = self.rgb_transform(img) 86 | organized_pc = read_tiff_organized_pc(tiff_path) 87 | 88 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis = 2) 89 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) 90 | resized_organized_pc = resize_organized_pc(organized_pc, target_height = self.size, target_width = self.size) 91 | resized_organized_pc = resized_organized_pc.clone().detach().float() 
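# A training sample is (normalized RGB image, organized point cloud resized to the working
# resolution, 3-channel depth map); the label is always 0 here, since only nominal ('good')
# samples are used to fit the CFMs.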
92 | 93 | return (img, resized_organized_pc, resized_depth_map_3channel), label 94 | 95 | 96 | class TestDataset(BaseAnomalyDetectionDataset): 97 | def __init__(self, class_name, img_size, dataset_path): 98 | super().__init__(split = "test", class_name = class_name, img_size = img_size, dataset_path = dataset_path) 99 | 100 | self.gt_transform = transforms.Compose([ 101 | SquarePad(), 102 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST), 103 | transforms.ToTensor()]) 104 | 105 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 106 | 107 | def load_dataset(self): 108 | img_tot_paths = [] 109 | gt_tot_paths = [] 110 | tot_labels = [] 111 | defect_types = os.listdir(self.img_path) 112 | 113 | for defect_type in defect_types: 114 | if defect_type == 'good': 115 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png") 116 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff") 117 | rgb_paths.sort() 118 | tiff_paths.sort() 119 | sample_paths = list(zip(rgb_paths, tiff_paths)) 120 | img_tot_paths.extend(sample_paths) 121 | gt_tot_paths.extend([0] * len(sample_paths)) 122 | tot_labels.extend([0] * len(sample_paths)) 123 | else: 124 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png") 125 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff") 126 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png") 127 | rgb_paths.sort() 128 | tiff_paths.sort() 129 | gt_paths.sort() 130 | sample_paths = list(zip(rgb_paths, tiff_paths)) 131 | 132 | img_tot_paths.extend(sample_paths) 133 | gt_tot_paths.extend(gt_paths) 134 | tot_labels.extend([1] * len(sample_paths)) 135 | 136 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 
137 | 138 | return img_tot_paths, gt_tot_paths, tot_labels 139 | 140 | def __len__(self): 141 | return len(self.img_paths) 142 | 143 | def __getitem__(self, idx): 144 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 145 | rgb_path = img_path[0] 146 | tiff_path = img_path[1] 147 | img_original = Image.open(rgb_path).convert('RGB') 148 | img = self.rgb_transform(img_original) 149 | 150 | organized_pc = read_tiff_organized_pc(tiff_path) 151 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) 152 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) 153 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) 154 | resized_organized_pc = resized_organized_pc.clone().detach().float() 155 | 156 | if gt == 0: 157 | gt = torch.zeros( 158 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-2]]) 159 | else: 160 | gt = Image.open(gt).convert('L') 161 | gt = self.gt_transform(gt) 162 | gt = torch.where(gt > 0.5, 1., .0) 163 | 164 | return (img, resized_organized_pc, resized_depth_map_3channel), gt[:1], label, rgb_path 165 | 166 | 167 | def get_data_loader(split, class_name, dataset_path, img_size = 224, batch_size = 1, shuffle = False): 168 | if split in ['train']: 169 | dataset = TrainValDataset(split = "train", class_name = class_name, img_size = img_size, dataset_path = dataset_path) 170 | elif split in ['validation']: 171 | dataset = TrainValDataset(split = "validation", class_name = class_name, img_size = img_size, dataset_path = dataset_path) 172 | elif split in ['test']: 173 | dataset = TestDataset(class_name = class_name, img_size = img_size, dataset_path = dataset_path) 174 | 175 | data_loader = DataLoader(dataset = dataset, batch_size = batch_size, shuffle = shuffle, 176 | num_workers = 1, drop_last = False, pin_memory = True) 177 | 178 | return data_loader 179 | -------------------------------------------------------------------------------- /models/feature_transfer_nets.py: -------------------------------------------------------------------------------- 1 | # Alex Costanzino, CVLab 2 | # July 2023 3 | 4 | import torch 5 | 6 | class FeatureProjectionMLP(torch.nn.Module): 7 | def __init__(self, in_features = None, out_features = None, act_layer = torch.nn.GELU): 8 | super().__init__() 9 | 10 | self.act_fcn = act_layer() 11 | 12 | self.input = torch.nn.Linear(in_features, (in_features + out_features) // 2) 13 | self.projection = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2) 14 | self.output = torch.nn.Linear((in_features + out_features) // 2, out_features) 15 | 16 | def forward(self, x): 17 | x = self.input(x) 18 | x = self.act_fcn(x) 19 | 20 | x = self.projection(x) 21 | x = self.act_fcn(x) 22 | 23 | x = self.output(x) 24 | 25 | return x 26 | 27 | class FeatureProjectionMLP_big(torch.nn.Module): 28 | def __init__(self, in_features = None, out_features = None, act_layer = torch.nn.GELU): 29 | super().__init__() 30 | 31 | self.act_fcn = act_layer() 32 | 33 | self.input = torch.nn.Linear(in_features, (in_features + out_features) // 2) 34 | 35 | self.projection_a = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2) 36 | self.projection_b = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2) 37 | self.projection_c = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2) 38 | 
self.projection_d = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2) 39 | self.projection_e = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2) 40 | 41 | self.output = torch.nn.Linear((in_features + out_features) // 2, out_features) 42 | 43 | def forward(self, x): 44 | x = self.input(x) 45 | x = self.act_fcn(x) 46 | 47 | x = self.projection_a(x) 48 | x = self.act_fcn(x) 49 | x = self.projection_b(x) 50 | x = self.act_fcn(x) 51 | x = self.projection_c(x) 52 | x = self.act_fcn(x) 53 | x = self.projection_d(x) 54 | x = self.act_fcn(x) 55 | x = self.projection_e(x) 56 | x = self.act_fcn(x) 57 | 58 | x = self.output(x) 59 | 60 | return x -------------------------------------------------------------------------------- /models/features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from sklearn.metrics import roc_auc_score 5 | from utils.metrics_utils import calculate_au_pro 6 | from utils.pointnet2_utils import interpolating_points 7 | from models.full_models import FeatureExtractors 8 | 9 | 10 | dino_backbone_name = 'vit_base_patch8_224.dino' # 224/8 -> 28 patches. 11 | group_size = 128 12 | num_group = 1024 13 | 14 | class MultimodalFeatures(torch.nn.Module): 15 | def __init__(self, image_size = 224): 16 | super().__init__() 17 | 18 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 19 | 20 | self.deep_feature_extractor = FeatureExtractors(device = self.device, 21 | rgb_backbone_name = dino_backbone_name, 22 | group_size = group_size, num_group = num_group) 23 | 24 | self.deep_feature_extractor.to(self.device) 25 | 26 | self.image_size = image_size 27 | 28 | # * Applies a 2D adaptive average pooling over an input signal composed of several input planes. 29 | # * The output is of size H x W, for any input size. The number of output features is equal to the number of input planes. 30 | self.resize = torch.nn.AdaptiveAvgPool2d((224, 224)) 31 | 32 | self.average = torch.nn.AvgPool2d(kernel_size = 3, stride = 1) 33 | 34 | def __call__(self, rgb, xyz): 35 | rgb = rgb.to(self.device) 36 | xyz = xyz.to(self.device) 37 | 38 | with torch.no_grad(): 39 | rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz) 40 | 41 | 42 | interpolated_feature_maps = interpolating_points(xyz, center.permute(0,2,1), xyz_feature_maps) 43 | 44 | xyz_feature_maps = [fmap for fmap in [xyz_feature_maps]] 45 | rgb_feature_maps = [fmap for fmap in [rgb_feature_maps]] 46 | 47 | return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx, interpolated_feature_maps 48 | 49 | def calculate_metrics(self): 50 | self.image_preds = np.stack(self.image_preds) 51 | self.image_labels = np.stack(self.image_labels) 52 | self.pixel_preds = np.array(self.pixel_preds) 53 | 54 | self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds) 55 | self.pixel_rocauc = roc_auc_score(self.pixel_labels, self.pixel_preds) 56 | self.au_pro, _ = calculate_au_pro(self.gts, self.predictions) 57 | 58 | def get_features_maps(self, rgb, pc): 59 | 60 | unorganized_pc = pc.squeeze().permute(1, 2, 0).reshape(-1, pc.shape[1]) 61 | 62 | # Find nonzero indices. 63 | nonzero_indices = torch.nonzero(torch.all(unorganized_pc != 0, dim=1)).squeeze(dim=1) 64 | 65 | # Select nonzero indices and discard the others. 
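# All-zero rows correspond to pixels without a valid 3D point (background): they are dropped
# before the point backbone and re-inserted as zero features afterwards (xyz_patch_full below),
# so that the 2D and 3D feature maps remain spatially aligned.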
66 | unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :].unsqueeze(dim=0).permute(0, 2, 1) 67 | 68 | rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(rgb, unorganized_pc_no_zeros.contiguous()) 69 | 70 | # Interpolation to obtain a "full image" with point cloud features. 71 | xyz_patch = torch.cat(xyz_feature_maps, 1) 72 | 73 | xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size * self.image_size), dtype = xyz_patch.dtype, device = self.device) 74 | xyz_patch_full[..., nonzero_indices] = interpolated_pc 75 | 76 | xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) 77 | xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) 78 | xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T 79 | 80 | rgb_patch = torch.cat(rgb_feature_maps, 1) 81 | 82 | upsample_shape = xyz_patch_full_resized.shape[-2:] 83 | rgb_patch_upsample = torch.nn.functional.interpolate(rgb_patch, size = upsample_shape, mode = 'bilinear', align_corners = False) 84 | rgb_patch_upsample = rgb_patch_upsample.reshape(rgb_patch.shape[1], -1).T 85 | 86 | return rgb_patch_upsample, xyz_patch -------------------------------------------------------------------------------- /models/full_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | from timm.models.layers import DropPath 5 | from pointnet2_ops import pointnet2_utils 6 | 7 | class FeatureExtractors(torch.nn.Module): 8 | def __init__(self, device, 9 | rgb_backbone_name = 'vit_base_patch8_224_dino.dino', out_indices = None, 10 | group_size = 128, num_group = 1024): 11 | 12 | super().__init__() 13 | 14 | self.device = device 15 | 16 | kwargs = {'features_only': True if out_indices else False} 17 | 18 | if out_indices: 19 | kwargs.update({'out_indices': out_indices}) 20 | 21 | layers_keep = 12 22 | 23 | ## RGB backbone 24 | self.rgb_backbone = timm.create_model(model_name = rgb_backbone_name, pretrained = True, **kwargs) 25 | # ! Use only the first k blocks. 26 | self.rgb_backbone.blocks = torch.nn.Sequential(*self.rgb_backbone.blocks[:layers_keep]) # Remove Block(s) from 5 to 11. 27 | 28 | ## XYZ backbone 29 | self.xyz_backbone = PointTransformer(group_size = group_size, num_group = num_group) 30 | self.xyz_backbone.load_model_from_ckpt("checkpoints/feature_extractors/pointmae_pretrain.pth") 31 | # ! Use only the first k blocks. 32 | self.xyz_backbone.blocks.blocks = torch.nn.Sequential(*self.xyz_backbone.blocks.blocks[:layers_keep]) # Remove Block(s) from 5 to 11. 
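# layers_keep controls the layer pruning discussed in the paper: with all 12 blocks kept, both
# backbones run at full depth; lowering it trades a marginal amount of accuracy for a smaller
# memory footprint and faster inference.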
33 | 34 | 35 | def forward_rgb_features(self, x): 36 | x = self.rgb_backbone.patch_embed(x) 37 | x = self.rgb_backbone._pos_embed(x) 38 | x = self.rgb_backbone.norm_pre(x) 39 | x = self.rgb_backbone.blocks(x) 40 | x = self.rgb_backbone.norm(x) 41 | 42 | feat = x[:,1:].permute(0, 2, 1).view(1, -1, 28, 28) # view(1, -1, 14, 14) 43 | return feat 44 | 45 | 46 | def forward(self, rgb, xyz): 47 | rgb_features = self.forward_rgb_features(rgb) 48 | xyz_features, center, ori_idx, center_idx = self.xyz_backbone(xyz) 49 | 50 | return rgb_features, xyz_features, center, ori_idx, center_idx 51 | 52 | 53 | def fps(data, number): 54 | ''' 55 | data B N 3 56 | number int 57 | ''' 58 | fps_idx = pointnet2_utils.furthest_point_sample(data, number) 59 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous() 60 | return fps_data, fps_idx 61 | 62 | 63 | class KNN(nn.Module): 64 | def __init__(self, k): 65 | super(KNN, self).__init__() 66 | self.k = k 67 | 68 | def forward(self, xyz, centers): 69 | assert xyz.size(0) == centers.size(0), "Batch size of xyz and centers should be the same" 70 | 71 | B, N_points, _ = xyz.size() 72 | K = centers.size(1) 73 | 74 | # Compute pairwise distances 75 | xyz = xyz.unsqueeze(2) # [B, N, 1, 3] 76 | centers = centers.unsqueeze(1) # [B, 1, K, 3] 77 | distances = torch.norm(xyz - centers, dim=-1) # [B, N, K] 78 | 79 | # Get the indices of the k nearest neighbors 80 | _, indices = torch.topk(distances, self.k, dim=1, largest=False, sorted=True) 81 | return indices 82 | 83 | 84 | class Group(nn.Module): 85 | def __init__(self, num_group, group_size): 86 | super().__init__() 87 | self.num_group = num_group 88 | self.group_size = group_size 89 | self.knn = KNN(k=self.group_size) 90 | 91 | def forward(self, xyz): 92 | ''' 93 | input: B N 3 94 | --------------------------- 95 | output: B G M 3 96 | center : B G 3 97 | ''' 98 | 99 | batch_size, num_points, _ = xyz.shape 100 | # fps the centers out 101 | center, center_idx = fps(xyz.contiguous(), self.num_group) # B G 3 102 | 103 | # knn to get the neighborhood 104 | # _, idx = self.knn(xyz, center) # B G M 105 | idx = self.knn(xyz, center).permute(0,2,1) # B G M 106 | 107 | assert idx.size(1) == self.num_group 108 | assert idx.size(2) == self.group_size 109 | ori_idx = idx 110 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points 111 | idx = idx + idx_base 112 | idx = idx[-1] 113 | neighborhood = xyz.reshape(batch_size * num_points, -1)[idx, :] 114 | neighborhood = neighborhood.reshape(batch_size, self.num_group, self.group_size, 3).contiguous() 115 | # normalize 116 | neighborhood = neighborhood - center.unsqueeze(2) 117 | return neighborhood, center, ori_idx, center_idx 118 | 119 | 120 | class Encoder(nn.Module): 121 | def __init__(self, encoder_channel): 122 | super().__init__() 123 | self.encoder_channel = encoder_channel 124 | self.first_conv = nn.Sequential( 125 | nn.Conv1d(3, 128, 1), 126 | nn.BatchNorm1d(128), 127 | nn.ReLU(inplace=True), 128 | nn.Conv1d(128, 256, 1) 129 | ) 130 | self.second_conv = nn.Sequential( 131 | nn.Conv1d(512, 512, 1), 132 | nn.BatchNorm1d(512), 133 | nn.ReLU(inplace=True), 134 | nn.Conv1d(512, self.encoder_channel, 1) 135 | ) 136 | 137 | def forward(self, point_groups): 138 | ''' 139 | point_groups : B G N 3 140 | ----------------- 141 | feature_global : B G C 142 | ''' 143 | bs, g, n, _ = point_groups.shape 144 | point_groups = point_groups.reshape(bs * g, n, 3) 145 | # encoder 146 | feature = 
self.first_conv(point_groups.transpose(2, 1)) 147 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] 148 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) 149 | feature = self.second_conv(feature) 150 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] 151 | return feature_global.reshape(bs, g, self.encoder_channel) 152 | 153 | 154 | class MLP(nn.Module): 155 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 156 | super().__init__() 157 | out_features = out_features or in_features 158 | hidden_features = hidden_features or in_features 159 | self.fc1 = nn.Linear(in_features, hidden_features) 160 | self.act = act_layer() 161 | self.fc2 = nn.Linear(hidden_features, out_features) 162 | self.drop = nn.Dropout(drop) 163 | 164 | def forward(self, x): 165 | x = self.fc1(x) 166 | x = self.act(x) 167 | x = self.drop(x) 168 | x = self.fc2(x) 169 | x = self.drop(x) 170 | return x 171 | 172 | 173 | class Attention(nn.Module): 174 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 175 | super().__init__() 176 | self.num_heads = num_heads 177 | head_dim = dim // num_heads 178 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 179 | self.scale = qk_scale or head_dim ** -0.5 180 | 181 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 182 | self.attn_drop = nn.Dropout(attn_drop) 183 | self.proj = nn.Linear(dim, dim) 184 | self.proj_drop = nn.Dropout(proj_drop) 185 | 186 | def forward(self, x): 187 | B, N, C = x.shape 188 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 189 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 190 | 191 | attn = (q * self.scale) @ k.transpose(-2, -1) 192 | attn = attn.softmax(dim=-1) 193 | attn = self.attn_drop(attn) 194 | 195 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 196 | x = self.proj(x) 197 | x = self.proj_drop(x) 198 | return x 199 | 200 | 201 | class Block(nn.Module): 202 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 203 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 204 | super().__init__() 205 | self.norm1 = norm_layer(dim) 206 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 207 | self.norm2 = norm_layer(dim) 208 | mlp_hidden_dim = int(dim * mlp_ratio) 209 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 210 | 211 | self.attn = Attention( 212 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 213 | 214 | def forward(self, x): 215 | x = x + self.drop_path(self.attn(self.norm1(x))) 216 | x = x + self.drop_path(self.mlp(self.norm2(x))) 217 | return x 218 | 219 | 220 | class TransformerEncoder(nn.Module): 221 | """ 222 | Transformer Encoder without hierarchical structure 223 | """ 224 | 225 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 226 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.): 227 | super().__init__() 228 | 229 | self.blocks = nn.ModuleList([ 230 | Block( 231 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 232 | drop=drop_rate, attn_drop=attn_drop_rate, 233 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 234 | ) 235 | for i in range(depth)]) 236 | 237 | def forward(self, x, pos): 238 | feature_list = [] 239 | # fetch_idx = [3, 7, 11] ### ! PD 240 | for i, block in enumerate(self.blocks): 241 | x = block(x + pos) 242 | # if i in fetch_idx: ### ! PD 243 | feature_list.append(x) 244 | return feature_list 245 | 246 | 247 | class PointTransformer(nn.Module): 248 | def __init__(self, group_size = 128, num_group = 1024, encoder_dims = 384): 249 | super().__init__() 250 | 251 | self.trans_dim = 384 252 | self.depth = 12 253 | self.drop_path_rate = 0.1 254 | self.num_heads = 6 255 | 256 | self.group_size = group_size 257 | self.num_group = num_group 258 | # grouper 259 | self.group_divider = Group(num_group = self.num_group, group_size = self.group_size) 260 | # define the encoder 261 | self.encoder_dims = encoder_dims 262 | if self.encoder_dims != self.trans_dim: 263 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 264 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim)) 265 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim) 266 | self.encoder = Encoder(encoder_channel=self.encoder_dims) 267 | # bridge encoder and transformer 268 | 269 | self.pos_embed = nn.Sequential( 270 | nn.Linear(3, 128), 271 | nn.GELU(), 272 | nn.Linear(128, self.trans_dim) 273 | ) 274 | 275 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 276 | self.blocks = TransformerEncoder( 277 | embed_dim=self.trans_dim, 278 | depth=self.depth, 279 | drop_path_rate=dpr, 280 | num_heads=self.num_heads 281 | ) 282 | 283 | self.norm = nn.LayerNorm(self.trans_dim) 284 | 285 | def load_model_from_ckpt(self, bert_ckpt_path): 286 | if bert_ckpt_path is not None: 287 | device = "cuda" if torch.cuda.is_available() else "cpu" 288 | ckpt = torch.load(bert_ckpt_path, map_location=device) 289 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()} 290 | 291 | for k in list(base_ckpt.keys()): 292 | if k.startswith('MAE_encoder'): 293 | base_ckpt[k[len('MAE_encoder.'):]] = base_ckpt[k] 294 | del base_ckpt[k] 295 | elif k.startswith('base_model'): 296 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k] 297 | del base_ckpt[k] 298 | 299 | incompatible = self.load_state_dict(base_ckpt, strict=False) 300 | 301 | def load_model_from_pb_ckpt(self, bert_ckpt_path): 302 | ckpt = torch.load(bert_ckpt_path) 303 | base_ckpt = {k.replace("module.", ""): v for k, 
v in ckpt['base_model'].items()} 304 | for k in list(base_ckpt.keys()): 305 | if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'): 306 | base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k] 307 | elif k.startswith('base_model'): 308 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k] 309 | del base_ckpt[k] 310 | 311 | incompatible = self.load_state_dict(base_ckpt, strict=False) 312 | 313 | if incompatible.missing_keys: 314 | print('missing_keys') 315 | print( 316 | incompatible.missing_keys 317 | ) 318 | if incompatible.unexpected_keys: 319 | print('unexpected_keys') 320 | print( 321 | incompatible.unexpected_keys 322 | ) 323 | 324 | print(f'[Transformer] Successfully loaded the ckpt from {bert_ckpt_path}') 325 | 326 | def forward(self, pts): 327 | if self.encoder_dims != self.trans_dim: 328 | B,C,N = pts.shape 329 | pts = pts.transpose(-1, -2) # B N 3 330 | # divide the point cloud in the same form. This is important 331 | neighborhood, center, ori_idx, center_idx = self.group_divider(pts) 332 | # # generate mask 333 | # bool_masked_pos = self._mask_center(center, no_mask = False) # B G 334 | # encode the input cloud blocks 335 | group_input_tokens = self.encoder(neighborhood) # B G N 336 | group_input_tokens = self.reduce_dim(group_input_tokens) 337 | # prepare cls 338 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1) 339 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1) 340 | # add pos embedding 341 | pos = self.pos_embed(center) 342 | # final input 343 | x = torch.cat((cls_tokens, group_input_tokens), dim=1) 344 | pos = torch.cat((cls_pos, pos), dim=1) 345 | # transformer 346 | feature_list = self.blocks(x, pos) 347 | feature_list = [self.norm(x)[:,1:].transpose(-1, -2).contiguous() for x in feature_list] 348 | x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152 349 | return x, center, ori_idx, center_idx 350 | else: 351 | B, C, N = pts.shape 352 | pts = pts.transpose(-1, -2) # B N 3 353 | # divide the point cloud in the same form.
This is important 354 | 355 | neighborhood, center, ori_idx, center_idx = self.group_divider(pts) 356 | group_input_tokens = self.encoder(neighborhood) # B G N 357 | 358 | pos = self.pos_embed(center) 359 | # final input 360 | x = group_input_tokens 361 | # transformer 362 | feature_list = self.blocks(x, pos) 363 | feature_list = [self.norm(x).transpose(-1, -2).contiguous() for x in feature_list] 364 | if len(feature_list) == 12: 365 | x = torch.cat((feature_list[3],feature_list[7],feature_list[11]), dim=1) 366 | elif len(feature_list) == 8: 367 | x = torch.cat((feature_list[1],feature_list[4],feature_list[7]), dim=1) 368 | elif len(feature_list) == 4: 369 | x = torch.cat((feature_list[1],feature_list[2],feature_list[3]), dim=1) 370 | else: 371 | x = feature_list[-1] 372 | return x, center, ori_idx, center_idx -------------------------------------------------------------------------------- /processing/aggregate_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | 6 | classes = ["bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire"] 7 | metrics = ["AUPRO@30\%", "AUPRO@10\%", "AUPRO@5\%", "AUPRO@1\%", "P-AUROC", "I-AUROC"] 8 | 9 | def read_md_files(directory): 10 | aggregated_results = [] 11 | md_files = sorted([filename for filename in os.listdir(directory) if filename.endswith(".md")]) 12 | 13 | for filename in md_files: 14 | filepath = os.path.join(directory, filename) 15 | with open(filepath, "r", encoding = "utf-8") as file: 16 | file_contents = file.read().split('\n')[-1].split('&') # Take only the last line (results) and split in string numbers. 17 | file_contents_num = [float(res) for res in file_contents] # Convert string numbers in float numbers. 
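# Each per-class .md file is expected to end with a single '&'-separated row of metric values, one value per entry in `metrics` above.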
18 | aggregated_results.append(file_contents_num) 19 | return np.array(aggregated_results) # Expected shape (10,6) 20 | 21 | def produce_table(args): 22 | data = read_md_files(args.quantitative_folder) 23 | results = pd.DataFrame(data, index=classes, columns=metrics).T 24 | results['mean'] = results.mean(axis=1) 25 | 26 | print(results.to_latex(float_format = "%.3f")) 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser(description = 'Create a LaTeX table with all the results.') 30 | 31 | parser.add_argument('--quantitative_folder', default = None, type = str, 32 | help = 'Path to the folder from which to fetch the quantitatives.') 33 | 34 | args = parser.parse_args() 35 | 36 | produce_table(args) -------------------------------------------------------------------------------- /processing/preprocess_eyecandies.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import copyfile 3 | import cv2 4 | import numpy as np 5 | import tifffile 6 | import yaml 7 | import imageio.v3 as iio 8 | import math 9 | import argparse 10 | 11 | # The same camera has been used for all the images 12 | FOCAL_LENGTH = 711.11 13 | 14 | def load_and_convert_depth(depth_img, info_depth): 15 | with open(info_depth) as f: 16 | data = yaml.safe_load(f) 17 | mind, maxd = data["normalization"]["min"], data["normalization"]["max"] 18 | 19 | dimg = iio.imread(depth_img) 20 | dimg = dimg.astype(np.float32) 21 | dimg = dimg / 65535.0 * (maxd - mind) + mind 22 | return dimg 23 | 24 | def depth_to_pointcloud(depth_img, info_depth, pose_txt, focal_length): 25 | # input depth map (in meters) --- cfr previous section 26 | depth_mt = load_and_convert_depth(depth_img, info_depth) 27 | 28 | # input pose 29 | pose = np.loadtxt(pose_txt) 30 | 31 | # camera intrinsics 32 | height, width = depth_mt.shape[:2] 33 | intrinsics_4x4 = np.array([ 34 | [focal_length, 0, width / 2, 0], 35 | [0, focal_length, height / 2, 0], 36 | [0, 0, 1, 0], 37 | [0, 0, 0, 1]] 38 | ) 39 | 40 | # build the camera projection matrix 41 | camera_proj = intrinsics_4x4 @ pose 42 | 43 | # build the (u, v, 1, 1/depth) vectors (non optimized version) 44 | camera_vectors = np.zeros((width * height, 4)) 45 | count=0 46 | for j in range(height): 47 | for i in range(width): 48 | camera_vectors[count, :] = np.array([i, j, 1, 1/depth_mt[j, i]]) 49 | count += 1 50 | 51 | # invert and apply to each 4-vector 52 | hom_3d_pts= np.linalg.inv(camera_proj) @ camera_vectors.T 53 | # print(hom_3d_pts.shape) 54 | # remove the homogeneous coordinate 55 | pcd = depth_mt.reshape(-1, 1) * hom_3d_pts.T 56 | return pcd[:, :3] 57 | 58 | def remove_point_cloud_background(pc): 59 | 60 | # The second dim is z 61 | dz = pc[256,1] - pc[-256,1] 62 | dy = pc[256,2] - pc[-256,2] 63 | 64 | norm = math.sqrt(dz**2 + dy**2) 65 | start_points = np.array([0, pc[-256, 1], pc[-256, 2]]) 66 | cos_theta = dy / norm 67 | sin_theta = dz / norm 68 | 69 | # Transform and rotation 70 | rotation_matrix = np.array([[1, 0, 0], [0, cos_theta, -sin_theta],[0, sin_theta, cos_theta]]) 71 | processed_pc = (rotation_matrix @ (pc - start_points).T).T 72 | 73 | # Remove background point 74 | for i in range(processed_pc.shape[0]): 75 | if processed_pc[i,1] > -0.02: 76 | processed_pc[i, :] = -start_points 77 | if processed_pc[i,2] > 1.8: 78 | processed_pc[i, :] = -start_points 79 | elif processed_pc[i,0] > 1 or processed_pc[i,0] < -1: 80 | processed_pc[i, :] = -start_points 81 | 82 | processed_pc = (rotation_matrix.T @ processed_pc.T).T + 
start_points 83 | 84 | index = [0, 2, 1] 85 | processed_pc = processed_pc[:,index] 86 | return processed_pc*[0.1, -0.1, 0.1] 87 | 88 | 89 | if __name__ == '__main__': 90 | 91 | parser = argparse.ArgumentParser(description='Preprocess the Eyecandies dataset.') 92 | parser.add_argument('--dataset_path', default='datasets/eyecandies', type=str, help="Original Eyecandies dataset path.") 93 | parser.add_argument('--target_dir', default='datasets/eyecandies_preprocessed', type=str, help="Processed Eyecandies dataset path") 94 | args = parser.parse_args() 95 | 96 | os.mkdir(args.target_dir) 97 | categories_list = os.listdir(args.dataset_path) 98 | 99 | for category_dir in categories_list: 100 | category_root_path = os.path.join(args.dataset_path, category_dir) 101 | 102 | category_train_path = os.path.join(category_root_path, 'train/data') 103 | category_test_path = os.path.join(category_root_path, 'test_public/data') 104 | 105 | category_target_path = os.path.join(args.target_dir, category_dir) 106 | os.mkdir(category_target_path) 107 | 108 | os.mkdir(os.path.join(category_target_path, 'train')) 109 | category_target_train_good_path = os.path.join(category_target_path, 'train/good') 110 | category_target_train_good_rgb_path = os.path.join(category_target_train_good_path, 'rgb') 111 | category_target_train_good_xyz_path = os.path.join(category_target_train_good_path, 'xyz') 112 | os.mkdir(category_target_train_good_path) 113 | os.mkdir(category_target_train_good_rgb_path) 114 | os.mkdir(category_target_train_good_xyz_path) 115 | 116 | os.mkdir(os.path.join(category_target_path, 'test')) 117 | category_target_test_good_path = os.path.join(category_target_path, 'test/good') 118 | category_target_test_good_rgb_path = os.path.join(category_target_test_good_path, 'rgb') 119 | category_target_test_good_xyz_path = os.path.join(category_target_test_good_path, 'xyz') 120 | category_target_test_good_gt_path = os.path.join(category_target_test_good_path, 'gt') 121 | os.mkdir(category_target_test_good_path) 122 | os.mkdir(category_target_test_good_rgb_path) 123 | os.mkdir(category_target_test_good_xyz_path) 124 | os.mkdir(category_target_test_good_gt_path) 125 | category_target_test_bad_path = os.path.join(category_target_path, 'test/bad') 126 | category_target_test_bad_rgb_path = os.path.join(category_target_test_bad_path, 'rgb') 127 | category_target_test_bad_xyz_path = os.path.join(category_target_test_bad_path, 'xyz') 128 | category_target_test_bad_gt_path = os.path.join(category_target_test_bad_path, 'gt') 129 | os.mkdir(category_target_test_bad_path) 130 | os.mkdir(category_target_test_bad_rgb_path) 131 | os.mkdir(category_target_test_bad_xyz_path) 132 | os.mkdir(category_target_test_bad_gt_path) 133 | 134 | category_train_files = os.listdir(category_train_path) 135 | num_train_files = len(category_train_files)//17 136 | for i in range(0, num_train_files): 137 | pc = depth_to_pointcloud( 138 | os.path.join(category_train_path,str(i).zfill(3)+'_depth.png'), 139 | os.path.join(category_train_path,str(i).zfill(3)+'_info_depth.yaml'), 140 | os.path.join(category_train_path,str(i).zfill(3)+'_pose.txt'), 141 | FOCAL_LENGTH, 142 | ) 143 | pc = remove_point_cloud_background(pc) 144 | pc = pc.reshape(512,512,3) 145 | tifffile.imwrite(os.path.join(category_target_train_good_xyz_path, str(i).zfill(3)+'.tiff'), pc) 146 | copyfile(os.path.join(category_train_path,str(i).zfill(3)+'_image_4.png'),os.path.join(category_target_train_good_rgb_path, str(i).zfill(3)+'.png')) 147 | 148 | 149 | category_test_files =
os.listdir(category_test_path) 150 | num_test_files = len(category_test_files)//17 151 | for i in range(0, num_test_files): 152 | mask = cv2.imread(os.path.join(category_test_path,str(i).zfill(2)+'_mask.png')) 153 | if np.any(mask): 154 | pc = depth_to_pointcloud( 155 | os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'), 156 | os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'), 157 | os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'), 158 | FOCAL_LENGTH, 159 | ) 160 | pc = remove_point_cloud_background(pc) 161 | pc = pc.reshape(512,512,3) 162 | tifffile.imwrite(os.path.join(category_target_test_bad_xyz_path, str(i).zfill(3)+'.tiff'), pc) 163 | cv2.imwrite(os.path.join(category_target_test_bad_gt_path, str(i).zfill(3)+'.png'), mask) 164 | copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_bad_rgb_path, str(i).zfill(3)+'.png')) 165 | else: 166 | pc = depth_to_pointcloud( 167 | os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'), 168 | os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'), 169 | os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'), 170 | FOCAL_LENGTH, 171 | ) 172 | pc = remove_point_cloud_background(pc) 173 | pc = pc.reshape(512,512,3) 174 | tifffile.imwrite(os.path.join(category_target_test_good_xyz_path, str(i).zfill(3)+'.tiff'), pc) 175 | cv2.imwrite(os.path.join(category_target_test_good_gt_path, str(i).zfill(3)+'.png'), mask) 176 | copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_good_rgb_path, str(i).zfill(3)+'.png')) 177 | -------------------------------------------------------------------------------- /processing/preprocess_mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tifffile as tiff 4 | import open3d as o3d 5 | from pathlib import Path 6 | from PIL import Image 7 | import math 8 | import utils.mvtec3d_utils as mvt_util 9 | import argparse 10 | 11 | 12 | def get_edges_of_pc(organized_pc): 13 | unorganized_edges_pc = organized_pc[0:10, :, :].reshape(organized_pc[0:10, :, :].shape[0]*organized_pc[0:10, :, :].shape[1],organized_pc[0:10, :, :].shape[2]) 14 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc,organized_pc[-10:, :, :].reshape(organized_pc[-10:, :, :].shape[0] * organized_pc[-10:, :, :].shape[1],organized_pc[-10:, :, :].shape[2])],axis=0) 15 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, 0:10, :].reshape(organized_pc[:, 0:10, :].shape[0] * organized_pc[:, 0:10, :].shape[1],organized_pc[:, 0:10, :].shape[2])], axis=0) 16 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, -10:, :].reshape(organized_pc[:, -10:, :].shape[0] * organized_pc[:, -10:, :].shape[1],organized_pc[:, -10:, :].shape[2])], axis=0) 17 | unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0],:] 18 | return unorganized_edges_pc 19 | 20 | def get_plane_eq(unorganized_pc,ransac_n_pts=50): 21 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc)) 22 | plane_model, inliers = o3d_pc.segment_plane(distance_threshold=0.004, ransac_n=ransac_n_pts, num_iterations=1000) 23 | return plane_model 24 | 25 | def remove_plane(organized_pc_clean, organized_rgb ,distance_threshold=0.005): 26 | # PREP PC 27 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean) 28 | 
unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb) 29 | clean_planeless_unorganized_pc = unorganized_pc.copy() 30 | planeless_unorganized_rgb = unorganized_rgb.copy() 31 | 32 | # REMOVE PLANE 33 | plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean)) 34 | distances = np.abs(np.dot(np.array(plane_model), np.hstack((clean_planeless_unorganized_pc, np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T)) 35 | plane_indices = np.argwhere(distances < distance_threshold) 36 | 37 | planeless_unorganized_rgb[plane_indices] = 0 38 | clean_planeless_unorganized_pc[plane_indices] = 0 39 | clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0], 40 | organized_pc_clean.shape[1], 41 | organized_pc_clean.shape[2]) 42 | planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0], 43 | organized_rgb.shape[1], 44 | organized_rgb.shape[2]) 45 | return clean_planeless_organized_pc, planeless_organized_rgb 46 | 47 | 48 | 49 | def connected_components_cleaning(organized_pc, organized_rgb, image_path): 50 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc) 51 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb) 52 | 53 | nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] 54 | unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :] 55 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros)) 56 | labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False)) 57 | 58 | 59 | unique_cluster_ids, cluster_size = np.unique(labels,return_counts=True) 60 | max_label = labels.max() 61 | if max_label>0: 62 | print("##########################################################################") 63 | print(f"Point cloud file {image_path} has {max_label + 1} clusters") 64 | print(f"Cluster ids: {unique_cluster_ids}. 
Cluster size {cluster_size}") 65 | print("##########################################################################\n\n") 66 | 67 | largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)] 68 | outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id) 69 | outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array] 70 | unorganized_pc[outlier_indices_original_pc_array] = 0 71 | unorganized_rgb[outlier_indices_original_pc_array] = 0 72 | organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0], 73 | organized_pc.shape[1], 74 | organized_pc.shape[2]) 75 | organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0], 76 | organized_rgb.shape[1], 77 | organized_rgb.shape[2]) 78 | return organized_clustered_pc, organized_clustered_rgb 79 | 80 | def roundup_next_100(x): 81 | return int(math.ceil(x / 100.0)) * 100 82 | 83 | def pad_cropped_pc(cropped_pc, single_channel=False): 84 | orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1] 85 | round_orig_h = roundup_next_100(orig_h) 86 | round_orig_w = roundup_next_100(orig_w) 87 | large_side = max(round_orig_h, round_orig_w) 88 | 89 | a = (large_side - orig_h) // 2 90 | aa = large_side - a - orig_h 91 | 92 | b = (large_side - orig_w) // 2 93 | bb = large_side - b - orig_w 94 | if single_channel: 95 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant') 96 | else: 97 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant') 98 | 99 | def preprocess_pc(tiff_path): 100 | # READ FILES 101 | organized_pc = mvt_util.read_tiff_organized_pc(tiff_path) 102 | rgb_path = str(tiff_path).replace("xyz", "rgb").replace("tiff", "png") 103 | gt_path = str(tiff_path).replace("xyz", "gt").replace("tiff", "png") 104 | organized_rgb = np.array(Image.open(rgb_path)) 105 | 106 | organized_gt = None 107 | gt_exists = os.path.isfile(gt_path) 108 | if gt_exists: 109 | organized_gt = np.array(Image.open(gt_path)) 110 | 111 | # REMOVE PLANE 112 | planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb) 113 | 114 | 115 | # PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE) 116 | padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False) 117 | padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False) 118 | #if gt_exists: 119 | # padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True) 120 | 121 | organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc, padded_planeless_organized_rgb, tiff_path) 122 | # SAVE PREPROCESSED FILES 123 | tiff.imsave(tiff_path, organized_clustered_pc) 124 | #Image.fromarray(organized_clustered_rgb).save(rgb_path) 125 | #if gt_exists: 126 | # Image.fromarray(padded_organized_gt).save(gt_path) 127 | 128 | 129 | 130 | if __name__ == '__main__': 131 | parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD') 132 | parser.add_argument('dataset_path', type=str, help='The root path of the MVTec 3D-AD. The preprocessing is done inplace (i.e. 
the preprocessed dataset overrides the existing one)') 133 | args = parser.parse_args() 134 | 135 | 136 | root_path = args.dataset_path 137 | paths = Path(root_path).rglob('*.tiff') 138 | print(f"Found {len(list(paths))} tiff files in {root_path}") 139 | processed_files = 0 140 | for path in Path(root_path).rglob('*.tiff'): 141 | preprocess_pc(path) 142 | processed_files += 1 143 | if processed_files % 50 == 0: 144 | print(f"Processed {processed_files} tiff files...") 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.26.0 2 | matplotlib==3.8.3 3 | numpy==1.23.1 4 | open3d==0.18.0 5 | opencv_contrib_python==4.9.0.80 6 | opencv_python==4.7.0.72 7 | opencv_python_headless==4.9.0.80 8 | pandas==1.5.3 9 | Pillow==10.3.0 10 | pointnet2_ops==3.0.0 11 | PyYAML==6.0.1 12 | scikit_learn==1.2.1 13 | scipy==1.9.1 14 | tifffile==2023.4.12 15 | timm==1.0.3 16 | torch==1.13.1 17 | torchvision==0.14.1 18 | tqdm==4.65.0 19 | wandb==0.15.5 20 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from torchvision import transforms 5 | from PIL import ImageFilter 6 | 7 | def set_seeds(seed: int = 0) -> None: 8 | np.random.seed(seed) 9 | random.seed(seed) 10 | torch.manual_seed(seed) 11 | 12 | class KNNGaussianBlur(torch.nn.Module): 13 | def __init__(self, radius : int = 4): 14 | super().__init__() 15 | self.radius = radius 16 | self.unload = transforms.ToPILImage() 17 | self.load = transforms.ToTensor() 18 | self.blur_kernel = ImageFilter.GaussianBlur(radius=self.radius) 19 | 20 | def __call__(self, img): 21 | img = img.unsqueeze(dim=0) 22 | map_max = img.max() 23 | final_map = self.load(self.unload(img[0] / map_max).filter(self.blur_kernel)).to('cuda') * map_max 24 | return final_map.squeeze() 25 | 26 | class SquarePad: 27 | def __call__(self, image): 28 | max_wh = max(image.size) 29 | p_left, p_top = [(max_wh - s) // 2 for s in image.size] 30 | p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])] 31 | padding = (p_left, p_top, p_right, p_bottom) 32 | return transforms.functional.pad(image, padding, padding_mode = 'edge') -------------------------------------------------------------------------------- /utils/metrics_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code based on the official MVTec 3D-AD evaluation code found at 3 | https://www.mydrive.ch/shares/45924/9ce7a138c69bbd4c8d648b72151f839d/download/428846918-1643297332/evaluation_code.tar.xz 4 | 5 | Utility functions that compute a PRO curve and its definite integral, given 6 | pairs of anomaly and ground truth maps. 7 | 8 | The PRO curve can also be integrated up to a constant integration limit. 9 | """ 10 | import numpy as np 11 | import sklearn.metrics 12 | from scipy.ndimage.measurements import label 13 | from bisect import bisect 14 | 15 | 16 | class GroundTruthComponent: 17 | """ 18 | Stores sorted anomaly scores of a single ground truth component. 19 | Used to efficiently compute the region overlap for many increasing thresholds. 20 | """ 21 | 22 | def __init__(self, anomaly_scores): 23 | """ 24 | Initialize the module.
25 | 26 | Args: 27 | anomaly_scores: List of all anomaly scores within the ground truth 28 | component as numpy array. 29 | """ 30 | # Keep a sorted list of all anomaly scores within the component. 31 | self.anomaly_scores = anomaly_scores.copy() 32 | self.anomaly_scores.sort() 33 | 34 | # Pointer to the anomaly score where the current threshold divides the component into OK / NOK pixels. 35 | self.index = 0 36 | 37 | # The last evaluated threshold. 38 | self.last_threshold = None 39 | 40 | def compute_overlap(self, threshold): 41 | """ 42 | Compute the region overlap for a specific threshold. 43 | Thresholds must be passed in increasing order. 44 | 45 | Args: 46 | threshold: Threshold to compute the region overlap. 47 | 48 | Returns: 49 | Region overlap for the specified threshold. 50 | """ 51 | if self.last_threshold is not None: 52 | assert self.last_threshold <= threshold 53 | 54 | # Increase the index until it points to an anomaly score that is just above the specified threshold. 55 | while (self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold): 56 | self.index += 1 57 | 58 | # Compute the fraction of component pixels that are correctly segmented as anomalous. 59 | return 1.0 - self.index / len(self.anomaly_scores) 60 | 61 | 62 | def trapezoid(x, y, x_max=None): 63 | """ 64 | This function calculates the definit integral of a curve given by x- and corresponding y-values. 65 | In contrast to, e.g., 'numpy.trapz()', this function allows to define an upper bound to the integration range by 66 | setting a value x_max. 67 | 68 | Points that do not have a finite x or y value will be ignored with a warning. 69 | 70 | Args: 71 | x: Samples from the domain of the function to integrate need to be sorted in ascending order. May contain 72 | the same value multiple times. In that case, the order of the corresponding y values will affect the 73 | integration with the trapezoidal rule. 74 | y: Values of the function corresponding to x values. 75 | x_max: Upper limit of the integration. The y value at max_x will be determined by interpolating between its 76 | neighbors. Must not lie outside of the range of x. 77 | 78 | Returns: 79 | Area under the curve. 80 | """ 81 | 82 | x = np.array(x) 83 | y = np.array(y) 84 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y)) 85 | if not finite_mask.all(): 86 | print( 87 | """WARNING: Not all x and y values passed to trapezoid are finite. Will continue with only the finite values.""") 88 | x = x[finite_mask] 89 | y = y[finite_mask] 90 | 91 | # Introduce a correction term if max_x is not an element of x. 92 | correction = 0. 93 | if x_max is not None: 94 | if x_max not in x: 95 | # Get the insertion index that would keep x sorted after np.insert(x, ins, x_max). 96 | ins = bisect(x, x_max) 97 | # x_max must be between the minimum and the maximum, so the insertion_point cannot be zero or len(x). 98 | assert 0 < ins < len(x) 99 | 100 | # Calculate the correction term which is the integral between the last x[ins-1] and x_max. Since we do not 101 | # know the exact value of y at x_max, we interpolate between y[ins] and y[ins-1]. 102 | y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1])) 103 | correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1]) 104 | 105 | # Cut off at x_max. 106 | mask = x <= x_max 107 | x = x[mask] 108 | y = y[mask] 109 | 110 | # Return area under the curve using the trapezoidal rule. 
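# Each interval [x_i, x_{i+1}] contributes 0.5 * (y_i + y_{i+1}) * (x_{i+1} - x_i); summing these contributions and adding the interpolated correction term yields the area up to x_max.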
111 | return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction 112 | 113 | 114 | def collect_anomaly_scores(anomaly_maps, ground_truth_maps): 115 | """ 116 | Extract anomaly scores for each ground truth connected component as well as anomaly scores for each potential false 117 | positive pixel from anomaly maps. 118 | 119 | Args: 120 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel. 121 | 122 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels 123 | for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains 124 | an anomaly. 125 | 126 | Returns: 127 | ground_truth_components: A list of all ground truth connected components that appear in the dataset. For each 128 | component, a sorted list of its anomaly scores is stored. 129 | 130 | anomaly_scores_ok_pixels: A sorted list of anomaly scores of all anomaly-free pixels of the dataset. This list 131 | can be used to quickly select thresholds that fix a certain false positive rate. 132 | """ 133 | # Make sure an anomaly map is present for each ground truth map. 134 | assert len(anomaly_maps) == len(ground_truth_maps) 135 | 136 | # Initialize ground truth components and scores of potential fp pixels. 137 | ground_truth_components = [] 138 | anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size) 139 | 140 | # Structuring element for computing connected components. 141 | structure = np.ones((3, 3), dtype=int) 142 | 143 | # Collect anomaly scores within each ground truth region and for all potential fp pixels. 144 | ok_index = 0 145 | for gt_map, prediction in zip(ground_truth_maps, anomaly_maps): 146 | 147 | # Compute the connected components in the ground truth map. 148 | labeled, n_components = label(gt_map, structure) 149 | 150 | # Store all potential fp scores. 151 | num_ok_pixels = len(prediction[labeled == 0]) 152 | anomaly_scores_ok_pixels[ok_index:ok_index + num_ok_pixels] = prediction[labeled == 0].copy() 153 | ok_index += num_ok_pixels 154 | 155 | # Fetch anomaly scores within each GT component. 156 | for k in range(n_components): 157 | component_scores = prediction[labeled == (k + 1)] 158 | ground_truth_components.append(GroundTruthComponent(component_scores)) 159 | 160 | # Sort all potential false positive scores. 161 | anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index) 162 | anomaly_scores_ok_pixels.sort() 163 | 164 | return ground_truth_components, anomaly_scores_ok_pixels 165 | 166 | 167 | def compute_pro(anomaly_maps, ground_truth_maps, num_thresholds): 168 | """ 169 | Compute the PRO curve at equidistant interpolation points for a set of anomaly maps with corresponding ground 170 | truth maps. The number of interpolation points can be set manually. 171 | 172 | Args: 173 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel. 174 | 175 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels 176 | for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains 177 | an anomaly. 178 | 179 | num_thresholds: Number of thresholds to compute the PRO curve. 180 | Returns: 181 | fprs: List of false positive rates. 182 | pros: List of correspoding PRO values. 183 | """ 184 | # Fetch sorted anomaly scores. 
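# collect_anomaly_scores returns the sorted per-component anomaly scores together with a sorted array of scores at anomaly-free pixels; thresholds are then drawn at equally spaced positions of that sorted array, so the false positive rate axis is sampled approximately uniformly.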
185 | ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps) 186 | 187 | # Select equidistant thresholds. 188 | threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int) 189 | 190 | fprs = [1.0] 191 | pros = [1.0] 192 | thr = [0.0] 193 | 194 | for pos in threshold_positions: 195 | threshold = anomaly_scores_ok_pixels[pos] 196 | 197 | # Compute the false positive rate for this threshold. 198 | fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels) 199 | 200 | # Compute the PRO value for this threshold. 201 | pro = 0.0 202 | for component in ground_truth_components: 203 | pro += component.compute_overlap(threshold) 204 | pro /= len(ground_truth_components) 205 | 206 | fprs.append(fpr) 207 | pros.append(pro) 208 | thr.append(threshold) 209 | 210 | # Return (FPR/PRO) pairs in increasing FPR order. 211 | fprs = fprs[::-1] 212 | pros = pros[::-1] 213 | thr = thr[::-1] 214 | 215 | return fprs, pros, thr 216 | 217 | 218 | def calculate_au_pro(gts, predictions, integration_limit = [0.3, 0.1, 0.05, 0.01], num_thresholds = 99): 219 | """ 220 | Compute the area under the PRO curve for a set of ground truth images and corresponding anomaly images. 221 | Args: 222 | gts: List of tensors that contain the ground truth images for a single dataset object. 223 | predictions: List of tensors containing anomaly images for each ground truth image. 224 | integration_limit: Integration limit to use when computing the area under the PRO curve. 225 | num_thresholds: Number of thresholds to use to sample the area under the PRO curve. 226 | 227 | Returns: 228 | au_pro: Area under the PRO curve computed up to the given integration limit. 229 | pro_curve: PRO curve values for localization (fpr,pro). 230 | """ 231 | # Compute the PRO curve. 232 | pro_curve = compute_pro(anomaly_maps = predictions, ground_truth_maps = gts, num_thresholds = num_thresholds) 233 | 234 | au_pros = [] 235 | 236 | # Compute the area under the PRO curve. 237 | for int_lim in integration_limit: 238 | au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max = int_lim) 239 | au_pro /= int_lim 240 | 241 | au_pros.append(au_pro) 242 | 243 | # Return the evaluation metrics. 244 | return au_pros, pro_curve 245 | 246 | 247 | def calculate_au_prc(gts, predictions): 248 | """ 249 | Compute the area under the PRC curve for a set of ground truth images and corresponding anomaly images. 250 | Args: 251 | gts: List of tensors that contain the ground truth images for a single dataset object. 252 | predictions: List of tensors containing anomaly images for each ground truth image. 253 | Returns: 254 | au_prc: Area under the PRC curve. 255 | """ 256 | # Compute the PRC curve. 
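# Note: roc_curve followed by auc(fpr, tpr) integrates the ROC curve, so the value returned below is the area under the ROC curve (pixel-level AUROC) rather than the area under a precision-recall curve.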
257 | fpr, tpr, _ = sklearn.metrics.roc_curve(gts, predictions) 258 | au_prc = sklearn.metrics.auc(fpr, tpr) 259 | 260 | return au_prc -------------------------------------------------------------------------------- /utils/mvtec3d_utils.py: -------------------------------------------------------------------------------- 1 | import tifffile as tiff 2 | import torch 3 | 4 | 5 | def organized_pc_to_unorganized_pc(organized_pc): 6 | return organized_pc.reshape(organized_pc.shape[0] * organized_pc.shape[1], organized_pc.shape[2]) 7 | 8 | 9 | def read_tiff_organized_pc(path): 10 | tiff_img = tiff.imread(path) 11 | return tiff_img 12 | 13 | 14 | def resize_organized_pc(organized_pc, target_height=224, target_width=224, tensor_out=True): 15 | torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous() 16 | torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc, size=(target_height, target_width), 17 | mode='nearest') 18 | if tensor_out: 19 | return torch_resized_organized_pc.squeeze(dim=0).contiguous() 20 | else: 21 | return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy() 22 | 23 | 24 | def organized_pc_to_depth_map(organized_pc): 25 | return organized_pc[:, :, 2] -------------------------------------------------------------------------------- /utils/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | def timeit(tag, t): 8 | print("{}: {}s".format(tag, time() - t)) 9 | return time() 10 | 11 | def pc_normalize(pc): 12 | l = pc.shape[0] 13 | centroid = np.mean(pc, axis=0) 14 | pc = pc - centroid 15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 16 | pc = pc / m 17 | return pc 18 | 19 | 20 | def square_distance(src, dst): 21 | """ 22 | Calculate Euclid distance between each two points. 
23 | src^T * dst = xn * xm + yn * ym + zn * zm; 24 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 25 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 26 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 27 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 28 | Input: 29 | src: source points, [B, N, C] 30 | dst: target points, [B, M, C] 31 | Output: 32 | dist: per-point square distance, [B, N, M] 33 | """ 34 | B, N, _ = src.shape 35 | _, M, _ = dst.shape 36 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 37 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 38 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 39 | return dist 40 | 41 | 42 | def index_points(points, idx): 43 | """ 44 | Input: 45 | points: input points data, [B, N, C] 46 | idx: sample index data, [B, S] 47 | Return: 48 | new_points:, indexed points data, [B, S, C] 49 | """ 50 | 51 | device = points.device 52 | B = points.shape[0] 53 | view_shape = list(idx.shape) 54 | view_shape[1:] = [1] * (len(view_shape) - 1) 55 | repeat_shape = list(idx.shape) 56 | repeat_shape[0] = 1 57 | batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape) 58 | new_points = points[batch_indices, idx, :] 59 | 60 | return new_points 61 | 62 | def farthest_point_sample(xyz, npoint): 63 | """ 64 | Input: 65 | xyz: pointcloud data, [B, N, 3] 66 | npoint: number of samples 67 | Return: 68 | centroids: sampled pointcloud index, [B, npoint] 69 | """ 70 | device = xyz.device 71 | B, N, C = xyz.shape 72 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 73 | distance = torch.ones(B, N).to(device) * 1e10 74 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 75 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 76 | for i in range(npoint): 77 | centroids[:, i] = farthest 78 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 79 | dist = torch.sum((xyz - centroid) ** 2, -1) 80 | mask = dist < distance 81 | distance[mask] = dist[mask] 82 | farthest = torch.max(distance, -1)[1] 83 | return centroids 84 | 85 | 86 | def query_ball_point(radius, nsample, xyz, new_xyz): 87 | """ 88 | Input: 89 | radius: local region radius 90 | nsample: max sample number in local region 91 | xyz: all points, [B, N, 3] 92 | new_xyz: query points, [B, S, 3] 93 | Return: 94 | group_idx: grouped points index, [B, S, nsample] 95 | """ 96 | device = xyz.device 97 | B, N, C = xyz.shape 98 | _, S, _ = new_xyz.shape 99 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 100 | sqrdists = square_distance(new_xyz, xyz) 101 | group_idx[sqrdists > radius ** 2] = N 102 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 103 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 104 | mask = group_idx == N 105 | group_idx[mask] = group_first[mask] 106 | return group_idx 107 | 108 | 109 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 110 | """ 111 | Input: 112 | npoint: 113 | radius: 114 | nsample: 115 | xyz: input points position data, [B, N, 3] 116 | points: input points data, [B, N, D] 117 | Return: 118 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 119 | new_points: sampled points data, [B, npoint, nsample, 3+D] 120 | """ 121 | B, N, C = xyz.shape 122 | S = npoint 123 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 124 | new_xyz = index_points(xyz, fps_idx) 125 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 126 | grouped_xyz = index_points(xyz, idx) # [B, npoint, 
nsample, C] 127 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 128 | 129 | if points is not None: 130 | grouped_points = index_points(points, idx) 131 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 132 | else: 133 | new_points = grouped_xyz_norm 134 | if returnfps: 135 | return new_xyz, new_points, grouped_xyz, fps_idx 136 | else: 137 | return new_xyz, new_points 138 | 139 | 140 | def sample_and_group_all(xyz, points): 141 | """ 142 | Input: 143 | xyz: input points position data, [B, N, 3] 144 | points: input points data, [B, N, D] 145 | Return: 146 | new_xyz: sampled points position data, [B, 1, 3] 147 | new_points: sampled points data, [B, 1, N, 3+D] 148 | """ 149 | device = xyz.device 150 | B, N, C = xyz.shape 151 | new_xyz = torch.zeros(B, 1, C).to(device) 152 | grouped_xyz = xyz.view(B, 1, N, C) 153 | if points is not None: 154 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 155 | else: 156 | new_points = grouped_xyz 157 | return new_xyz, new_points 158 | 159 | 160 | def interpolating_points(xyz1, xyz2, points2): 161 | """ 162 | Input: 163 | xyz1: input points position data, [B, C, N] 164 | xyz2: sampled input points position data, [B, C, S] 165 | points2: input points data, [B, D, S] 166 | Return: 167 | new_points: upsampled points data, [B, D', N] 168 | """ 169 | xyz1 = xyz1.permute(0, 2, 1) 170 | xyz2 = xyz2.permute(0, 2, 1) 171 | 172 | points2 = points2.permute(0, 2, 1) 173 | B, N, C = xyz1.shape 174 | _, S, _ = xyz2.shape 175 | 176 | if S == 1: 177 | interpolated_points = points2.repeat(1, N, 1) 178 | else: 179 | dists = square_distance(xyz1, xyz2) 180 | dists, idx = dists.sort(dim=-1) 181 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 182 | dist_recip = 1.0 / (dists + 1e-8) 183 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 184 | weight = dist_recip / norm 185 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 186 | interpolated_points = interpolated_points.permute(0, 2, 1) 187 | 188 | return interpolated_points --------------------------------------------------------------------------------