├── .gitignore
├── 01_train_mvtec.sh
├── 02_eval_mvtec.sh
├── 03_train_eyecandies.sh
├── 04_eval_eyecandies.sh
├── LICENSE.txt
├── README.md
├── cfm_inference.py
├── cfm_training.py
├── images
└── architecture.jpg
├── models
├── dataset.py
├── feature_transfer_nets.py
├── features.py
└── full_models.py
├── processing
├── aggregate_results.py
├── preprocess_eyecandies.py
└── preprocess_mvtec.py
├── requirements.txt
└── utils
├── general_utils.py
├── metrics_utils.py
├── mvtec3d_utils.py
└── pointnet2_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | models/__pycache__
3 | utils/__pycache__
4 | processing/__pycache__
5 | datasets
6 | checkpoints
7 | wandb
8 | results
--------------------------------------------------------------------------------
/01_train_mvtec.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | epochs=50
4 | batch_size=4
5 |
6 | class_names=("bagel" "cable_gland" "carrot" "cookie" "dowel" "foam" "peach" "potato" "rope" "tire")
7 |
8 | for class_name in "${class_names[@]}"
9 | do
10 | python cfm_training.py --class_name $class_name --epochs_no $epochs --batch_size $batch_size
11 | done
--------------------------------------------------------------------------------
/02_eval_mvtec.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | epochs=250
4 | batch_size=4
5 |
6 | class_names=("bagel" "cable_gland" "carrot" "cookie" "dowel" "foam" "peach" "potato" "rope" "tire")
7 |
8 | for class_name in "${class_names[@]}"
9 | do
10 | python cfm_inference.py --class_name $class_name --epochs_no $epochs --batch_size $batch_size
11 | done
--------------------------------------------------------------------------------
/03_train_eyecandies.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | dataset_path=datasets/eyecandies
4 | checkpoint_savepath=models/checkpoints_CFM_eyecandies
5 | epochs=50
6 | batch_size=2
7 |
8 | class_names=("CandyCane" "ChocolateCookie" "ChocolatePraline" "Confetto" "GummyBear" "HazelnutTruffle" "LicoriceSandwich" "Lollipop" "Marshmallow" "PeppermintCandy")
9 |
10 | for class_name in "${class_names[@]}"
11 | do
12 |     python cfm_training.py --class_name $class_name --epochs_no $epochs --batch_size $batch_size --dataset_path $dataset_path --checkpoint_savepath $checkpoint_savepath
13 | done
--------------------------------------------------------------------------------
/04_eval_eyecandies.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | dataset_path=datasets/eyecandies
4 | checkpoint_savepath=models/checkpoints_CFM_eyecandies
5 | epochs=50
6 | batch_size=2
7 |
8 | quantitative_folder=results/quantitatives_eyecandies
9 |
10 | class_names=("CandyCane" "ChocolateCookie" "ChocolatePraline" "Confetto" "GummyBear" "HazelnutTruffle" "LicoriceSandwich" "Lollipop" "Marshmallow" "PeppermintCandy")
11 |
12 | for class_name in "${class_names[@]}"
13 | do
14 |     python cfm_inference.py --class_name $class_name --epochs_no $epochs --batch_size $batch_size --dataset_path $dataset_path --checkpoint_folder $checkpoint_savepath --quantitative_folder $quantitative_folder --produce_qualitatives
15 | done
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | This Software is licensed under the terms of the following license
2 | which allows for non-commercial use only. For any other use of the software not
3 | covered by the terms of this license, please contact Pierluigi Zama Ramirez (pierluigi.zama@unibo.it)
4 |
5 | Article 1 – Definitions.
6 | 1. The following terms, as used herein, shall have the following meanings:
7 | - "Licensee" shall mean the person or organization agreeing to use the Software
8 | - "Licensor" shall mean Alma Mater Studiorum – Università di Bologna, organized and existing under the laws of Italy, whose principal place of business is via Zamboni 33 – 40126 Bologna (Italy).
9 | - "Software" shall mean software Crossomodal Feature Mappings (hereinafter referred to as “CFM”) for neural networks trained to detect and segment anomalies on multimodal data. This SW is the implementation of the approach described in “A. Costanzino, P. Zama Ramirez, G. Lisanti, L. Di Stefano. Multimodal Industrial Anomaly Detection by Crossmodal Feature Mapping. The IEEE/CVF Conference on Computer Vision and Pattern Recognition 2024. as uploaded by Licensor to the github repository at https://github.com/CVLAB-Unibo/crossmodal-feature-mapping on May 22 2024, in source code or object code form and any accompanying documentation as well as any modifications or additions uploaded to the same github repository by Licensor.
10 |
11 | Article 2 - Grant and scope.
12 | 1. Licensor grants Licensee a non-exclusive license authorizing Licensee to reproduce and use the Software for research and teaching activities only, with exclusion of any directly or indirectly commercial activities or purposes, without territorial limitation
13 | (hereinafter referred to as "License").
14 | 2. The License extends to the source code of software CFM.
15 | 3. The License is non-assignable even if Licensee were to cease all activities.
16 | 4. Except as specifically set forth in this Agreement, Licensee acknowledges that this Agreement does not grant Licensee any use or rights to the Software, including, but not limited to, any author’s economic rights on the Software.
17 | 6. Licensee shall be solely responsible towards any user, specifically with regard to the installation, maintenance and documentation relative to the Software as well as the training of its users.
18 | 7. Licensor will keep the faculty to develop new versions of Software with no obligations to provide these new versions to Licensee, as well to use Software for research activities.
19 | 8. Licensor is under no obligation to create any upgrades or enhancements to the Software.
20 |
21 | Article 3 – No Fee.
22 | 1. The License is free of charge and no payment will be due by Licensee to Licensor for research purposes.
23 |
24 | Article 4 – Markings.
25 | 1. Licensee shall place the following text in all documentation and information it publishes or distributes in connection with the Software:
26 | "Software developed by Alma Mater Studiorum - University of Bologna (Italy) and SACMI IMOLA S.C. (Italy). Alma Mater Studiorum - University of Bologna (Italy) and SACMI IMOLA S.C. (Italy), makes no warranties of any kind on the software and shall in no event be liable for damages of any kind in connection with the use and exploitation of the software.”
27 |
28 | Article 5 - Limited liability.
29 | 1. To the extent permitted by applicable law, Licensor shall not be held liable for any damages resulting from the use of the Software by Licensee.
30 | 4. Licensee shall be solely responsible for any claims of third parties in connection with use of Software.
31 | 5. Licensee shall indemnify, defend and hold Licensor harmless against all claims by third parties arising out of any damage caused by such use.
32 | 6. The License extends to the Software "as is" and as transferred by Licensor no later than on the date of signature of this Agreement and without warranties as to the proper functioning of the Software or its functions, its fitness for a particular purpose or its merchantability.
33 | 7. To the extent permitted by applicable law, Licensor makes no warranty that the practice of the License does not infringe any trademark, patent, copyright, or similar rights of any third party.
34 |
35 | Article 6 - Claims by third parties and Infringements.
36 | 1. Licensor shall not be obligated to take measures to defend the Software against any claims made by third parties nor in case of infringements of the Software by third parties.
37 | 2. Should claims by third parties result in a limitation of the rights granted under the License, then the parties shall negotiate to adapt the provisions of this Agreement accordingly, whereby Licensee hereby waives any further claim or remedy towards Licensor.
38 | 3. Should claims by third parties result in the prohibition to use the Software, then Licensor shall be entitled to terminate this Agreement without indemnification.
39 |
40 | Article 7 - Duration and termination.
41 | 1. This Agreement is effective as long as it is available for download from the repository.
42 | 2. Upon breach or default of this Agreement by either party, a minimum two months' notice will be given by the other party to cure such breach or default. If at the end of such period the breach or default has not been cured, the other party will be entitled to terminate this Agreement by giving a one-month's notice.
43 | 3. At the end of this Agreement, Licensee shall cease any use of the Software and delete any copy of it, unless the copyrights to the Software have expired.
44 |
45 | Article 8 – Warranties and maintenance.
46 | 1. Licensor does not warrant that the Software will operate error-free.
47 | 2. No debugging service or any maintenance assistance will be provided by Licensor to Licensee for the Software.
48 | 3. Licensor will not be responsible for maintaining Licensee-modified portions of the Software.
49 |
50 |
51 | Article 10 - Nature of the Agreement.
52 | 1. Nothing herein shall be deemed to constitute either party as the agent or representative of the other party, or both parties as joint venturers or partners for any purposes.
53 | 2. Neither party shall be responsible for the acts or omissions of the other party, and neither party will have authority to speak for, represent or obligate the other party in any way without prior written authority from the other party.
54 |
55 | Article 11 - Applicable law and Jurisdiction.
56 | 1. This Agreement is governed by Italian law.
57 | 2. Any claims or disputes arising in connection with this Agreement, whether before or after it has expired, shall be brought before the court of Bologna, Italy.
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Multimodal Industrial Anomaly Detection by Crossmodal Feature Mapping (CVPR 2024)
3 |
4 |
5 |
6 |
7 | :rotating_light: This repository contains download links to the datasets, code snippets, and checkpoints of our work "**Multimodal Industrial Anomaly Detection by Crossmodal Feature Mapping**", [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024)
8 |
9 | by [Alex Costanzino*](https://alex-costanzino.github.io/), [Pierluigi Zama Ramirez*](https://pierlui92.github.io/), [Giuseppe Lisanti](https://www.unibo.it/sitoweb/giuseppe.lisanti), and [Luigi Di Stefano](https://www.unibo.it/sitoweb/luigi.distefano). \* _Equal Contribution_
10 |
11 | University of Bologna
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | [Project Page](https://cvlab-unibo.github.io/CrossmodalFeatureMapping/) | [Paper](https://arxiv.org/abs/2312.04521)
20 |
21 |
22 |
23 | ## :bookmark_tabs: Table of Contents
24 |
25 | 1. [Introduction](#clapper-introduction)
26 | 2. [Datasets](#file_cabinet-datasets)
27 | 3. [Checkpoints](#inbox_tray-checkpoints)
28 | 4. [Code](#memo-code)
29 | 5. [Contacts](#envelope-contacts)
30 |
31 |
32 |
33 | ## :clapper: Introduction
34 | Recent advancements have shown the potential of leveraging both point clouds and images to localize anomalies.
35 | Nevertheless, their applicability in industrial manufacturing is often constrained by significant drawbacks, such as the use of memory banks, which lead to a substantial increase in memory footprint and inference time.
36 | We propose a novel, light and fast framework that learns to map features from one modality to the other on nominal samples and detects anomalies by pinpointing inconsistencies between observed and mapped features.
37 | Extensive experiments show that our approach achieves state-of-the-art detection and segmentation performance, in both the standard and few-shot settings, on the MVTec 3D-AD dataset while achieving faster inference and occupying less memory than previous multimodal AD methods.
38 | Furthermore, we propose a layer pruning technique to improve memory and time efficiency with a marginal sacrifice in performance.
39 |
40 |
41 |
42 | ![Architecture overview](images/architecture.jpg)
43 |
44 |
45 |
46 | :fountain_pen: If you find this code useful in your research, please cite:
47 |
48 | ```bibtex
49 | @inproceedings{costanzino2024cross,
50 | title = {Multimodal Industrial Anomaly Detection by Crossmodal Feature Mapping},
51 | author = {Costanzino, Alex and Zama Ramirez, Pierluigi and Lisanti, Giuseppe and Di Stefano, Luigi},
52 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
53 | note = {CVPR},
54 | year = {2024},
55 | }
56 | ```
57 |
58 | ## :file_cabinet: Datasets
59 |
60 | In our experiments, we employed two datasets featuring RGB images and point clouds: [MVTec 3D-AD](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad) and [Eyecandies](https://eyecan-ai.github.io/eyecandies/). You can preprocess them with the scripts contained in `processing`, as sketched below.
61 |
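As a purely illustrative sketch of this step, assuming the datasets should end up under the default paths used by the training and evaluation scripts; the `--dataset_path` flag shown here is a hypothetical argument, so check each script's actual interface before running it:

```bash
# Hypothetical invocations; verify the real arguments of the preprocessing scripts first.
python processing/preprocess_mvtec.py --dataset_path datasets/mvtec3d
python processing/preprocess_eyecandies.py --dataset_path datasets/eyecandies
```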
62 |
63 | ## :inbox_tray: Checkpoints
64 |
65 | Here you can download the weights of the **CFMs** employed to obtain the results in Table 1 and Table 2 of our paper.
66 |
67 | To use these weights, please follow these steps:
68 |
69 | 1. Create a folder named `checkpoints/checkpoints_CFM_mvtec` in the project directory;
70 | 2. Download the weights [[Download]](https://t.ly/DZ-o1);
71 | 3. Copy the downloaded weights into the `checkpoints_CFM_mvtec` folder (see the sketch below).
72 |
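A minimal sketch of these steps; the archive name below is a placeholder for whatever file the download link provides:

```bash
# Create the folder expected by the evaluation scripts.
mkdir -p checkpoints/checkpoints_CFM_mvtec

# Unpack the downloaded weights into it (the archive name is a placeholder).
unzip CFM_mvtec_weights.zip -d checkpoints/checkpoints_CFM_mvtec
```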
73 |
74 | ## :memo: Code
75 |
76 |
77 |
78 | **Warning**:
79 | - The code uses `wandb` to log results during training, so be sure to have a `wandb` account. If you prefer not to use `wandb`, disable it by passing `mode = 'disabled'` to the `wandb.init` call in `cfm_training.py`.
80 |
81 |
82 |
83 |
84 | ### :hammer_and_wrench: Setup Instructions
85 |
86 | **Dependencies**: Ensure that all the necessary dependencies are installed; they are listed in `./requirements.txt`.
87 |
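For instance, assuming a working Python and CUDA environment, the listed packages can be installed with `pip`:

```bash
pip install -r requirements.txt
```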
88 |
89 | ### :rocket: Inference CFMs
90 |
91 | The `cfm_inference.py` script tests the CFMs and can be used to generate anomaly maps.
92 |
93 | You can specify the following options (an example invocation is sketched after the list):
94 | - `--dataset_path`: Path to the root directory of the dataset.
95 | - `--checkpoint_folder`: Path to the directory of the checkpoints, i.e., `checkpoints/checkpoints_CFM_mvtec`.
96 | - `--class_name`: Class on which the CFMs were trained.
97 | - `--epochs_no`: Number of epochs used for CFMs optimization.
98 | - `--batch_size`: Number of samples per batch employed for CFMs optimization.
99 | - `--qualitative_folder`: Folder in which the anomaly maps are saved.
100 | - `--quantitative_folder`: Folder in which the metrics are saved.
101 | - `--visualize_plot`: Flag to visualize qualitatives during inference.
102 | - `--produce_qualitatives`: Flag to save qualitatives during inference.
103 |
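For instance, a minimal sketch of a single-class evaluation on MVTec 3D-AD; the class name, epoch count and batch size are examples and must match the checkpoint you want to evaluate:

```bash
python cfm_inference.py \
    --class_name bagel \
    --epochs_no 50 \
    --batch_size 4 \
    --dataset_path ./datasets/mvtec3d \
    --checkpoint_folder ./checkpoints/checkpoints_CFM_mvtec \
    --quantitative_folder ./results/quantitatives_mvtec \
    --produce_qualitatives
```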
104 | You can reproduce the results of Table 1 and Table 2 of the paper by running `02_eval_mvtec.sh`.
105 |
106 | If you haven't downloaded the checkpoints yet, you can find the download links in the **Checkpoints** section above.
107 |
108 |
109 | ### :rocket: Train CFMs
110 |
111 | To train CFMs, refer to the examples in `01_train_mvtec.sh` and `03_train_eyecandies.sh`.
112 |
113 | The `cfm_training.py` script trains the CFMs.
114 |
115 | You can specify the following options (an example invocation is sketched after the list):
116 | - `--dataset_path`: Path to the root directory of the dataset.
117 | - `--checkpoint_savepath`: Path to the directory in which checkpoints will be saved, i.e., `checkpoints/checkpoints_CFM_mvtec`.
118 | - `--class_name`: Class on which the CFMs are trained.
119 | - `--epochs_no`: Number of epochs for CFMs optimization.
120 | - `--batch_size`: Number of samples per batch for CFMs optimization.
121 |
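For instance, a minimal sketch of training the CFMs for a single MVTec 3D-AD class (values mirror `01_train_mvtec.sh` and the argument defaults):

```bash
python cfm_training.py \
    --class_name bagel \
    --epochs_no 50 \
    --batch_size 4 \
    --dataset_path ./datasets/mvtec3d \
    --checkpoint_savepath ./checkpoints/checkpoints_CFM_mvtec
```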
122 |
123 | ## :envelope: Contacts
124 |
125 | For questions, please send an email to alex.costanzino@unibo.it or pierluigi.zama@unibo.it.
126 |
127 |
128 | ## :pray: Acknowledgements
129 |
130 | We would like to extend our sincere appreciation to the authors of the following projects for making their code available, which we have utilized in our work:
131 |
132 | - [M3DM](https://github.com/nomewang/M3DM), [3D-ADS](https://github.com/eliahuhorwitz/3D-ADS) and [AST](https://github.com/marco-rudolph/AST), whose code has been instrumental in our experiments.
--------------------------------------------------------------------------------
/cfm_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torchvision import transforms
5 | import numpy as np
6 |
7 | from tqdm import tqdm
8 | import matplotlib.pyplot as plt
9 |
10 | from models.features import MultimodalFeatures
11 | from models.dataset import get_data_loader
12 | from models.feature_transfer_nets import FeatureProjectionMLP, FeatureProjectionMLP_big
13 |
14 | from utils.metrics_utils import calculate_au_pro
15 | from sklearn.metrics import roc_auc_score
16 |
17 |
18 | def set_seeds(sid=42):
19 | np.random.seed(sid)
20 |
21 | torch.manual_seed(sid)
22 | if torch.cuda.is_available():
23 | torch.cuda.manual_seed(sid)
24 | torch.cuda.manual_seed_all(sid)
25 |
26 |
27 | def infer_CFM(args):
28 |
29 | set_seeds()
30 | device = "cuda" if torch.cuda.is_available() else "cpu"
31 |
32 | # Dataloaders.
33 | test_loader = get_data_loader("test", class_name = args.class_name, img_size = 224, dataset_path = args.dataset_path)
34 |
35 | # Feature extractors.
36 | feature_extractor = MultimodalFeatures()
37 |
38 | # Model instantiation.
39 | CFM_2Dto3D = FeatureProjectionMLP(in_features = 768, out_features = 1152)
40 | CFM_3Dto2D = FeatureProjectionMLP(in_features = 1152, out_features = 768)
41 |
42 | CFM_2Dto3D_path = rf'{args.checkpoint_folder}/{args.class_name}/CFM_2Dto3D_{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs.pth'
43 | CFM_3Dto2D_path = rf'{args.checkpoint_folder}/{args.class_name}/CFM_3Dto2D_{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs.pth'
44 |
45 |     CFM_2Dto3D.load_state_dict(torch.load(CFM_2Dto3D_path, map_location = device))
46 |     CFM_3Dto2D.load_state_dict(torch.load(CFM_3Dto2D_path, map_location = device))
47 |
48 | CFM_2Dto3D.to(device), CFM_3Dto2D.to(device)
49 |
50 | # Make CFMs non-trainable.
51 | CFM_2Dto3D.eval(), CFM_3Dto2D.eval()
52 |
53 | # Use box filters to approximate gaussian blur (https://www.peterkovesi.com/papers/FastGaussianSmoothing.pdf).
54 | w_l, w_u = 5, 7
55 | pad_l, pad_u = 2, 3
56 | weight_l = torch.ones(1, 1, w_l, w_l, device = device)/(w_l**2)
57 | weight_u = torch.ones(1, 1, w_u, w_u, device = device)/(w_u**2)
58 |
59 | predictions, gts = [], []
60 | image_labels, pixel_labels = [], []
61 | image_preds, pixel_preds = [], []
62 |
63 | # ------------ [Testing Loop] ------------ #
64 |
65 | # * Return (img, resized_organized_pc, resized_depth_map_3channel), gt[:1], label, rgb_path
66 | for (rgb, pc, depth), gt, label, rgb_path in tqdm(test_loader, desc = f'Extracting feature from class: {args.class_name}.'):
67 |
68 | rgb, pc, depth = rgb.to(device), pc.to(device), depth.to(device)
69 |
70 | with torch.no_grad():
71 | rgb_patch, xyz_patch = feature_extractor.get_features_maps(rgb, pc)
72 |
73 | rgb_feat_pred = CFM_3Dto2D(xyz_patch)
74 | xyz_feat_pred = CFM_2Dto3D(rgb_patch)
75 |
76 | xyz_mask = (xyz_patch.sum(axis = -1) == 0) # Mask only the feature vectors that are 0 everywhere.
77 |
78 | cos_3d = (torch.nn.functional.normalize(xyz_feat_pred, dim = 1) - torch.nn.functional.normalize(xyz_patch, dim = 1)).pow(2).sum(1).sqrt()
79 | cos_3d[xyz_mask] = 0.
80 | cos_3d = cos_3d.reshape(224,224)
81 |
82 | cos_2d = (torch.nn.functional.normalize(rgb_feat_pred, dim = 1) - torch.nn.functional.normalize(rgb_patch, dim = 1)).pow(2).sum(1).sqrt()
83 | cos_2d[xyz_mask] = 0.
84 | cos_2d = cos_2d.reshape(224,224)
85 |
86 | cos_comb = (cos_2d * cos_3d)
87 | cos_comb.reshape(-1)[xyz_mask] = 0.
88 |
89 | # Repeated box filters to approximate a Gaussian blur.
90 | cos_comb = cos_comb.reshape(1, 1, 224, 224)
91 |
92 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l)
93 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l)
94 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l)
95 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l)
96 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_l, weight = weight_l)
97 |
98 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_u, weight = weight_u)
99 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_u, weight = weight_u)
100 | cos_comb = torch.nn.functional.conv2d(input = cos_comb, padding = pad_u, weight = weight_u)
101 |
102 | cos_comb = cos_comb.reshape(224,224)
103 |
104 | # Prediction and ground-truth accumulation.
105 | gts.append(gt.squeeze().cpu().detach().numpy()) # * (224,224)
106 | predictions.append((cos_comb / (cos_comb[cos_comb!=0].mean())).cpu().detach().numpy()) # * (224,224)
107 |
108 | # GTs.
109 | image_labels.append(label) # * (1,)
110 | pixel_labels.extend(gt.flatten().cpu().detach().numpy()) # * (50176,)
111 |
112 | # Predictions.
113 | image_preds.append((cos_comb / torch.sqrt(cos_comb[cos_comb!=0].mean())).cpu().detach().numpy().max()) # * number
114 | pixel_preds.extend((cos_comb / torch.sqrt(cos_comb.mean())).flatten().cpu().detach().numpy()) # * (224,224)
115 |
116 | if args.produce_qualitatives:
117 |
118 | defect_class_str = rgb_path[0].split('/')[-3]
119 | image_name_str = rgb_path[0].split('/')[-1]
120 |
121 | save_path = f'{args.qualitative_folder}/{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs/{defect_class_str}'
122 |
123 | if not os.path.exists(save_path):
124 | os.makedirs(save_path)
125 |
126 | fig, axs = plt.subplots(2,3, figsize = (7,7))
127 |
128 | denormalize = transforms.Compose([
129 | transforms.Normalize(mean = [0., 0., 0.], std = [1/0.229, 1/0.224, 1/0.225]),
130 | transforms.Normalize(mean = [-0.485, -0.456, -0.406], std = [1., 1., 1.]),
131 | ])
132 |
133 | rgb = denormalize(rgb)
134 |
135 |
136 |
137 | axs[0, 0].imshow(rgb.squeeze().permute(1,2,0).cpu().detach().numpy())
138 | axs[0, 0].set_title('RGB')
139 |
140 | axs[0, 1].imshow(gt.squeeze().cpu().detach().numpy())
141 | axs[0, 1].set_title('Ground-truth')
142 |
143 | axs[0, 2].imshow(depth.squeeze().permute(1,2,0).mean(axis=-1).cpu().detach().numpy())
144 | axs[0, 2].set_title('Depth')
145 |
146 | axs[1, 0].imshow(cos_3d.cpu().detach().numpy(), cmap=plt.cm.jet)
147 | axs[1, 0].set_title('3D Cosine Similarity')
148 |
149 | axs[1, 1].imshow(cos_2d.cpu().detach().numpy(), cmap=plt.cm.jet)
150 | axs[1, 1].set_title('2D Cosine Similarity')
151 |
152 | axs[1, 2].imshow(cos_comb.cpu().detach().numpy(), cmap=plt.cm.jet)
153 | axs[1, 2].set_title('Combined Cosine Similarity')
154 |
155 | # Remove ticks and labels from all subplots
156 | for ax in axs.flat:
157 | ax.set_xticks([])
158 | ax.set_yticks([])
159 | ax.set_xticklabels([])
160 | ax.set_yticklabels([])
161 |
162 | # Adjust the layout and spacing
163 | plt.tight_layout()
164 |
165 | plt.savefig(os.path.join(save_path, image_name_str), dpi = 256)
166 |
167 | if args.visualize_plot:
168 | plt.show()
169 |
170 | # Calculate AD&S metrics.
171 | au_pros, _ = calculate_au_pro(gts, predictions)
172 | pixel_rocauc = roc_auc_score(np.stack(pixel_labels), np.stack(pixel_preds))
173 | image_rocauc = roc_auc_score(np.stack(image_labels), np.stack(image_preds))
174 |
175 | result_file_name = f'{args.quantitative_folder}/{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs.md'
176 |
177 | title_string = f'Metrics for class {args.class_name} with {args.epochs_no}ep_{args.batch_size}bs'
178 | header_string = 'AUPRO@30% & AUPRO@10% & AUPRO@5% & AUPRO@1% & P-AUROC & I-AUROC'
179 | results_string = f'{au_pros[0]:.3f} & {au_pros[1]:.3f} & {au_pros[2]:.3f} & {au_pros[3]:.3f} & {pixel_rocauc:.3f} & {image_rocauc:.3f}'
180 |
181 | if not os.path.exists(args.quantitative_folder):
182 | os.makedirs(args.quantitative_folder)
183 |
184 | with open(result_file_name, "w") as markdown_file:
185 | markdown_file.write(title_string + '\n' + header_string + '\n' + results_string)
186 |
187 | # Print AD&S metrics.
188 | print(title_string)
189 | print("AUPRO@30% | AUPRO@10% | AUPRO@5% | AUPRO@1% | P-AUROC | I-AUROC")
190 | print(f' {au_pros[0]:.3f} | {au_pros[1]:.3f} | {au_pros[2]:.3f} | {au_pros[3]:.3f} | {pixel_rocauc:.3f} | {image_rocauc:.3f}', end = '\n')
191 |
192 | if __name__ == '__main__':
193 | parser = argparse.ArgumentParser(description = 'Make inference with Crossmodal Feature Networks (CFMs) on a dataset.')
194 |
195 | parser.add_argument('--dataset_path', default = './datasets/mvtec3d', type = str,
196 | help = 'Dataset path.')
197 |
198 | parser.add_argument('--class_name', default = None, type = str, choices = ["bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire",
199 | 'CandyCane', 'ChocolateCookie', 'ChocolatePraline', 'Confetto', 'GummyBear', 'HazelnutTruffle', 'LicoriceSandwich', 'Lollipop', 'Marshmallow', 'PeppermintCandy'],
200 | help = 'Category name.')
201 |
202 | parser.add_argument('--checkpoint_folder', default = './checkpoints/checkpoints_CFM_mvtec', type = str,
203 | help = 'Path to the folder containing CFMs checkpoints.')
204 |
205 | parser.add_argument('--qualitative_folder', default = './results/qualitatives_mvtec', type = str,
206 | help = 'Path to the folder in which to save the qualitatives.')
207 |
208 | parser.add_argument('--quantitative_folder', default = './results/quantitatives_mvtec', type = str,
209 | help = 'Path to the folder in which to save the quantitatives.')
210 |
211 | parser.add_argument('--epochs_no', default = 50, type = int,
212 | help = 'Number of epochs to train the CFMs.')
213 |
214 | parser.add_argument('--batch_size', default = 4, type = int,
215 | help = 'Batch dimension. Usually 16 is around the max.')
216 |
217 | parser.add_argument('--visualize_plot', default = False, action = 'store_true',
218 | help = 'Whether to show plot or not.')
219 |
220 | parser.add_argument('--produce_qualitatives', default = False, action = 'store_true',
221 | help = 'Whether to produce qualitatives or not.')
222 |
223 | args = parser.parse_args()
224 |
225 | infer_CFM(args)
--------------------------------------------------------------------------------
/cfm_training.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import os
4 | import torch
5 | import wandb
6 |
7 | import numpy as np
8 | from itertools import chain
9 |
10 | from tqdm import tqdm, trange
11 |
12 | from models.features import MultimodalFeatures
13 | from models.dataset import get_data_loader
14 | from models.feature_transfer_nets import FeatureProjectionMLP, FeatureProjectionMLP_big
15 |
16 |
17 | def set_seeds(sid=115):
18 | np.random.seed(sid)
19 |
20 | torch.manual_seed(sid)
21 | if torch.cuda.is_available():
22 | torch.cuda.manual_seed(sid)
23 | torch.cuda.manual_seed_all(sid)
24 |
25 |
26 | def train_CFM(args):
27 |
28 | set_seeds()
29 | device = "cuda" if torch.cuda.is_available() else "cpu"
30 |
31 | model_name = f'{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs'
32 |
33 | wandb.init(
34 | project = 'crossmodal-feature-mappings',
35 | name = model_name
36 | )
37 |
38 | # Dataloader.
39 | train_loader = get_data_loader("train", class_name = args.class_name, img_size = 224, dataset_path = args.dataset_path, batch_size = args.batch_size, shuffle = True)
40 |
41 | # Feature extractors.
42 | feature_extractor = MultimodalFeatures()
43 |
44 | # Model instantiation.
45 | CFM_2Dto3D = FeatureProjectionMLP(in_features = 768, out_features = 1152)
46 | CFM_3Dto2D = FeatureProjectionMLP(in_features = 1152, out_features = 768)
47 |
48 | optimizer = torch.optim.Adam(params = chain(CFM_2Dto3D.parameters(), CFM_3Dto2D.parameters()))
49 |
50 | CFM_2Dto3D.to(device), CFM_3Dto2D.to(device)
51 |
52 | metric = torch.nn.CosineSimilarity(dim = -1, eps = 1e-06)
53 |
54 | for epoch in trange(args.epochs_no, desc = f'Training Feature Transfer Net.'):
55 |
56 | epoch_cos_sim_3Dto2D, epoch_cos_sim_2Dto3D = [], []
57 |
58 |         # ------------ [Training Loop] ------------ #
59 | # * Return (rgb_img, organized_pc, depth_map_3channel), globl_label
60 | for (rgb, pc, _), _ in tqdm(train_loader, desc = f'Extracting feature from class: {args.class_name}.'):
61 | rgb, pc = rgb.to(device), pc.to(device)
62 |
63 | # Make CFMs trainable.
64 | CFM_2Dto3D.train(), CFM_3Dto2D.train()
65 |
66 | if args.batch_size == 1:
67 | rgb_patch, xyz_patch = feature_extractor.get_features_maps(rgb, pc)
68 | else:
69 | rgb_patches = []
70 | xyz_patches = []
71 |
72 | for i in range(rgb.shape[0]):
73 | rgb_patch, xyz_patch = feature_extractor.get_features_maps(rgb[i].unsqueeze(dim=0), pc[i].unsqueeze(dim=0))
74 |
75 | rgb_patches.append(rgb_patch)
76 | xyz_patches.append(xyz_patch)
77 |
78 | rgb_patch = torch.stack(rgb_patches, dim=0)
79 | xyz_patch = torch.stack(xyz_patches, dim=0)
80 |
81 | # Predictions.
82 | rgb_feat_pred = CFM_3Dto2D(xyz_patch)
83 | xyz_feat_pred = CFM_2Dto3D(rgb_patch)
84 |
85 | # Losses.
86 | xyz_mask = (xyz_patch.sum(axis = -1) == 0) # Mask only the feature vectors that are 0 everywhere.
87 |
88 | loss_3Dto2D = 1 - metric(xyz_feat_pred[~xyz_mask], xyz_patch[~xyz_mask]).mean()
89 | loss_2Dto3D = 1 - metric(rgb_feat_pred[~xyz_mask], rgb_patch[~xyz_mask]).mean()
90 |
91 | cos_sim_3Dto2D, cos_sim_2Dto3D = 1 - loss_3Dto2D.cpu(), 1 - loss_2Dto3D.cpu()
92 |
93 | epoch_cos_sim_3Dto2D.append(cos_sim_3Dto2D), epoch_cos_sim_2Dto3D.append(cos_sim_2Dto3D)
94 |
95 | # Logging.
96 | wandb.log({
97 | "train/loss_3Dto2D" : loss_3Dto2D,
98 | "train/loss_2Dto3D" : loss_2Dto3D,
99 | "train/cosine_similarity_3Dto2D" : cos_sim_3Dto2D,
100 | "train/cosine_similarity_2Dto3D" : cos_sim_2Dto3D,
101 | })
102 |
103 | if torch.isnan(loss_3Dto2D) or torch.isinf(loss_3Dto2D) or torch.isnan(loss_2Dto3D) or torch.isinf(loss_2Dto3D):
104 | exit()
105 |
106 | # Optimization.
107 | if not torch.isnan(loss_3Dto2D) and not torch.isinf(loss_3Dto2D) and not torch.isnan(loss_2Dto3D) and not torch.isinf(loss_2Dto3D):
108 |
109 | optimizer.zero_grad()
110 |
111 | loss_3Dto2D.backward(), loss_2Dto3D.backward()
112 |
113 | optimizer.step()
114 |
115 | # Global logging.
116 | wandb.log({
117 | "global_train/cos_sim_3Dto2D" : torch.Tensor(epoch_cos_sim_3Dto2D, device = 'cpu').mean(),
118 | "global_train/cos_sim_2Dto3D" : torch.Tensor(epoch_cos_sim_2Dto3D, device = 'cpu').mean()
119 | })
120 |
121 | # Model saving.
122 | directory = f'{args.checkpoint_savepath}/{args.class_name}'
123 |
124 | if not os.path.exists(directory):
125 | os.makedirs(directory)
126 |
127 | torch.save(CFM_2Dto3D.state_dict(), os.path.join(directory, 'CFM_2Dto3D_' + model_name + '.pth'))
128 | torch.save(CFM_3Dto2D.state_dict(), os.path.join(directory, 'CFM_3Dto2D_' + model_name + '.pth'))
129 |
130 |
131 | if __name__ == '__main__':
132 | parser = argparse.ArgumentParser(description = 'Train Crossmodal Feature Networks (CFMs) on a dataset.')
133 |
134 | parser.add_argument('--dataset_path', default = './datasets/mvtec3d', type = str,
135 | help = 'Dataset path.')
136 |
137 | parser.add_argument('--checkpoint_savepath', default = './checkpoints/checkpoints_CFM_mvtec', type = str,
138 | help = 'Where to save the model checkpoints.')
139 |
140 | parser.add_argument('--class_name', default = None, type = str, choices = ["bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire",
141 | 'CandyCane', 'ChocolateCookie', 'ChocolatePraline', 'Confetto', 'GummyBear', 'HazelnutTruffle', 'LicoriceSandwich', 'Lollipop', 'Marshmallow', 'PeppermintCandy'],
142 | help = 'Category name.')
143 |
144 | parser.add_argument('--epochs_no', default = 50, type = int,
145 | help = 'Number of epochs to train the CFMs.')
146 |
147 | parser.add_argument('--batch_size', default = 4, type = int,
148 | help = 'Batch dimension. Usually 16 is around the max.')
149 |
150 | args = parser.parse_args()
151 | train_CFM(args)
--------------------------------------------------------------------------------
/images/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVLAB-Unibo/crossmodal-feature-mapping/10198716d65919dcc5257323af512907e63a4003/images/architecture.jpg
--------------------------------------------------------------------------------
/models/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torchvision import transforms
4 | import glob
5 | from torch.utils.data import Dataset
6 | from utils.mvtec3d_utils import *
7 | from torch.utils.data import DataLoader
8 | import numpy as np
9 | from utils.general_utils import SquarePad
10 |
11 | def eyecandies_classes():
12 | return [
13 | 'CandyCane',
14 | 'ChocolateCookie',
15 | 'ChocolatePraline',
16 | 'Confetto',
17 | 'GummyBear',
18 | 'HazelnutTruffle',
19 | 'LicoriceSandwich',
20 | 'Lollipop',
21 | 'Marshmallow',
22 | 'PeppermintCandy',
23 | ]
24 |
25 | def mvtec3d_classes():
26 | return [
27 | "bagel",
28 | "cable_gland",
29 | "carrot",
30 | "cookie",
31 | "dowel",
32 | "foam",
33 | "peach",
34 | "potato",
35 | "rope",
36 | "tire",
37 | ]
38 |
39 | RGB_SIZE = 224
40 |
41 | class BaseAnomalyDetectionDataset(Dataset):
42 | def __init__(self, split, class_name, img_size, dataset_path):
43 | self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
44 | self.IMAGENET_STD = [0.229, 0.224, 0.225]
45 |
46 | self.cls = class_name
47 | self.size = img_size
48 | self.img_path = os.path.join(dataset_path, self.cls, split)
49 |
50 | self.rgb_transform = transforms.Compose([
51 | SquarePad(),
52 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation = transforms.InterpolationMode.BICUBIC),
53 | transforms.ToTensor(),
54 | transforms.Normalize(mean = self.IMAGENET_MEAN, std = self.IMAGENET_STD)
55 | ])
56 |
57 |
58 | class TrainValDataset(BaseAnomalyDetectionDataset):
59 | def __init__(self, split, class_name, img_size, dataset_path):
60 | super().__init__(split = split, class_name = class_name, img_size = img_size, dataset_path = dataset_path)
61 |
62 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
63 |
64 | def load_dataset(self):
65 | img_tot_paths = []
66 | tot_labels = []
67 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png")
68 | tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff")
69 | rgb_paths.sort()
70 | tiff_paths.sort()
71 | sample_paths = list(zip(rgb_paths, tiff_paths))
72 | img_tot_paths.extend(sample_paths)
73 | tot_labels.extend([0] * len(sample_paths))
74 | return img_tot_paths, tot_labels
75 |
76 | def __len__(self):
77 | return len(self.img_paths)
78 |
79 | def __getitem__(self, idx):
80 | img_path, label = self.img_paths[idx], self.labels[idx]
81 | rgb_path = img_path[0]
82 | tiff_path = img_path[1]
83 | img = Image.open(rgb_path).convert('RGB')
84 |
85 | img = self.rgb_transform(img)
86 | organized_pc = read_tiff_organized_pc(tiff_path)
87 |
88 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis = 2)
89 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
90 | resized_organized_pc = resize_organized_pc(organized_pc, target_height = self.size, target_width = self.size)
91 | resized_organized_pc = resized_organized_pc.clone().detach().float()
92 |
93 | return (img, resized_organized_pc, resized_depth_map_3channel), label
94 |
95 |
96 | class TestDataset(BaseAnomalyDetectionDataset):
97 | def __init__(self, class_name, img_size, dataset_path):
98 | super().__init__(split = "test", class_name = class_name, img_size = img_size, dataset_path = dataset_path)
99 |
100 | self.gt_transform = transforms.Compose([
101 | SquarePad(),
102 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
103 | transforms.ToTensor()])
104 |
105 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
106 |
107 | def load_dataset(self):
108 | img_tot_paths = []
109 | gt_tot_paths = []
110 | tot_labels = []
111 | defect_types = os.listdir(self.img_path)
112 |
113 | for defect_type in defect_types:
114 | if defect_type == 'good':
115 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
116 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
117 | rgb_paths.sort()
118 | tiff_paths.sort()
119 | sample_paths = list(zip(rgb_paths, tiff_paths))
120 | img_tot_paths.extend(sample_paths)
121 | gt_tot_paths.extend([0] * len(sample_paths))
122 | tot_labels.extend([0] * len(sample_paths))
123 | else:
124 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
125 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
126 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png")
127 | rgb_paths.sort()
128 | tiff_paths.sort()
129 | gt_paths.sort()
130 | sample_paths = list(zip(rgb_paths, tiff_paths))
131 |
132 | img_tot_paths.extend(sample_paths)
133 | gt_tot_paths.extend(gt_paths)
134 | tot_labels.extend([1] * len(sample_paths))
135 |
136 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
137 |
138 | return img_tot_paths, gt_tot_paths, tot_labels
139 |
140 | def __len__(self):
141 | return len(self.img_paths)
142 |
143 | def __getitem__(self, idx):
144 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
145 | rgb_path = img_path[0]
146 | tiff_path = img_path[1]
147 | img_original = Image.open(rgb_path).convert('RGB')
148 | img = self.rgb_transform(img_original)
149 |
150 | organized_pc = read_tiff_organized_pc(tiff_path)
151 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
152 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
153 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
154 | resized_organized_pc = resized_organized_pc.clone().detach().float()
155 |
156 | if gt == 0:
157 | gt = torch.zeros(
158 |                 [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-1]])
159 | else:
160 | gt = Image.open(gt).convert('L')
161 | gt = self.gt_transform(gt)
162 | gt = torch.where(gt > 0.5, 1., .0)
163 |
164 | return (img, resized_organized_pc, resized_depth_map_3channel), gt[:1], label, rgb_path
165 |
166 |
167 | def get_data_loader(split, class_name, dataset_path, img_size = 224, batch_size = 1, shuffle = False):
168 | if split in ['train']:
169 | dataset = TrainValDataset(split = "train", class_name = class_name, img_size = img_size, dataset_path = dataset_path)
170 | elif split in ['validation']:
171 | dataset = TrainValDataset(split = "validation", class_name = class_name, img_size = img_size, dataset_path = dataset_path)
172 | elif split in ['test']:
173 | dataset = TestDataset(class_name = class_name, img_size = img_size, dataset_path = dataset_path)
174 |
175 | data_loader = DataLoader(dataset = dataset, batch_size = batch_size, shuffle = shuffle,
176 | num_workers = 1, drop_last = False, pin_memory = True)
177 |
178 | return data_loader
179 |
--------------------------------------------------------------------------------
/models/feature_transfer_nets.py:
--------------------------------------------------------------------------------
1 | # Alex Costanzino, CVLab
2 | # July 2023
3 |
4 | import torch
5 |
6 | class FeatureProjectionMLP(torch.nn.Module):
7 | def __init__(self, in_features = None, out_features = None, act_layer = torch.nn.GELU):
8 | super().__init__()
9 |
10 | self.act_fcn = act_layer()
11 |
12 | self.input = torch.nn.Linear(in_features, (in_features + out_features) // 2)
13 | self.projection = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2)
14 | self.output = torch.nn.Linear((in_features + out_features) // 2, out_features)
15 |
16 | def forward(self, x):
17 | x = self.input(x)
18 | x = self.act_fcn(x)
19 |
20 | x = self.projection(x)
21 | x = self.act_fcn(x)
22 |
23 | x = self.output(x)
24 |
25 | return x
26 |
27 | class FeatureProjectionMLP_big(torch.nn.Module):
28 | def __init__(self, in_features = None, out_features = None, act_layer = torch.nn.GELU):
29 | super().__init__()
30 |
31 | self.act_fcn = act_layer()
32 |
33 | self.input = torch.nn.Linear(in_features, (in_features + out_features) // 2)
34 |
35 | self.projection_a = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2)
36 | self.projection_b = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2)
37 | self.projection_c = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2)
38 | self.projection_d = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2)
39 | self.projection_e = torch.nn.Linear((in_features + out_features) // 2, (in_features + out_features) // 2)
40 |
41 | self.output = torch.nn.Linear((in_features + out_features) // 2, out_features)
42 |
43 | def forward(self, x):
44 | x = self.input(x)
45 | x = self.act_fcn(x)
46 |
47 | x = self.projection_a(x)
48 | x = self.act_fcn(x)
49 | x = self.projection_b(x)
50 | x = self.act_fcn(x)
51 | x = self.projection_c(x)
52 | x = self.act_fcn(x)
53 | x = self.projection_d(x)
54 | x = self.act_fcn(x)
55 | x = self.projection_e(x)
56 | x = self.act_fcn(x)
57 |
58 | x = self.output(x)
59 |
60 | return x
--------------------------------------------------------------------------------
/models/features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from sklearn.metrics import roc_auc_score
5 | from utils.metrics_utils import calculate_au_pro
6 | from utils.pointnet2_utils import interpolating_points
7 | from models.full_models import FeatureExtractors
8 |
9 |
10 | dino_backbone_name = 'vit_base_patch8_224.dino' # 224/8 -> 28 patches.
11 | group_size = 128
12 | num_group = 1024
13 |
14 | class MultimodalFeatures(torch.nn.Module):
15 | def __init__(self, image_size = 224):
16 | super().__init__()
17 |
18 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
19 |
20 | self.deep_feature_extractor = FeatureExtractors(device = self.device,
21 | rgb_backbone_name = dino_backbone_name,
22 | group_size = group_size, num_group = num_group)
23 |
24 | self.deep_feature_extractor.to(self.device)
25 |
26 | self.image_size = image_size
27 |
28 | # * Applies a 2D adaptive average pooling over an input signal composed of several input planes.
29 | # * The output is of size H x W, for any input size. The number of output features is equal to the number of input planes.
30 | self.resize = torch.nn.AdaptiveAvgPool2d((224, 224))
31 |
32 | self.average = torch.nn.AvgPool2d(kernel_size = 3, stride = 1)
33 |
34 | def __call__(self, rgb, xyz):
35 | rgb = rgb.to(self.device)
36 | xyz = xyz.to(self.device)
37 |
38 | with torch.no_grad():
39 | rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)
40 |
41 |
42 | interpolated_feature_maps = interpolating_points(xyz, center.permute(0,2,1), xyz_feature_maps)
43 |
44 | xyz_feature_maps = [fmap for fmap in [xyz_feature_maps]]
45 | rgb_feature_maps = [fmap for fmap in [rgb_feature_maps]]
46 |
47 | return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx, interpolated_feature_maps
48 |
49 | def calculate_metrics(self):
50 | self.image_preds = np.stack(self.image_preds)
51 | self.image_labels = np.stack(self.image_labels)
52 | self.pixel_preds = np.array(self.pixel_preds)
53 |
54 | self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds)
55 | self.pixel_rocauc = roc_auc_score(self.pixel_labels, self.pixel_preds)
56 | self.au_pro, _ = calculate_au_pro(self.gts, self.predictions)
57 |
58 | def get_features_maps(self, rgb, pc):
59 |
60 | unorganized_pc = pc.squeeze().permute(1, 2, 0).reshape(-1, pc.shape[1])
61 |
62 | # Find nonzero indices.
63 | nonzero_indices = torch.nonzero(torch.all(unorganized_pc != 0, dim=1)).squeeze(dim=1)
64 |
65 | # Select nonzero indices and discard the others.
66 | unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :].unsqueeze(dim=0).permute(0, 2, 1)
67 |
68 | rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(rgb, unorganized_pc_no_zeros.contiguous())
69 |
70 | # Interpolation to obtain a "full image" with point cloud features.
71 | xyz_patch = torch.cat(xyz_feature_maps, 1)
72 |
73 | xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size * self.image_size), dtype = xyz_patch.dtype, device = self.device)
74 | xyz_patch_full[..., nonzero_indices] = interpolated_pc
75 |
76 | xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
77 | xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
78 | xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
79 |
80 | rgb_patch = torch.cat(rgb_feature_maps, 1)
81 |
82 | upsample_shape = xyz_patch_full_resized.shape[-2:]
83 | rgb_patch_upsample = torch.nn.functional.interpolate(rgb_patch, size = upsample_shape, mode = 'bilinear', align_corners = False)
84 | rgb_patch_upsample = rgb_patch_upsample.reshape(rgb_patch.shape[1], -1).T
85 |
86 | return rgb_patch_upsample, xyz_patch
--------------------------------------------------------------------------------
/models/full_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | from timm.models.layers import DropPath
5 | from pointnet2_ops import pointnet2_utils
6 |
7 | class FeatureExtractors(torch.nn.Module):
8 | def __init__(self, device,
9 |                  rgb_backbone_name = 'vit_base_patch8_224.dino', out_indices = None,
10 | group_size = 128, num_group = 1024):
11 |
12 | super().__init__()
13 |
14 | self.device = device
15 |
16 | kwargs = {'features_only': True if out_indices else False}
17 |
18 | if out_indices:
19 | kwargs.update({'out_indices': out_indices})
20 |
21 | layers_keep = 12
22 |
23 | ## RGB backbone
24 | self.rgb_backbone = timm.create_model(model_name = rgb_backbone_name, pretrained = True, **kwargs)
25 | # ! Use only the first k blocks.
26 |         self.rgb_backbone.blocks = torch.nn.Sequential(*self.rgb_backbone.blocks[:layers_keep]) # Keep only the first `layers_keep` blocks (with 12, all blocks are kept).
27 |
28 | ## XYZ backbone
29 | self.xyz_backbone = PointTransformer(group_size = group_size, num_group = num_group)
30 | self.xyz_backbone.load_model_from_ckpt("checkpoints/feature_extractors/pointmae_pretrain.pth")
31 | # ! Use only the first k blocks.
32 |         self.xyz_backbone.blocks.blocks = torch.nn.Sequential(*self.xyz_backbone.blocks.blocks[:layers_keep]) # Keep only the first `layers_keep` blocks (with 12, all blocks are kept).
33 |
34 |
35 | def forward_rgb_features(self, x):
36 | x = self.rgb_backbone.patch_embed(x)
37 | x = self.rgb_backbone._pos_embed(x)
38 | x = self.rgb_backbone.norm_pre(x)
39 | x = self.rgb_backbone.blocks(x)
40 | x = self.rgb_backbone.norm(x)
41 |
42 | feat = x[:,1:].permute(0, 2, 1).view(1, -1, 28, 28) # view(1, -1, 14, 14)
43 | return feat
44 |
45 |
46 | def forward(self, rgb, xyz):
47 | rgb_features = self.forward_rgb_features(rgb)
48 | xyz_features, center, ori_idx, center_idx = self.xyz_backbone(xyz)
49 |
50 | return rgb_features, xyz_features, center, ori_idx, center_idx
51 |
52 |
53 | def fps(data, number):
54 | '''
55 | data B N 3
56 | number int
57 | '''
58 | fps_idx = pointnet2_utils.furthest_point_sample(data, number)
59 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
60 | return fps_data, fps_idx
61 |
62 |
63 | class KNN(nn.Module):
64 | def __init__(self, k):
65 | super(KNN, self).__init__()
66 | self.k = k
67 |
68 | def forward(self, xyz, centers):
69 | assert xyz.size(0) == centers.size(0), "Batch size of xyz and centers should be the same"
70 |
71 | B, N_points, _ = xyz.size()
72 | K = centers.size(1)
73 |
74 | # Compute pairwise distances
75 | xyz = xyz.unsqueeze(2) # [B, N, 1, 3]
76 | centers = centers.unsqueeze(1) # [B, 1, K, 3]
77 | distances = torch.norm(xyz - centers, dim=-1) # [B, N, K]
78 |
79 | # Get the indices of the k nearest neighbors
80 | _, indices = torch.topk(distances, self.k, dim=1, largest=False, sorted=True)
81 | return indices
82 |
83 |
84 | class Group(nn.Module):
85 | def __init__(self, num_group, group_size):
86 | super().__init__()
87 | self.num_group = num_group
88 | self.group_size = group_size
89 | self.knn = KNN(k=self.group_size)
90 |
91 | def forward(self, xyz):
92 | '''
93 | input: B N 3
94 | ---------------------------
95 | output: B G M 3
96 | center : B G 3
97 | '''
98 |
99 | batch_size, num_points, _ = xyz.shape
100 | # fps the centers out
101 | center, center_idx = fps(xyz.contiguous(), self.num_group) # B G 3
102 |
103 | # knn to get the neighborhood
104 | # _, idx = self.knn(xyz, center) # B G M
105 | idx = self.knn(xyz, center).permute(0,2,1) # B G M
106 |
107 | assert idx.size(1) == self.num_group
108 | assert idx.size(2) == self.group_size
109 | ori_idx = idx
110 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
111 | idx = idx + idx_base
112 | idx = idx[-1]
113 | neighborhood = xyz.reshape(batch_size * num_points, -1)[idx, :]
114 | neighborhood = neighborhood.reshape(batch_size, self.num_group, self.group_size, 3).contiguous()
115 | # normalize
116 | neighborhood = neighborhood - center.unsqueeze(2)
117 | return neighborhood, center, ori_idx, center_idx
118 |
119 |
120 | class Encoder(nn.Module):
121 | def __init__(self, encoder_channel):
122 | super().__init__()
123 | self.encoder_channel = encoder_channel
124 | self.first_conv = nn.Sequential(
125 | nn.Conv1d(3, 128, 1),
126 | nn.BatchNorm1d(128),
127 | nn.ReLU(inplace=True),
128 | nn.Conv1d(128, 256, 1)
129 | )
130 | self.second_conv = nn.Sequential(
131 | nn.Conv1d(512, 512, 1),
132 | nn.BatchNorm1d(512),
133 | nn.ReLU(inplace=True),
134 | nn.Conv1d(512, self.encoder_channel, 1)
135 | )
136 |
137 | def forward(self, point_groups):
138 | '''
139 | point_groups : B G N 3
140 | -----------------
141 | feature_global : B G C
142 | '''
143 | bs, g, n, _ = point_groups.shape
144 | point_groups = point_groups.reshape(bs * g, n, 3)
145 | # encoder
146 | feature = self.first_conv(point_groups.transpose(2, 1))
147 | feature_global = torch.max(feature, dim=2, keepdim=True)[0]
148 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)
149 | feature = self.second_conv(feature)
150 | feature_global = torch.max(feature, dim=2, keepdim=False)[0]
151 | return feature_global.reshape(bs, g, self.encoder_channel)
152 |
153 |
154 | class MLP(nn.Module):
155 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
156 | super().__init__()
157 | out_features = out_features or in_features
158 | hidden_features = hidden_features or in_features
159 | self.fc1 = nn.Linear(in_features, hidden_features)
160 | self.act = act_layer()
161 | self.fc2 = nn.Linear(hidden_features, out_features)
162 | self.drop = nn.Dropout(drop)
163 |
164 | def forward(self, x):
165 | x = self.fc1(x)
166 | x = self.act(x)
167 | x = self.drop(x)
168 | x = self.fc2(x)
169 | x = self.drop(x)
170 | return x
171 |
172 |
173 | class Attention(nn.Module):
174 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
175 | super().__init__()
176 | self.num_heads = num_heads
177 | head_dim = dim // num_heads
178 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
179 | self.scale = qk_scale or head_dim ** -0.5
180 |
181 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
182 | self.attn_drop = nn.Dropout(attn_drop)
183 | self.proj = nn.Linear(dim, dim)
184 | self.proj_drop = nn.Dropout(proj_drop)
185 |
186 | def forward(self, x):
187 | B, N, C = x.shape
188 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
189 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
190 |
191 | attn = (q * self.scale) @ k.transpose(-2, -1)
192 | attn = attn.softmax(dim=-1)
193 | attn = self.attn_drop(attn)
194 |
195 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
196 | x = self.proj(x)
197 | x = self.proj_drop(x)
198 | return x
199 |
200 |
201 | class Block(nn.Module):
202 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
203 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
204 | super().__init__()
205 | self.norm1 = norm_layer(dim)
206 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
207 | self.norm2 = norm_layer(dim)
208 | mlp_hidden_dim = int(dim * mlp_ratio)
209 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
210 |
211 | self.attn = Attention(
212 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
213 |
214 | def forward(self, x):
215 | x = x + self.drop_path(self.attn(self.norm1(x)))
216 | x = x + self.drop_path(self.mlp(self.norm2(x)))
217 | return x
218 |
219 |
220 | class TransformerEncoder(nn.Module):
221 | """
222 | Transformer Encoder without hierarchical structure
223 | """
224 |
225 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
226 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
227 | super().__init__()
228 |
229 | self.blocks = nn.ModuleList([
230 | Block(
231 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
232 | drop=drop_rate, attn_drop=attn_drop_rate,
233 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
234 | )
235 | for i in range(depth)])
236 |
237 | def forward(self, x, pos):
238 | feature_list = []
239 | # fetch_idx = [3, 7, 11] ### ! PD
240 | for i, block in enumerate(self.blocks):
241 | x = block(x + pos)
242 | # if i in fetch_idx: ### ! PD
243 | feature_list.append(x)
244 | return feature_list
245 |
246 |
247 | class PointTransformer(nn.Module):
248 | def __init__(self, group_size = 128, num_group = 1024, encoder_dims = 384):
249 | super().__init__()
250 |
251 | self.trans_dim = 384
252 | self.depth = 12
253 | self.drop_path_rate = 0.1
254 | self.num_heads = 6
255 |
256 | self.group_size = group_size
257 | self.num_group = num_group
258 | # grouper
259 | self.group_divider = Group(num_group = self.num_group, group_size = self.group_size)
260 | # define the encoder
261 | self.encoder_dims = encoder_dims
262 | if self.encoder_dims != self.trans_dim:
263 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
264 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
265 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim)
266 | self.encoder = Encoder(encoder_channel=self.encoder_dims)
267 | # bridge encoder and transformer
268 |
269 | self.pos_embed = nn.Sequential(
270 | nn.Linear(3, 128),
271 | nn.GELU(),
272 | nn.Linear(128, self.trans_dim)
273 | )
274 |
275 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
276 | self.blocks = TransformerEncoder(
277 | embed_dim=self.trans_dim,
278 | depth=self.depth,
279 | drop_path_rate=dpr,
280 | num_heads=self.num_heads
281 | )
282 |
283 | self.norm = nn.LayerNorm(self.trans_dim)
284 |
285 | def load_model_from_ckpt(self, bert_ckpt_path):
286 | if bert_ckpt_path is not None:
287 | device = "cuda" if torch.cuda.is_available() else "cpu"
288 | ckpt = torch.load(bert_ckpt_path, map_location=device)
289 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()}
290 |
291 | for k in list(base_ckpt.keys()):
292 | if k.startswith('MAE_encoder'):
293 | base_ckpt[k[len('MAE_encoder.'):]] = base_ckpt[k]
294 | del base_ckpt[k]
295 | elif k.startswith('base_model'):
296 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
297 | del base_ckpt[k]
298 |
299 | incompatible = self.load_state_dict(base_ckpt, strict=False)
300 |
301 | def load_model_from_pb_ckpt(self, bert_ckpt_path):
302 |         ckpt = torch.load(bert_ckpt_path, map_location="cuda" if torch.cuda.is_available() else "cpu")
303 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()}
304 | for k in list(base_ckpt.keys()):
305 | if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'):
306 | base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k]
307 | elif k.startswith('base_model'):
308 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
309 | del base_ckpt[k]
310 |
311 | incompatible = self.load_state_dict(base_ckpt, strict=False)
312 |
313 | if incompatible.missing_keys:
314 | print('missing_keys')
315 | print(
316 | incompatible.missing_keys
317 | )
318 | if incompatible.unexpected_keys:
319 | print('unexpected_keys')
320 | print(
321 | incompatible.unexpected_keys
322 | )
323 |
324 |         print(f'[Transformer] Successfully loaded the ckpt from {bert_ckpt_path}')
325 |
326 | def forward(self, pts):
327 | if self.encoder_dims != self.trans_dim:
328 | B,C,N = pts.shape
329 | pts = pts.transpose(-1, -2) # B N 3
330 |             # divide the point cloud in the same form. This is important
331 | neighborhood, center, ori_idx, center_idx = self.group_divider(pts)
332 | # # generate mask
333 | # bool_masked_pos = self._mask_center(center, no_mask = False) # B G
334 | # encoder the input cloud blocks
335 | group_input_tokens = self.encoder(neighborhood) # B G N
336 | group_input_tokens = self.reduce_dim(group_input_tokens)
337 | # prepare cls
338 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
339 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
340 | # add pos embedding
341 | pos = self.pos_embed(center)
342 | # final input
343 | x = torch.cat((cls_tokens, group_input_tokens), dim=1)
344 | pos = torch.cat((cls_pos, pos), dim=1)
345 | # transformer
346 | feature_list = self.blocks(x, pos)
347 | feature_list = [self.norm(x)[:,1:].transpose(-1, -2).contiguous() for x in feature_list]
348 | x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152
349 | return x, center, ori_idx, center_idx
350 | else:
351 | B, C, N = pts.shape
352 | pts = pts.transpose(-1, -2) # B N 3
353 |             # divide the point cloud in the same form. This is important
354 |
355 | neighborhood, center, ori_idx, center_idx = self.group_divider(pts)
356 | group_input_tokens = self.encoder(neighborhood) # B G N
357 |
358 | pos = self.pos_embed(center)
359 | # final input
360 | x = group_input_tokens
361 | # transformer
362 | feature_list = self.blocks(x, pos)
363 | feature_list = [self.norm(x).transpose(-1, -2).contiguous() for x in feature_list]
364 | if len(feature_list) == 12:
365 | x = torch.cat((feature_list[3],feature_list[7],feature_list[11]), dim=1)
366 | elif len(feature_list) == 8:
367 | x = torch.cat((feature_list[1],feature_list[4],feature_list[7]), dim=1)
368 | elif len(feature_list) == 4:
369 | x = torch.cat((feature_list[1],feature_list[2],feature_list[3]), dim=1)
370 | else:
371 | x = feature_list[-1]
372 | return x, center, ori_idx, center_idx
--------------------------------------------------------------------------------
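
For orientation, a minimal usage sketch of the point feature extractor above, assuming the class is importable from models.features, that a CUDA device is available for the grouping ops, and that input clouds come as (B, 3, N) tensors; none of this invocation code is part of the repository:

import torch
from models.features import PointTransformer  # assumed import path

pts = torch.rand(2, 3, 8192, device="cuda")  # dummy batch of point clouds, (B, 3, N)

model = PointTransformer(group_size=128, num_group=1024, encoder_dims=384).cuda().eval()
with torch.no_grad():
    feats, center, ori_idx, center_idx = model(pts)

# With encoder_dims == trans_dim (384), three intermediate transformer blocks are
# concatenated, so feats is (B, 3 * 384, num_group) and center is (B, num_group, 3).
print(feats.shape, center.shape)
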
/processing/aggregate_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import pandas as pd
5 |
6 | classes = ["bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire"]
7 | metrics = ["AUPRO@30\%", "AUPRO@10\%", "AUPRO@5\%", "AUPRO@1\%", "P-AUROC", "I-AUROC"]
8 |
9 | def read_md_files(directory):
10 | aggregated_results = []
11 | md_files = sorted([filename for filename in os.listdir(directory) if filename.endswith(".md")])
12 |
13 | for filename in md_files:
14 | filepath = os.path.join(directory, filename)
15 | with open(filepath, "r", encoding = "utf-8") as file:
16 |             file_contents = file.read().split('\n')[-1].split('&') # Take only the last line (results) and split it into number strings.
17 |             file_contents_num = [float(res) for res in file_contents] # Convert the number strings into floats.
18 | aggregated_results.append(file_contents_num)
19 | return np.array(aggregated_results) # Expected shape (10,6)
20 |
21 | def produce_table(args):
22 | data = read_md_files(args.quantitative_folder)
23 | results = pd.DataFrame(data, index=classes, columns=metrics).T
24 | results['mean'] = results.mean(axis=1)
25 |
26 | print(results.to_latex(float_format = "%.3f"))
27 |
28 | if __name__ == '__main__':
29 | parser = argparse.ArgumentParser(description = 'Create a LaTeX table with all the results.')
30 |
31 | parser.add_argument('--quantitative_folder', default = None, type = str,
32 | help = 'Path to the folder from which to fetch the quantitatives.')
33 |
34 | args = parser.parse_args()
35 |
36 | produce_table(args)
--------------------------------------------------------------------------------
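
The script above expects one .md file per class whose last line is an '&'-separated row of the six metrics listed at the top. A hedged sketch of producing compatible dummy files and aggregating them (the folder name and the metric values are placeholders, not repository outputs):

import os

# Hypothetical output folder; the evaluation scripts are expected to write one .md per class here.
folder = "results/quantitatives"
os.makedirs(folder, exist_ok=True)

classes = ["bagel", "cable_gland", "carrot", "cookie", "dowel",
           "foam", "peach", "potato", "rope", "tire"]
for c in classes:
    # Placeholder numbers in the order AUPRO@30%, AUPRO@10%, AUPRO@5%, AUPRO@1%, P-AUROC, I-AUROC.
    with open(os.path.join(folder, f"{c}.md"), "w", encoding="utf-8") as f:
        f.write(f"# {c}\n0.950&0.900&0.850&0.500&0.990&0.980")

# Aggregate into a LaTeX table:
#   python processing/aggregate_results.py --quantitative_folder results/quantitatives
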
/processing/preprocess_eyecandies.py:
--------------------------------------------------------------------------------
1 | import os
2 | from shutil import copyfile
3 | import cv2
4 | import numpy as np
5 | import tifffile
6 | import yaml
7 | import imageio.v3 as iio
8 | import math
9 | import argparse
10 |
11 | # The same camera has been used for all the images
12 | FOCAL_LENGTH = 711.11
13 |
14 | def load_and_convert_depth(depth_img, info_depth):
15 | with open(info_depth) as f:
16 | data = yaml.safe_load(f)
17 | mind, maxd = data["normalization"]["min"], data["normalization"]["max"]
18 |
19 | dimg = iio.imread(depth_img)
20 | dimg = dimg.astype(np.float32)
21 | dimg = dimg / 65535.0 * (maxd - mind) + mind
22 | return dimg
23 |
24 | def depth_to_pointcloud(depth_img, info_depth, pose_txt, focal_length):
25 | # input depth map (in meters) --- cfr previous section
26 | depth_mt = load_and_convert_depth(depth_img, info_depth)
27 |
28 | # input pose
29 | pose = np.loadtxt(pose_txt)
30 |
31 | # camera intrinsics
32 | height, width = depth_mt.shape[:2]
33 | intrinsics_4x4 = np.array([
34 | [focal_length, 0, width / 2, 0],
35 | [0, focal_length, height / 2, 0],
36 | [0, 0, 1, 0],
37 | [0, 0, 0, 1]]
38 | )
39 |
40 | # build the camera projection matrix
41 | camera_proj = intrinsics_4x4 @ pose
42 |
43 | # build the (u, v, 1, 1/depth) vectors (non optimized version)
44 | camera_vectors = np.zeros((width * height, 4))
45 | count=0
46 | for j in range(height):
47 | for i in range(width):
48 | camera_vectors[count, :] = np.array([i, j, 1, 1/depth_mt[j, i]])
49 | count += 1
50 |
51 | # invert and apply to each 4-vector
52 | hom_3d_pts= np.linalg.inv(camera_proj) @ camera_vectors.T
53 | # print(hom_3d_pts.shape)
54 | # remove the homogeneous coordinate
55 | pcd = depth_mt.reshape(-1, 1) * hom_3d_pts.T
56 | return pcd[:, :3]
57 |
58 | def remove_point_cloud_background(pc):
59 |
60 | # The second dim is z
61 | dz = pc[256,1] - pc[-256,1]
62 | dy = pc[256,2] - pc[-256,2]
63 |
64 | norm = math.sqrt(dz**2 + dy**2)
65 | start_points = np.array([0, pc[-256, 1], pc[-256, 2]])
66 | cos_theta = dy / norm
67 | sin_theta = dz / norm
68 |
69 | # Transform and rotation
70 | rotation_matrix = np.array([[1, 0, 0], [0, cos_theta, -sin_theta],[0, sin_theta, cos_theta]])
71 | processed_pc = (rotation_matrix @ (pc - start_points).T).T
72 |
73 | # Remove background point
74 | for i in range(processed_pc.shape[0]):
75 | if processed_pc[i,1] > -0.02:
76 | processed_pc[i, :] = -start_points
77 | if processed_pc[i,2] > 1.8:
78 | processed_pc[i, :] = -start_points
79 | elif processed_pc[i,0] > 1 or processed_pc[i,0] < -1:
80 | processed_pc[i, :] = -start_points
81 |
82 | processed_pc = (rotation_matrix.T @ processed_pc.T).T + start_points
83 |
84 | index = [0, 2, 1]
85 | processed_pc = processed_pc[:,index]
86 | return processed_pc*[0.1, -0.1, 0.1]
87 |
88 |
89 | if __name__ == '__main__':
90 |
91 |     parser = argparse.ArgumentParser(description='Preprocess the Eyecandies dataset.')
92 | parser.add_argument('--dataset_path', default='datasets/eyecandies', type=str, help="Original Eyecandies dataset path.")
93 | parser.add_argument('--target_dir', default='datasets/eyecandies_preprocessed', type=str, help="Processed Eyecandies dataset path")
94 | args = parser.parse_args()
95 |
96 | os.mkdir(args.target_dir)
97 | categories_list = os.listdir(args.dataset_path)
98 |
99 | for category_dir in categories_list:
100 | category_root_path = os.path.join(args.dataset_path, category_dir)
101 |
102 |         category_train_path = os.path.join(category_root_path, 'train', 'data')
103 |         category_test_path = os.path.join(category_root_path, 'test_public', 'data')
104 |
105 | category_target_path = os.path.join(args.target_dir, category_dir)
106 | os.mkdir(category_target_path)
107 |
108 | os.mkdir(os.path.join(category_target_path, 'train'))
109 | category_target_train_good_path = os.path.join(category_target_path, 'train/good')
110 | category_target_train_good_rgb_path = os.path.join(category_target_train_good_path, 'rgb')
111 | category_target_train_good_xyz_path = os.path.join(category_target_train_good_path, 'xyz')
112 | os.mkdir(category_target_train_good_path)
113 | os.mkdir(category_target_train_good_rgb_path)
114 | os.mkdir(category_target_train_good_xyz_path)
115 |
116 | os.mkdir(os.path.join(category_target_path, 'test'))
117 | category_target_test_good_path = os.path.join(category_target_path, 'test/good')
118 | category_target_test_good_rgb_path = os.path.join(category_target_test_good_path, 'rgb')
119 | category_target_test_good_xyz_path = os.path.join(category_target_test_good_path, 'xyz')
120 | category_target_test_good_gt_path = os.path.join(category_target_test_good_path, 'gt')
121 | os.mkdir(category_target_test_good_path)
122 | os.mkdir(category_target_test_good_rgb_path)
123 | os.mkdir(category_target_test_good_xyz_path)
124 | os.mkdir(category_target_test_good_gt_path)
125 | category_target_test_bad_path = os.path.join(category_target_path, 'test/bad')
126 | category_target_test_bad_rgb_path = os.path.join(category_target_test_bad_path, 'rgb')
127 | category_target_test_bad_xyz_path = os.path.join(category_target_test_bad_path, 'xyz')
128 | category_target_test_bad_gt_path = os.path.join(category_target_test_bad_path, 'gt')
129 | os.mkdir(category_target_test_bad_path)
130 | os.mkdir(category_target_test_bad_rgb_path)
131 | os.mkdir(category_target_test_bad_xyz_path)
132 | os.mkdir(category_target_test_bad_gt_path)
133 |
134 | category_train_files = os.listdir(category_train_path)
135 | num_train_files = len(category_train_files)//17
136 | for i in range(0, num_train_files):
137 | pc = depth_to_pointcloud(
138 | os.path.join(category_train_path,str(i).zfill(3)+'_depth.png'),
139 | os.path.join(category_train_path,str(i).zfill(3)+'_info_depth.yaml'),
140 | os.path.join(category_train_path,str(i).zfill(3)+'_pose.txt'),
141 | FOCAL_LENGTH,
142 | )
143 | pc = remove_point_cloud_background(pc)
144 | pc = pc.reshape(512,512,3)
145 | tifffile.imwrite(os.path.join(category_target_train_good_xyz_path, str(i).zfill(3)+'.tiff'), pc)
146 | copyfile(os.path.join(category_train_path,str(i).zfill(3)+'_image_4.png'),os.path.join(category_target_train_good_rgb_path, str(i).zfill(3)+'.png'))
147 |
148 |
149 | category_test_files = os.listdir(category_test_path)
150 | num_test_files = len(category_test_files)//17
151 | for i in range(0, num_test_files):
152 | mask = cv2.imread(os.path.join(category_test_path,str(i).zfill(2)+'_mask.png'))
153 | if np.any(mask):
154 | pc = depth_to_pointcloud(
155 | os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'),
156 | os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'),
157 | os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'),
158 | FOCAL_LENGTH,
159 | )
160 | pc = remove_point_cloud_background(pc)
161 | pc = pc.reshape(512,512,3)
162 | tifffile.imwrite(os.path.join(category_target_test_bad_xyz_path, str(i).zfill(3)+'.tiff'), pc)
163 | cv2.imwrite(os.path.join(category_target_test_bad_gt_path, str(i).zfill(3)+'.png'), mask)
164 | copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_bad_rgb_path, str(i).zfill(3)+'.png'))
165 | else:
166 | pc = depth_to_pointcloud(
167 | os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'),
168 | os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'),
169 | os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'),
170 | FOCAL_LENGTH,
171 | )
172 | pc = remove_point_cloud_background(pc)
173 | pc = pc.reshape(512,512,3)
174 | tifffile.imwrite(os.path.join(category_target_test_good_xyz_path, str(i).zfill(3)+'.tiff'), pc)
175 | cv2.imwrite(os.path.join(category_target_test_good_gt_path, str(i).zfill(3)+'.png'), mask)
176 | copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_good_rgb_path, str(i).zfill(3)+'.png'))
177 |
--------------------------------------------------------------------------------
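
As a side note on the back-projection above: the per-pixel loop builds the (u, v, 1, 1/depth) rows one at a time, and the same array can be built in one shot with numpy broadcasting. A sketch of an equivalent construction (not part of the repository; build_camera_vectors is a hypothetical helper name):

import numpy as np

def build_camera_vectors(depth_mt):
    """Vectorized equivalent of the (u, v, 1, 1/depth) construction in depth_to_pointcloud."""
    height, width = depth_mt.shape[:2]
    jj, ii = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    ones = np.ones_like(depth_mt)
    # Row-major order matches the nested for-j, for-i loop in the script.
    return np.stack([ii, jj, ones, 1.0 / depth_mt], axis=-1).reshape(-1, 4)

This returns the same (height * width, 4) array that the loop fills, so the rest of the projection math is unchanged.
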
/processing/preprocess_mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tifffile as tiff
4 | import open3d as o3d
5 | from pathlib import Path
6 | from PIL import Image
7 | import math
8 | import utils.mvtec3d_utils as mvt_util
9 | import argparse
10 |
11 |
12 | def get_edges_of_pc(organized_pc):
13 |     edge_slices = [organized_pc[0:10, :, :], organized_pc[-10:, :, :],
14 |                    organized_pc[:, 0:10, :], organized_pc[:, -10:, :]]
15 |     unorganized_edges_pc = np.concatenate([s.reshape(-1, s.shape[2]) for s in edge_slices], axis=0)
16 |     # Keep only edge points that are not the all-zero background.
17 |     unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0], :]
18 | return unorganized_edges_pc
19 |
20 | def get_plane_eq(unorganized_pc,ransac_n_pts=50):
21 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc))
22 | plane_model, inliers = o3d_pc.segment_plane(distance_threshold=0.004, ransac_n=ransac_n_pts, num_iterations=1000)
23 | return plane_model
24 |
25 | def remove_plane(organized_pc_clean, organized_rgb ,distance_threshold=0.005):
26 | # PREP PC
27 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean)
28 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
29 | clean_planeless_unorganized_pc = unorganized_pc.copy()
30 | planeless_unorganized_rgb = unorganized_rgb.copy()
31 |
32 | # REMOVE PLANE
33 | plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean))
34 | distances = np.abs(np.dot(np.array(plane_model), np.hstack((clean_planeless_unorganized_pc, np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T))
35 | plane_indices = np.argwhere(distances < distance_threshold)
36 |
37 | planeless_unorganized_rgb[plane_indices] = 0
38 | clean_planeless_unorganized_pc[plane_indices] = 0
39 | clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0],
40 | organized_pc_clean.shape[1],
41 | organized_pc_clean.shape[2])
42 | planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0],
43 | organized_rgb.shape[1],
44 | organized_rgb.shape[2])
45 | return clean_planeless_organized_pc, planeless_organized_rgb
46 |
47 |
48 |
49 | def connected_components_cleaning(organized_pc, organized_rgb, image_path):
50 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc)
51 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
52 |
53 | nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
54 | unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]
55 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros))
56 | labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False))
57 |
58 |
59 | unique_cluster_ids, cluster_size = np.unique(labels,return_counts=True)
60 | max_label = labels.max()
61 | if max_label>0:
62 | print("##########################################################################")
63 | print(f"Point cloud file {image_path} has {max_label + 1} clusters")
64 | print(f"Cluster ids: {unique_cluster_ids}. Cluster size {cluster_size}")
65 | print("##########################################################################\n\n")
66 |
67 | largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)]
68 | outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id)
69 | outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array]
70 | unorganized_pc[outlier_indices_original_pc_array] = 0
71 | unorganized_rgb[outlier_indices_original_pc_array] = 0
72 | organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0],
73 | organized_pc.shape[1],
74 | organized_pc.shape[2])
75 | organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0],
76 | organized_rgb.shape[1],
77 | organized_rgb.shape[2])
78 | return organized_clustered_pc, organized_clustered_rgb
79 |
80 | def roundup_next_100(x):
81 | return int(math.ceil(x / 100.0)) * 100
82 |
83 | def pad_cropped_pc(cropped_pc, single_channel=False):
84 | orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1]
85 | round_orig_h = roundup_next_100(orig_h)
86 | round_orig_w = roundup_next_100(orig_w)
87 | large_side = max(round_orig_h, round_orig_w)
88 |
89 | a = (large_side - orig_h) // 2
90 | aa = large_side - a - orig_h
91 |
92 | b = (large_side - orig_w) // 2
93 | bb = large_side - b - orig_w
94 | if single_channel:
95 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant')
96 | else:
97 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant')
98 |
99 | def preprocess_pc(tiff_path):
100 | # READ FILES
101 | organized_pc = mvt_util.read_tiff_organized_pc(tiff_path)
102 | rgb_path = str(tiff_path).replace("xyz", "rgb").replace("tiff", "png")
103 | gt_path = str(tiff_path).replace("xyz", "gt").replace("tiff", "png")
104 | organized_rgb = np.array(Image.open(rgb_path))
105 |
106 | organized_gt = None
107 | gt_exists = os.path.isfile(gt_path)
108 | if gt_exists:
109 | organized_gt = np.array(Image.open(gt_path))
110 |
111 | # REMOVE PLANE
112 | planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb)
113 |
114 |
115 | # PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE)
116 | padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False)
117 | padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False)
118 | #if gt_exists:
119 | # padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True)
120 |
121 | organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc, padded_planeless_organized_rgb, tiff_path)
122 | # SAVE PREPROCESSED FILES
123 |     tiff.imwrite(tiff_path, organized_clustered_pc)
124 | #Image.fromarray(organized_clustered_rgb).save(rgb_path)
125 | #if gt_exists:
126 | # Image.fromarray(padded_organized_gt).save(gt_path)
127 |
128 |
129 |
130 | if __name__ == '__main__':
131 | parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD')
132 | parser.add_argument('dataset_path', type=str, help='The root path of the MVTec 3D-AD. The preprocessing is done inplace (i.e. the preprocessed dataset overrides the existing one)')
133 | args = parser.parse_args()
134 |
135 |
136 |     root_path = args.dataset_path
137 |     paths = list(Path(root_path).rglob('*.tiff'))
138 |     print(f"Found {len(paths)} tiff files in {root_path}")
139 |     processed_files = 0
140 |     for path in paths:
141 | preprocess_pc(path)
142 | processed_files += 1
143 | if processed_files % 50 == 0:
144 | print(f"Processed {processed_files} tiff files...")
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
--------------------------------------------------------------------------------
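
A quick, hedged sanity check of the padding helpers above (the array sizes and the import path are assumptions, and the shape comments follow from the code as written):

import numpy as np
from processing.preprocess_mvtec import pad_cropped_pc, roundup_next_100  # assumed import path

# roundup_next_100 rounds each side up to the next multiple of 100, and
# pad_cropped_pc centers the array inside a square of the larger rounded side.
cropped = np.random.rand(413, 377, 3)
print(roundup_next_100(413), roundup_next_100(377))   # 500, 400
padded = pad_cropped_pc(cropped, single_channel=False)
print(padded.shape)                                    # (500, 500, 3)
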
/requirements.txt:
--------------------------------------------------------------------------------
1 | imageio==2.26.0
2 | matplotlib==3.8.3
3 | numpy==1.23.1
4 | open3d==0.18.0
5 | opencv_contrib_python==4.9.0.80
6 | opencv_python==4.7.0.72
7 | opencv_python_headless==4.9.0.80
8 | pandas==1.5.3
9 | Pillow==10.3.0
10 | pointnet2_ops==3.0.0
11 | PyYAML==6.0.1
12 | scikit_learn==1.2.1
13 | scipy==1.9.1
14 | tifffile==2023.4.12
15 | timm==1.0.3
16 | torch==1.13.1
17 | torchvision==0.14.1
18 | tqdm==4.65.0
19 | wandb==0.15.5
20 |
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | from torchvision import transforms
5 | from PIL import ImageFilter
6 |
7 | def set_seeds(seed: int = 0) -> None:
8 | np.random.seed(seed)
9 | random.seed(seed)
10 | torch.manual_seed(seed)
11 |
12 | class KNNGaussianBlur(torch.nn.Module):
13 | def __init__(self, radius : int = 4):
14 | super().__init__()
15 | self.radius = radius
16 | self.unload = transforms.ToPILImage()
17 | self.load = transforms.ToTensor()
18 | self.blur_kernel = ImageFilter.GaussianBlur(radius=self.radius)
19 |
20 | def __call__(self, img):
21 | img = img.unsqueeze(dim=0)
22 | map_max = img.max()
23 | final_map = self.load(self.unload(img[0] / map_max).filter(self.blur_kernel)).to('cuda') * map_max
24 | return final_map.squeeze()
25 |
26 | class SquarePad:
27 | def __call__(self, image):
28 | max_wh = max(image.size)
29 | p_left, p_top = [(max_wh - s) // 2 for s in image.size]
30 | p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])]
31 | padding = (p_left, p_top, p_right, p_bottom)
32 | return transforms.functional.pad(image, padding, padding_mode = 'edge')
--------------------------------------------------------------------------------
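
A hedged sketch of how SquarePad might be combined with torchvision transforms for an RGB input (the 224x224 resolution, the dummy image, and the import path are assumptions, not the repository's pipeline):

from PIL import Image
from torchvision import transforms
from utils.general_utils import SquarePad, set_seeds  # assumed import path

set_seeds(0)

rgb_transform = transforms.Compose([
    SquarePad(),                      # pad to a square with edge replication
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

img = Image.new("RGB", (400, 300))    # dummy non-square image
tensor = rgb_transform(img)
print(tensor.shape)                   # torch.Size([3, 224, 224])
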
/utils/metrics_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Code based on the official MVTec 3D-AD evaluation code found at
3 | https://www.mydrive.ch/shares/45924/9ce7a138c69bbd4c8d648b72151f839d/download/428846918-1643297332/evaluation_code.tar.xz
4 |
5 | Utility functions that compute a PRO curve and its definite integral, given
6 | pairs of anomaly and ground truth maps.
7 |
8 | The PRO curve can also be integrated up to a constant integration limit.
9 | """
10 | import numpy as np
11 | import sklearn.metrics
12 | from scipy.ndimage import label
13 | from bisect import bisect
14 |
15 |
16 | class GroundTruthComponent:
17 | """
18 | Stores sorted anomaly scores of a single ground truth component.
19 | Used to efficiently compute the region overlap for many increasing thresholds.
20 | """
21 |
22 | def __init__(self, anomaly_scores):
23 | """
24 | Initialize the module.
25 |
26 | Args:
27 | anomaly_scores: List of all anomaly scores within the ground truth
28 | component as numpy array.
29 | """
30 | # Keep a sorted list of all anomaly scores within the component.
31 | self.anomaly_scores = anomaly_scores.copy()
32 | self.anomaly_scores.sort()
33 |
34 | # Pointer to the anomaly score where the current threshold divides the component into OK / NOK pixels.
35 | self.index = 0
36 |
37 | # The last evaluated threshold.
38 | self.last_threshold = None
39 |
40 | def compute_overlap(self, threshold):
41 | """
42 | Compute the region overlap for a specific threshold.
43 | Thresholds must be passed in increasing order.
44 |
45 | Args:
46 | threshold: Threshold to compute the region overlap.
47 |
48 | Returns:
49 | Region overlap for the specified threshold.
50 | """
51 | if self.last_threshold is not None:
52 | assert self.last_threshold <= threshold
53 |
54 | # Increase the index until it points to an anomaly score that is just above the specified threshold.
55 | while (self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold):
56 | self.index += 1
57 |
58 | # Compute the fraction of component pixels that are correctly segmented as anomalous.
59 | return 1.0 - self.index / len(self.anomaly_scores)
60 |
61 |
62 | def trapezoid(x, y, x_max=None):
63 | """
64 |     This function calculates the definite integral of a curve given by x- and corresponding y-values.
65 |     In contrast to, e.g., 'numpy.trapz()', this function allows defining an upper bound to the integration range by
66 | setting a value x_max.
67 |
68 | Points that do not have a finite x or y value will be ignored with a warning.
69 |
70 | Args:
71 | x: Samples from the domain of the function to integrate need to be sorted in ascending order. May contain
72 | the same value multiple times. In that case, the order of the corresponding y values will affect the
73 | integration with the trapezoidal rule.
74 | y: Values of the function corresponding to x values.
75 |         x_max: Upper limit of the integration. The y value at x_max will be determined by interpolating between its
76 | neighbors. Must not lie outside of the range of x.
77 |
78 | Returns:
79 | Area under the curve.
80 | """
81 |
82 | x = np.array(x)
83 | y = np.array(y)
84 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))
85 | if not finite_mask.all():
86 | print(
87 | """WARNING: Not all x and y values passed to trapezoid are finite. Will continue with only the finite values.""")
88 | x = x[finite_mask]
89 | y = y[finite_mask]
90 |
91 |     # Introduce a correction term if x_max is not an element of x.
92 | correction = 0.
93 | if x_max is not None:
94 | if x_max not in x:
95 | # Get the insertion index that would keep x sorted after np.insert(x, ins, x_max).
96 | ins = bisect(x, x_max)
97 | # x_max must be between the minimum and the maximum, so the insertion_point cannot be zero or len(x).
98 | assert 0 < ins < len(x)
99 |
100 | # Calculate the correction term which is the integral between the last x[ins-1] and x_max. Since we do not
101 | # know the exact value of y at x_max, we interpolate between y[ins] and y[ins-1].
102 | y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1]))
103 | correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])
104 |
105 | # Cut off at x_max.
106 | mask = x <= x_max
107 | x = x[mask]
108 | y = y[mask]
109 |
110 | # Return area under the curve using the trapezoidal rule.
111 | return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction
112 |
113 |
114 | def collect_anomaly_scores(anomaly_maps, ground_truth_maps):
115 | """
116 | Extract anomaly scores for each ground truth connected component as well as anomaly scores for each potential false
117 | positive pixel from anomaly maps.
118 |
119 | Args:
120 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.
121 |
122 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels
123 | for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains
124 | an anomaly.
125 |
126 | Returns:
127 | ground_truth_components: A list of all ground truth connected components that appear in the dataset. For each
128 | component, a sorted list of its anomaly scores is stored.
129 |
130 | anomaly_scores_ok_pixels: A sorted list of anomaly scores of all anomaly-free pixels of the dataset. This list
131 | can be used to quickly select thresholds that fix a certain false positive rate.
132 | """
133 | # Make sure an anomaly map is present for each ground truth map.
134 | assert len(anomaly_maps) == len(ground_truth_maps)
135 |
136 | # Initialize ground truth components and scores of potential fp pixels.
137 | ground_truth_components = []
138 | anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size)
139 |
140 | # Structuring element for computing connected components.
141 | structure = np.ones((3, 3), dtype=int)
142 |
143 | # Collect anomaly scores within each ground truth region and for all potential fp pixels.
144 | ok_index = 0
145 | for gt_map, prediction in zip(ground_truth_maps, anomaly_maps):
146 |
147 | # Compute the connected components in the ground truth map.
148 | labeled, n_components = label(gt_map, structure)
149 |
150 | # Store all potential fp scores.
151 | num_ok_pixels = len(prediction[labeled == 0])
152 | anomaly_scores_ok_pixels[ok_index:ok_index + num_ok_pixels] = prediction[labeled == 0].copy()
153 | ok_index += num_ok_pixels
154 |
155 | # Fetch anomaly scores within each GT component.
156 | for k in range(n_components):
157 | component_scores = prediction[labeled == (k + 1)]
158 | ground_truth_components.append(GroundTruthComponent(component_scores))
159 |
160 | # Sort all potential false positive scores.
161 | anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index)
162 | anomaly_scores_ok_pixels.sort()
163 |
164 | return ground_truth_components, anomaly_scores_ok_pixels
165 |
166 |
167 | def compute_pro(anomaly_maps, ground_truth_maps, num_thresholds):
168 | """
169 | Compute the PRO curve at equidistant interpolation points for a set of anomaly maps with corresponding ground
170 | truth maps. The number of interpolation points can be set manually.
171 |
172 | Args:
173 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.
174 |
175 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels
176 | for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains
177 | an anomaly.
178 |
179 | num_thresholds: Number of thresholds to compute the PRO curve.
180 | Returns:
181 | fprs: List of false positive rates.
182 |         pros: List of corresponding PRO values.
183 | """
184 | # Fetch sorted anomaly scores.
185 | ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps)
186 |
187 | # Select equidistant thresholds.
188 | threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int)
189 |
190 | fprs = [1.0]
191 | pros = [1.0]
192 | thr = [0.0]
193 |
194 | for pos in threshold_positions:
195 | threshold = anomaly_scores_ok_pixels[pos]
196 |
197 | # Compute the false positive rate for this threshold.
198 | fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels)
199 |
200 | # Compute the PRO value for this threshold.
201 | pro = 0.0
202 | for component in ground_truth_components:
203 | pro += component.compute_overlap(threshold)
204 | pro /= len(ground_truth_components)
205 |
206 | fprs.append(fpr)
207 | pros.append(pro)
208 | thr.append(threshold)
209 |
210 | # Return (FPR/PRO) pairs in increasing FPR order.
211 | fprs = fprs[::-1]
212 | pros = pros[::-1]
213 | thr = thr[::-1]
214 |
215 | return fprs, pros, thr
216 |
217 |
218 | def calculate_au_pro(gts, predictions, integration_limit = [0.3, 0.1, 0.05, 0.01], num_thresholds = 99):
219 | """
220 | Compute the area under the PRO curve for a set of ground truth images and corresponding anomaly images.
221 | Args:
222 | gts: List of tensors that contain the ground truth images for a single dataset object.
223 | predictions: List of tensors containing anomaly images for each ground truth image.
224 |         integration_limit: List of integration limits to use when computing the area under the PRO curve.
225 | num_thresholds: Number of thresholds to use to sample the area under the PRO curve.
226 |
227 | Returns:
228 |         au_pros: List of areas under the PRO curve, one per integration limit.
229 |         pro_curve: PRO curve values for localization (fprs, pros, thresholds).
230 | """
231 | # Compute the PRO curve.
232 | pro_curve = compute_pro(anomaly_maps = predictions, ground_truth_maps = gts, num_thresholds = num_thresholds)
233 |
234 | au_pros = []
235 |
236 | # Compute the area under the PRO curve.
237 | for int_lim in integration_limit:
238 | au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max = int_lim)
239 | au_pro /= int_lim
240 |
241 | au_pros.append(au_pro)
242 |
243 | # Return the evaluation metrics.
244 | return au_pros, pro_curve
245 |
246 |
247 | def calculate_au_prc(gts, predictions):
248 | """
249 |     Compute the area under the ROC curve for a set of ground truth labels and the corresponding anomaly scores.
250 | Args:
251 | gts: List of tensors that contain the ground truth images for a single dataset object.
252 | predictions: List of tensors containing anomaly images for each ground truth image.
253 | Returns:
254 |         au_prc: Area under the ROC curve.
255 | """
256 |     # Compute the ROC curve (note: despite the variable name, this is ROC AUC, not precision-recall).
257 | fpr, tpr, _ = sklearn.metrics.roc_curve(gts, predictions)
258 | au_prc = sklearn.metrics.auc(fpr, tpr)
259 |
260 | return au_prc
--------------------------------------------------------------------------------
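
A small, hedged end-to-end check of the AUPRO utilities above on toy maps (the arrays are invented and the import path is assumed; only the function signatures come from the file itself):

import numpy as np
from utils.metrics_utils import calculate_au_pro  # assumed import path

# One ground truth map with a single 2x2 anomalous component, and an anomaly
# map that scores that region higher than the background.
gt = np.zeros((8, 8), dtype=np.uint8)
gt[2:4, 2:4] = 1
pred = np.random.rand(8, 8) * 0.1
pred[2:4, 2:4] += 0.9

au_pros, pro_curve = calculate_au_pro([gt], [pred], num_thresholds=20)
print(au_pros)  # four values, one per integration limit [0.3, 0.1, 0.05, 0.01]
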
/utils/mvtec3d_utils.py:
--------------------------------------------------------------------------------
1 | import tifffile as tiff
2 | import torch
3 |
4 |
5 | def organized_pc_to_unorganized_pc(organized_pc):
6 | return organized_pc.reshape(organized_pc.shape[0] * organized_pc.shape[1], organized_pc.shape[2])
7 |
8 |
9 | def read_tiff_organized_pc(path):
10 | tiff_img = tiff.imread(path)
11 | return tiff_img
12 |
13 |
14 | def resize_organized_pc(organized_pc, target_height=224, target_width=224, tensor_out=True):
15 | torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous()
16 | torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc, size=(target_height, target_width),
17 | mode='nearest')
18 | if tensor_out:
19 | return torch_resized_organized_pc.squeeze(dim=0).contiguous()
20 | else:
21 | return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy()
22 |
23 |
24 | def organized_pc_to_depth_map(organized_pc):
25 | return organized_pc[:, :, 2]
--------------------------------------------------------------------------------
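
A hedged usage sketch of the helpers above (the tiff path is a placeholder and the import path is assumed):

from utils.mvtec3d_utils import read_tiff_organized_pc, resize_organized_pc, organized_pc_to_depth_map  # assumed import path

organized_pc = read_tiff_organized_pc("datasets/mvtec3d/bagel/train/good/xyz/000.tiff")  # placeholder path
resized = resize_organized_pc(organized_pc, target_height=224, target_width=224, tensor_out=True)
print(resized.shape)                             # torch.Size([3, 224, 224])
depth = organized_pc_to_depth_map(organized_pc)  # the z channel as a 2D map
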
/utils/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 | def timeit(tag, t):
8 | print("{}: {}s".format(tag, time() - t))
9 | return time()
10 |
11 | def pc_normalize(pc):
12 | l = pc.shape[0]
13 | centroid = np.mean(pc, axis=0)
14 | pc = pc - centroid
15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
16 | pc = pc / m
17 | return pc
18 |
19 |
20 | def square_distance(src, dst):
21 | """
22 |     Calculate the squared Euclidean distance between each pair of points.
23 | src^T * dst = xn * xm + yn * ym + zn * zm;
24 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
25 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
26 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
27 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
28 | Input:
29 | src: source points, [B, N, C]
30 | dst: target points, [B, M, C]
31 | Output:
32 | dist: per-point square distance, [B, N, M]
33 | """
34 | B, N, _ = src.shape
35 | _, M, _ = dst.shape
36 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
37 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
38 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
39 | return dist
40 |
41 |
42 | def index_points(points, idx):
43 | """
44 | Input:
45 | points: input points data, [B, N, C]
46 | idx: sample index data, [B, S]
47 | Return:
48 | new_points:, indexed points data, [B, S, C]
49 | """
50 |
51 | device = points.device
52 | B = points.shape[0]
53 | view_shape = list(idx.shape)
54 | view_shape[1:] = [1] * (len(view_shape) - 1)
55 | repeat_shape = list(idx.shape)
56 | repeat_shape[0] = 1
57 | batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape)
58 | new_points = points[batch_indices, idx, :]
59 |
60 | return new_points
61 |
62 | def farthest_point_sample(xyz, npoint):
63 | """
64 | Input:
65 | xyz: pointcloud data, [B, N, 3]
66 | npoint: number of samples
67 | Return:
68 | centroids: sampled pointcloud index, [B, npoint]
69 | """
70 | device = xyz.device
71 | B, N, C = xyz.shape
72 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
73 | distance = torch.ones(B, N).to(device) * 1e10
74 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
75 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
76 | for i in range(npoint):
77 | centroids[:, i] = farthest
78 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
79 | dist = torch.sum((xyz - centroid) ** 2, -1)
80 | mask = dist < distance
81 | distance[mask] = dist[mask]
82 | farthest = torch.max(distance, -1)[1]
83 | return centroids
84 |
85 |
86 | def query_ball_point(radius, nsample, xyz, new_xyz):
87 | """
88 | Input:
89 | radius: local region radius
90 | nsample: max sample number in local region
91 | xyz: all points, [B, N, 3]
92 | new_xyz: query points, [B, S, 3]
93 | Return:
94 | group_idx: grouped points index, [B, S, nsample]
95 | """
96 | device = xyz.device
97 | B, N, C = xyz.shape
98 | _, S, _ = new_xyz.shape
99 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
100 | sqrdists = square_distance(new_xyz, xyz)
101 | group_idx[sqrdists > radius ** 2] = N
102 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
103 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
104 | mask = group_idx == N
105 | group_idx[mask] = group_first[mask]
106 | return group_idx
107 |
108 |
109 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
110 | """
111 | Input:
112 | npoint:
113 | radius:
114 | nsample:
115 | xyz: input points position data, [B, N, 3]
116 | points: input points data, [B, N, D]
117 | Return:
118 | new_xyz: sampled points position data, [B, npoint, nsample, 3]
119 | new_points: sampled points data, [B, npoint, nsample, 3+D]
120 | """
121 | B, N, C = xyz.shape
122 | S = npoint
123 |     fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
124 | new_xyz = index_points(xyz, fps_idx)
125 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
126 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
127 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
128 |
129 | if points is not None:
130 | grouped_points = index_points(points, idx)
131 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
132 | else:
133 | new_points = grouped_xyz_norm
134 | if returnfps:
135 | return new_xyz, new_points, grouped_xyz, fps_idx
136 | else:
137 | return new_xyz, new_points
138 |
139 |
140 | def sample_and_group_all(xyz, points):
141 | """
142 | Input:
143 | xyz: input points position data, [B, N, 3]
144 | points: input points data, [B, N, D]
145 | Return:
146 | new_xyz: sampled points position data, [B, 1, 3]
147 | new_points: sampled points data, [B, 1, N, 3+D]
148 | """
149 | device = xyz.device
150 | B, N, C = xyz.shape
151 | new_xyz = torch.zeros(B, 1, C).to(device)
152 | grouped_xyz = xyz.view(B, 1, N, C)
153 | if points is not None:
154 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
155 | else:
156 | new_points = grouped_xyz
157 | return new_xyz, new_points
158 |
159 |
160 | def interpolating_points(xyz1, xyz2, points2):
161 | """
162 | Input:
163 | xyz1: input points position data, [B, C, N]
164 | xyz2: sampled input points position data, [B, C, S]
165 | points2: input points data, [B, D, S]
166 | Return:
167 | new_points: upsampled points data, [B, D', N]
168 | """
169 | xyz1 = xyz1.permute(0, 2, 1)
170 | xyz2 = xyz2.permute(0, 2, 1)
171 |
172 | points2 = points2.permute(0, 2, 1)
173 | B, N, C = xyz1.shape
174 | _, S, _ = xyz2.shape
175 |
176 | if S == 1:
177 | interpolated_points = points2.repeat(1, N, 1)
178 | else:
179 | dists = square_distance(xyz1, xyz2)
180 | dists, idx = dists.sort(dim=-1)
181 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
182 | dist_recip = 1.0 / (dists + 1e-8)
183 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
184 | weight = dist_recip / norm
185 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
186 | interpolated_points = interpolated_points.permute(0, 2, 1)
187 |
188 | return interpolated_points
--------------------------------------------------------------------------------
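
Finally, a hedged sketch of the sampling and grouping utilities above on random points (the shapes, radius, and import path are assumptions; everything runs on CPU with plain torch ops):

import torch
from utils.pointnet2_utils import farthest_point_sample, index_points, query_ball_point, sample_and_group  # assumed import path

xyz = torch.rand(2, 1024, 3)                         # [B, N, 3] random cloud

fps_idx = farthest_point_sample(xyz, npoint=128)     # [B, 128] indices of FPS centers
centers = index_points(xyz, fps_idx)                 # [B, 128, 3] center coordinates
group_idx = query_ball_point(0.2, 32, xyz, centers)  # [B, 128, 32] neighbor indices

# Or in one call: centers plus local, center-relative neighborhoods.
new_xyz, new_points = sample_and_group(npoint=128, radius=0.2, nsample=32, xyz=xyz, points=None)
print(new_xyz.shape, new_points.shape)               # [2, 128, 3], [2, 128, 32, 3]
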