├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── main.iml
├── train.sh
├── utils
│   ├── lr_sched.py
│   ├── utils.py
│   ├── mvtec3d_util.py
│   ├── preprocessing1.py
│   ├── preprocess_eyecandies.py
│   ├── preprocessing.py
│   ├── au_pro_util.py
│   └── misc.py
├── requirement.txt
├── feature_extractors
│   ├── change_ex.py
│   └── features.py
├── engine_fusion_pretrain.py
├── .gitignore
├── README.md
├── models
│   ├── pointnet2_utils.py
│   ├── feature_fusion.py
│   └── models.py
├── m3dm_runner1.py
├── fusion_pretrain.py
├── main.py
├── LICENSE
├── dataset2.py
└── dataset.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python3 main.py \
2 | --method_name DINO+Point_MAE+Fusion \
3 | --use_uff \
4 | --memory_bank multiple \
5 | --rgb_backbone_name vit_base_patch8_224_dino \
6 | --xyz_backbone_name Point_MAE \
7 | --fusion_module_path checkpoints/checkpoint-0.pth
8 |
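9 | # Flag summary (annotation based on train.sh, m3dm_runner1.py and the README; see the
10 | # argparse definitions in main.py for the authoritative list):
11 | #   --method_name DINO+Point_MAE+Fusion   combine DINO RGB features, Point_MAE point-cloud features and a learned fusion module
12 | #   --use_uff                             use the pre-trained feature-fusion branch (paired with --fusion_module_path)
13 | #   --memory_bank multiple                keep per-modality memory banks and train a late-fusion decision layer
14 | #   --rgb_backbone_name / --xyz_backbone_name   backbones for the RGB and point-cloud branches
15 | #   --fusion_module_path                  checkpoint produced by fusion_pretrain.py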
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/main.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/utils/lr_sched.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | def adjust_learning_rate(optimizer, epoch, args):
4 | """Decay the learning rate with half-cycle cosine after warmup"""
5 | if epoch < args.warmup_epochs:
6 | lr = args.lr * epoch / args.warmup_epochs
7 | else:
8 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
9 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
10 | for param_group in optimizer.param_groups:
11 | if "lr_scale" in param_group:
12 | param_group["lr"] = lr * param_group["lr_scale"]
13 | else:
14 | param_group["lr"] = lr
15 | return lr
16 |
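17 | # Illustrative check (a minimal sketch, assuming an args namespace that exposes the same
18 | # fields fusion_pretrain.py configures): print the warmup + half-cycle cosine schedule.
19 | if __name__ == "__main__":
20 |     import argparse
21 |     import torch
22 |     args = argparse.Namespace(lr=2e-3, min_lr=0.0, warmup_epochs=1, epochs=3)
23 |     opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=args.lr)
24 |     for e in (0.0, 0.5, 1.0, 2.0, 2.9):
25 |         print(f"epoch {e}: lr = {adjust_learning_rate(opt, e, args):.6f}")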
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | from torchvision import transforms
5 | from PIL import ImageFilter
6 |
7 | def set_seeds(seed: int = 0) -> None:
8 | np.random.seed(seed)
9 | random.seed(seed)
10 | torch.manual_seed(seed)
11 |
12 | class KNNGaussianBlur(torch.nn.Module):
13 | def __init__(self, radius : int = 4):
14 | super().__init__()
15 | self.radius = radius
16 | self.unload = transforms.ToPILImage()
17 | self.load = transforms.ToTensor()
18 | self.blur_kernel = ImageFilter.GaussianBlur(radius=radius)  # use the configured radius instead of a hard-coded 4
19 |
20 | def __call__(self, img):
21 | map_max = img.max()
22 | final_map = self.load(self.unload(img[0] / map_max).filter(self.blur_kernel)) * map_max
23 | return final_map
24 |
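25 | # Minimal usage sketch (illustrative): smooth a single-channel anomaly score map of shape (1, H, W).
26 | if __name__ == "__main__":
27 |     blur = KNNGaussianBlur(radius=4)
28 |     score_map = torch.rand(1, 224, 224)
29 |     print(blur(score_map).shape)  # torch.Size([1, 224, 224])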
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | # Requires Python >= 3.9
2 | # Core deep learning & computer vision
3 | torch>=2.5.1            # install a CUDA build if available, e.g. torch>=2.5.1+cu121
4 | torchvision>=0.20.1
5 | torchaudio>=2.5.1
6 | numpy>=1.26.3
7 | opencv-python>=4.11.0
8 | Pillow>=11.0.0          # PIL fork
9 | matplotlib>=3.4.2       # plotting
10 | scikit-learn>=1.6.1
11 | scikit-image>=0.24.0
12 | einops>=0.8.1
13 | timm>=1.0.15            # PyTorch Image Models
14 |
15 | # Utilities & configuration
16 | pyyaml>=5.4.1           # YAML configuration files
17 | tqdm>=4.67.1            # progress bars
18 | easydict>=1.13          # or addict>=2.4.0, for easy dict access
19 | pandas>=2.2.3           # data manipulation, if applicable
20 | h5py>=3.13.0            # if using HDF5 files
21 |
22 | # Experiment tracking & model hubs (if used)
23 | wandb>=0.19.10          # Weights & Biases
24 | huggingface-hub>=0.29.3
25 |
26 | # 3D-specific (include if the project uses 3D data)
27 | open3d>=0.19.0
28 | pointnet2-ops>=3.0.0
29 | chamferdist>=1.0.3
--------------------------------------------------------------------------------
/utils/mvtec3d_util.py:
--------------------------------------------------------------------------------
1 | import tifffile as tiff
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def organized_pc_to_unorganized_pc(organized_pc):
7 | print(f"organized_pc shape: {organized_pc.shape}")
8 | return organized_pc.reshape(organized_pc.shape[0] * organized_pc.shape[1], organized_pc.shape[2])
9 |
10 |
11 | def read_tiff_organized_pc(path):
12 | tiff_img = tiff.imread(path)
13 | # nan_count = np.isnan(tiff_img).sum()
14 | # nan_percentage = nan_count / tiff_img.size * 100
15 | # print(f"NaN percentage: {nan_percentage:.2f}%")
16 | # Compute the minimum, ignoring NaN values
17 | # mean_min = np.nanmin(tiff_img)
18 |
19 | # # Replace NaN values with this minimum
20 | # tiff_img[np.isnan(tiff_img)] = mean_min
21 | # Replace NaN values with 0
22 | # tiff_img[np.isnan(tiff_img)] = 0
23 | #tiff_img = np.resize(tiff_img, (800, 800))
24 | # Get the minimum non-NaN value of the third (z) channel
25 | # min_z = np.nanmin(tiff_img[:, :, 2])
26 |
27 | # # If the third-channel minimum is <= 0
28 | # if min_z <= 0:
29 | # # shift the non-NaN z values by -min_z + 1
30 | # mask = ~np.isnan(tiff_img[:, :, 2])
31 | # tiff_img[mask, 2] = tiff_img[mask, 2] - min_z + 1
32 | # # Set NaN values to 0
33 | #tiff_img[np.isnan(tiff_img)] = 0
34 | return tiff_img
35 |
36 |
37 | def resize_organized_pc(organized_pc, target_height=224, target_width=224, tensor_out=True):
38 | torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous()
39 | torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc, size=(target_height, target_width),
40 | mode='nearest')
41 | if tensor_out:
42 | return torch_resized_organized_pc.squeeze(dim=0).contiguous()
43 | else:
44 | return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy()
45 |
46 |
47 | def organized_pc_to_depth_map(organized_pc):
48 | return organized_pc[:, :, 2]
49 |
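50 | # Minimal usage sketch (illustrative; the tiff path is a placeholder for an MVTec 3D-AD scan):
51 | if __name__ == "__main__":
52 |     pc = read_tiff_organized_pc("path/to/xyz/000.tiff")
53 |     resized = resize_organized_pc(pc, target_height=224, target_width=224)
54 |     depth = organized_pc_to_depth_map(pc)
55 |     print(resized.shape, depth.shape)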
--------------------------------------------------------------------------------
/feature_extractors/change_ex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class ChannelExchange(nn.Module):
5 | def __init__(self, p=2):
6 | super().__init__()
7 | self.p = p # 1/p of the features will be exchanged.
8 |
9 | def forward(self, x0, x1):
10 | # x0, x1: the bi-temporal feature maps.
11 | N, C, H, W = x0.shape
12 | exchange_mask = torch.arange(C, device=x0.device) % self.p == 0
13 | exchange_mask = exchange_mask.unsqueeze(0).expand((N, -1))
14 |
15 | out_x0 = torch.zeros_like(x0)
16 | out_x1 = torch.zeros_like(x1)
17 |
18 | out_x0[~exchange_mask] = x0[~exchange_mask]
19 | out_x1[~exchange_mask] = x1[~exchange_mask]
20 | out_x0[exchange_mask] = x1[exchange_mask]
21 | out_x1[exchange_mask] = x0[exchange_mask]
22 |
23 | return out_x0, out_x1
24 |
25 | class SpatialExchange(nn.Module):
26 | def __init__(self, p=2):
27 | super().__init__()
28 | self.p = p # 1/p of the features will be exchanged.
29 |
30 | def forward(self, x0, x1):
31 | # x0, x1: the bi-temporal feature maps.
32 | N, C, H, W = x0.shape
33 | # Create a mask based on width dimension
34 | exchange_mask = torch.arange(W, device=x0.device) % self.p == 0
35 | # Expand mask to match feature dimensions
36 | exchange_mask = exchange_mask.view(1, 1, 1, W).expand(N, C, H, -1)
37 |
38 | out_x0 = x0.clone()
39 | out_x1 = x1.clone()
40 |
41 | # Perform column-wise exchange
42 | out_x0[..., exchange_mask] = x1[..., exchange_mask]
43 | out_x1[..., exchange_mask] = x0[..., exchange_mask]
44 |
45 | return out_x0, out_x1
46 |
47 | class CombinedExchange(nn.Module):
48 | def __init__(self, p=2):
49 | super().__init__()
50 | self.channel_exchange = ChannelExchange(p=p)
51 | self.spatial_exchange = SpatialExchange(p=p)
52 |
53 | def forward(self, x0, x1):
54 | # First perform channel exchange
55 | x0, x1 = self.channel_exchange(x0, x1)
56 | # Then perform spatial exchange
57 | x0, x1 = self.spatial_exchange(x0, x1)
58 | return x0, x1
59 |
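60 | # Minimal usage sketch (illustrative): exchange half of the channels / columns between two feature maps.
61 | if __name__ == "__main__":
62 |     x0, x1 = torch.zeros(2, 8, 4, 4), torch.ones(2, 8, 4, 4)
63 |     c0, c1 = ChannelExchange(p=2)(x0, x1)
64 |     s0, s1 = SpatialExchange(p=2)(x0, x1)
65 |     print(c0[0, :, 0, 0])  # even channels now come from x1 (ones), odd channels stay from x0 (zeros)
66 |     print(s0[0, 0, 0, :])  # even columns now come from x1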
--------------------------------------------------------------------------------
/engine_fusion_pretrain.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable
4 |
5 | import torch
6 |
7 | import utils.misc as misc
8 | import utils.lr_sched as lr_sched
9 |
10 | def check_max_gradient(model):
11 | max_grad = 0.0
12 | max_grad_param_name = ""
13 |
14 | # Iterate over all model parameters and find the one with the largest gradient norm
15 | for name, param in model.named_parameters():
16 | if param.grad is not None:
17 | grad_norm = param.grad.norm(2).item()
18 | if grad_norm > max_grad:
19 | max_grad = grad_norm
20 | max_grad_param_name = name
21 |
22 | print(f"Maximum gradient norm: {max_grad} (Parameter: {max_grad_param_name})")
23 |
24 | # Use max_grad to decide whether a gradient explosion is likely
25 | if max_grad > 20000:  # threshold can be tuned to the task
26 | print("Warning: Potential gradient explosion detected!")
27 |
28 |
29 |
30 | def train_one_epoch(model: torch.nn.Module,
31 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
32 | device: torch.device, epoch: int, loss_scaler,
33 | log_writer=None,
34 | args=None):
35 | model.train(True)
36 | metric_logger = misc.MetricLogger(delimiter=" ")
37 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
38 | header = 'Epoch: [{}]'.format(epoch)
39 | print_freq = 20
40 |
41 | accum_iter = args.accum_iter
42 |
43 | optimizer.zero_grad()
44 |
45 | if log_writer is not None:
46 | print('log_dir: {}'.format(log_writer.log_dir))
47 |
48 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
49 |
50 | # we use a per iteration (instead of per epoch) lr scheduler
51 | if data_iter_step % accum_iter == 0:
52 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
53 |
54 |
55 | xyz_samples = samples[:,:,:1152].to(device, non_blocking=True)
56 | rgb_samples = samples[:,:,1152:].to(device, non_blocking=True)
57 |
58 | with torch.cuda.amp.autocast():
59 | # print('model device:', model.device)
60 | # print('data device:', xyz_samples.device)
61 | loss = model(xyz_samples, rgb_samples)
62 |
63 | check_max_gradient(model)
64 | loss_value = loss.item()
65 |
66 | if not math.isfinite(loss_value):
67 | print("Loss is {}, stopping training".format(loss_value))
68 | sys.exit(1)
69 |
70 | loss /= accum_iter
71 | loss_scaler(loss, optimizer, parameters=model.parameters(),
72 | update_grad=(data_iter_step + 1) % accum_iter == 0)
73 | if (data_iter_step + 1) % accum_iter == 0:
74 | optimizer.zero_grad()
75 |
76 | torch.cuda.synchronize()
77 |
78 | metric_logger.update(loss=loss_value)
79 |
80 | lr = optimizer.param_groups[0]["lr"]
81 | metric_logger.update(lr=lr)
82 |
83 |
84 | loss_value_reduce = misc.all_reduce_mean(loss_value)
85 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
86 | """ We use epoch_1000x as the x-axis in tensorboard.
87 | This calibrates different curves when batch size changes.
88 | """
89 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
90 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
91 | log_writer.add_scalar('lr', lr, epoch_1000x)
92 |
93 |
94 | # gather the stats from all processes
95 | metric_logger.synchronize_between_processes()
96 | print("Averaged stats:", metric_logger)
97 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | D³M is a comprehensive framework for industrial anomaly detection, localization, and classification, leveraging a variety of deep learning-based feature extractors for 2D (RGB) and 3D (point cloud) data, along with flexible fusion strategies. This repository provides an implementation of the D³M framework and tools for evaluating its performance on standard and custom datasets.
2 |
3 | The project is designed to explore and benchmark different uni-modal and multi-modal approaches for anomaly detection in industrial settings, inspired by the challenges of real-world defect identification.
4 |
5 | ## 📄 Citation
6 |
7 | If you find this work useful in your research, please consider citing our associated paper:
8 |
9 | ```bibtex
10 | @article{zhu2025real,
11 | title={Real-IAD D3: A Real-World 2D/Pseudo-3D/3D Dataset for Industrial Anomaly Detection},
12 | author={Zhu, Wenbing and Wang, Lidong and Zhou, Ziqing and Wang, Chengjie and Pan, Yurui and Zhang, Ruoyi and Chen, Zhuhao and Cheng, Linjie and Gao, Bin-Bin and Zhang, Jiangning and others},
13 | journal={arXiv preprint arXiv:2504.14221},
14 | year={2025}
15 | }
16 | ```
17 |
18 | ---
19 |
20 | ## ✨ Features
21 |
22 | * **Versatile Anomaly Detection Framework:** Supports a wide range of feature extraction and fusion methodologies.
23 | * **Multi-Modal Support:** Handles both 2D RGB images and 3D point cloud data.
24 | * **RGB Feature Extractors:** Utilizes pre-trained models like DINO (e.g., `vit_base_patch8_224_dino`).
25 | * **Point Cloud Feature Extractors:** Incorporates models like PointMAE, PointBert, and traditional FPFH.
26 | * **Flexible Fusion Strategies:** Implements various early and late fusion techniques for combining multi-modal features (e.g., simple addition, dedicated fusion modules).
27 | * **Multiple Method Configurations:** Easily experiment with different combinations of backbones and fusion approaches via command-line arguments (e.g., `DINO`, `Point_MAE`, `DINO+Point_MAE`, `DINO+Point_MAE+Fusion`, and custom "ours" variants).
28 | * **Dataset Compatibility:**
29 | * Works with standard 3D anomaly detection datasets like MVTec 3D-AD and Eyecandies.
30 | * Easily adaptable to custom datasets (`test_3d` option).
31 | * **Comprehensive Evaluation:**
32 | * Calculates image-level ROCAUC.
33 | * Calculates pixel/point-level ROCAUC for localization.
34 | * Calculates AU-PRO scores.
35 | * **Memory Bank & Coreset Subsampling:** Implements memory bank concepts and coreset subsampling for efficient training and inference, particularly with large feature sets (a minimal sketch follows this list).
36 | * **Extensible:** Designed to be easily extended with new feature extractors, fusion modules, or datasets.
37 | * **Result Visualization:** Option to save prediction maps for qualitative analysis.
38 |
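For a rough intuition of the coreset step, the sketch below shows greedy k-center (farthest-point) subsampling of a feature memory bank. It is illustrative only and does not reproduce the repository's exact implementation or its hyperparameters (`--coreset_eps`, `--max_sample`).

```python
import torch

def greedy_coreset(features: torch.Tensor, ratio: float = 0.1) -> torch.Tensor:
    """Greedy k-center selection over an (N, D) feature bank."""
    n_select = max(1, int(features.shape[0] * ratio))
    selected = [int(torch.randint(features.shape[0], (1,)))]
    # distance of every feature to its nearest already-selected center
    min_dist = torch.cdist(features, features[selected]).squeeze(1)
    for _ in range(n_select - 1):
        idx = int(torch.argmax(min_dist))        # farthest point joins the coreset
        selected.append(idx)
        new_dist = torch.cdist(features, features[idx:idx + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, new_dist)
    return features[selected]

bank = torch.randn(10_000, 768)                  # e.g. patch features from a backbone
coreset = greedy_coreset(bank, ratio=0.1)        # 10x smaller bank kept for inference
```
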
39 | ---
40 |
41 | ## 🚀 Getting Started
42 |
43 | ### Prerequisites
44 |
45 | * Python >= 3.9
46 | * PyTorch >= 2.5.1 (preferably with CUDA support for GPU acceleration)
47 | * Other dependencies as listed in `requirement.txt`
48 |
49 | ### Installation
50 |
51 | 1. **Clone the repository:**
52 | ```bash
53 | git clone https://github.com/your-username/d3m.git
54 | cd d3m
55 | ```
56 |
57 | 2. **Create a virtual environment (recommended):**
58 | ```bash
59 | python -m venv venv
60 | source venv/bin/activate # On Windows use `venv\Scripts\activate`
61 | ```
62 | Or using Conda:
63 | ```bash
64 | conda create -n d3m python=3.9
65 | conda activate d3m
66 | ```
67 |
68 | 3. **Install dependencies:**
69 | The repository ships a `requirement.txt`; install from it directly, or adapt the following to your precise needs:
70 | ```text
71 | # requirement.txt
72 |
73 | # Core Deep Learning & Computer Vision
74 | torch>=2.5.1 # Install with specific CUDA version if needed, e.g., torch>=2.5.1+cu121
75 | torchvision>=0.20.1
76 | torchaudio>=2.5.1
77 | numpy>=1.26.3
78 | opencv-python>=4.11.0
79 | Pillow>=11.0.0
80 | matplotlib>=3.4.2
81 | scikit-learn>=1.6.1
82 | scikit-image>=0.24.0
83 | einops>=0.8.1
84 | timm>=1.0.15
85 |
86 | # Utilities & Configuration
87 | pyyaml>=5.4.1
88 | tqdm>=4.67.1
89 | easydict>=1.13 # or addict>=2.4.0
90 | pandas>=2.2.3
91 | h5py>=3.13.0
92 |
93 | # Experiment Tracking & Model Hubs (Optional, uncomment if used)
94 | # wandb>=0.19.10
95 | # huggingface-hub>=0.29.3
96 |
97 | # Potentially 3D-Specific (Uncomment if your project uses these)
98 | # open3d>=0.19.0
99 | # pointnet2-ops>=3.0.0 # May require custom compilation
100 | # chamferdist>=1.0.3
101 | ```
102 | Then install using:
103 | ```bash
104 | pip install -r requirement.txt
105 | ```
106 | **Note on PyTorch:** Ensure you install the PyTorch version compatible with your CUDA toolkit. Refer to the [official PyTorch website](https://pytorch.org/) for installation instructions.
107 |
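For example, a CUDA 12.1 build can usually be installed via the matching index URL (check the PyTorch site for the URL corresponding to your CUDA version; this command is illustrative):

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
```
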
108 | ---
109 |
110 | ## 💻 Usage
111 |
112 | The main script for running experiments is `main.py` (invoked by `train.sh`); it contains the `run_3d_ads` function and the `argparse` definitions. You can configure experiments using command-line arguments.
113 |
114 | ### Example Command
115 |
116 | ```bash
117 | python main.py \
118 | --dataset_type mvtec3d \
119 | --dataset_path /path/to/your/mvtec3d_anomaly_detection \
120 | --method_name DINO+Point_MAE+Fusion \
121 | --rgb_backbone_name vit_base_patch8_224_dino \
122 | --xyz_backbone_name Point_MAE \
123 | --fusion_module_path /path/to/your/fusion_checkpoint.pth \
124 | --img_size 224 \
125 | --max_sample 400 \
126 | --coreset_eps 0.9 \
127 | --save_preds
128 | # Add other arguments as needed
129 | ```
130 | Refer to the `argparse` section in `main.py` for a full list of available arguments and their descriptions.
131 |
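To list every available flag directly from the command line (the exact set depends on the `argparse` definitions in `main.py`):

```bash
python main.py --help
```
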
132 | ---
133 |
134 | ## 📊 Output
135 |
136 | The script will print tables summarizing the performance metrics for each class and method:
137 |
138 | * **Image ROCAUC Results**
139 | * **Pixel ROCAUC Results**
140 | * **AU-PRO Results**
141 |
142 | If `--save_preds` is enabled, anomaly maps are saved for qualitative inspection (by default to `./pred_maps`, as set in `m3dm_runner1.py`).
143 |
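A minimal way to inspect a saved map (illustrative; the file name below is a placeholder, and the exact format and naming are defined by `save_prediction_maps` in the feature extractor):

```python
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

anomaly_map = mpimg.imread("./pred_maps/example_prediction.png")  # hypothetical file name
plt.imshow(anomaly_map, cmap="jet")
plt.colorbar()
plt.show()
```
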
144 | ---
145 |
146 | ## 🙏 Acknowledgements
147 |
148 | * This framework builds upon concepts from various SOTA anomaly detection and feature extraction literature.
149 | * Utilizes awesome libraries like PyTorch, TIMM, Open3D, etc.
150 |
151 | ---
152 |
153 | ## 📞 Contact
154 |
155 | For questions, issues, or suggestions, please open an issue in this repository.
156 |
157 |
158 |
--------------------------------------------------------------------------------
/models/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 | def timeit(tag, t):
8 | print("{}: {}s".format(tag, time() - t))
9 | return time()
10 |
11 | def pc_normalize(pc):
12 | l = pc.shape[0]
13 | centroid = np.mean(pc, axis=0)
14 | pc = pc - centroid
15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
16 | pc = pc / m
17 | return pc
18 |
19 | def square_distance(src, dst):
20 | """
21 | Calculate Euclid distance between each two points.
22 | src^T * dst = xn * xm + yn * ym + zn * zm;
23 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
24 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
25 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
26 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
27 | Input:
28 | src: source points, [B, N, C]
29 | dst: target points, [B, M, C]
30 | Output:
31 | dist: per-point square distance, [B, N, M]
32 | """
33 | B, N, _ = src.shape
34 | _, M, _ = dst.shape
35 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
36 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
37 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
38 | return dist
39 |
40 |
41 | def index_points(points, idx):
42 | """
43 | Input:
44 | points: input points data, [B, N, C]
45 | idx: sample index data, [B, S]
46 | Return:
47 | new_points:, indexed points data, [B, S, C]
48 | """
49 | device = points.device
50 | B = points.shape[0]
51 | view_shape = list(idx.shape)
52 | view_shape[1:] = [1] * (len(view_shape) - 1)
53 | repeat_shape = list(idx.shape)
54 | repeat_shape[0] = 1
55 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
56 | # print("points device:", points.device)
57 | # print("idx device:", idx.device)
58 | points = points.to(idx.device)
59 | new_points = points[batch_indices, idx, :]
60 | return new_points
61 |
62 |
63 | def farthest_point_sample(xyz, npoint):
64 | """
65 | Input:
66 | xyz: pointcloud data, [B, N, 3]
67 | npoint: number of samples
68 | Return:
69 | centroids: sampled pointcloud index, [B, npoint]
70 | """
71 | device = xyz.device
72 | B, N, C = xyz.shape
73 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
74 | distance = torch.ones(B, N).to(device) * 1e10
75 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
76 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
77 | for i in range(npoint):
78 | centroids[:, i] = farthest
79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
80 | dist = torch.sum((xyz - centroid) ** 2, -1)
81 | mask = dist < distance
82 | distance[mask] = dist[mask]
83 | farthest = torch.max(distance, -1)[1]
84 | return centroids
85 |
86 |
87 | def query_ball_point(radius, nsample, xyz, new_xyz):
88 | """
89 | Input:
90 | radius: local region radius
91 | nsample: max sample number in local region
92 | xyz: all points, [B, N, 3]
93 | new_xyz: query points, [B, S, 3]
94 | Return:
95 | group_idx: grouped points index, [B, S, nsample]
96 | """
97 | device = xyz.device
98 | B, N, C = xyz.shape
99 | _, S, _ = new_xyz.shape
100 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
101 | sqrdists = square_distance(new_xyz, xyz)
102 | group_idx[sqrdists > radius ** 2] = N
103 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
104 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
105 | mask = group_idx == N
106 | group_idx[mask] = group_first[mask]
107 | return group_idx
108 |
109 |
110 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
111 | """
112 | Input:
113 | npoint:
114 | radius:
115 | nsample:
116 | xyz: input points position data, [B, N, 3]
117 | points: input points data, [B, N, D]
118 | Return:
119 | new_xyz: sampled points position data, [B, npoint, nsample, 3]
120 | new_points: sampled points data, [B, npoint, nsample, 3+D]
121 | """
122 | B, N, C = xyz.shape
123 | S = npoint
124 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
125 | new_xyz = index_points(xyz, fps_idx)
126 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
127 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
128 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
129 |
130 | if points is not None:
131 | grouped_points = index_points(points, idx)
132 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
133 | else:
134 | new_points = grouped_xyz_norm
135 | if returnfps:
136 | return new_xyz, new_points, grouped_xyz, fps_idx
137 | else:
138 | return new_xyz, new_points
139 |
140 |
141 | def sample_and_group_all(xyz, points):
142 | """
143 | Input:
144 | xyz: input points position data, [B, N, 3]
145 | points: input points data, [B, N, D]
146 | Return:
147 | new_xyz: sampled points position data, [B, 1, 3]
148 | new_points: sampled points data, [B, 1, N, 3+D]
149 | """
150 | device = xyz.device
151 | B, N, C = xyz.shape
152 | new_xyz = torch.zeros(B, 1, C).to(device)
153 | grouped_xyz = xyz.view(B, 1, N, C)
154 | if points is not None:
155 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
156 | else:
157 | new_points = grouped_xyz
158 | return new_xyz, new_points
159 |
160 | def interpolating_points(xyz1, xyz2, points2):
161 | """
162 | Input:
163 | xyz1: input points position data, [B, C, N]
164 | xyz2: sampled input points position data, [B, C, S]
165 | points2: input points data, [B, D, S]
166 | Return:
167 | new_points: upsampled points data, [B, D', N]
168 | """
169 | xyz1 = xyz1.permute(0, 2, 1)
170 | xyz2 = xyz2.permute(0, 2, 1)
171 |
172 | points2 = points2.permute(0, 2, 1)
173 | B, N, C = xyz1.shape
174 | _, S, _ = xyz2.shape
175 |
176 | if S == 1:
177 | interpolated_points = points2.repeat(1, N, 1)
178 | else:
179 | dists = square_distance(xyz1, xyz2)
180 | dists, idx = dists.sort(dim=-1)
181 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
182 |
183 | dist_recip = 1.0 / (dists + 1e-8)
184 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
185 | weight = dist_recip / norm
186 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
187 |
188 | interpolated_points = interpolated_points.permute(0, 2, 1)
189 | return interpolated_points
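190 |
191 | # Minimal shape check (illustrative): group a random cloud; the printed shapes should be
192 | # (2, 128, 3) for the sampled centers and (2, 128, 32, 3) for the grouped, centered neighborhoods.
193 | if __name__ == "__main__":
194 |     xyz = torch.rand(2, 1024, 3)
195 |     new_xyz, new_points = sample_and_group(npoint=128, radius=0.2, nsample=32, xyz=xyz, points=None)
196 |     print(new_xyz.shape, new_points.shape)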
--------------------------------------------------------------------------------
/m3dm_runner1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import os
4 | import numpy as np
5 |
6 | from feature_extractors import multiple_features_clean as multiple_features
7 |
8 | from dataset import get_data_loader
9 |
10 | class M3DM():
11 | def __init__(self, args):
12 | self.args = args
13 | self.image_size = args.img_size
14 | self.count = args.max_sample
15 | if args.method_name == 'DINO':
16 | self.methods = {
17 | "DINO": multiple_features.RGBFeatures(args),
18 | }
19 | elif args.method_name == 'Point_MAE':
20 | self.methods = {
21 | "Point_MAE": multiple_features.PointFeatures(args),
22 | }
23 | elif args.method_name == 'Fusion':
24 | self.methods = {
25 | "Fusion": multiple_features.FusionFeatures(args),
26 | }
27 | elif args.method_name == 'DINO+Point_MAE':
28 | self.methods = {
29 | "DINO+Point_MAE": multiple_features.DoubleRGBPointFeatures(args),
30 | }
31 | elif args.method_name == 'DINO+Point_MAE+add':
32 | self.methods = {
33 | "DINO+Point_MAE": multiple_features.DoubleRGBPointFeatures_add(args),
34 | }
35 | elif args.method_name == 'DINO+Point_MAE+Fusion':
36 | self.methods = {
37 | "DINO+Point_MAE+Fusion": multiple_features.TripleFeatures(args),
38 | }
39 | elif args.method_name == 'DINO+FPFH':
40 | self.methods = {
41 | "DINO+FPFH": multiple_features.DoubleRGBPointFeatures(args),
42 | }
43 | elif args.method_name == 'DINO+FPFH+Fusion':
44 | self.methods = {
45 | "DINO+FPFH+Fusion": multiple_features.TripleFeatures(args),
46 | }
47 | elif args.method_name == 'DINO+FPFH+Fusion+ps':
48 | self.methods = {
49 | "DINO+FPFH+Fusion+ps": multiple_features.TripleFeatures_PS(args),
50 | }
51 | elif args.method_name == 'DINO+Point_MAE+Fusion+ps':
52 | self.methods = {
53 | "DINO+Point_MAE+Fusion+ps": multiple_features.TripleFeatures_PS(args),
54 | }
55 | elif args.method_name == 'DINO+Point_MAE+ps':
56 | self.methods = {
57 | "DINO+Point_MAE+ps": multiple_features.DoubleRGB_PS_Features(args),
58 | }
59 | elif args.method_name == 'DINO+FPFH+ps':
60 | self.methods = {
61 | "DINO+FPFH+ps": multiple_features.DoubleRGB_PS_Features(args),
62 | }
63 | elif args.method_name == 'ours':
64 | self.methods = {
65 | "ours": multiple_features.PSRGBPointFeatures_add(args),
66 | }
67 | elif args.method_name == 'ours2':
68 | self.methods = {
69 | "ours2": multiple_features.TripleFeatures_PS2(args),
70 | }
71 | elif args.method_name == 'ours3':
72 | self.methods = {
73 | "ours3": multiple_features.FourFeatures(args),
74 | }
75 | elif args.method_name == 'ours_final':
76 | self.methods = {
77 | "ours_final": multiple_features.TripleFeatures_PS_EX(args),
78 | }
79 | elif args.method_name == 'ours_final1':
80 | self.methods = {
81 | "ours_final1": multiple_features.PSRGBPointFeatures_add_EX(args),
82 | }
83 | elif args.method_name == 'm3dm_uninterpolate':
84 | self.methods = {
85 | "m3dm_uninterpolate": multiple_features.DoubleRGBPointFeatures_uninter_full(args),
86 | }
87 |
88 | def fit(self, class_name):
89 | train_loader = get_data_loader("train", class_name=class_name, img_size=self.image_size, args=self.args)
90 |
91 | flag = 0
92 | for sample, _ in tqdm(train_loader, desc=f'Extracting train features for class {class_name}'):
93 | for method in self.methods.values():
94 | if self.args.save_feature:
95 | method.add_sample_to_mem_bank(sample, class_name=class_name)
96 | else:
97 | method.add_sample_to_mem_bank(sample)
98 | flag += 1
99 | if flag > self.count:
100 | flag = 0
101 | break
102 |
103 | for method_name, method in self.methods.items():
104 | print(f'\n\nRunning coreset for {method_name} on class {class_name}...')
105 | method.run_coreset()
106 |
107 |
108 | if self.args.memory_bank == 'multiple':
109 | flag = 0
110 | for sample, _ in tqdm(train_loader, desc=f'Running late fusion for {method_name} on class {class_name}..'):
111 | for method_name, method in self.methods.items():
112 | method.add_sample_to_late_fusion_mem_bank(sample)
113 | flag += 1
114 | if flag > self.count:
115 | flag = 0
116 | break
117 |
118 | for method_name, method in self.methods.items():
119 | print(f'\n\nTraining Decision Layer Fusion for {method_name} on class {class_name}...')
120 | method.run_late_fusion()
121 |
122 | def evaluate(self, class_name):
123 | image_rocaucs = dict()
124 | pixel_rocaucs = dict()
125 | au_pros = dict()
126 | valid_dir = os.path.join(self.args.dataset_path, class_name, 'validation')
127 | defect_names = os.listdir(valid_dir)
128 | # print(defect_names)
129 | path_list = []
130 | for defect_name in defect_names:
131 | if defect_name == 'GOOD':
132 | continue
133 | test_loader = get_data_loader("validation", class_name=class_name, img_size=self.image_size, args=self.args, defect_name=defect_name)
134 | with torch.no_grad():
135 | print(class_name, defect_name)
136 | print(f'len of loader:{len(test_loader)}')
137 | for sample, mask, label, rgb_path in tqdm(test_loader, desc=f'Extracting test features for class {class_name}'):
138 | for method in self.methods.values():
139 | method.predict(sample, mask, label)
140 | path_list.append(rgb_path)
141 |
142 |
143 | for method_name, method in self.methods.items():
144 | method.calculate_metrics()
145 | image_rocauc = method.image_rocauc
146 | pixel_rocauc = method.pixel_rocauc
147 | au_pro = method.au_pro
148 |
149 | print(f"Debug - Raw values: Image ROCAUC: {image_rocauc}, Pixel ROCAUC: {pixel_rocauc}, AU-PRO: {au_pro}")
150 |
151 | if np.isnan(image_rocauc) or np.isnan(pixel_rocauc) or np.isnan(au_pro):
152 | print(f"Warning: NaN detected for {method_name}")
153 | # More debugging information could be added here
154 |
155 | image_rocaucs[f'{method_name}_{defect_name}'] = round(method.image_rocauc, 3)
156 | pixel_rocaucs[f'{method_name}_{defect_name}'] = round(method.pixel_rocauc, 3)
157 | au_pros[f'{method_name}_{defect_name}'] = round(method.au_pro, 3)
158 | print(
159 | f'Debug - Class: {class_name}, {method_name}, defect_name:{defect_name} Image ROCAUC: {method.image_rocauc:.3f}, {method_name} Pixel ROCAUC: {method.pixel_rocauc:.3f}, {method_name} AU-PRO: {method.au_pro:.3f}')
160 | if self.args.save_preds:
161 | method.save_prediction_maps('./pred_maps', path_list)
162 | return image_rocaucs, pixel_rocaucs, au_pros
163 |
--------------------------------------------------------------------------------
/utils/preprocessing1.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tifffile as tiff
4 | import open3d as o3d
5 | from pathlib import Path
6 | from PIL import Image
7 | import math
8 | import mvtec3d_util as mvt_util
9 | import argparse
10 |
11 |
12 | def get_edges_of_pc(organized_pc):
13 | unorganized_edges_pc = organized_pc[0:10, :, :].reshape(organized_pc[0:10, :, :].shape[0]*organized_pc[0:10, :, :].shape[1],organized_pc[0:10, :, :].shape[2])
14 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc,organized_pc[-10:, :, :].reshape(organized_pc[-10:, :, :].shape[0] * organized_pc[-10:, :, :].shape[1],organized_pc[-10:, :, :].shape[2])],axis=0)
15 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, 0:10, :].reshape(organized_pc[:, 0:10, :].shape[0] * organized_pc[:, 0:10, :].shape[1],organized_pc[:, 0:10, :].shape[2])], axis=0)
16 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, -10:, :].reshape(organized_pc[:, -10:, :].shape[0] * organized_pc[:, -10:, :].shape[1],organized_pc[:, -10:, :].shape[2])], axis=0)
17 | unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0],:]
18 | return unorganized_edges_pc
19 |
20 | def get_plane_eq(unorganized_pc,ransac_n_pts=50):
21 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc))
22 | plane_model, inliers = o3d_pc.segment_plane(distance_threshold=0.004, ransac_n=ransac_n_pts, num_iterations=1000)
23 | return plane_model
24 |
25 | def remove_plane(organized_pc_clean, organized_rgb ,distance_threshold=0.005):
26 | # PREP PC
27 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean)
28 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
29 | clean_planeless_unorganized_pc = unorganized_pc.copy()
30 | planeless_unorganized_rgb = unorganized_rgb.copy()
31 |
32 | # REMOVE PLANE
33 | plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean))
34 | distances = np.abs(np.dot(np.array(plane_model), np.hstack((clean_planeless_unorganized_pc, np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T))
35 | plane_indices = np.argwhere(distances < distance_threshold)
36 |
37 | planeless_unorganized_rgb[plane_indices] = 0
38 | clean_planeless_unorganized_pc[plane_indices] = 0
39 | clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0],
40 | organized_pc_clean.shape[1],
41 | organized_pc_clean.shape[2])
42 | planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0],
43 | organized_rgb.shape[1],
44 | organized_rgb.shape[2])
45 | return clean_planeless_organized_pc, planeless_organized_rgb
46 |
47 |
48 |
49 | def connected_components_cleaning(organized_pc, organized_rgb, image_path):
50 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc)
51 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
52 |
53 | nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
54 | unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]
55 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros))
56 | labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False))
57 |
58 |
59 | unique_cluster_ids, cluster_size = np.unique(labels,return_counts=True)
60 | max_label = labels.max()
61 | if max_label>0:
62 | print("##########################################################################")
63 | print(f"Point cloud file {image_path} has {max_label + 1} clusters")
64 | print(f"Cluster ids: {unique_cluster_ids}. Cluster size {cluster_size}")
65 | print("##########################################################################\n\n")
66 |
67 | largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)]
68 | outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id)
69 | outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array]
70 | unorganized_pc[outlier_indices_original_pc_array] = 0
71 | unorganized_rgb[outlier_indices_original_pc_array] = 0
72 | organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0],
73 | organized_pc.shape[1],
74 | organized_pc.shape[2])
75 | organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0],
76 | organized_rgb.shape[1],
77 | organized_rgb.shape[2])
78 | return organized_clustered_pc, organized_clustered_rgb
79 |
80 | def roundup_next_100(x):
81 | return int(math.ceil(x / 100.0)) * 100
82 |
83 | def pad_cropped_pc(cropped_pc, single_channel=False):
84 | orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1]
85 | round_orig_h = roundup_next_100(orig_h)
86 | round_orig_w = roundup_next_100(orig_w)
87 | large_side = max(round_orig_h, round_orig_w)
88 |
89 | a = (large_side - orig_h) // 2
90 | aa = large_side - a - orig_h
91 |
92 | b = (large_side - orig_w) // 2
93 | bb = large_side - b - orig_w
94 | if single_channel:
95 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant')
96 | else:
97 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant')
98 |
99 | def preprocess_pc(tiff_path):
100 | # READ FILES
101 | organized_pc = mvt_util.read_tiff_organized_pc(tiff_path)
102 | rgb_path = str(tiff_path).replace("xyz", "rgb").replace("tiff", "png")
103 | gt_path = str(tiff_path).replace("xyz", "gt").replace("tiff", "png")
104 | organized_rgb = np.array(Image.open(rgb_path))
105 |
106 | organized_gt = None
107 | gt_exists = os.path.isfile(gt_path)
108 | if gt_exists:
109 | organized_gt = np.array(Image.open(gt_path))
110 |
111 | # REMOVE PLANE
112 | planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb)
113 |
114 |
115 | # PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE)
116 | padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False)
117 | padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False)
118 | if gt_exists:
119 | padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True)
120 |
121 | organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc, padded_planeless_organized_rgb, tiff_path)
122 | # SAVE PREPROCESSED FILES
123 | tiff.imsave(tiff_path, organized_clustered_pc)
124 | Image.fromarray(organized_clustered_rgb).save(rgb_path)
125 | if gt_exists:
126 | Image.fromarray(padded_organized_gt).save(gt_path)
127 |
128 |
129 |
130 | if __name__ == '__main__':
131 | parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD')
132 | parser.add_argument('dataset_path', type=str, help='The root path of the MVTec 3D-AD. The preprocessing is done inplace (i.e. the preprocessed dataset overrides the existing one)')
133 | args = parser.parse_args()
134 |
135 |
136 | root_path = args.dataset_path
137 | paths = Path(root_path).rglob('*.tiff')
138 | print(f"Found {len(list(paths))} tiff files in {root_path}")
139 | processed_files = 0
140 | for path in Path(root_path).rglob('*.tiff'):
141 | preprocess_pc(path)
142 | processed_files += 1
143 | if processed_files % 50 == 0:
144 | print(f"Processed {processed_files} tiff files...")
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
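155 | # Example launch (the dataset path is a placeholder; run from the utils/ directory, since
156 | # mvtec3d_util is imported as a sibling module):
157 | #   python preprocessing1.py /path/to/mvtec_3d_anomaly_detection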
--------------------------------------------------------------------------------
/fusion_pretrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import numpy as np
5 | import os
6 | import time
7 | from pathlib import Path
8 |
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from torch.utils.tensorboard import SummaryWriter
12 | import torchvision.transforms as transforms
13 |
14 | import timm
15 |
16 | import timm.optim.optim_factory as optim_factory
17 |
18 | import utils.misc as misc
19 | from utils.misc import NativeScalerWithGradNormCount as NativeScaler
20 |
21 |
22 | from engine_fusion_pretrain import train_one_epoch
23 |
24 | import dataset
25 |
26 | import torch
27 | from models.feature_fusion import FeatureFusionBlock
28 |
29 |
30 |
31 |
32 | def get_args_parser():
33 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
34 | parser.add_argument('--batch_size', default=64, type=int,
35 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
36 | parser.add_argument('--epochs', default=3, type=int)
37 | parser.add_argument('--accum_iter', default=1, type=int,
38 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
39 |
40 | # Model parameters
41 |
42 | parser.add_argument('--input_size', default=224, type=int,
43 | help='images input size')
44 |
45 |
46 | # Optimizer parameters
47 | parser.add_argument('--clip_grad', type=float, default=None,
48 | help='gradient clipping norm (default: None)')
49 |
50 | parser.add_argument('--weight_decay', type=float, default=1.5e-6,
51 | help='weight decay (default: 1.5e-6)')
52 |
53 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
54 | help='learning rate (absolute lr)')
55 | parser.add_argument('--blr', type=float, default=0.002, metavar='LR',
56 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
57 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
58 | help='lower lr bound for cyclic schedulers that hit 0')
59 |
60 | parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N',
61 | help='epochs to warmup LR')
62 |
63 | # Dataset parameters
64 | parser.add_argument('--data_path', default='', type=str,
65 | help='dataset path')
66 |
67 | parser.add_argument('--output_dir', default='./output_dir',
68 | help='path where to save, empty for no saving')
69 | parser.add_argument('--log_dir', default='./output_dir',
70 | help='path where to tensorboard log')
71 | parser.add_argument('--device', default='cuda',
72 | help='device to use for training / testing')
73 | parser.add_argument('--seed', default=0, type=int)
74 | parser.add_argument('--resume', default='',
75 | help='resume from checkpoint')
76 |
77 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
78 | help='start epoch')
79 | parser.add_argument('--num_workers', default=10, type=int)
80 | parser.add_argument('--pin_mem', action='store_true',
81 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
82 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
83 | parser.set_defaults(pin_mem=True)
84 |
85 | # distributed training parameters
86 | parser.add_argument('--world_size', default=1, type=int,
87 | help='number of distributed processes')
88 | parser.add_argument('--local_rank', default=-1, type=int)
89 | parser.add_argument('--dist_on_itp', action='store_true')
90 | parser.add_argument('--dist_url', default='env://',
91 | help='url used to set up distributed training')
92 |
93 | return parser
94 |
95 |
96 |
97 |
98 | def main(args):
99 | misc.init_distributed_mode(args)
100 | # args.gpu = 1
101 | # args.device = 'cuda:' + str(args.gpu)
102 |
103 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
104 | print("{}".format(args).replace(', ', ',\n'))
105 |
106 | device = torch.device(args.device)
107 |
108 | # fix the seed for reproducibility
109 | seed = args.seed + misc.get_rank()
110 | torch.manual_seed(seed)
111 | np.random.seed(seed)
112 |
113 | cudnn.benchmark = True
114 |
115 |
116 | dataset_train = dataset.PreTrainTensorDataset(args.data_path)
117 |
118 | print(dataset_train)
119 |
120 |
121 | if True: # args.distributed:
122 | num_tasks = misc.get_world_size()
123 | global_rank = misc.get_rank()
124 | sampler_train = torch.utils.data.DistributedSampler(
125 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
126 | )
127 | print("Sampler_train = %s" % str(sampler_train))
128 | else:
129 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
130 |
131 | if global_rank == 0 and args.log_dir is not None:
132 | os.makedirs(args.log_dir, exist_ok=True)
133 | log_writer = SummaryWriter(log_dir=args.log_dir)
134 | else:
135 | log_writer = None
136 |
137 | data_loader_train = torch.utils.data.DataLoader(
138 | dataset_train, sampler=sampler_train,
139 | batch_size=args.batch_size,
140 | num_workers=args.num_workers,
141 | pin_memory=args.pin_mem,
142 | drop_last=True,
143 | )
144 |
145 |
146 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
147 |
148 | if args.lr is None: # only base_lr is specified
149 | args.lr = args.blr * eff_batch_size / 256
150 |
151 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
152 | print("actual lr: %.2e" % args.lr)
153 |
154 | print("accumulate grad iterations: %d" % args.accum_iter)
155 | print("effective batch size: %d" % eff_batch_size)
156 |
157 | model = FeatureFusionBlock(1152, 768)
158 |
159 | model.to(device)
160 | model_without_ddp = model  # keep a plain handle; replaced by model.module when DDP wraps the model
161 | if args.distributed:
162 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
163 | model_without_ddp = model.module
164 | print('gpu using:', args.gpu)
165 | # following timm: set wd as 0 for bias and norm layers
166 | optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr, betas=(0.9, 0.95))
167 | print(optimizer)
168 | print('clip_grad:', args.clip_grad)
169 | loss_scaler = NativeScaler()
170 |
171 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
172 |
173 | print(f"Start training for {args.epochs} epochs")
174 | start_time = time.time()
175 | for epoch in range(args.start_epoch, args.epochs):
176 | if args.distributed:
177 | data_loader_train.sampler.set_epoch(epoch)
178 | train_stats = train_one_epoch(
179 | model, data_loader_train,
180 | optimizer, device, epoch, loss_scaler,
181 | log_writer=log_writer,
182 | args=args
183 | )
184 | if args.output_dir and (epoch % 1 == 0 or epoch + 1 == args.epochs):
185 | misc.save_model(
186 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
187 | loss_scaler=loss_scaler, epoch=epoch)
188 |
189 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
190 | 'epoch': epoch,}
191 |
192 | if args.output_dir and misc.is_main_process():
193 | if log_writer is not None:
194 | log_writer.flush()
195 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
196 | f.write(json.dumps(log_stats) + "\n")
197 |
198 | total_time = time.time() - start_time
199 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
200 | print('Training time {}'.format(total_time_str))
201 |
202 |
203 | if __name__ == '__main__':
204 | args = get_args_parser()
205 | args = args.parse_args()
206 | if args.output_dir:
207 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
208 | main(args)
209 |
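210 | # Example launch (illustrative; paths are placeholders). The fusion block's contrastive loss calls
211 | # torch.distributed.get_rank(), so launching through torchrun (even single-process) is the safer route:
212 | #   torchrun --nproc_per_node=1 fusion_pretrain.py --data_path /path/to/precomputed_patch_features \
213 | #       --batch_size 16 --epochs 3 --output_dir ./checkpoints --log_dir ./checkpoints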
--------------------------------------------------------------------------------
/models/feature_fusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 | def initialize_weights(model):
6 | for layer in model.modules():
7 | if isinstance(layer, (nn.Conv2d, nn.Linear)):
8 | nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
9 | if layer.bias is not None:
10 | nn.init.constant_(layer.bias, 0)
11 |
12 | class Mlp(nn.Module):
13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
14 | super().__init__()
15 | out_features = out_features or in_features
16 | hidden_features = hidden_features or in_features
17 | self.fc1 = nn.Linear(in_features, hidden_features)
18 | self.act = act_layer()
19 | self.fc2 = nn.Linear(hidden_features, out_features)
20 | self.drop = nn.Dropout(drop)
21 |
22 | def forward(self, x):
23 | x = self.fc1(x)
24 | x = self.act(x)
25 | x = self.drop(x)
26 | x = self.fc2(x)
27 | x = self.drop(x)
28 | return x
29 | # class Mlp(nn.Module):
30 | # def __init__(self, in_features, hidden_features=None, out_features=None,
31 | # act_layer=nn.GELU, drop=0.2, use_ln=True):
32 | # super().__init__()
33 | # out_features = out_features or in_features
34 | # hidden_features = hidden_features or in_features
35 |
36 | # self.use_ln = use_ln
37 | # self.scale = min(1.0, (in_features ** -0.5))
38 |
39 | # # First residual block
40 | # self.res_block1 = nn.Sequential(
41 | # nn.Linear(in_features, hidden_features),
42 | # nn.LayerNorm(hidden_features) if use_ln else nn.Identity(),
43 | # act_layer(),
44 | # nn.Dropout(drop),
45 | # nn.Linear(hidden_features, in_features), # note: output dim must match the input for the residual add
46 | # nn.LayerNorm(in_features) if use_ln else nn.Identity(),
47 | # nn.Dropout(drop)
48 | # )
49 |
50 | # # Second residual block
51 | # self.res_block2 = nn.Sequential(
52 | # nn.Linear(in_features, hidden_features),
53 | # nn.LayerNorm(hidden_features) if use_ln else nn.Identity(),
54 | # act_layer(),
55 | # nn.Dropout(drop),
56 | # nn.Linear(hidden_features, out_features),
57 | # nn.LayerNorm(out_features) if use_ln else nn.Identity(),
58 | # nn.Dropout(drop)
59 | # )
60 |
61 | # # If the input and output dims differ, a projection layer is needed
62 | # self.proj = None
63 | # if in_features != out_features:
64 | # self.proj = nn.Linear(in_features, out_features)
65 |
66 | # self._init_weights()
67 |
68 | # def _init_weights(self):
69 | # gain = 0.001
70 | # # Initialize the first residual block
71 | # for m in self.res_block1.modules():
72 | # if isinstance(m, nn.Linear):
73 | # nn.init.xavier_uniform_(m.weight, gain=gain)
74 | # if m.bias is not None:
75 | # nn.init.zeros_(m.bias)
76 | # m.weight.data *= 0.1
77 |
78 | # # Initialize the second residual block
79 | # for m in self.res_block2.modules():
80 | # if isinstance(m, nn.Linear):
81 | # nn.init.xavier_uniform_(m.weight, gain=gain)
82 | # if m.bias is not None:
83 | # nn.init.zeros_(m.bias)
84 | # m.weight.data *= 0.1
85 |
86 | # # Initialize the projection layer
87 | # if self.proj is not None:
88 | # nn.init.xavier_uniform_(self.proj.weight, gain=gain)
89 | # nn.init.zeros_(self.proj.bias)
90 | # self.proj.weight.data *= 0.1
91 |
92 | # def forward(self, x):
93 | # # First residual block
94 | # identity = x
95 | # x = self.res_block1(x * self.scale)
96 | # x = x * self.scale + identity
97 |
98 | # # Second residual block
99 | # identity = x if self.proj is None else self.proj(x)
100 | # x = self.res_block2(x * self.scale)
101 | # x = x * self.scale + identity
102 |
103 | # return x
104 | # class Mlp(nn.Module):
105 | # def __init__(self, in_features, hidden_features=None, out_features=None,
106 | # act_layer=nn.GELU, drop=0.1, use_ln=True):
107 | # super().__init__()
108 | # out_features = out_features or in_features
109 | # hidden_features = hidden_features or in_features
110 |
111 | # self.use_ln = use_ln
112 |
113 | # # First layer
114 | # self.fc1 = nn.Linear(in_features, hidden_features)
115 | # if use_ln:
116 | # self.ln1 = nn.LayerNorm(hidden_features)
117 | # self.act = act_layer()
118 | # self.drop1 = nn.Dropout(drop)
119 |
120 | # # Second layer
121 | # self.fc2 = nn.Linear(hidden_features, out_features)
122 | # if use_ln:
123 | # self.ln2 = nn.LayerNorm(out_features)
124 | # self.drop2 = nn.Dropout(drop)
125 |
126 | # # Initialize parameters
127 | # self._init_weights()
128 |
129 | # def _init_weights(self):
130 | # # Use small initial values to prevent exploding gradients
131 | # nn.init.xavier_uniform_(self.fc1.weight, gain=0.01)
132 | # nn.init.zeros_(self.fc1.bias)
133 | # nn.init.xavier_uniform_(self.fc2.weight, gain=0.01)
134 | # nn.init.zeros_(self.fc2.bias)
135 |
136 | # def forward(self, x):
137 | # x = self.fc1(x)
138 | # if self.use_ln:
139 | # x = self.ln1(x)
140 | # x = self.act(x)
141 | # x = self.drop1(x)
142 |
143 | # x = self.fc2(x)
144 | # if self.use_ln:
145 | # x = self.ln2(x)
146 | # x = self.drop2(x)
147 |
148 | # return x
149 | class FeatureFusionBlock(nn.Module):
150 | def __init__(self, xyz_dim, rgb_dim, mlp_ratio=4.):
151 | super().__init__()
152 |
153 | self.xyz_dim = xyz_dim
154 | self.rgb_dim = rgb_dim
155 |
156 | self.xyz_norm = nn.LayerNorm(xyz_dim)
157 | self.xyz_mlp = Mlp(in_features=xyz_dim, hidden_features=int(xyz_dim * mlp_ratio), act_layer=nn.GELU, drop=0.)
158 | initialize_weights(self.xyz_mlp)
159 |
160 | self.rgb_norm = nn.LayerNorm(rgb_dim)
161 | self.rgb_mlp = Mlp(in_features=rgb_dim, hidden_features=int(rgb_dim * mlp_ratio), act_layer=nn.GELU, drop=0.)
162 | initialize_weights(self.rgb_mlp)
163 |
164 | self.rgb_head = nn.Linear(rgb_dim, 256)
165 | self.xyz_head = nn.Linear(xyz_dim, 256)
166 | initialize_weights(self.rgb_head)
167 | initialize_weights(self.xyz_head)
168 |
169 |
170 | self.T = 1
171 |
172 | def feature_fusion(self, xyz_feature, rgb_feature):
173 |
174 | xyz_feature = self.xyz_mlp(self.xyz_norm(xyz_feature))
175 | rgb_feature = self.rgb_mlp(self.rgb_norm(rgb_feature))
176 |
177 | feature = torch.cat([xyz_feature, rgb_feature], dim=2)
178 |
179 | return feature
180 |
181 | def contrastive_loss(self, q, k):
182 | # normalize
183 | q = nn.functional.normalize(q, dim=1)
184 | k = nn.functional.normalize(k, dim=1)
185 | # gather all targets
186 | # Einstein sum is more intuitive
187 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
188 | N = logits.shape[0] # batch size per GPU
189 | labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
190 | return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)
191 |
192 | def reparameterize(self, mu, logvar):
193 | """
194 |         Reparameterization trick: sample z = mu + std * eps with std = exp(0.5 * logvar).
195 |         A single z is used to approximate the expectation in the loss.
196 |         :param mu: (Tensor) Mean of the latent Gaussian
197 |         :param logvar: (Tensor) Log variance of the latent Gaussian
198 |         :return: (Tensor) Sampled latent vector
199 | """
200 | std = torch.exp(0.5 * logvar)
201 | eps = torch.randn_like(std)
202 | return eps * std + mu
203 |
204 |
205 | def forward(self, xyz_feature, rgb_feature):
206 |
207 |
208 | feature = self.feature_fusion(xyz_feature, rgb_feature)
209 |
210 | feature_xyz = feature[:,:, :self.xyz_dim]
211 | feature_rgb = feature[:,:, self.xyz_dim:]
212 |
213 | q = self.rgb_head(feature_rgb.view(-1, feature_rgb.shape[2]))
214 | k = self.xyz_head(feature_xyz.view(-1, feature_xyz.shape[2]))
215 |
216 | xyz_feature = xyz_feature.view(-1, xyz_feature.shape[2])
217 | rgb_feature = rgb_feature.view(-1, rgb_feature.shape[2])
218 |
219 | patch_no_zeros_indices = torch.nonzero(torch.all(xyz_feature != 0, dim=1))
220 |
221 | loss = self.contrastive_loss(q[patch_no_zeros_indices,:].squeeze(), k[patch_no_zeros_indices,:].squeeze())
222 |
223 | return loss
224 |
225 |
--------------------------------------------------------------------------------
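
Note: FeatureFusionBlock above pairs each xyz patch with the RGB patch at the same index and scores all cross pairs with an InfoNCE-style loss; contrastive_loss additionally assumes an initialized torch.distributed process group and CUDA. Below is a minimal single-process sketch of the same pairing logic, using dummy 256-d head outputs and omitting the rank offset; it is for orientation only, not the block's actual training entry point.

import torch
import torch.nn as nn

def infonce_loss(q, k, temperature=1.0):
    # Normalize both projections; matching row indices are the positive pairs.
    q = nn.functional.normalize(q, dim=1)
    k = nn.functional.normalize(k, dim=1)
    logits = torch.einsum('nc,mc->nm', [q, k]) / temperature
    labels = torch.arange(q.shape[0], dtype=torch.long)  # no rank offset: single process
    return nn.CrossEntropyLoss()(logits, labels) * (2 * temperature)

# Dummy projected patch features: 2 samples x 16 patches, 256-d heads (illustrative only).
q = torch.randn(2 * 16, 256)
k = torch.randn(2 * 16, 256)
print(infonce_loss(q, k))
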
/utils/preprocess_eyecandies.py:
--------------------------------------------------------------------------------
1 | import os
2 | from shutil import copyfile
3 | import cv2
4 | import numpy as np
5 | import tifffile
6 | import yaml
7 | import imageio.v3 as iio
8 | import math
9 | import argparse
10 |
11 | # The same camera has been used for all the images
12 | FOCAL_LENGTH = 711.11
13 |
14 | def load_and_convert_depth(depth_img, info_depth):
15 | with open(info_depth) as f:
16 | data = yaml.safe_load(f)
17 | mind, maxd = data["normalization"]["min"], data["normalization"]["max"]
18 |
19 | dimg = iio.imread(depth_img)
20 | dimg = dimg.astype(np.float32)
21 | dimg = dimg / 65535.0 * (maxd - mind) + mind
22 | return dimg
23 |
24 | def depth_to_pointcloud(depth_img, info_depth, pose_txt, focal_length):
25 | # input depth map (in meters) --- cfr previous section
26 | depth_mt = load_and_convert_depth(depth_img, info_depth)
27 |
28 | # input pose
29 | pose = np.loadtxt(pose_txt)
30 |
31 | # camera intrinsics
32 | height, width = depth_mt.shape[:2]
33 | intrinsics_4x4 = np.array([
34 | [focal_length, 0, width / 2, 0],
35 | [0, focal_length, height / 2, 0],
36 | [0, 0, 1, 0],
37 | [0, 0, 0, 1]]
38 | )
39 |
40 | # build the camera projection matrix
41 | camera_proj = intrinsics_4x4 @ pose
42 |
43 | # build the (u, v, 1, 1/depth) vectors (non optimized version)
44 | camera_vectors = np.zeros((width * height, 4))
45 | count=0
46 | for j in range(height):
47 | for i in range(width):
48 | camera_vectors[count, :] = np.array([i, j, 1, 1/depth_mt[j, i]])
49 | count += 1
50 |
51 | # invert and apply to each 4-vector
52 | hom_3d_pts= np.linalg.inv(camera_proj) @ camera_vectors.T
53 | # print(hom_3d_pts.shape)
54 | # remove the homogeneous coordinate
55 | pcd = depth_mt.reshape(-1, 1) * hom_3d_pts.T
56 | return pcd[:, :3]
57 |
58 | def remove_point_cloud_background(pc):
59 |
60 | # The second dim is z
61 | dz = pc[256,1] - pc[-256,1]
62 | dy = pc[256,2] - pc[-256,2]
63 |
64 | norm = math.sqrt(dz**2 + dy**2)
65 | start_points = np.array([0, pc[-256, 1], pc[-256, 2]])
66 | cos_theta = dy / norm
67 | sin_theta = dz / norm
68 |
69 | # Transform and rotation
70 | rotation_matrix = np.array([[1, 0, 0], [0, cos_theta, -sin_theta],[0, sin_theta, cos_theta]])
71 | processed_pc = (rotation_matrix @ (pc - start_points).T).T
72 |
73 | # Remove background point
74 | for i in range(processed_pc.shape[0]):
75 | if processed_pc[i,1] > -0.02:
76 | processed_pc[i, :] = -start_points
77 | if processed_pc[i,2] > 1.8:
78 | processed_pc[i, :] = -start_points
79 | elif processed_pc[i,0] > 1 or processed_pc[i,0] < -1:
80 | processed_pc[i, :] = -start_points
81 |
82 | processed_pc = (rotation_matrix.T @ processed_pc.T).T + start_points
83 |
84 | index = [0, 2, 1]
85 | processed_pc = processed_pc[:,index]
86 | return processed_pc*[0.1, -0.1, 0.1]
87 |
88 |
89 | if __name__ == '__main__':
90 |
91 |     parser = argparse.ArgumentParser(description='Preprocess the Eyecandies dataset into an MVTec 3D-AD style layout.')
92 | parser.add_argument('--dataset_path', default='datasets/eyecandies', type=str, help="Original Eyecandies dataset path.")
93 | parser.add_argument('--target_dir', default='datasets/eyecandies_preprocessed', type=str, help="Processed Eyecandies dataset path")
94 | args = parser.parse_args()
95 |
96 | os.mkdir(args.target_dir)
97 | categories_list = os.listdir(args.dataset_path)
98 |
99 | for category_dir in categories_list:
100 | category_root_path = os.path.join(args.dataset_path, category_dir)
101 |
102 |         category_train_path = os.path.join(category_root_path, 'train/data')  # no leading slash: '/train/data' would discard category_root_path
103 |         category_test_path = os.path.join(category_root_path, 'test_public/data')
104 |
105 | category_target_path = os.path.join(args.target_dir, category_dir)
106 | os.mkdir(category_target_path)
107 |
108 | os.mkdir(os.path.join(category_target_path, 'train'))
109 | category_target_train_good_path = os.path.join(category_target_path, 'train/good')
110 | category_target_train_good_rgb_path = os.path.join(category_target_train_good_path, 'rgb')
111 | category_target_train_good_xyz_path = os.path.join(category_target_train_good_path, 'xyz')
112 | os.mkdir(category_target_train_good_path)
113 | os.mkdir(category_target_train_good_rgb_path)
114 | os.mkdir(category_target_train_good_xyz_path)
115 |
116 | os.mkdir(os.path.join(category_target_path, 'test'))
117 | category_target_test_good_path = os.path.join(category_target_path, 'test/good')
118 | category_target_test_good_rgb_path = os.path.join(category_target_test_good_path, 'rgb')
119 | category_target_test_good_xyz_path = os.path.join(category_target_test_good_path, 'xyz')
120 | category_target_test_good_gt_path = os.path.join(category_target_test_good_path, 'gt')
121 | os.mkdir(category_target_test_good_path)
122 | os.mkdir(category_target_test_good_rgb_path)
123 | os.mkdir(category_target_test_good_xyz_path)
124 | os.mkdir(category_target_test_good_gt_path)
125 | category_target_test_bad_path = os.path.join(category_target_path, 'test/bad')
126 | category_target_test_bad_rgb_path = os.path.join(category_target_test_bad_path, 'rgb')
127 | category_target_test_bad_xyz_path = os.path.join(category_target_test_bad_path, 'xyz')
128 | category_target_test_bad_gt_path = os.path.join(category_target_test_bad_path, 'gt')
129 | os.mkdir(category_target_test_bad_path)
130 | os.mkdir(category_target_test_bad_rgb_path)
131 | os.mkdir(category_target_test_bad_xyz_path)
132 | os.mkdir(category_target_test_bad_gt_path)
133 |
134 | category_train_files = os.listdir(category_train_path)
135 | num_train_files = len(category_train_files)//17
136 | for i in range(0, num_train_files):
137 | pc = depth_to_pointcloud(
138 | os.path.join(category_train_path,str(i).zfill(3)+'_depth.png'),
139 | os.path.join(category_train_path,str(i).zfill(3)+'_info_depth.yaml'),
140 | os.path.join(category_train_path,str(i).zfill(3)+'_pose.txt'),
141 | FOCAL_LENGTH,
142 | )
143 | pc = remove_point_cloud_background(pc)
144 | pc = pc.reshape(512,512,3)
145 | tifffile.imwrite(os.path.join(category_target_train_good_xyz_path, str(i).zfill(3)+'.tiff'), pc)
146 | copyfile(os.path.join(category_train_path,str(i).zfill(3)+'_image_4.png'),os.path.join(category_target_train_good_rgb_path, str(i).zfill(3)+'.png'))
147 |
148 |
149 | category_test_files = os.listdir(category_test_path)
150 | num_test_files = len(category_test_files)//17
151 | for i in range(0, num_test_files):
152 | mask = cv2.imread(os.path.join(category_test_path,str(i).zfill(2)+'_mask.png'))
153 | if np.any(mask):
154 | pc = depth_to_pointcloud(
155 | os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'),
156 | os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'),
157 | os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'),
158 | FOCAL_LENGTH,
159 | )
160 | pc = remove_point_cloud_background(pc)
161 | pc = pc.reshape(512,512,3)
162 | tifffile.imwrite(os.path.join(category_target_test_bad_xyz_path, str(i).zfill(3)+'.tiff'), pc)
163 | cv2.imwrite(os.path.join(category_target_test_bad_gt_path, str(i).zfill(3)+'.png'), mask)
164 | copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_bad_rgb_path, str(i).zfill(3)+'.png'))
165 | else:
166 | pc = depth_to_pointcloud(
167 | os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'),
168 | os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'),
169 | os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'),
170 | FOCAL_LENGTH,
171 | )
172 | pc = remove_point_cloud_background(pc)
173 | pc = pc.reshape(512,512,3)
174 | tifffile.imwrite(os.path.join(category_target_test_good_xyz_path, str(i).zfill(3)+'.tiff'), pc)
175 | cv2.imwrite(os.path.join(category_target_test_good_gt_path, str(i).zfill(3)+'.png'), mask)
176 | copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_good_rgb_path, str(i).zfill(3)+'.png'))
177 |
--------------------------------------------------------------------------------
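
Note: depth_to_pointcloud above builds one homogeneous (u, v, 1, 1/z) vector per pixel and maps it through the inverse of intrinsics_4x4 @ pose; the double loop is explicitly marked as non-optimized. The following is a vectorized sketch of the same back-projection, run on a synthetic depth map with an identity pose purely for illustration.

import numpy as np

def backproject_depth(depth_mt, pose, focal_length):
    # Same math as depth_to_pointcloud, vectorized: invert K @ pose and apply it
    # to homogeneous (u, v, 1, 1/z) pixel vectors, then scale by depth.
    h, w = depth_mt.shape[:2]
    K = np.array([[focal_length, 0, w / 2, 0],
                  [0, focal_length, h / 2, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]])
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    pix = np.stack([u.ravel(), v.ravel(),
                    np.ones(w * h), 1.0 / depth_mt.ravel()], axis=1)
    pts = (np.linalg.inv(K @ pose) @ pix.T).T
    return depth_mt.reshape(-1, 1) * pts[:, :3]

# Flat synthetic depth map and identity pose (illustrative values only).
depth = np.full((4, 4), 0.5, dtype=np.float32)
print(backproject_depth(depth, np.eye(4), 711.11).shape)  # (16, 3)
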
/utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tifffile as tiff
4 | import open3d as o3d
5 | from pathlib import Path
6 | from PIL import Image
7 | import math
8 | import mvtec3d_util as mvt_util
9 | import argparse
10 |
11 |
12 | def get_edges_of_pc(organized_pc):
13 | unorganized_edges_pc = organized_pc[0:10, :, :].reshape(organized_pc[0:10, :, :].shape[0]*organized_pc[0:10, :, :].shape[1],organized_pc[0:10, :, :].shape[2])
14 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc,organized_pc[-10:, :, :].reshape(organized_pc[-10:, :, :].shape[0] * organized_pc[-10:, :, :].shape[1],organized_pc[-10:, :, :].shape[2])],axis=0)
15 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, 0:10, :].reshape(organized_pc[:, 0:10, :].shape[0] * organized_pc[:, 0:10, :].shape[1],organized_pc[:, 0:10, :].shape[2])], axis=0)
16 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, -10:, :].reshape(organized_pc[:, -10:, :].shape[0] * organized_pc[:, -10:, :].shape[1],organized_pc[:, -10:, :].shape[2])], axis=0)
17 | unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0],:]
18 | return unorganized_edges_pc
19 |
20 | def get_plane_eq(unorganized_pc,ransac_n_pts=50):
21 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc))
22 | plane_model, inliers = o3d_pc.segment_plane(distance_threshold=0.004, ransac_n=ransac_n_pts, num_iterations=1000)
23 | return plane_model
24 |
25 | def remove_plane(organized_pc_clean, organized_rgb ,distance_threshold=0.005):
26 | # PREP PC
27 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean)
28 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
29 | clean_planeless_unorganized_pc = unorganized_pc.copy()
30 | planeless_unorganized_rgb = unorganized_rgb.copy()
31 |
32 | # REMOVE PLANE
33 | plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean))
34 | distances = np.abs(np.dot(np.array(plane_model), np.hstack((clean_planeless_unorganized_pc, np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T))
35 | plane_indices = np.argwhere(distances < distance_threshold)
36 |
37 | planeless_unorganized_rgb[plane_indices] = 0
38 | clean_planeless_unorganized_pc[plane_indices] = 0
39 | clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0],
40 | organized_pc_clean.shape[1],
41 | organized_pc_clean.shape[2])
42 | planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0],
43 | organized_rgb.shape[1],
44 | organized_rgb.shape[2])
45 | return clean_planeless_organized_pc, planeless_organized_rgb
46 |
47 |
48 |
49 | def connected_components_cleaning(organized_pc, organized_rgb, image_path):
50 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc)
51 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
52 |
53 | nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
54 | unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]
55 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros))
56 | labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False))
57 |
58 |
59 | unique_cluster_ids, cluster_size = np.unique(labels,return_counts=True)
60 | max_label = labels.max()
61 | if max_label>0:
62 | print("##########################################################################")
63 | print(f"Point cloud file {image_path} has {max_label + 1} clusters")
64 | print(f"Cluster ids: {unique_cluster_ids}. Cluster size {cluster_size}")
65 | print("##########################################################################\n\n")
66 |
67 | largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)]
68 | outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id)
69 | outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array]
70 | unorganized_pc[outlier_indices_original_pc_array] = 0
71 | unorganized_rgb[outlier_indices_original_pc_array] = 0
72 | organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0],
73 | organized_pc.shape[1],
74 | organized_pc.shape[2])
75 | organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0],
76 | organized_rgb.shape[1],
77 | organized_rgb.shape[2])
78 | return organized_clustered_pc, organized_clustered_rgb
79 |
80 | def roundup_next_100(x):
81 | return int(math.ceil(x / 100.0)) * 100
82 |
83 | def pad_cropped_pc(cropped_pc, single_channel=False):
84 | orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1]
85 | round_orig_h = roundup_next_100(orig_h)
86 | round_orig_w = roundup_next_100(orig_w)
87 | large_side = max(round_orig_h, round_orig_w)
88 |
89 | a = (large_side - orig_h) // 2
90 | aa = large_side - a - orig_h
91 |
92 | b = (large_side - orig_w) // 2
93 | bb = large_side - b - orig_w
94 | if single_channel:
95 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant')
96 | else:
97 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant')
98 |
99 | def preprocess_pc(tiff_path):
100 | # READ FILES
101 | organized_pc = mvt_util.read_tiff_organized_pc(tiff_path)
102 |
103 | # rgb_path = str(tiff_path).replace("xyz", "rgb").replace("tiff", "jpg")
104 | # gt_path = str(tiff_path).replace("xyz", "gt").replace("tiff", "png")
105 |
106 |     # Changed: locate the original RGB image and GT by matching file-name prefixes
107 | rgb_path = ''
108 | gt_path = ''
109 |     # Get the tiff file name
110 | tiff_file_name = os.path.basename(tiff_path)
111 |
112 |     # Get the directory containing the tiff file and derive the matching RGB and GT directories
113 |     tiff_dir = os.path.dirname(tiff_path)
114 |     rgb_dir = tiff_dir.replace("xyz", "rgb")  # replace 'xyz' with 'rgb' in the tiff path
115 |     gt_dir = tiff_dir.replace("xyz", "gt")  # replace 'xyz' with 'gt' in the tiff path
116 |
117 |     # Get the matching prefix of the tiff file name (the part before the first underscore)
118 |     if "_" in tiff_file_name:
119 |         tiff_prefix = tiff_file_name.split('_', 1)[0]  # part before the first underscore
120 |     else:
121 |         tiff_prefix = tiff_file_name  # no underscore: use the full file name as the prefix
122 |
123 |     # Walk the RGB directory and find the file whose prefix matches the tiff prefix
124 |     for rgb_file in os.listdir(rgb_dir):
125 |         if "_" in rgb_file:
126 |             rgb_prefix = rgb_file.split('_', 1)[0]  # part before the first underscore in the RGB file name
127 |         else:
128 |             rgb_prefix = rgb_file  # no underscore: use the file name directly
129 |
130 |         # If the prefixes match, use this file (keeping its original name) as the RGB path
131 |         if tiff_prefix == rgb_prefix:
132 |             rgb_path = os.path.join(rgb_dir, rgb_file)  # keep the original RGB file name
133 | print(rgb_path)
134 |
135 |     # Walk the GT directory and find the file whose prefix matches the tiff prefix
136 |     for gt_file in os.listdir(gt_dir):
137 |         if "_" in gt_file:
138 |             gt_prefix = gt_file.split('_', 1)[0]  # part before the first underscore in the GT file name
139 |         else:
140 |             gt_prefix = gt_file.split('.', 1)[0]  # no underscore: use the file name stem
141 |
142 |         # If the prefixes match, use this file (keeping its original name) as the GT path
143 |         if tiff_prefix == gt_prefix:
144 |             gt_path = os.path.join(gt_dir, gt_file)  # keep the original GT file name
145 | print(gt_path)
146 |
147 |
148 |
149 | organized_rgb = np.array(Image.open(rgb_path))
150 |
151 | organized_gt = None
152 | gt_exists = os.path.isfile(gt_path)
153 | if gt_exists:
154 | organized_gt = np.array(Image.open(gt_path))
155 |
156 | # REMOVE PLANE
157 | planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb)
158 |
159 |
160 | # PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE)
161 | padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False)
162 | padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False)
163 | if gt_exists:
164 | padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True)
165 |
166 | organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc, padded_planeless_organized_rgb, tiff_path)
167 | # SAVE PREPROCESSED FILES
168 |     tiff.imwrite(tiff_path, organized_clustered_pc)  # imwrite: imsave is deprecated in recent tifffile releases
169 | Image.fromarray(organized_clustered_rgb).save(rgb_path)
170 | if gt_exists:
171 | Image.fromarray(padded_organized_gt).save(gt_path)
172 |
173 |
174 |
175 | if __name__ == '__main__':
176 | parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD')
177 | parser.add_argument('dataset_path', type=str, help='The root path of the MVTec 3D-AD. The preprocessing is done inplace (i.e. the preprocessed dataset overrides the existing one)')
178 | args = parser.parse_args()
179 |
180 |
181 | root_path = args.dataset_path
182 | paths = Path(root_path).rglob('*.tiff')
183 | print(f"Found {len(list(paths))} tiff files in {root_path}")
184 | processed_files = 0
185 | for path in Path(root_path).rglob('*.tiff'):
186 | preprocess_pc(path)
187 | processed_files += 1
188 | if processed_files % 50 == 0:
189 | print(f"Processed {processed_files} tiff files...")
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
--------------------------------------------------------------------------------
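
Note: remove_plane above scores every point by its distance to the RANSAC plane fitted on the image border and zeroes out points closer than distance_threshold. A small numpy sketch of that distance test, assuming the plane model (a, b, c, d) has a unit normal as Open3D's segment_plane returns; values are synthetic.

import numpy as np

# Plane z = 0 written as a*x + b*y + c*z + d = 0 with unit normal (illustrative values).
plane_model = np.array([0.0, 0.0, 1.0, 0.0])

pts = np.array([[0.1, 0.2, 0.001],    # lies on the background plane
                [0.1, 0.2, 0.050]])   # part of the object
hom = np.hstack([pts, np.ones((pts.shape[0], 1))])
distances = np.abs(hom @ plane_model)           # same expression as in remove_plane
plane_indices = np.argwhere(distances < 0.005)  # default distance_threshold
print(plane_indices.ravel())                    # -> [0]: only the first point is zeroed out
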
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from m3dm_runner import M3DM
3 | from dataset import eyecandies_classes, mvtec3d_classes, test_3d_classes
4 | import os
5 | import pandas as pd
6 | import torch
7 | #os.environ['CUDA_VISIBLE_DEVICES'] = '7'
8 |
9 |
10 | def run_3d_ads(args):
11 | if args.dataset_type=='eyecandies':
12 | classes = eyecandies_classes()
13 | elif args.dataset_type=='mvtec3d':
14 | classes = mvtec3d_classes()
15 | elif args.dataset_type=='test_3d':
16 | classes = test_3d_classes()
17 |
18 | METHOD_NAMES = [args.method_name]
19 |
20 | image_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])
21 | pixel_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])
22 | au_pros_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])
23 | for cls in classes:
24 | model = M3DM(args)
25 | model.fit(cls)
26 | image_rocaucs, pixel_rocaucs, au_pros = model.evaluate(cls)
27 |
28 | image_rocaucs_df[cls.title()] = image_rocaucs_df['Method'].map(image_rocaucs)
29 | pixel_rocaucs_df[cls.title()] = pixel_rocaucs_df['Method'].map(pixel_rocaucs)
30 | au_pros_df[cls.title()] = au_pros_df['Method'].map(au_pros)
31 |
32 | print(f"\nFinished running on class {cls}")
33 | print("################################################################################\n\n")
34 |
35 | # image_rocaucs_df['Mean'] = round(image_rocaucs_df.iloc[:, 1:].mean(axis=1),3)
36 | # pixel_rocaucs_df['Mean'] = round(pixel_rocaucs_df.iloc[:, 1:].mean(axis=1),3)
37 | # au_pros_df['Mean'] = round(au_pros_df.iloc[:, 1:].mean(axis=1),3)
38 |
39 | # print("\n\n################################################################################")
40 | # print("############################# Image ROCAUC Results #############################")
41 | # print("################################################################################\n")
42 | # print(image_rocaucs_df.to_markdown(index=False))
43 |
44 | # print("\n\n################################################################################")
45 | # print("############################# Pixel ROCAUC Results #############################")
46 | # print("################################################################################\n")
47 | # print(pixel_rocaucs_df.to_markdown(index=False))
48 |
49 | # print("\n\n##########################################################################")
50 | # print("############################# AU PRO Results #############################")
51 | # print("##########################################################################\n")
52 | # print(au_pros_df.to_markdown(index=False))
53 | print("\n\n################################################################################")
54 | print("############################# Image ROCAUC Results #############################")
55 | print("################################################################################\n")
56 | print(image_rocaucs_df.to_string(index=False))
57 |
58 | print("\n\n################################################################################")
59 | print("############################# Pixel ROCAUC Results #############################")
60 | print("################################################################################\n")
61 | print(pixel_rocaucs_df.to_string(index=False))
62 |
63 | print("\n\n##########################################################################")
64 | print("############################# AU PRO Results #############################")
65 | print("##########################################################################\n")
66 | print(au_pros_df.to_string(index=False))
67 |
68 |
69 |
70 | # with open("results/image_rocauc_results.md", "a") as tf:
71 | # tf.write(image_rocaucs_df.to_markdown(index=False))
72 | # with open("results/pixel_rocauc_results.md", "a") as tf:
73 | # tf.write(pixel_rocaucs_df.to_markdown(index=False))
74 | # with open("results/aupro_results.md", "a") as tf:
75 | # tf.write(au_pros_df.to_markdown(index=False))
76 | # with open("results/image_rocauc_results.txt", "a") as tf:
77 | # tf.write(image_rocaucs_df.to_string(index=False))
78 | # with open("results/pixel_rocauc_results.txt", "a") as tf:
79 | # tf.write(pixel_rocaucs_df.to_string(index=False))
80 | # with open("results/aupro_results.txt", "a") as tf:
81 | # tf.write(au_pros_df.to_string(index=False))
82 |
83 |
84 | if __name__ == '__main__':
85 | #torch.cuda.set_device(7)
86 |     parser = argparse.ArgumentParser(description='Run multimodal (RGB + point cloud) 3D anomaly detection.')
87 |
88 | parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str,
89 | choices=['DINO','Point_MAE','Fusion','DINO+Point_MAE','DINO+Point_MAE+Fusion','DINO+Point_MAE+add','DINO+FPFH','DINO+FPFH+Fusion',
90 | 'DINO+FPFH+Fusion+ps','DINO+Point_MAE+Fusion+ps','DINO+Point_MAE+ps','DINO+FPFH+ps','ours','ours2','ours3','ours_final','ours_final1'
91 | ,'ours_final1_VS', 'm3dm_uninterpolate', 'ours_PS','m3dm_VS','Point_MAE_VS','DINO_VS','patchcore_VS','PS_VS','OURS_EX_VS','NEW_OURS_EX_VS','patchcore_uninterpolate','shape'],
92 |                         help='Anomaly detection method name.')
93 | parser.add_argument('--max_sample', default=400, type=int,
94 | help='Max sample number.')
95 | parser.add_argument('--memory_bank', default='multiple', type=str,
96 | choices=["multiple", "single"],
97 | help='memory bank mode: "multiple", "single".')
98 | parser.add_argument('--rgb_backbone_name', default='vit_base_patch8_224_dino', type=str,
99 | choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'],
100 |                         help='Timm checkpoint name of the RGB backbone.')
101 | parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert','FPFH'],
102 |                         help='Checkpoint name of the xyz (point cloud) backbone [Point_MAE, Point_Bert, FPFH].')
103 | parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str,
104 | help='Checkpoints for fusion module.')
105 | parser.add_argument('--save_feature', default=False, action='store_true',
106 | help='Save feature for training fusion block.')
107 | parser.add_argument('--use_uff', default=False, action='store_true',
108 | help='Use UFF module.')
109 | parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str,
110 |                         help='Path where features for training the fusion block are saved.')
111 | parser.add_argument('--save_preds', default=False, action='store_true',
112 |                         help='Save prediction results.')
113 | parser.add_argument('--group_size', default=128, type=int,
114 | help='Point group size of Point Transformer.')
115 | parser.add_argument('--num_group', default=1024, type=int,
116 | help='Point groups number of Point Transformer.')
117 | parser.add_argument('--random_state', default=None, type=int,
118 |                         help='random_state for the sparse random projection')
119 | parser.add_argument('--dataset_type', default='test_3d', type=str, choices=['mvtec3d', 'eyecandies','test_3d'],
120 | help='Dataset type for training or testing')
121 | parser.add_argument('--dataset_path', default='/fuxi_team14_intern/D3', type=str,
122 | help='Dataset store path')
123 | parser.add_argument('--img_size', default=224, type=int,
124 | help='Images size for model')
125 | parser.add_argument('--xyz_s_lambda', default=1.0, type=float,
126 | help='xyz_s_lambda')
127 | parser.add_argument('--xyz_smap_lambda', default=1.0, type=float,
128 | help='xyz_smap_lambda')
129 | parser.add_argument('--rgb_s_lambda', default=0.1, type=float,
130 | help='rgb_s_lambda')
131 | parser.add_argument('--rgb_smap_lambda', default=0.1, type=float,
132 | help='rgb_smap_lambda')
133 | parser.add_argument('--ps_s_lambda', default=0.1, type=float,
134 |                         help='ps_s_lambda')
135 | parser.add_argument('--ps_smap_lambda', default=0.1, type=float,
136 |                         help='ps_smap_lambda')
137 | parser.add_argument('--fusion_s_lambda', default=1.0, type=float,
138 | help='fusion_s_lambda')
139 | parser.add_argument('--fusion_smap_lambda', default=1.0, type=float,
140 | help='fusion_smap_lambda')
141 | parser.add_argument('--coreset_eps', default=0.9, type=float,
142 |                         help='eps for the sparse random projection')
143 | parser.add_argument('--f_coreset', default=0.1, type=float,
144 |                         help='Fraction of patches kept when subsampling the memory bank (coreset).')
145 | parser.add_argument('--asy_memory_bank', default=None, type=int,
146 | help='build an asymmetric memory bank for point clouds')
147 | parser.add_argument('--ocsvm_nu', default=0.5, type=float,
148 | help='ocsvm nu')
149 | parser.add_argument('--ocsvm_maxiter', default=1000, type=int,
150 | help='ocsvm maxiter')
151 | parser.add_argument('--rm_zero_for_project', default=False, action='store_true',
152 |                         help='Remove all-zero patches before the random projection.')
153 | parser.add_argument('--downsampling', default=1, type=int,
154 | help='downsampling factor')
155 | parser.add_argument('--rotate_angle', default=0, type=float,
156 | help='rotate angle')
157 | parser.add_argument('--small', default=False, action='store_true',
158 | help='small predict')
159 | parser.add_argument('--split', default=False, action='store_true',
160 | help='split_predict')
161 | parser.add_argument('--ex', default=1, type=int,
162 | help='ex_factor')
163 |
164 | args = parser.parse_args()
165 | run_3d_ads(args)
166 |
--------------------------------------------------------------------------------
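
Note: run_3d_ads fills one DataFrame column per object class by mapping each method name to the score dict returned by model.evaluate(cls); the commented-out block also shows how a 'Mean' column was added. A self-contained pandas sketch of that bookkeeping, using made-up scores and class names purely for illustration:

import pandas as pd

METHOD_NAMES = ['DINO+Point_MAE+Fusion']
image_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])

# Per-class results as returned by model.evaluate(cls): {method_name: score} (synthetic values).
fake_results = {'bagel': {'DINO+Point_MAE+Fusion': 0.981},
                'cable_gland': {'DINO+Point_MAE+Fusion': 0.942}}
for cls, scores in fake_results.items():
    image_rocaucs_df[cls.title()] = image_rocaucs_df['Method'].map(scores)

image_rocaucs_df['Mean'] = round(image_rocaucs_df.iloc[:, 1:].mean(axis=1), 3)
print(image_rocaucs_df.to_string(index=False))
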
/utils/au_pro_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Code based on the official MVTec 3D-AD evaluation code found at
3 | https://www.mydrive.ch/shares/45924/9ce7a138c69bbd4c8d648b72151f839d/download/428846918-1643297332/evaluation_code.tar.xz
4 |
5 | Utility functions that compute a PRO curve and its definite integral, given
6 | pairs of anomaly and ground truth maps.
7 |
8 | The PRO curve can also be integrated up to a constant integration limit.
9 | """
10 | import numpy as np
11 | from scipy.ndimage import label  # scipy.ndimage.measurements is deprecated in recent SciPy
12 | from bisect import bisect
13 |
14 |
15 | class GroundTruthComponent:
16 | """
17 | Stores sorted anomaly scores of a single ground truth component.
18 | Used to efficiently compute the region overlap for many increasing thresholds.
19 | """
20 |
21 | def __init__(self, anomaly_scores):
22 | """
23 | Initialize the module.
24 |
25 | Args:
26 | anomaly_scores: List of all anomaly scores within the ground truth
27 | component as numpy array.
28 | """
29 | # Keep a sorted list of all anomaly scores within the component.
30 | self.anomaly_scores = anomaly_scores.copy()
31 | self.anomaly_scores.sort()
32 |
33 | # Pointer to the anomaly score where the current threshold divides the component into OK / NOK pixels.
34 | self.index = 0
35 |
36 | # The last evaluated threshold.
37 | self.last_threshold = None
38 |
39 | def compute_overlap(self, threshold):
40 | """
41 | Compute the region overlap for a specific threshold.
42 | Thresholds must be passed in increasing order.
43 |
44 | Args:
45 | threshold: Threshold to compute the region overlap.
46 |
47 | Returns:
48 | Region overlap for the specified threshold.
49 | """
50 | if self.last_threshold is not None:
51 | assert self.last_threshold <= threshold
52 |
53 | # Increase the index until it points to an anomaly score that is just above the specified threshold.
54 | while (self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold):
55 | self.index += 1
56 |
57 | # Compute the fraction of component pixels that are correctly segmented as anomalous.
58 | return 1.0 - self.index / len(self.anomaly_scores)
59 |
60 |
61 | def trapezoid(x, y, x_max=None):
62 | """
63 |     This function calculates the definite integral of a curve given by x- and corresponding y-values.
64 |     In contrast to, e.g., 'numpy.trapz()', this function allows one to define an upper bound on the integration range
65 |     by setting a value x_max.
66 |
67 | Points that do not have a finite x or y value will be ignored with a warning.
68 |
69 | Args:
70 | x: Samples from the domain of the function to integrate need to be sorted in ascending order. May contain
71 | the same value multiple times. In that case, the order of the corresponding y values will affect the
72 | integration with the trapezoidal rule.
73 | y: Values of the function corresponding to x values.
74 |         x_max: Upper limit of the integration. The y value at x_max will be determined by interpolating between its
75 | neighbors. Must not lie outside of the range of x.
76 |
77 | Returns:
78 | Area under the curve.
79 | """
80 |
81 | x = np.array(x)
82 | y = np.array(y)
83 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))
84 | if not finite_mask.all():
85 | print(
86 | """WARNING: Not all x and y values passed to trapezoid are finite. Will continue with only the finite values.""")
87 | x = x[finite_mask]
88 | y = y[finite_mask]
89 |
90 | # Introduce a correction term if max_x is not an element of x.
91 | correction = 0.
92 | if x_max is not None:
93 | if x_max not in x:
94 | # Get the insertion index that would keep x sorted after np.insert(x, ins, x_max).
95 | ins = bisect(x, x_max)
96 | # x_max must be between the minimum and the maximum, so the insertion_point cannot be zero or len(x).
97 | assert 0 < ins < len(x)
98 |
99 | # Calculate the correction term which is the integral between the last x[ins-1] and x_max. Since we do not
100 | # know the exact value of y at x_max, we interpolate between y[ins] and y[ins-1].
101 | y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1]))
102 | correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])
103 |
104 | # Cut off at x_max.
105 | mask = x <= x_max
106 | x = x[mask]
107 | y = y[mask]
108 |
109 | # Return area under the curve using the trapezoidal rule.
110 | return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction
111 |
112 |
113 | def collect_anomaly_scores(anomaly_maps, ground_truth_maps):
114 | """
115 | Extract anomaly scores for each ground truth connected component as well as anomaly scores for each potential false
116 | positive pixel from anomaly maps.
117 |
118 | Args:
119 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.
120 |
121 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels
122 | for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains
123 | an anomaly.
124 |
125 | Returns:
126 | ground_truth_components: A list of all ground truth connected components that appear in the dataset. For each
127 | component, a sorted list of its anomaly scores is stored.
128 |
129 | anomaly_scores_ok_pixels: A sorted list of anomaly scores of all anomaly-free pixels of the dataset. This list
130 | can be used to quickly select thresholds that fix a certain false positive rate.
131 | """
132 | # Make sure an anomaly map is present for each ground truth map.
133 | assert len(anomaly_maps) == len(ground_truth_maps)
134 |
135 | # Initialize ground truth components and scores of potential fp pixels.
136 | ground_truth_components = []
137 | anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size)
138 |
139 | # Structuring element for computing connected components.
140 | structure = np.ones((3, 3), dtype=int)
141 |
142 | # Collect anomaly scores within each ground truth region and for all potential fp pixels.
143 | ok_index = 0
144 | for gt_map, prediction in zip(ground_truth_maps, anomaly_maps):
145 |
146 | # Compute the connected components in the ground truth map.
147 | labeled, n_components = label(gt_map, structure)
148 |
149 | # Store all potential fp scores.
150 | num_ok_pixels = len(prediction[labeled == 0])
151 | anomaly_scores_ok_pixels[ok_index:ok_index + num_ok_pixels] = prediction[labeled == 0].copy()
152 | ok_index += num_ok_pixels
153 |
154 | # Fetch anomaly scores within each GT component.
155 | for k in range(n_components):
156 | component_scores = prediction[labeled == (k + 1)]
157 | ground_truth_components.append(GroundTruthComponent(component_scores))
158 |
159 | # Sort all potential false positive scores.
160 | anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index)
161 | anomaly_scores_ok_pixels.sort()
162 |
163 | return ground_truth_components, anomaly_scores_ok_pixels
164 |
165 |
166 | def compute_pro(anomaly_maps, ground_truth_maps, num_thresholds):
167 | """
168 | Compute the PRO curve at equidistant interpolation points for a set of anomaly maps with corresponding ground
169 | truth maps. The number of interpolation points can be set manually.
170 |
171 | Args:
172 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.
173 |
174 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels
175 | for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains
176 | an anomaly.
177 |
178 | num_thresholds: Number of thresholds to compute the PRO curve.
179 | Returns:
180 | fprs: List of false positive rates.
181 |         pros: List of corresponding PRO values.
182 | """
183 | # Fetch sorted anomaly scores.
184 | ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps)
185 |
186 | # Select equidistant thresholds.
187 | threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int)
188 |
189 | fprs = [1.0]
190 | pros = [1.0]
191 | for pos in threshold_positions:
192 | threshold = anomaly_scores_ok_pixels[pos]
193 |
194 | # Compute the false positive rate for this threshold.
195 | fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels)
196 |
197 | # Compute the PRO value for this threshold.
198 | pro = 0.0
199 | for component in ground_truth_components:
200 | pro += component.compute_overlap(threshold)
201 | pro /= len(ground_truth_components)
202 |
203 | fprs.append(fpr)
204 | pros.append(pro)
205 |
206 | # Return (FPR/PRO) pairs in increasing FPR order.
207 | fprs = fprs[::-1]
208 | pros = pros[::-1]
209 |
210 | return fprs, pros
211 |
212 |
213 | def calculate_au_pro(gts, predictions, integration_limit=0.3, num_thresholds=100):
214 | """
215 | Compute the area under the PRO curve for a set of ground truth images and corresponding anomaly images.
216 | Args:
217 | gts: List of tensors that contain the ground truth images for a single dataset object.
218 | predictions: List of tensors containing anomaly images for each ground truth image.
219 | integration_limit: Integration limit to use when computing the area under the PRO curve.
220 | num_thresholds: Number of thresholds to use to sample the area under the PRO curve.
221 |
222 | Returns:
223 | au_pro: Area under the PRO curve computed up to the given integration limit.
224 | pro_curve: PRO curve values for localization (fpr,pro).
225 | """
226 | # Compute the PRO curve.
227 | pro_curve = compute_pro(anomaly_maps=predictions, ground_truth_maps=gts, num_thresholds=num_thresholds)
228 |
229 | # Compute the area under the PRO curve.
230 | au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max=integration_limit)
231 | au_pro /= integration_limit
232 |
233 | # Return the evaluation metrics.
234 | return au_pro, pro_curve
235 |
--------------------------------------------------------------------------------
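
Note: a small end-to-end example of calculate_au_pro on synthetic maps (a single 4x4 ground truth with one 2x2 defect and an anomaly map that separates it cleanly); the module path follows this repository's layout and the values are made up.

import numpy as np
from utils.au_pro_util import calculate_au_pro  # path as in this repository

# One 4x4 ground truth map with a single 2x2 defect, and an anomaly map that
# scores the defect strictly higher than the background.
gt = np.zeros((4, 4), dtype=int)
gt[1:3, 1:3] = 1
pred = np.where(gt == 1, 0.9, 0.1) + 0.01 * np.random.rand(4, 4)

au_pro, (fprs, pros) = calculate_au_pro([gt], [pred], integration_limit=0.3, num_thresholds=10)
print(f"AU-PRO@0.3 = {au_pro:.3f}")  # ~1.0 for this perfectly separable example
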
/feature_extractors/features.py:
--------------------------------------------------------------------------------
1 | """
2 | PatchCore logic based on https://github.com/rvorias/ind_knn_ad
3 | """
4 | import torch
5 | import numpy as np
6 | import os
7 | from tqdm import tqdm
8 | from matplotlib import pyplot as plt
9 |
10 | from sklearn import random_projection
11 | from sklearn import linear_model
12 | from sklearn.svm import OneClassSVM
13 | from sklearn.ensemble import IsolationForest
14 | from sklearn.metrics import roc_auc_score
15 |
16 | from timm.models.layers import DropPath, trunc_normal_
17 | from pointnet2_ops import pointnet2_utils
18 | from knn_cuda import KNN
19 |
20 | from utils.utils import KNNGaussianBlur
21 | from utils.utils import set_seeds
22 | from utils.au_pro_util import calculate_au_pro
23 |
24 | from models.pointnet2_utils import interpolating_points
25 | from models.feature_fusion import FeatureFusionBlock
26 | from models.models import Model
27 |
28 | class Features(torch.nn.Module):
29 |
30 | def __init__(self, args, image_size=224, f_coreset=0.1, coreset_eps=0.9):
31 | super().__init__()
32 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
33 | self.deep_feature_extractor = Model(
34 | device=self.device,
35 | rgb_backbone_name=args.rgb_backbone_name,
36 | xyz_backbone_name=args.xyz_backbone_name,
37 | group_size = args.group_size,
38 | num_group=args.num_group
39 | )
40 | self.deep_feature_extractor.to(self.device)
41 |
42 | self.args = args
43 | self.image_size = args.img_size
44 | self.f_coreset = args.f_coreset
45 | self.coreset_eps = args.coreset_eps
46 | self.ex_factor =args.ex
47 | self.blur = KNNGaussianBlur(4)
48 | self.n_reweight = 3
49 | set_seeds(0)
50 | self.patch_xyz_lib = []
51 | self.patch_rgb_lib = []
52 | self.patch_ps_lib = []
53 | self.patch_fusion_lib = []
54 | self.patch_lib = []
55 | self.random_state = args.random_state
56 |
57 | self.xyz_dim = 0
58 | self.rgb_dim = 0
59 |
60 | self.xyz_mean=0
61 | self.xyz_std=0
62 | self.rgb_mean=0
63 | self.rgb_std=0
64 | self.fusion_mean=0
65 | self.fusion_std=0
66 | self.ps_mean=0
67 | self.ps_std=0
68 |
69 | self.average = torch.nn.AvgPool2d(3, stride=1) # torch.nn.AvgPool2d(1, stride=1) #
70 | self.resize = torch.nn.AdaptiveAvgPool2d((56, 56))
71 | self.resize2 = torch.nn.AdaptiveAvgPool2d((56, 56))
72 |
73 | self.image_preds = list()
74 | self.image_labels = list()
75 | self.pixel_preds = list()
76 | self.pixel_labels = list()
77 | self.gts = []
78 | self.predictions = []
79 | self.image_rocauc = 0
80 | self.pixel_rocauc = 0
81 | self.au_pro = 0
82 | self.ins_id = 0
83 | self.rgb_layernorm = torch.nn.LayerNorm(768, elementwise_affine=False)
84 | if self.args.use_uff:
85 | if self.args.xyz_backbone_name == 'FPFH' :
86 | self.fusion = FeatureFusionBlock(33, 768, mlp_ratio=4.)
87 | else :
88 | self.fusion = FeatureFusionBlock(1152, 768, mlp_ratio=4.)
89 |             # Because the pretrained checkpoint was not found, this load was temporarily commented out before
90 | ckpt = torch.load('/fuxi_team14_intern/m3dm/checkpoints/uff_pretrain.pth')['model']
91 | #ckpt = torch.load('/workspace/data3/code/M3DM-main/checkpoints5/checkpoint-99.pth')['model']
92 | #ckpt = torch.load('/workspace/data3/code/M3DM-main/checkpoints6/checkpoint-400.pth')['model']
93 |
94 | incompatible = self.fusion.load_state_dict(ckpt, strict=False)
95 |
96 | print('[Fusion Block]', incompatible)
97 |
98 | self.detect_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu, max_iter=args.ocsvm_maxiter)
99 | self.seg_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu, max_iter=args.ocsvm_maxiter)
100 |
101 | self.s_lib = []
102 | self.s_map_lib = []
103 |
104 | def __call__(self, rgb, xyz):
105 | # Extract the desired feature maps using the backbone model.
106 | rgb = rgb.to(self.device)
107 | xyz = xyz.to(self.device)
108 | with torch.no_grad():
109 | rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)
110 |
111 | interpolate = True
112 | if interpolate:
113 | interpolated_feature_maps = interpolating_points(xyz, center.permute(0,2,1), xyz_feature_maps).to("cpu")
114 |
115 | xyz_feature_maps = [fmap.to("cpu") for fmap in [xyz_feature_maps]]
116 | rgb_feature_maps = [fmap.to("cpu") for fmap in [rgb_feature_maps]]
117 |
118 | if interpolate:
119 | return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx, interpolated_feature_maps
120 | else:
121 | return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx
122 |
123 | def add_sample_to_mem_bank(self, sample):
124 | raise NotImplementedError
125 |
126 | def predict(self, sample, mask, label):
127 | raise NotImplementedError
128 |
129 | def add_sample_to_late_fusion_mem_bank(self, sample):
130 | raise NotImplementedError
131 |
132 | def interpolate_points(self, rgb, xyz):
133 | with torch.no_grad():
134 | rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)
135 | return xyz_feature_maps, center, xyz
136 |
137 | def compute_s_s_map(self, xyz_patch, rgb_patch, fusion_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):
138 | raise NotImplementedError
139 |
140 | def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):
141 | raise NotImplementedError
142 |
143 | def run_coreset(self):
144 | raise NotImplementedError
145 |
146 | def calculate_metrics(self):
147 | self.image_preds = np.stack(self.image_preds)
148 | self.image_labels = np.stack(self.image_labels)
149 | self.pixel_preds = np.array(self.pixel_preds)
150 |         # Check for and handle NaN values before computing the metrics
151 | if np.any(np.isnan(self.image_preds)):
152 | print(f"Warning: Found {np.sum(np.isnan(self.image_preds))} NaN values in image_preds")
153 |             # Replace NaNs with 0 (or the mean of that feature)
154 | self.image_preds = np.nan_to_num(self.image_preds, nan=0.0)
155 |
156 | if np.any(np.isnan(self.pixel_preds)):
157 | print(f"Warning: Found {np.sum(np.isnan(self.pixel_preds))} NaN values in pixel_preds")
158 | self.pixel_preds = np.nan_to_num(self.pixel_preds, nan=0.0)
159 |
160 | try:
161 | self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds)
162 | except Exception as e:
163 | print(f"Error calculating image_rocauc: {e}")
164 | print(f"image_labels shape: {self.image_labels.shape}")
165 | print(f"image_preds shape: {self.image_preds.shape}")
166 | print(f"image_labels unique values: {np.unique(self.image_labels)}")
167 | print(f"image_preds range: [{np.min(self.image_preds)}, {np.max(self.image_preds)}]")
168 | self.image_rocauc = 0.0
169 | #self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds)
170 | self.pixel_rocauc = roc_auc_score(self.pixel_labels, self.pixel_preds)
171 | self.au_pro, _ = calculate_au_pro(self.gts, self.predictions)
172 |
173 | self.image_preds = list()
174 | self.image_labels = list()
175 | self.pixel_preds = list()
176 | self.pixel_labels = list()
177 | self.gts = []
178 | self.predictions = []
179 |
180 | def save_prediction_maps(self, output_path, rgb_path, save_num=5):
181 |         for i in range(min(save_num, len(self.predictions))):  # min: save at most save_num maps and never index past the prediction list
182 | # fig = plt.figure(dpi=300)
183 | fig = plt.figure()
184 |
185 | ax3 = fig.add_subplot(1,3,1)
186 | gt = plt.imread(rgb_path[i][0])
187 | ax3.imshow(gt)
188 |
189 | ax2 = fig.add_subplot(1,3,2)
190 | im2 = ax2.imshow(self.gts[i], cmap=plt.cm.gray)
191 |
192 | ax = fig.add_subplot(1,3,3)
193 | im = ax.imshow(self.predictions[i], cmap=plt.cm.jet)
194 |
195 | class_dir = os.path.join(output_path, rgb_path[i][0].split('/')[-5])
196 | if not os.path.exists(class_dir):
197 | os.mkdir(class_dir)
198 |
199 | ad_dir = os.path.join(class_dir, rgb_path[i][0].split('/')[-3])
200 | if not os.path.exists(ad_dir):
201 | os.mkdir(ad_dir)
202 |
203 | plt.savefig(os.path.join(ad_dir, str(self.image_preds[i]) + '_pred_' + rgb_path[i][0].split('/')[-1] + '.jpg'))
204 |
205 | def run_late_fusion(self):
206 | self.s_lib = torch.cat(self.s_lib, 0)
207 | self.s_map_lib = torch.cat(self.s_map_lib, 0)
208 |
209 |         # # Check for and handle missing values
210 |         # if torch.isnan(self.s_lib).any():
211 |         #     print("Warning: self.s_lib contains NaN values; handling them.")
212 |         #     self.s_lib = self.s_lib[~torch.isnan(self.s_lib).any(dim=1)] # drop samples containing NaNs
213 |
214 |         # if torch.isnan(self.s_map_lib).any():
215 |         #     print("Warning: self.s_map_lib contains NaN values; handling them.")
216 |         #     self.s_map_lib = self.s_map_lib[~torch.isnan(self.s_map_lib).any(dim=1)] # drop samples containing NaNs
217 | self.detect_fuser.fit(self.s_lib)
218 | self.seg_fuser.fit(self.s_map_lib)
219 |
220 | def get_coreset_idx_randomp(self, z_lib, n=1000, eps=0.90, float16=True, force_cpu=False):
221 |
222 | print(f" Fitting random projections. Start dim = {z_lib.shape}.")
223 | try:
224 | transformer = random_projection.SparseRandomProjection(eps=eps, random_state=self.random_state)
225 | z_lib = torch.tensor(transformer.fit_transform(z_lib))
226 |
227 | print(f" DONE. Transformed dim = {z_lib.shape}.")
228 | except ValueError:
229 | print(" Error: could not project vectors. Please increase `eps`.")
230 |
231 | select_idx = 0
232 | last_item = z_lib[select_idx:select_idx + 1]
233 | coreset_idx = [torch.tensor(select_idx)]
234 | min_distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True)
235 |
236 | if float16:
237 | last_item = last_item.half()
238 | z_lib = z_lib.half()
239 | min_distances = min_distances.half()
240 | if torch.cuda.is_available() and not force_cpu:
241 | last_item = last_item.to("cuda")
242 | z_lib = z_lib.to("cuda")
243 | min_distances = min_distances.to("cuda")
244 |
245 | for _ in tqdm(range(n - 1)):
246 | distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True) # broadcasting step
247 | min_distances = torch.minimum(distances, min_distances) # iterative step
248 | select_idx = torch.argmax(min_distances) # selection step
249 |
250 | # bookkeeping
251 | last_item = z_lib[select_idx:select_idx + 1]
252 | min_distances[select_idx] = 0
253 | coreset_idx.append(select_idx.to("cpu"))
254 | return torch.stack(coreset_idx)
255 |
--------------------------------------------------------------------------------
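
Note: get_coreset_idx_randomp above performs greedy k-center (farthest-point) selection, optionally after a sparse random projection. A stripped-down sketch of just the greedy selection on random features (no projection, fp16, or CUDA handling), for orientation only:

import torch

def greedy_coreset(z_lib, n=10):
    # Simplified restatement of Features.get_coreset_idx_randomp: at each step,
    # pick the feature farthest from the coreset selected so far.
    select_idx = 0
    last_item = z_lib[select_idx:select_idx + 1]
    coreset_idx = [torch.tensor(select_idx)]
    min_distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdim=True)
    for _ in range(n - 1):
        distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdim=True)
        min_distances = torch.minimum(distances, min_distances)
        select_idx = torch.argmax(min_distances)   # farthest point from the current coreset
        last_item = z_lib[select_idx:select_idx + 1]
        min_distances[select_idx] = 0
        coreset_idx.append(select_idx.to("cpu"))
    return torch.stack(coreset_idx)

# 1000 fake patch features of dim 768; keep 10 of them, as f_coreset=0.01 would.
patch_lib = torch.randn(1000, 768)
print(greedy_coreset(patch_lib, n=10))
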
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import datetime
3 | import os
4 | import time
5 | from collections import defaultdict, deque
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.distributed as dist
10 | # from torch._six import inf
11 | inf = float('inf')
12 |
13 | class SmoothedValue(object):
14 | """Track a series of values and provide access to smoothed values over a
15 | window or the global series average.
16 | """
17 |
18 | def __init__(self, window_size=20, fmt=None):
19 | if fmt is None:
20 | fmt = "{median:.4f} ({global_avg:.4f})"
21 | self.deque = deque(maxlen=window_size)
22 | self.total = 0.0
23 | self.count = 0
24 | self.fmt = fmt
25 |
26 | def update(self, value, n=1):
27 | self.deque.append(value)
28 | self.count += n
29 | self.total += value * n
30 |
31 | def synchronize_between_processes(self):
32 | """
33 | Warning: does not synchronize the deque!
34 | """
35 | if not is_dist_avail_and_initialized():
36 | return
37 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
38 | dist.barrier()
39 | dist.all_reduce(t)
40 | t = t.tolist()
41 | self.count = int(t[0])
42 | self.total = t[1]
43 |
44 | @property
45 | def median(self):
46 | d = torch.tensor(list(self.deque))
47 | return d.median().item()
48 |
49 | @property
50 | def avg(self):
51 | d = torch.tensor(list(self.deque), dtype=torch.float32)
52 | return d.mean().item()
53 |
54 | @property
55 | def global_avg(self):
56 | return self.total / self.count
57 |
58 | @property
59 | def max(self):
60 | return max(self.deque)
61 |
62 | @property
63 | def value(self):
64 | return self.deque[-1]
65 |
66 | def __str__(self):
67 | return self.fmt.format(
68 | median=self.median,
69 | avg=self.avg,
70 | global_avg=self.global_avg,
71 | max=self.max,
72 | value=self.value)
73 |
74 |
75 | class MetricLogger(object):
76 | def __init__(self, delimiter="\t"):
77 | self.meters = defaultdict(SmoothedValue)
78 | self.delimiter = delimiter
79 |
80 | def update(self, **kwargs):
81 | for k, v in kwargs.items():
82 | if v is None:
83 | continue
84 | if isinstance(v, torch.Tensor):
85 | v = v.item()
86 | assert isinstance(v, (float, int))
87 | self.meters[k].update(v)
88 |
89 | def __getattr__(self, attr):
90 | if attr in self.meters:
91 | return self.meters[attr]
92 | if attr in self.__dict__:
93 | return self.__dict__[attr]
94 | raise AttributeError("'{}' object has no attribute '{}'".format(
95 | type(self).__name__, attr))
96 |
97 | def __str__(self):
98 | loss_str = []
99 | for name, meter in self.meters.items():
100 | loss_str.append(
101 | "{}: {}".format(name, str(meter))
102 | )
103 | return self.delimiter.join(loss_str)
104 |
105 | def synchronize_between_processes(self):
106 | for meter in self.meters.values():
107 | meter.synchronize_between_processes()
108 |
109 | def add_meter(self, name, meter):
110 | self.meters[name] = meter
111 |
112 | def log_every(self, iterable, print_freq, header=None):
113 | i = 0
114 | if not header:
115 | header = ''
116 | start_time = time.time()
117 | end = time.time()
118 | iter_time = SmoothedValue(fmt='{avg:.4f}')
119 | data_time = SmoothedValue(fmt='{avg:.4f}')
120 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
121 | log_msg = [
122 | header,
123 | '[{0' + space_fmt + '}/{1}]',
124 | 'eta: {eta}',
125 | '{meters}',
126 | 'time: {time}',
127 | 'data: {data}'
128 | ]
129 | if torch.cuda.is_available():
130 | log_msg.append('max mem: {memory:.0f}')
131 | log_msg = self.delimiter.join(log_msg)
132 | MB = 1024.0 * 1024.0
133 | for obj in iterable:
134 | data_time.update(time.time() - end)
135 | yield obj
136 | iter_time.update(time.time() - end)
137 | if i % print_freq == 0 or i == len(iterable) - 1:
138 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
139 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
140 | if torch.cuda.is_available():
141 | print(log_msg.format(
142 | i, len(iterable), eta=eta_string,
143 | meters=str(self),
144 | time=str(iter_time), data=str(data_time),
145 | memory=torch.cuda.max_memory_allocated() / MB))
146 | else:
147 | print(log_msg.format(
148 | i, len(iterable), eta=eta_string,
149 | meters=str(self),
150 | time=str(iter_time), data=str(data_time)))
151 | i += 1
152 | end = time.time()
153 | total_time = time.time() - start_time
154 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
155 | print('{} Total time: {} ({:.4f} s / it)'.format(
156 | header, total_time_str, total_time / len(iterable)))
157 |
158 |
159 | def setup_for_distributed(is_master):
160 | """
161 | This function disables printing when not in master process
162 | """
163 | builtin_print = builtins.print
164 |
165 | def print(*args, **kwargs):
166 | force = kwargs.pop('force', False)
167 | force = force or (get_world_size() > 8)
168 | if is_master or force:
169 | now = datetime.datetime.now().time()
170 | builtin_print('[{}] '.format(now), end='') # print with time stamp
171 | builtin_print(*args, **kwargs)
172 |
173 | builtins.print = print
174 |
175 |
176 | def is_dist_avail_and_initialized():
177 | if not dist.is_available():
178 | return False
179 | if not dist.is_initialized():
180 | return False
181 | return True
182 |
183 |
184 | def get_world_size():
185 | if not is_dist_avail_and_initialized():
186 | return 1
187 | return dist.get_world_size()
188 |
189 |
190 | def get_rank():
191 | if not is_dist_avail_and_initialized():
192 | return 0
193 | return dist.get_rank()
194 |
195 |
196 | def is_main_process():
197 | return get_rank() == 0
198 |
199 |
200 | def save_on_master(*args, **kwargs):
201 | if is_main_process():
202 | torch.save(*args, **kwargs)
203 |
204 |
205 | def init_distributed_mode(args):
206 | if args.dist_on_itp:
207 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
208 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
209 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
210 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
211 | os.environ['LOCAL_RANK'] = str(args.gpu)
212 | os.environ['RANK'] = str(args.rank)
213 | os.environ['WORLD_SIZE'] = str(args.world_size)
214 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
215 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
216 | args.rank = int(os.environ["RANK"])
217 | args.world_size = int(os.environ['WORLD_SIZE'])
218 | args.gpu = int(os.environ['LOCAL_RANK'])
219 | elif 'SLURM_PROCID' in os.environ:
220 | args.rank = int(os.environ['SLURM_PROCID'])
221 | args.gpu = args.rank % torch.cuda.device_count()
222 | else:
223 | print('Not using distributed mode')
224 | setup_for_distributed(is_master=True) # hack
225 | args.distributed = False
226 | return
227 |
228 | args.distributed = True
229 |
230 | torch.cuda.set_device(args.gpu)
231 | args.dist_backend = 'nccl'
232 | print('| distributed init (rank {}): {}, gpu {}'.format(
233 | args.rank, args.dist_url, args.gpu), flush=True)
234 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
235 | world_size=args.world_size, rank=args.rank)
236 | torch.distributed.barrier()
237 | setup_for_distributed(args.rank == 0)
238 |
239 |
240 | class NativeScalerWithGradNormCount:
241 | state_dict_key = "amp_scaler"
242 |
243 | def __init__(self):
244 | self._scaler = torch.cuda.amp.GradScaler()
245 |
246 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
247 | self._scaler.scale(loss).backward(create_graph=create_graph)
248 | if update_grad:
249 | if clip_grad is not None:
250 | assert parameters is not None
251 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
252 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
253 | else:
254 | self._scaler.unscale_(optimizer)
255 | norm = get_grad_norm_(parameters)
256 | self._scaler.step(optimizer)
257 | self._scaler.update()
258 | else:
259 | norm = None
260 | return norm
261 |
262 | def state_dict(self):
263 | return self._scaler.state_dict()
264 |
265 | def load_state_dict(self, state_dict):
266 | self._scaler.load_state_dict(state_dict)
267 |
268 |
269 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
270 | if isinstance(parameters, torch.Tensor):
271 | parameters = [parameters]
272 | parameters = [p for p in parameters if p.grad is not None]
273 | norm_type = float(norm_type)
274 | if len(parameters) == 0:
275 | return torch.tensor(0.)
276 | device = parameters[0].grad.device
277 | if norm_type == inf:
278 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
279 | else:
280 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
281 | return total_norm
282 |
283 |
284 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
285 | output_dir = Path(args.output_dir)
286 | epoch_name = str(epoch)
287 | if loss_scaler is not None:
288 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
289 | for checkpoint_path in checkpoint_paths:
290 | to_save = {
291 | 'model': model_without_ddp.state_dict(),
292 | 'optimizer': optimizer.state_dict(),
293 | 'epoch': epoch,
294 | 'scaler': loss_scaler.state_dict(),
295 | 'args': args,
296 | }
297 |
298 | save_on_master(to_save, checkpoint_path)
299 | else:
300 | client_state = {'epoch': epoch}
301 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
302 |
303 | def save_model_gan(args, epoch, model, discriminator, model_without_ddp, discriminator_without_ddp,
304 | optimizer_g, optimizer_d, loss_scaler):
305 | output_dir = Path(args.output_dir)
306 | epoch_name = str(epoch)
307 | if loss_scaler is not None:
308 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
309 | for checkpoint_path in checkpoint_paths:
310 | to_save = {
311 | 'model': model_without_ddp.state_dict(),
312 |                 'discriminator': discriminator_without_ddp.state_dict(),
313 | 'optimizer_g': optimizer_g.state_dict(),
314 | 'optimizer_d': optimizer_d.state_dict(),
315 | 'epoch': epoch,
316 | 'scaler': loss_scaler.state_dict(),
317 | 'args': args,
318 | }
319 |
320 | save_on_master(to_save, checkpoint_path)
321 | else:
322 | client_state = {'epoch': epoch}
323 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
324 | discriminator.save_checkpoint(save_dir=args.output_dir, tag="checkpoint_d-%s" % epoch_name, client_state=client_state)
325 |
326 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
327 | if args.resume:
328 | if args.resume.startswith('https'):
329 | checkpoint = torch.hub.load_state_dict_from_url(
330 | args.resume, map_location='cpu', check_hash=True)
331 | else:
332 | checkpoint = torch.load(args.resume, map_location='cpu')
333 | model_without_ddp.load_state_dict(checkpoint['model'])
334 | print("Resume checkpoint %s" % args.resume)
335 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
336 | optimizer.load_state_dict(checkpoint['optimizer'])
337 | args.start_epoch = checkpoint['epoch'] + 1
338 | if 'scaler' in checkpoint:
339 | loss_scaler.load_state_dict(checkpoint['scaler'])
340 | print("With optim & sched!")
341 |
342 | def load_model_gan(args, model_without_ddp, discriminator_without_ddp,
343 | optimizer_g, optimizer_d, loss_scaler):
344 | if args.resume:
345 | if args.resume.startswith('https'):
346 | checkpoint = torch.hub.load_state_dict_from_url(
347 | args.resume, map_location='cpu', check_hash=True)
348 | else:
349 | checkpoint = torch.load(args.resume, map_location='cpu')
350 | model_without_ddp.load_state_dict(checkpoint['model'])
351 | discriminator_without_ddp.load_state_dict(checkpoint['discriminator'])
352 | print("Resume checkpoint %s" % args.resume)
353 | if 'optimizer_d' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
354 | optimizer_d.load_state_dict(checkpoint['optimizer_d'])
355 | if 'optimizer_g' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
356 | optimizer_g.load_state_dict(checkpoint['optimizer_g'])
357 | args.start_epoch = checkpoint['epoch'] + 1
358 | if 'scaler' in checkpoint:
359 | loss_scaler.load_state_dict(checkpoint['scaler'])
360 | print("With optim & sched!")
361 |
362 |
363 |
364 | def all_reduce_mean(x):
365 | world_size = get_world_size()
366 | if world_size > 1:
367 | x_reduce = torch.tensor(x).cuda()
368 | dist.all_reduce(x_reduce)
369 | x_reduce /= world_size
370 | return x_reduce.item()
371 | else:
372 | return x
--------------------------------------------------------------------------------
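Note: the logging utilities in utils/misc.py are easiest to see in context. The snippet below is a minimal, hypothetical sketch of how MetricLogger and SmoothedValue are typically driven from a training loop; `model`, `criterion`, `data_loader`, `optimizer` and `device` are placeholder names and are not objects defined in this repository.

from utils.misc import MetricLogger, SmoothedValue

def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch):
    # illustrative loop: only the logger calls matter here
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    for samples, targets in metric_logger.log_every(data_loader, 20, header):
        samples, targets = samples.to(device), targets.to(device)
        loss = criterion(model(samples), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update(**kwargs) lazily creates one SmoothedValue per keyword
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

    # when running distributed, average the meters across processes
    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}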
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | from timm.models.layers import DropPath, trunc_normal_
5 | from timm.models.helpers import checkpoint_seq  # used in forward_rgb_features when grad checkpointing is enabled
6 | from pointnet2_ops import pointnet2_utils
7 | from knn_cuda import KNN
8 | class Model(torch.nn.Module):
9 |
10 | def __init__(self, device, rgb_backbone_name='vit_base_patch8_224_dino', out_indices=None,
11 | checkpoint_path='/fuxi_team14_intern/m3dm/checkpoints/dino_vitbase8_pretrain.pth',
12 | pool_last=False, xyz_backbone_name='Point_MAE', group_size=128, num_group=1024):
13 | super().__init__()
14 | # 'vit_base_patch8_224_dino'
15 |         # Determine whether the timm backbone should return intermediate features only.
16 | self.device = device
17 |
18 | kwargs = {'features_only': True if out_indices else False}
19 | if out_indices:
20 | kwargs.update({'out_indices': out_indices})
21 |
22 | ## RGB backbone
23 | self.rgb_backbone = timm.create_model(model_name=rgb_backbone_name, pretrained=False,
24 | checkpoint_path=checkpoint_path,
25 | **kwargs)
26 |
27 | ## XYZ backbone
28 |
29 | if xyz_backbone_name == 'Point_MAE':
30 | self.xyz_backbone = PointTransformer(group_size=group_size, num_group=num_group)
31 | self.xyz_backbone.load_model_from_ckpt("/fuxi_team14_intern/m3dm/checkpoints/pointmae_pretrain.pth")
32 | elif xyz_backbone_name == 'Point-BERT':
33 | self.xyz_backbone=PointTransformer(group_size=group_size, num_group=num_group, encoder_dims=256)
34 | self.xyz_backbone.load_model_from_pb_ckpt("/fuxi_team14_intern/m3dm/checkpoints/Point-BERT.pth")
35 | elif xyz_backbone_name == 'FPFH':
36 | self.xyz_backbone=FPFH(group_size=group_size, num_group=num_group,voxel_size=0.05)
37 | #self.xyz_backbone.load_model_from_pb_ckpt("/workspace/data2/checkpoints/Point-BERT.pth")
38 |
39 |
40 |
41 |
42 | def forward_rgb_features(self, x):
43 | x = self.rgb_backbone.patch_embed(x)
44 | x = self.rgb_backbone._pos_embed(x)
45 | x = self.rgb_backbone.norm_pre(x)
46 | if self.rgb_backbone.grad_checkpointing and not torch.jit.is_scripting():
47 |             x = checkpoint_seq(self.rgb_backbone.blocks, x)
48 | else:
49 | x = self.rgb_backbone.blocks(x)
50 | x = self.rgb_backbone.norm(x)
51 |
52 | feat = x[:,1:].permute(0, 2, 1).view(1, -1, 28, 28)
53 | return feat
54 |
55 |
56 | def forward(self, rgb, xyz):
57 |
58 | rgb_features = self.forward_rgb_features(rgb)
59 |
60 | xyz_features, center, ori_idx, center_idx = self.xyz_backbone(xyz)
61 |
62 |         # xyz_features.permute(0, 2, 1)  # no-op: the result was never assigned, so the features stay in (B, C, G) layout
63 |
64 | return rgb_features, xyz_features, center, ori_idx, center_idx
65 |
66 |
67 |
68 | def fps(data, number):
69 | '''
70 | data B N 3
71 | number int
72 | '''
73 | fps_idx = pointnet2_utils.furthest_point_sample(data, number)
74 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
75 | return fps_data, fps_idx
76 |
77 | class Group(nn.Module):
78 | def __init__(self, num_group, group_size):
79 | super().__init__()
80 | self.num_group = num_group
81 | self.group_size = group_size
82 | self.knn = KNN(k=self.group_size, transpose_mode=True)
83 |
84 | def forward(self, xyz):
85 | '''
86 | input: B N 3
87 | ---------------------------
88 | output: B G M 3
89 | center : B G 3
90 | '''
91 | batch_size, num_points, _ = xyz.shape
92 | # fps the centers out
93 | center, center_idx = fps(xyz.contiguous(), self.num_group) # B G 3
94 | # knn to get the neighborhood
95 | _, idx = self.knn(xyz, center) # B G M
96 | assert idx.size(1) == self.num_group
97 | assert idx.size(2) == self.group_size
98 | ori_idx = idx
99 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
100 | idx = idx + idx_base
101 | idx = idx.view(-1)
102 | neighborhood = xyz.reshape(batch_size * num_points, -1)[idx, :]
103 | neighborhood = neighborhood.reshape(batch_size, self.num_group, self.group_size, 3).contiguous()
104 | # normalize
105 | neighborhood = neighborhood - center.unsqueeze(2)
106 | return neighborhood, center, ori_idx, center_idx
107 |
108 |
109 | class Encoder(nn.Module):
110 | def __init__(self, encoder_channel):
111 | super().__init__()
112 | self.encoder_channel = encoder_channel
113 | self.first_conv = nn.Sequential(
114 | nn.Conv1d(3, 128, 1),
115 | nn.BatchNorm1d(128),
116 | nn.ReLU(inplace=True),
117 | nn.Conv1d(128, 256, 1)
118 | )
119 | self.second_conv = nn.Sequential(
120 | nn.Conv1d(512, 512, 1),
121 | nn.BatchNorm1d(512),
122 | nn.ReLU(inplace=True),
123 | nn.Conv1d(512, self.encoder_channel, 1)
124 | )
125 |
126 | def forward(self, point_groups):
127 | '''
128 | point_groups : B G N 3
129 | -----------------
130 | feature_global : B G C
131 | '''
132 | bs, g, n, _ = point_groups.shape
133 | point_groups = point_groups.reshape(bs * g, n, 3)
134 | # encoder
135 | feature = self.first_conv(point_groups.transpose(2, 1))
136 | feature_global = torch.max(feature, dim=2, keepdim=True)[0]
137 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)
138 | feature = self.second_conv(feature)
139 | feature_global = torch.max(feature, dim=2, keepdim=False)[0]
140 | return feature_global.reshape(bs, g, self.encoder_channel)
141 |
142 |
143 | class Mlp(nn.Module):
144 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
145 | super().__init__()
146 | out_features = out_features or in_features
147 | hidden_features = hidden_features or in_features
148 | self.fc1 = nn.Linear(in_features, hidden_features)
149 | self.act = act_layer()
150 | self.fc2 = nn.Linear(hidden_features, out_features)
151 | self.drop = nn.Dropout(drop)
152 |
153 | def forward(self, x):
154 | x = self.fc1(x)
155 | x = self.act(x)
156 | x = self.drop(x)
157 | x = self.fc2(x)
158 | x = self.drop(x)
159 | return x
160 |
161 |
162 | class Attention(nn.Module):
163 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
164 | super().__init__()
165 | self.num_heads = num_heads
166 | head_dim = dim // num_heads
167 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
168 | self.scale = qk_scale or head_dim ** -0.5
169 |
170 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
171 | self.attn_drop = nn.Dropout(attn_drop)
172 | self.proj = nn.Linear(dim, dim)
173 | self.proj_drop = nn.Dropout(proj_drop)
174 |
175 | def forward(self, x):
176 | B, N, C = x.shape
177 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
178 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
179 |
180 | attn = (q * self.scale) @ k.transpose(-2, -1)
181 | attn = attn.softmax(dim=-1)
182 | attn = self.attn_drop(attn)
183 |
184 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
185 | x = self.proj(x)
186 | x = self.proj_drop(x)
187 | return x
188 |
189 |
190 | class Block(nn.Module):
191 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
192 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
193 | super().__init__()
194 | self.norm1 = norm_layer(dim)
195 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
196 | self.norm2 = norm_layer(dim)
197 | mlp_hidden_dim = int(dim * mlp_ratio)
198 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
199 |
200 | self.attn = Attention(
201 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
202 |
203 | def forward(self, x):
204 | x = x + self.drop_path(self.attn(self.norm1(x)))
205 | x = x + self.drop_path(self.mlp(self.norm2(x)))
206 | return x
207 |
208 |
209 | class TransformerEncoder(nn.Module):
210 | """ Transformer Encoder without hierarchical structure
211 | """
212 |
213 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
214 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
215 | super().__init__()
216 |
217 | self.blocks = nn.ModuleList([
218 | Block(
219 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
220 | drop=drop_rate, attn_drop=attn_drop_rate,
221 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
222 | )
223 | for i in range(depth)])
224 |
225 | def forward(self, x, pos):
226 | feature_list = []
227 | fetch_idx = [3, 7, 11]
228 | for i, block in enumerate(self.blocks):
229 | x = block(x + pos)
230 | if i in fetch_idx:
231 | feature_list.append(x)
232 | return feature_list
233 |
234 |
235 | class PointTransformer(nn.Module):
236 | def __init__(self, group_size=128, num_group=1024, encoder_dims=384):
237 | super().__init__()
238 |
239 | self.trans_dim = 384
240 | self.depth = 12
241 | self.drop_path_rate = 0.1
242 | self.num_heads = 6
243 |
244 | self.group_size = group_size
245 | self.num_group = num_group
246 | # grouper
247 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
248 | # define the encoder
249 | self.encoder_dims = encoder_dims
250 | if self.encoder_dims != self.trans_dim:
251 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
252 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
253 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim)
254 | self.encoder = Encoder(encoder_channel=self.encoder_dims)
255 | # bridge encoder and transformer
256 |
257 | self.pos_embed = nn.Sequential(
258 | nn.Linear(3, 128),
259 | nn.GELU(),
260 | nn.Linear(128, self.trans_dim)
261 | )
262 |
263 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
264 | self.blocks = TransformerEncoder(
265 | embed_dim=self.trans_dim,
266 | depth=self.depth,
267 | drop_path_rate=dpr,
268 | num_heads=self.num_heads
269 | )
270 |
271 | self.norm = nn.LayerNorm(self.trans_dim)
272 |
273 | def load_model_from_ckpt(self, bert_ckpt_path):
274 | if bert_ckpt_path is not None:
275 | ckpt = torch.load(bert_ckpt_path)
276 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()}
277 |
278 | for k in list(base_ckpt.keys()):
279 | if k.startswith('MAE_encoder'):
280 | base_ckpt[k[len('MAE_encoder.'):]] = base_ckpt[k]
281 | del base_ckpt[k]
282 | elif k.startswith('base_model'):
283 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
284 | del base_ckpt[k]
285 |
286 | incompatible = self.load_state_dict(base_ckpt, strict=False)
287 |
288 | #if incompatible.missing_keys:
289 | # print('missing_keys')
290 | # print(
291 | # incompatible.missing_keys
292 | # )
293 | #if incompatible.unexpected_keys:
294 | # print('unexpected_keys')
295 | # print(
296 | # incompatible.unexpected_keys
297 |
298 | # )
299 |
300 | # print(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}')
301 |
302 | def load_model_from_pb_ckpt(self, bert_ckpt_path):
303 | ckpt = torch.load(bert_ckpt_path)
304 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()}
305 | for k in list(base_ckpt.keys()):
306 | if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'):
307 | base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k]
308 | elif k.startswith('base_model'):
309 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
310 | del base_ckpt[k]
311 |
312 | incompatible = self.load_state_dict(base_ckpt, strict=False)
313 |
314 | if incompatible.missing_keys:
315 | print('missing_keys')
316 | print(
317 | incompatible.missing_keys
318 | )
319 | if incompatible.unexpected_keys:
320 | print('unexpected_keys')
321 | print(
322 | incompatible.unexpected_keys
323 |
324 | )
325 |
326 | print(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}')
327 |
328 |
329 | def forward(self, pts):
330 | if self.encoder_dims != self.trans_dim:
331 | B,C,N = pts.shape
332 | pts = pts.transpose(-1, -2) # B N 3
333 |             # divide the point cloud in the same form. This is important
334 | neighborhood, center, ori_idx, center_idx = self.group_divider(pts)
335 | # # generate mask
336 | # bool_masked_pos = self._mask_center(center, no_mask = False) # B G
337 | # encoder the input cloud blocks
338 | group_input_tokens = self.encoder(neighborhood) # B G N
339 | group_input_tokens = self.reduce_dim(group_input_tokens)
340 | # prepare cls
341 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
342 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
343 | # add pos embedding
344 | pos = self.pos_embed(center)
345 | # final input
346 | x = torch.cat((cls_tokens, group_input_tokens), dim=1)
347 | pos = torch.cat((cls_pos, pos), dim=1)
348 | # transformer
349 | feature_list = self.blocks(x, pos)
350 | feature_list = [self.norm(x)[:,1:].transpose(-1, -2).contiguous() for x in feature_list]
351 | x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152
352 | return x, center, ori_idx, center_idx
353 | else:
354 | B, C, N = pts.shape
355 | pts = pts.transpose(-1, -2) # B N 3
356 |             # divide the point cloud in the same form. This is important
357 | neighborhood, center, ori_idx, center_idx = self.group_divider(pts)
358 |
359 | group_input_tokens = self.encoder(neighborhood) # B G N
360 |
361 | pos = self.pos_embed(center)
362 | # final input
363 | x = group_input_tokens
364 | # transformer
365 | feature_list = self.blocks(x, pos)
366 | feature_list = [self.norm(x).transpose(-1, -2).contiguous() for x in feature_list]
367 | x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152
368 | return x, center, ori_idx, center_idx
369 |
370 | # class FPFH(nn.Module):
371 | # def __init__(self, group_size=32, num_group=512, voxel_size=0.05):
372 | # super(FPFH, self).__init__()
373 | # self.group_size = group_size
374 | # self.num_group = num_group
375 | # self.voxel_size = voxel_size
376 | # self.resize = nn.AdaptiveAvgPool2d((28, 28))
377 | # self.average = nn.AvgPool2d(2, 2)
378 |
379 | # def organized_pc_to_unorganized_pc(self, organized_pc):
380 | # return organized_pc.reshape(-1, 3)
381 |
382 | # def get_fpfh_features(self, organized_pc):
383 | # organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
384 | # unorganized_pc = self.organized_pc_to_unorganized_pc(organized_pc_np)
385 |
386 | # nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
387 | # unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]
388 |
389 | # o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros))
390 |
391 | # radius_normal = self.voxel_size * 2
392 | # o3d_pc.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30))
393 |
394 | # radius_feature = self.voxel_size * 5
395 | # pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
396 | # o3d_pc,
397 | # o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100)
398 | # )
399 | # fpfh = pcd_fpfh.data.T
400 |
401 | # full_fpfh = np.zeros((unorganized_pc.shape[0], fpfh.shape[1]), dtype=fpfh.dtype)
402 | # full_fpfh[nonzero_indices, :] = fpfh
403 | # full_fpfh_reshaped = full_fpfh.reshape((organized_pc_np.shape[0], organized_pc_np.shape[1], fpfh.shape[1]))
404 | # full_fpfh_tensor = torch.tensor(full_fpfh_reshaped).permute(2, 0, 1).unsqueeze(dim=0)
405 |
406 | # return full_fpfh_tensor
407 |
408 | # def forward(self, xyz):
409 | # batch_size, _, height, width = xyz.shape
410 |
411 | # # Compute FPFH features
412 | # xyz_features = self.get_fpfh_features(xyz)
413 |
414 | # # Resize and average
415 | # xyz_features_resized = self.resize(self.average(xyz_features))
416 |
417 | # # Randomly sample center points
418 | # center_idx = torch.randperm(height * width)[:self.num_group]
419 | # center = xyz.view(batch_size, 3, -1).permute(0, 2, 1)[:, center_idx, :]
420 |
421 | # # Create original indices
422 | # ori_idx = torch.arange(height * width).view(1, height, width).expand(batch_size, -1, -1)
423 |
424 | # return xyz_features_resized, center, ori_idx, center_idx
425 |
426 | # def add_sample_to_mem_bank(self, sample):
427 | # fpfh_feature_maps = self.get_fpfh_features(sample[1])
428 | # fpfh_feature_maps_resized = self.resize(self.average(fpfh_feature_maps))
429 | # fpfh_patch = fpfh_feature_maps_resized.reshape(fpfh_feature_maps_resized.shape[1], -1).T
430 | # return fpfh_patch
431 |
432 | # def predict(self, sample):
433 | # depth_feature_maps = self.get_fpfh_features(sample[1])
434 | # depth_feature_maps_resized = self.resize(self.average(depth_feature_maps))
435 | # patch = depth_feature_maps_resized.reshape(depth_feature_maps_resized.shape[1], -1).T
436 | # return patch, depth_feature_maps_resized.shape[-2:]
437 | import numpy as np
438 | import open3d as o3d
439 | class FPFH(nn.Module):
440 | def __init__(self, group_size=32, num_group=512, voxel_size=0.05):
441 | super(FPFH, self).__init__()
442 | self.group_size = group_size
443 | self.num_group = num_group
444 | self.voxel_size = voxel_size
445 |
446 | def get_fpfh_features(self, unorganized_pc):
447 |         # make sure unorganized_pc is a numpy array on the CPU
448 |         if isinstance(unorganized_pc, torch.Tensor):
449 |             unorganized_pc = unorganized_pc.cpu().numpy()
450 |
451 |         # make sure the shape is (N, 3)
452 |         if unorganized_pc.shape[0] == 3:
453 |             unorganized_pc = unorganized_pc.T
454 |
455 |         # remove zero points
456 |         nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
457 |         unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]
458 |
459 |         # make sure the dtype is float64
460 | unorganized_pc_no_zeros = unorganized_pc_no_zeros.astype(np.float64)
461 |
462 | o3d_pc = o3d.geometry.PointCloud()
463 | o3d_pc.points = o3d.utility.Vector3dVector(unorganized_pc_no_zeros)
464 |
465 | radius_normal = self.voxel_size * 2
466 | o3d_pc.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30))
467 |
468 | radius_feature = self.voxel_size * 5
469 | pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
470 | o3d_pc,
471 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100)
472 | )
473 |         fpfh = pcd_fpfh.data # shape (33, M), where M is the number of non-zero points
474 |
475 |         # convert the FPFH features to a torch.Tensor
476 | fpfh_tensor = torch.tensor(fpfh, dtype=torch.float32)
477 |
478 | return fpfh_tensor
479 |
480 | def forward(self, xyz):
481 |         # xyz arrives as a (B, 3, N) torch.Tensor (B: batch size, N: number of points); permute to (B, N, 3)
482 |         xyz = xyz.permute(0, 2, 1)
483 |         batch_size,num_points, _ = xyz.shape
484 |
485 |         # compute FPFH features for each sample in the batch
486 | fpfh_features = []
487 | for i in range(batch_size):
488 | fpfh = self.get_fpfh_features(xyz[i])
489 | fpfh_features.append(fpfh)
490 |
491 | fpfh_features = torch.stack(fpfh_features)
492 |
493 |         # randomly sample center points
494 | center_idx = torch.randperm(num_points)[:self.num_group]
495 | center = xyz[:, center_idx, :]
496 |
497 | ori_idx = torch.arange(num_points)
498 |
499 | return fpfh_features, center, ori_idx,center_idx
500 |
501 | def add_sample_to_mem_bank(self, sample):
502 | print(sample.shape)
503 |         # sample is assumed to be a (N, 3) torch.Tensor
504 | fpfh_features = self.get_fpfh_features(sample)
505 | return fpfh_features
506 |
507 | def predict(self, sample):
508 |         # sample is assumed to be a (N, 3) torch.Tensor
509 |         fpfh_features = self.get_fpfh_features(sample)
510 |         return fpfh_features, None # return None for the spatial shape: an unorganized point cloud has no fixed spatial dimensions
511 |
--------------------------------------------------------------------------------
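Note: the FPFH branch above has no learned weights, so it can be exercised in isolation. The snippet below is a small illustrative sketch (not part of the repository) that runs FPFH on a random point cloud; it assumes torch and open3d are installed, and importing models.models also pulls in the pointnet2_ops and knn_cuda extensions at module level. Shapes follow the (B, 3, N) convention expected by FPFH.forward.

import torch
from models.models import FPFH

# random unorganized point cloud: batch of 1, xyz coordinates, 2048 points
xyz = torch.rand(1, 3, 2048) + 0.1  # offset keeps every point away from the zero-point filter

fpfh = FPFH(group_size=32, num_group=512, voxel_size=0.05)
features, center, ori_idx, center_idx = fpfh(xyz)

print(features.shape)    # torch.Size([1, 33, 2048]): one 33-dim FPFH descriptor per point
print(center.shape)      # torch.Size([1, 512, 3]): randomly sampled center points
print(center_idx.shape)  # torch.Size([512]): indices of the sampled centers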
/dataset2.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torchvision import transforms
4 | import glob
5 | from torch.utils.data import Dataset
6 | from utils.mvtec3d_util import *
7 | from torch.utils.data import DataLoader
8 | import numpy as np
9 | import math
10 | import cv2
11 | import argparse
12 | import torch  # used below by torch.load / torch.zeros / torch.where
13 | from matplotlib import pyplot as plt
14 | def eyecandies_classes():
15 | return [
16 | 'CandyCane',
17 | 'ChocolateCookie',
18 | 'ChocolatePraline',
19 | 'Confetto',
20 | 'GummyBear',
21 | 'HazelnutTruffle',
22 | 'LicoriceSandwich',
23 | 'Lollipop',
24 | 'Marshmallow',
25 | 'PeppermintCandy',
26 | ]
27 |
28 | def mvtec3d_classes():
29 | return [
30 | "bagel",
31 | "cable_gland",
32 | "carrot",
33 | "cookie",
34 | "dowel",
35 | "foam",
36 | "peach",
37 | "potato",
38 | "rope",
39 | "tire",
40 | ]
41 | def test_3d_classes():
42 | return [
43 | 'audio_jack_socket',
44 | 'common_mode_filter',
45 | 'connector_housing-female',
46 | 'crimp_st_cable_mount_box',
47 | 'dc_power_connector',
48 | 'fork_crimp_terminal',
49 | 'headphone_jack_socket',
50 | 'miniature_lifting_motor',
51 | 'purple-clay-pot',
52 | 'power_jack',
53 |
54 | 'ethernet_connector',
55 | 'ferrite_bead',
56 | 'fuse_holder',
57 | 'humidity_sensor',
58 | 'knob-cap',
59 | 'lattice_block_plug',
60 | 'lego_pin_connector_plate',
61 | 'lego_propeller',
62 | 'limit-switch',
63 | 'telephone_spring_switch',
64 | # "bagel",
65 | # "cable_gland",
66 | # "carrot",
67 | # "cookie",
68 | # "dowel",
69 | # "foam",
70 | # "peach",
71 | # "potato",
72 | # "rope",
73 | # "tire",
74 | ]
75 |
76 | RGB_SIZE = 224
77 |
78 | class BaseAnomalyDetectionDataset(Dataset):
79 |
80 | def __init__(self, split, class_name, img_size,downsampling, angle, small,dataset_path='datasets/eyecandies_preprocessed'):
81 | self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
82 | self.IMAGENET_STD = [0.229, 0.224, 0.225]
83 | self.cls = class_name
84 | self.size = img_size
85 | self.img_path = os.path.join(dataset_path, self.cls, split)
86 | self.downsampling = downsampling
87 | self.angle = angle
88 | self.small = small
89 | self.rgb_transform = transforms.Compose(
90 | [transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),
91 | transforms.ToTensor(),
92 | transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)])
93 |     def analyze_depth_importance(self,organized_pc):
94 |         """Analyze the importance of the depth map and return row/column importance flags."""
95 |         depth = organized_pc[:,:,2] # take the depth channel
96 |
97 |         # fraction of non-zero depth values in each row / column
98 |         row_importance = np.mean(depth != 0, axis=1)
99 |         col_importance = np.mean(depth != 0, axis=0)
100 |
101 |         # use the median (over non-empty rows/columns) as the threshold
102 |         row_median = np.median(row_importance[row_importance > 0])
103 |         col_median = np.median(col_importance[col_importance > 0])
104 |
105 |         # mark importance (True means an important row/column)
106 |         important_rows = row_importance >= row_median
107 |         important_cols = col_importance >= col_median
108 |
109 |         return important_rows, important_cols
110 |
111 |     def smart_downsample(self, organized_pc, target_factor):
112 |         """Content-aware downsampling: important regions are downsampled more aggressively while the overall downsampling factor is preserved."""
113 |         important_rows, important_cols = self.analyze_depth_importance(organized_pc)
114 |
115 |         # indices of important and unimportant rows/columns
116 |         important_row_indices = np.where(important_rows)[0]
117 |         unimportant_row_indices = np.where(~important_rows)[0]
118 |         important_col_indices = np.where(important_cols)[0]
119 |         unimportant_col_indices = np.where(~important_cols)[0]
120 |
121 |         # original counts of important and unimportant rows/columns
122 |         total_rows = len(important_row_indices) + len(unimportant_row_indices)
123 |         total_cols = len(important_col_indices) + len(unimportant_col_indices)
124 |         factor = int(math.sqrt(target_factor))
125 |         # target total numbers of rows and columns
126 |         target_total_rows = total_rows // factor
127 |         target_total_cols = total_cols // factor
128 |
129 |         # use a higher downsampling rate for the important regions (twice the base factor)
130 |         important_factor = factor*2
131 |
132 |         # target counts for the important regions
133 |         n_important_rows = len(important_row_indices) // important_factor
134 |         n_important_cols = len(important_col_indices) // important_factor
135 |
136 |         # how many unimportant rows/columns to keep so the totals match the target
137 |         n_unimportant_rows = target_total_rows - n_important_rows
138 |         n_unimportant_cols = target_total_cols - n_important_cols
139 |
140 |         # never keep more unimportant rows/columns than actually exist
141 |         n_unimportant_rows = min(n_unimportant_rows, len(unimportant_row_indices))
142 |         n_unimportant_cols = min(n_unimportant_cols, len(unimportant_col_indices))
143 |
144 |         # select rows
145 |         selected_important_rows = np.linspace(0, len(important_row_indices)-1, n_important_rows, dtype=int)
146 |         selected_important_rows = important_row_indices[selected_important_rows]
147 |
148 |         selected_unimportant_rows = np.linspace(0, len(unimportant_row_indices)-1, n_unimportant_rows, dtype=int)
149 |         selected_unimportant_rows = unimportant_row_indices[selected_unimportant_rows]
150 |
151 |         # select columns
152 |         selected_important_cols = np.linspace(0, len(important_col_indices)-1, n_important_cols, dtype=int)
153 |         selected_important_cols = important_col_indices[selected_important_cols]
154 |
155 |         selected_unimportant_cols = np.linspace(0, len(unimportant_col_indices)-1, n_unimportant_cols, dtype=int)
156 |         selected_unimportant_cols = unimportant_col_indices[selected_unimportant_cols]
157 |
158 |         # merge the selected rows and columns
159 |         selected_rows = np.sort(np.concatenate([selected_important_rows, selected_unimportant_rows]))
160 |         selected_cols = np.sort(np.concatenate([selected_important_cols, selected_unimportant_cols]))
161 |
162 |         # report the downsampling statistics
163 |         print(f"original size: {organized_pc.shape[:2]}")
164 |         print(f"important rows: {len(important_row_indices)} -> {len(selected_important_rows)} (1/{important_factor})")
165 |         print(f"unimportant rows: {len(unimportant_row_indices)} -> {len(selected_unimportant_rows)}")
166 |         print(f"important cols: {len(important_col_indices)} -> {len(selected_important_cols)} (1/{important_factor})")
167 |         print(f"unimportant cols: {len(unimportant_col_indices)} -> {len(selected_unimportant_cols)}")
168 |         print(f"final size: {len(selected_rows)}x{len(selected_cols)}")
169 |         print(f"effective downsampling factor: {(total_rows*total_cols)/(len(selected_rows)*len(selected_cols)):.2f}")
170 |
171 |         return organized_pc[selected_rows][:, selected_cols]
172 |
173 | def get_matrix(self, image, angle):
174 | image = self.pillow_to_opencv(image)
175 | (h, w) = image.shape[:2]
176 | src_points = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
177 | dst_points = self.calculate_destination_points(0, w, 0, h, angle)
178 | M = cv2.getPerspectiveTransform(src_points, dst_points)
179 | return M
180 |
181 | def calculate_destination_points(self, left, right, top, bottom, angle):
182 |         # compute the center point
183 |         center_x = (left + right) / 2
184 |         center_y = (top + bottom) / 2
185 |
186 |         # convert the angle to radians
187 |         angle_rad = math.radians(angle)
188 |
189 |         # rotate each corner around the center to get the destination points
190 |         dst_points = []
191 | for x, y in [(left, top), (right, top), (right, bottom), (left, bottom)]:
192 | new_x = center_x + (x - center_x) * math.cos(angle_rad) - (y - center_y) * math.sin(angle_rad)
193 | new_y = center_y + (x - center_x) * math.sin(angle_rad) + (y - center_y) * math.cos(angle_rad)
194 | dst_points.append([new_x, new_y])
195 |
196 | return np.float32(dst_points)
197 |
198 | def perspective_transform(self, image, matrix):
199 |         """
200 |         Apply the given perspective transform to a Pillow image.
201 |         :param image: input Pillow image
202 |         :param matrix: perspective transform matrix
203 |         :return: transformed Pillow image
204 |         """
205 |         # convert to OpenCV format
206 |         image = self.pillow_to_opencv(image)
207 |         (h, w) = image.shape[:2]
208 |         # apply the perspective transform
209 | transformed = cv2.warpPerspective(image, matrix, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=image[0, 0].tolist())
210 | transformed = self.opencv_to_pillow(transformed)
211 | return transformed
212 |
213 |     def opencv_to_pillow(self, opencv_image):
214 |         """
215 |         Convert an OpenCV image to a Pillow image.
216 |
217 |         :param opencv_image: OpenCV image (BGR format)
218 |         :return: Pillow image
219 |         """
220 |         # for color images, convert the channel order from BGR to RGB
221 |         if len(opencv_image.shape) == 3: # color image
222 |             opencv_image_rgb = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
223 |             return Image.fromarray(opencv_image_rgb)
224 |         else: # grayscale image
225 |             return Image.fromarray(opencv_image)
226 |
227 |     def pillow_to_opencv(self, pil_image):
228 |         """
229 |         Convert a Pillow image to an OpenCV image.
230 |
231 |         :param pil_image: Pillow image
232 |         :return: OpenCV image (BGR format)
233 |         """
234 |         # convert the Pillow image to a numpy array
235 |         opencv_image = np.array(pil_image)
236 |
237 |         # if it is an RGB image, convert it to BGR
238 |         if opencv_image.ndim == 3 and opencv_image.shape[2] == 3: # RGB image
239 |             opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
240 |
241 |         return opencv_image
242 |
243 | class PreTrainTensorDataset(Dataset):
244 | def __init__(self, root_path):
245 | super().__init__()
246 | self.root_path = root_path
247 | self.tensor_paths = os.listdir(self.root_path)
248 |
249 |
250 | def __len__(self):
251 | return len(self.tensor_paths)
252 |
253 | def __getitem__(self, idx):
254 | tensor_path = self.tensor_paths[idx]
255 |
256 | tensor = torch.load(os.path.join(self.root_path, tensor_path))
257 |
258 | label = 0
259 |
260 | return tensor, label
261 |
262 | class TrainDataset(BaseAnomalyDetectionDataset):
263 | def __init__(self, class_name, img_size,downsampling,angle, small, dataset_path='datasets/eyecandies_preprocessed'):
264 | super().__init__(split="train", class_name=class_name, img_size=img_size,downsampling=downsampling, angle=angle, small=small, dataset_path=dataset_path)
265 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
266 |
267 | def load_dataset(self):
268 | img_tot_paths = []
269 | tot_labels = []
270 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png")
271 | rgb_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'rgb', '*', '*L05*RGB*.jpg'))
272 | tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff")
273 | ps_paths = glob.glob(os.path.join(self.img_path, 'good', 'ps') + "/*.jpg")
274 | rgb_paths.sort()
275 | tiff_paths.sort()
276 | ps_paths.sort()
277 | sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
278 | # sample_paths = list(zip(rgb_paths, tiff_paths))
279 | img_tot_paths.extend(sample_paths)
280 | tot_labels.extend([0] * len(sample_paths))
281 | # img_tot_paths = img_tot_paths[0:10]
282 | # tot_labels = tot_labels[0:10]
283 | return img_tot_paths, tot_labels
284 |
285 | def __len__(self):
286 | return len(self.img_paths)
287 |
288 | def __getitem__(self, idx):
289 | img_path, label = self.img_paths[idx], self.labels[idx]
290 | rgb_path = img_path[0]
291 | tiff_path = img_path[1]
292 | ps_path = img_path[2]
293 | img = Image.open(rgb_path).convert('RGB')
294 | # add rotation
295 | # matrix = self.get_matrix(img, self.angle)
296 | # img = self.perspective_transform(img, matrix)
297 |
298 | img = self.rgb_transform(img)
299 | ps= Image.open(ps_path).convert('RGB')
300 | # # add rotation
301 | # ps = self.perspective_transform(ps, matrix)
302 |
303 | ps = self.rgb_transform(ps)
304 | if self.downsampling > 1:
305 | organized_pc = read_tiff_organized_pc(tiff_path)
306 | factor1 = int(math.floor(math.sqrt(self.downsampling)))
307 | factor2 = int(math.ceil(self.downsampling / factor1))
308 | organized_pc = organized_pc[::factor1, ::factor2]
309 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling)
310 | else:
311 | organized_pc = read_tiff_organized_pc(tiff_path)
312 |
313 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
314 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
315 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
316 | resized_organized_pc = resized_organized_pc.clone().detach().float()
317 |
318 |
319 | return (img, resized_organized_pc, resized_depth_map_3channel,ps), label
320 |
321 |
322 | class TestDataset(BaseAnomalyDetectionDataset):
323 | def __init__(self, class_name, img_size,downsampling,angle,small,dataset_path='datasets/eyecandies_preprocessed'):
324 | super().__init__(split="test", class_name=class_name, img_size=img_size,downsampling=downsampling,angle=angle, small=small, dataset_path=dataset_path)
325 | self.gt_transform = transforms.Compose([
326 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
327 | transforms.ToTensor()])
328 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
329 |
330 | def load_dataset(self):
331 | img_tot_paths = []
332 | gt_tot_paths = []
333 | tot_labels = []
334 | ps_tot_paths = []
335 | defect_types = os.listdir(self.img_path)
336 | print(defect_types)
337 |         # if types is not None, keep only GOOD and the specified defect types
338 | for defect_type in defect_types:
339 | if defect_type == 'good':
340 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png")
341 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
342 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
343 | #gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png")
344 | ps_paths = glob.glob(os.path.join(self.img_path, 'good', 'ps') + "/*.jpg")
345 | rgb_paths.sort()
346 | tiff_paths.sort()
347 | ps_paths.sort()
348 |                 # keep only the first 5 samples
349 | if self.small:
350 | rgb_paths = rgb_paths[:5]
351 | tiff_paths = tiff_paths[:5]
352 | ps_paths = ps_paths[:5]
353 | sample_paths = list(zip(rgb_paths, tiff_paths,ps_paths))
354 | img_tot_paths.extend(sample_paths)
355 | gt_tot_paths.extend([0] * len(sample_paths))
356 | tot_labels.extend([0] * len(sample_paths))
357 | else:
358 | # rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
359 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
360 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
361 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png")
362 | ps_paths = glob.glob(os.path.join(self.img_path, defect_type, 'ps') + "/*.png")
363 | rgb_paths.sort()
364 | tiff_paths.sort()
365 | gt_paths.sort()
366 | ps_paths.sort()
367 | # if self.small:
368 |                 #     # check the defect-pixel ratio in each gt mask
369 | # valid_indices = []
370 | # for i, gt_path in enumerate(gt_paths):
371 | # gt_mask = np.array(Image.open(gt_path))
372 | # total_pixels = gt_mask.shape[0] * gt_mask.shape[1]
373 |                 #         threshold = int(total_pixels * 0.005) # number of pixels corresponding to 0.5% of the image
374 |                 #         defect_pixels = np.sum(gt_mask > 0)
375 |                 #         if defect_pixels <= threshold: # compare pixel counts directly
376 | # valid_indices.append(i)
377 |
378 |                 #     # keep only samples whose defect ratio is <= 0.5% of pixels
379 | # rgb_paths = [rgb_paths[i] for i in valid_indices]
380 | # tiff_paths = [tiff_paths[i] for i in valid_indices]
381 | # gt_paths = [gt_paths[i] for i in valid_indices]
382 | # # ps_paths = [ps_paths[i] for i in valid_indices]
383 | sample_paths = list(zip(rgb_paths, tiff_paths,ps_paths))
384 | print(f"rgb_paths: {len(rgb_paths)}, tiff_paths: {len(tiff_paths)}, ps_paths: {len(ps_paths)}")
385 | img_tot_paths.extend(sample_paths)
386 | gt_tot_paths.extend(gt_paths)
387 | tot_labels.extend([1] * len(sample_paths))
388 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
389 |
390 | return img_tot_paths, gt_tot_paths, tot_labels
391 |
392 | def __len__(self):
393 | return len(self.img_paths)
394 |
395 | def __getitem__(self, idx):
396 |
397 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5))
398 |
399 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
400 | rgb_path = img_path[0]
401 | tiff_path = img_path[1]
402 | ps_path = img_path[2]
403 | img_original = Image.open(rgb_path).convert('RGB')
404 | # matrix = self.get_matrix(img_original, self.angle)
405 | # img_original = self.perspective_transform(img_original, matrix)
406 |
407 | # axes[0].imshow(cv2.cvtColor(self.pillow_to_opencv(img_original), cv2.COLOR_BGR2RGB))
408 |
409 | img = self.rgb_transform(img_original)
410 | ps_original = Image.open(ps_path).convert('RGB')
411 | # ps_original = self.perspective_transform(ps_original, matrix)
412 | ps = self.rgb_transform(ps_original)
413 | # organized_pc = read_tiff_organized_pc(tiff_path)
414 | if self.downsampling > 1:
415 | organized_pc = read_tiff_organized_pc(tiff_path)
416 | factor1 = int(math.floor(math.sqrt(self.downsampling)))
417 | factor2 = int(math.ceil(self.downsampling / factor1))
418 | organized_pc = organized_pc[::factor1, ::factor2]
419 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling)
420 | else:
421 | organized_pc = read_tiff_organized_pc(tiff_path)
422 |
423 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
424 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
425 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
426 | resized_organized_pc = resized_organized_pc.clone().detach().float()
427 |
428 |
429 |
430 |
431 | if gt == 0:
432 | gt = torch.zeros(
433 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-2]])
434 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray')
435 | else:
436 | gt = Image.open(gt).convert('L')
437 | # gt = self.perspective_transform(gt, matrix)
438 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray')
439 | gt = self.gt_transform(gt)
440 | gt = torch.where(gt > 0.5, 1., .0)
441 |
442 | # fig.show()
443 | # fig.savefig("test_image_3.png", dpi=300, bbox_inches='tight')
444 |
445 | return (img, resized_depth_map_3channel,ps), gt[:1], label, rgb_path
446 |
447 | class ValidDataset(BaseAnomalyDetectionDataset):
448 | def __init__(self, class_name, img_size,downsampling,angle,small,defect_name,dataset_path='datasets/eyecandies_preprocessed'):
449 | super().__init__(split="test", class_name=class_name, img_size=img_size,downsampling=downsampling,angle=angle, small=small, dataset_path=dataset_path)
450 | self.gt_transform = transforms.Compose([
451 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
452 | transforms.ToTensor()])
453 | self.defect_name = defect_name
454 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
455 |
456 | def load_dataset(self):
457 | img_tot_paths = []
458 | gt_tot_paths = []
459 | tot_labels = []
460 | #ps_tot_paths = []
461 | defect_types = os.listdir(self.img_path)
462 | # print(defect_types)
463 |         # if types is not None, keep only GOOD and the specified defect types
464 | for defect_type in defect_types:
465 | # print(defect_type)
466 | if defect_type == 'GOOD':
467 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
468 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
469 | ps_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'ps') + "/*.jpg")
470 | # print(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
471 | # print(rgb_paths)
472 | rgb_paths.sort()
473 | tiff_paths.sort()
474 | ps_paths.sort()
475 |                 # keep only the first 5 samples
476 | if self.small:
477 | rgb_paths = rgb_paths[:5]
478 | tiff_paths = tiff_paths[:5]
479 | ps_paths = ps_paths[:5]
480 | sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
481 | sample_paths = sample_paths[:5]
482 | img_tot_paths.extend(sample_paths)
483 | gt_tot_paths.extend([0] * len(sample_paths))
484 | tot_labels.extend([0] * len(sample_paths))
485 | elif defect_type in self.defect_name:
486 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
487 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
488 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt','rgb') + "/*.png")
489 | ps_paths = glob.glob(os.path.join(self.img_path, defect_type, 'ps') + "/*.jpg")
490 | # print(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
491 | # print(rgb_paths)
492 | rgb_paths.sort()
493 | tiff_paths.sort()
494 | gt_paths.sort()
495 | ps_paths.sort()
496 | if self.small:
497 |                     # check the defect-pixel ratio in each gt mask
498 | valid_indices = []
499 | for i, gt_path in enumerate(gt_paths):
500 | gt_mask = np.array(Image.open(gt_path))
501 | total_pixels = gt_mask.shape[0] * gt_mask.shape[1]
502 |                         threshold = int(total_pixels * 0.005) # number of pixels corresponding to 0.5% of the image
503 |                         defect_pixels = np.sum(gt_mask > 0)
504 |                         if defect_pixels <= threshold: # compare pixel counts directly
505 | valid_indices.append(i)
506 |
507 |                     # keep only samples whose defect ratio is <= 0.5% of pixels
508 | rgb_paths = [rgb_paths[i] for i in valid_indices]
509 | tiff_paths = [tiff_paths[i] for i in valid_indices]
510 | gt_paths = [gt_paths[i] for i in valid_indices]
511 | ps_paths = [ps_paths[i] for i in valid_indices]
512 | sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
513 | print(f"rgb_paths: {len(rgb_paths)}, tiff_paths: {len(tiff_paths)}, ps_paths: {len(ps_paths)}")
514 | img_tot_paths.extend(sample_paths)
515 | gt_tot_paths.extend(gt_paths)
516 | tot_labels.extend([1] * len(sample_paths))
517 |
518 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
519 |
520 | return img_tot_paths, gt_tot_paths, tot_labels
521 |
522 | def __len__(self):
523 | return len(self.img_paths)
524 |
525 | def __getitem__(self, idx):
526 |
527 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5))
528 |
529 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
530 | rgb_path = img_path[0]
531 | tiff_path = img_path[1]
532 | ps_path = img_path[2]
533 | img_original = Image.open(rgb_path).convert('RGB')
534 | matrix = self.get_matrix(img_original, self.angle)
535 | img_original = self.perspective_transform(img_original, matrix)
536 |
537 | # axes[0].imshow(cv2.cvtColor(self.pillow_to_opencv(img_original), cv2.COLOR_BGR2RGB))
538 |
539 | img = self.rgb_transform(img_original)
540 | ps_original = Image.open(ps_path).convert('RGB')
541 | ps_original = self.perspective_transform(ps_original, matrix)
542 | ps = self.rgb_transform(ps_original)
543 | # organized_pc = read_tiff_organized_pc(tiff_path)
544 | if self.downsampling > 1:
545 | organized_pc = read_tiff_organized_pc(tiff_path)
546 | factor1 = int(math.floor(math.sqrt(self.downsampling)))
547 | factor2 = int(math.ceil(self.downsampling / factor1))
548 | organized_pc = organized_pc[::factor1, ::factor2]
549 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling)
550 | else:
551 | organized_pc = read_tiff_organized_pc(tiff_path)
552 |
553 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
554 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
555 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
556 | resized_organized_pc = resized_organized_pc.clone().detach().float()
557 |
558 |
559 |
560 |
561 | if gt == 0:
562 | gt = torch.zeros(
563 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-1]])
564 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray')
565 | else:
566 | gt = Image.open(gt).convert('L')
567 | gt = self.perspective_transform(gt, matrix)
568 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray')
569 | gt = self.gt_transform(gt)
570 | gt = torch.where(gt > 0.5, 1., .0)
571 |
572 | # fig.show()
573 | # fig.savefig("test_image_3.png", dpi=300, bbox_inches='tight')
574 |
575 | return (img, resized_depth_map_3channel, ps), gt[:1], label, rgb_path
576 |
577 | from torch.utils.data import Subset
578 | def redistribute_dataset(dataset, chunk_size=50):
579 | """
580 | 将数据集每50个分成一组,并重新分配包含GOOD的组
581 | :param dataset: 原始数据集
582 | :param chunk_size: 每组的大小
583 | :return: 重新分配后的数据集列表
584 | """
585 | total_size = len(dataset)
586 | num_chunks = (total_size + chunk_size - 1) // chunk_size # 向上取整
587 |
588 | # 初始分组
589 | chunks = []
590 | for i in range(num_chunks):
591 | start_idx = i * chunk_size
592 | end_idx = min(start_idx + chunk_size, total_size)
593 | chunk_indices = list(range(start_idx, end_idx))
594 | chunks.append(Subset(dataset, chunk_indices))
595 |
596 | # Find the chunks that contain GOOD samples
597 | good_chunks = []
598 | normal_chunks = []
599 |
600 | for i, chunk in enumerate(chunks):
601 | has_good = False
602 | # Check whether this chunk contains any GOOD sample
603 | for idx in chunk.indices:
604 | if 'GOOD' in dataset.img_paths[idx][0]:
605 | has_good = True
606 | break
607 |
608 | if has_good:
609 | good_chunks.append(i)
610 | else:
611 | normal_chunks.append(i)
612 |
613 | # If there is no GOOD chunk or no normal chunk, return the original chunking as-is
614 | if not good_chunks or not normal_chunks:
615 | return chunks
616 |
617 | # Redistribute the data held by the GOOD chunks
618 | for good_chunk_idx in good_chunks:
619 | chunk = chunks[good_chunk_idx]
620 | good_indices = []
621 | normal_indices = []
622 |
623 | # Separate GOOD and non-GOOD samples
624 | for idx in chunk.indices:
625 | if 'GOOD' in dataset.img_paths[idx][0]:
626 | good_indices.append(idx)
627 | else:
628 | normal_indices.append(idx)
629 |
630 | # Spread the GOOD samples evenly over the other chunks
631 | num_good = len(good_indices)
632 | num_normal_chunks = len(normal_chunks)
633 |
634 | if num_normal_chunks > 0:
635 | # How many GOOD samples each normal chunk should receive
636 | indices_per_chunk = num_good // num_normal_chunks
637 | remainder = num_good % num_normal_chunks
638 |
639 | # Hand out the GOOD samples
640 | current_good_idx = 0
641 | for i, normal_chunk_idx in enumerate(normal_chunks):
642 | extra = 1 if i < remainder else 0
643 | num_to_add = indices_per_chunk + extra
644 |
645 | # Append GOOD samples to this normal chunk
646 | chunks[normal_chunk_idx] = Subset(dataset,
647 | list(chunks[normal_chunk_idx].indices) +
648 | good_indices[current_good_idx:current_good_idx + num_to_add]
649 | )
650 | current_good_idx += num_to_add
651 |
652 | # Update the original GOOD chunk so that it keeps only non-GOOD samples
653 | chunks[good_chunk_idx] = Subset(dataset, normal_indices)
654 |
655 | return chunks
656 | def get_data_loader(split, class_name, img_size, args, defect_name = None):
657 | if split in ['train']:
658 | dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small)
659 | elif split in ['test']:
660 | dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small)
661 | elif split in ['validation']:
662 | dataset = ValidDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small, defect_name = defect_name)
663 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False,
664 | pin_memory=True)
665 | return data_loader
666 |
667 | def get_data_set(split, class_name, img_size, args, defect_name = None):
668 | if split in ['train']:
669 | dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle)
670 | elif split in ['test']:
671 | print('test')
672 | dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle)
673 | elif split in ['validation']:
674 | print('validation')
675 | dataset = ValidDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small, defect_name = defect_name)
676 |
677 | return dataset
678 |
679 | # if __name__ == "__main__":
680 | # parser = argparse.ArgumentParser(description='Process some integers.')
681 | #
682 | # parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str,
683 | # choices=['DINO','Point_MAE','Fusion','DINO+Point_MAE','DINO+Point_MAE+Fusion','DINO+Point_MAE+add','DINO+FPFH','DINO+FPFH+Fusion',
684 | # 'DINO+FPFH+Fusion+ps','DINO+Point_MAE+Fusion+ps','DINO+Point_MAE+ps','DINO+FPFH+ps','ours','ours2','ours3','ours_final','ours_final1'
685 | # ,'ours_final1_VS'],
686 | # help='Anomaly detection modal name.')
687 | # parser.add_argument('--max_sample', default=400, type=int,
688 | # help='Max sample number.')
689 | # parser.add_argument('--memory_bank', default='multiple', type=str,
690 | # choices=["multiple", "single"],
691 | # help='memory bank mode: "multiple", "single".')
692 | # parser.add_argument('--rgb_backbone_name', default='vit_base_patch8_224_dino', type=str,
693 | # choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'],
694 | # help='Timm checkpoints name of RGB backbone.')
695 | # parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert','FPFH'],
696 | # help='Checkpoints name of RGB backbone[Point_MAE, Point_Bert, FPFH].')
697 | # parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str,
698 | # help='Checkpoints for fusion module.')
699 | # parser.add_argument('--save_feature', default=False, action='store_true',
700 | # help='Save feature for training fusion block.')
701 | # parser.add_argument('--use_uff', default=False, action='store_true',
702 | # help='Use UFF module.')
703 | # parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str,
704 | # help='Save feature for training fusion block.')
705 | # parser.add_argument('--save_preds', default=False, action='store_true',
706 | # help='Save predicts results.')
707 | # parser.add_argument('--group_size', default=128, type=int,
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torchvision import transforms
4 | import glob
5 | from torch.utils.data import Dataset
6 | from utils.mvtec3d_util import *
7 | from torch.utils.data import DataLoader
8 | import numpy as np
9 | import math
10 | import cv2
11 | import argparse
12 | from matplotlib import pyplot as plt
13 | import torch
14 | def eyecandies_classes():
15 | return [
16 | 'CandyCane',
17 | 'ChocolateCookie',
18 | 'ChocolatePraline',
19 | 'Confetto',
20 | 'GummyBear',
21 | 'HazelnutTruffle',
22 | 'LicoriceSandwich',
23 | 'Lollipop',
24 | 'Marshmallow',
25 | 'PeppermintCandy',
26 | ]
27 |
28 | def mvtec3d_classes():
29 | return [
30 | "bagel",
31 | "cable_gland",
32 | "carrot",
33 | "cookie",
34 | "dowel",
35 | "foam",
36 | "peach",
37 | "potato",
38 | "rope",
39 | "tire",
40 | ]
41 | def test_3d_classes():
42 | return [
43 | # 'audio_jack_socket',
44 | # 'common_mode_filter',
45 | # 'connector_housing-female',
46 | # 'crimp_st_cable_mount_box',
47 | # 'dc_power_connector',
48 | # 'fork_crimp_terminal',
49 | # 'headphone_jack_socket',
50 | # 'miniature_lifting_motor',
51 | # 'purple-clay-pot',
52 | # 'power_jack',
53 | # 'ethernet_connector',
54 | # 'ferrite_bead',
55 | # 'fuse_holder',
56 | # 'humidity_sensor',
57 | # 'knob-cap',
58 | # 'lattice_block_plug',
59 | 'lego_pin_connector_plate',
60 | 'lego_propeller',
61 | 'limit-switch',
62 | 'telephone_spring_switch',
63 | ]
64 |
65 | RGB_SIZE = 224
66 |
67 | class BaseAnomalyDetectionDataset(Dataset):
68 |
69 | def __init__(self, split, class_name, img_size,downsampling, angle, small,dataset_path='datasets/eyecandies_preprocessed'):
70 | self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
71 | self.IMAGENET_STD = [0.229, 0.224, 0.225]
72 | self.cls = class_name
73 | self.size = img_size
74 | self.img_path = os.path.join(dataset_path, self.cls, split)
75 | self.downsampling = downsampling
76 | self.angle = angle
77 | self.small = small
78 | self.rgb_transform = transforms.Compose(
79 | [transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),
80 | transforms.ToTensor(),
81 | transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)])
82 | def analyze_depth_importance(self,organized_pc):
83 | """分析深度图的重要性,返回行列的重要性标记"""
84 | depth = organized_pc[:,:,2] # 获取深度通道
85 |
86 | # Fraction of non-zero depth points in each row and each column
87 | row_importance = np.mean(depth != 0, axis=1)
88 | col_importance = np.mean(depth != 0, axis=0)
89 |
90 | # Use the median over non-empty rows/columns as the cut-off
91 | row_median = np.median(row_importance[row_importance > 0])
92 | col_median = np.median(col_importance[col_importance > 0])
93 |
94 | # Importance flags (True = important row/column)
95 | important_rows = row_importance >= row_median
96 | important_cols = col_importance >= col_median
97 |
98 | return important_rows, important_cols
99 |
100 | def smart_downsample(self, organized_pc, target_factor):
101 | """智能降采样,重要区域降采样更多,保持总体降采样因子不变"""
102 | important_rows, important_cols = self.analyze_depth_importance(organized_pc)
103 |
104 | # Indices of the important and unimportant rows/columns
105 | important_row_indices = np.where(important_rows)[0]
106 | unimportant_row_indices = np.where(~important_rows)[0]
107 | important_col_indices = np.where(important_cols)[0]
108 | unimportant_col_indices = np.where(~important_cols)[0]
109 |
110 | # Total numbers of rows and columns (important + unimportant)
111 | total_rows = len(important_row_indices) + len(unimportant_row_indices)
112 | total_cols = len(important_col_indices) + len(unimportant_col_indices)
113 | factor = int(math.sqrt(target_factor))
114 | # Target total numbers of rows and columns
115 | target_total_rows = total_rows // factor
116 | target_total_cols = total_cols // factor
117 |
118 | # Downsample the important regions more aggressively (keep 1 in 2*factor rows/columns)
119 | important_factor = factor*2
120 |
121 | # Target counts kept from the important regions
122 | n_important_rows = len(important_row_indices) // important_factor
123 | n_important_cols = len(important_col_indices) // important_factor
124 |
125 | # Number of unimportant rows/columns to keep so that the totals hit the targets
126 | n_unimportant_rows = target_total_rows - n_important_rows
127 | n_unimportant_cols = target_total_cols - n_important_cols
128 |
129 | # Do not keep more unimportant rows/columns than actually exist
130 | n_unimportant_rows = min(n_unimportant_rows, len(unimportant_row_indices))
131 | n_unimportant_cols = min(n_unimportant_cols, len(unimportant_col_indices))
132 |
133 | # Select rows
134 | selected_important_rows = np.linspace(0, len(important_row_indices)-1, n_important_rows, dtype=int)
135 | selected_important_rows = important_row_indices[selected_important_rows]
136 |
137 | selected_unimportant_rows = np.linspace(0, len(unimportant_row_indices)-1, n_unimportant_rows, dtype=int)
138 | selected_unimportant_rows = unimportant_row_indices[selected_unimportant_rows]
139 |
140 | # Select columns
141 | selected_important_cols = np.linspace(0, len(important_col_indices)-1, n_important_cols, dtype=int)
142 | selected_important_cols = important_col_indices[selected_important_cols]
143 |
144 | selected_unimportant_cols = np.linspace(0, len(unimportant_col_indices)-1, n_unimportant_cols, dtype=int)
145 | selected_unimportant_cols = unimportant_col_indices[selected_unimportant_cols]
146 |
147 | # Merge the selected rows and columns
148 | selected_rows = np.sort(np.concatenate([selected_important_rows, selected_unimportant_rows]))
149 | selected_cols = np.sort(np.concatenate([selected_important_cols, selected_unimportant_cols]))
150 |
151 | # Print downsampling statistics
152 | print(f"Original size: {organized_pc.shape[:2]}")
153 | print(f"Important rows: {len(important_row_indices)} -> {len(selected_important_rows)} (1/{important_factor})")
154 | print(f"Unimportant rows: {len(unimportant_row_indices)} -> {len(selected_unimportant_rows)}")
155 | print(f"Important cols: {len(important_col_indices)} -> {len(selected_important_cols)} (1/{important_factor})")
156 | print(f"Unimportant cols: {len(unimportant_col_indices)} -> {len(selected_unimportant_cols)}")
157 | print(f"Final size: {len(selected_rows)}x{len(selected_cols)}")
158 | print(f"Effective downsampling factor: {(total_rows*total_cols)/(len(selected_rows)*len(selected_cols)):.2f}")
159 |
160 | return organized_pc[selected_rows][:, selected_cols]
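# Illustrative worked example (added comment, not from the original code), assuming an
# 800x800 organized point cloud, target_factor=16 and 400 important / 400 unimportant
# rows and columns:
#   factor = 4, important_factor = 8
#   target_total_rows = 800 // 4 = 200; n_important_rows = 400 // 8 = 50
#   n_unimportant_rows = 200 - 50 = 150 (same for columns)
#   -> the output is 200x200, an effective downsampling factor of 16, with important
#      rows/columns kept at 1/8 and unimportant ones at roughly 150/400.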
161 |
162 | def get_matrix(self, image, angle):
163 | image = self.pillow_to_opencv(image)
164 | (h, w) = image.shape[:2]
165 | src_points = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
166 | dst_points = self.calculate_destination_points(0, w, 0, h, angle)
167 | M = cv2.getPerspectiveTransform(src_points, dst_points)
168 | return M
169 |
170 | def calculate_destination_points(self, left, right, top, bottom, angle):
171 | # Centre of the rectangle
172 | center_x = (left + right) / 2
173 | center_y = (top + bottom) / 2
174 |
175 | # Rotation angle in radians
176 | angle_rad = math.radians(angle)
177 |
178 | # Rotate each corner around the centre to get the destination points
179 | dst_points = []
180 | for x, y in [(left, top), (right, top), (right, bottom), (left, bottom)]:
181 | new_x = center_x + (x - center_x) * math.cos(angle_rad) - (y - center_y) * math.sin(angle_rad)
182 | new_y = center_y + (x - center_x) * math.sin(angle_rad) + (y - center_y) * math.cos(angle_rad)
183 | dst_points.append([new_x, new_y])
184 |
185 | return np.float32(dst_points)
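# Illustrative worked example (added comment, not from the original code): for a
# 100x100 image and angle=90, the centre is (50, 50) and the corners
# (0,0), (100,0), (100,100), (0,100) map to (100,0), (100,100), (0,100), (0,0),
# i.e. a 90-degree rotation of the image plane about its centre.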
186 |
187 | def perspective_transform(self, image, matrix):
188 | """
189 | :param image_path: 输入图像路径
190 | :param angle: 旋转角度
191 | :param save_type: 保存的图片类型
192 | :return: 输出图像
193 | """
194 | # 读取图像
195 | image = self.pillow_to_opencv(image)
196 | (h, w) = image.shape[:2]
197 | # 执行透视变换
198 | transformed = cv2.warpPerspective(image, matrix, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=image[0, 0].tolist())
199 | transformed = self.opencv_to_pillow(transformed)
200 | return transformed
201 |
202 | def opencv_to_pillow(self, opencv_image):
203 | """
204 | Convert an OpenCV image to a Pillow image.
205 |
206 | :param opencv_image: OpenCV image (BGR format)
207 | :return: Pillow image
208 | """
209 | # For colour images, swap the channel order from BGR to RGB
210 | if len(opencv_image.shape) == 3: # colour image
211 | opencv_image_rgb = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
212 | return Image.fromarray(opencv_image_rgb)
213 | else: # greyscale image
214 | return Image.fromarray(opencv_image)
215 |
216 | def pillow_to_opencv(self, pil_image):
217 | """
218 | Convert a Pillow image to an OpenCV image.
219 |
220 | :param pil_image: Pillow image
221 | :return: OpenCV image (BGR format)
222 | """
223 | # Convert the Pillow image to a numpy array
224 | opencv_image = np.array(pil_image)
225 |
226 | # If it is an RGB image, convert it to BGR
227 | if opencv_image.ndim == 3 and opencv_image.shape[2] == 3: # RGB image
228 | opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
229 |
230 | return opencv_image
231 |
232 | class PreTrainTensorDataset(Dataset):
233 | def __init__(self, root_path):
234 | super().__init__()
235 | self.root_path = root_path
236 | self.tensor_paths = os.listdir(self.root_path)
237 |
238 |
239 | def __len__(self):
240 | return len(self.tensor_paths)
241 |
242 | def __getitem__(self, idx):
243 | tensor_path = self.tensor_paths[idx]
244 |
245 | tensor = torch.load(os.path.join(self.root_path, tensor_path))
246 |
247 | label = 0
248 |
249 | return tensor, label
250 |
251 | class TrainDataset(BaseAnomalyDetectionDataset):
252 | def __init__(self, class_name, img_size,downsampling,angle, small, dataset_path='datasets/eyecandies_preprocessed'):
253 | super().__init__(split="train", class_name=class_name, img_size=img_size,downsampling=downsampling, angle=angle, small=small, dataset_path=dataset_path)
254 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
255 |
256 | def load_dataset(self):
257 | img_tot_paths = []
258 | tot_labels = []
259 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'rgb') + "/*.png")
260 | rgb_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'rgb', '*', '*L05*RGB*.jpg'))
261 | tiff_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'xyz') + "/*.tiff")
262 | # ps_paths = glob.glob(os.path.join(self.img_path, 'good', 'ps') + "/*.jpg")
263 | rgb_paths.sort()
264 | tiff_paths.sort()
265 |
266 | # print(len(rgb_paths))
267 | # ps_paths.sort()
268 | # sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
269 | sample_paths = list(zip(rgb_paths, tiff_paths))
270 | img_tot_paths.extend(sample_paths)
271 | tot_labels.extend([0] * len(sample_paths))
272 | # img_tot_paths = img_tot_paths[0:10]
273 | # tot_labels = tot_labels[0:10]
274 | return img_tot_paths, tot_labels
275 |
276 | def __len__(self):
277 | return len(self.img_paths)
278 |
279 | def __getitem__(self, idx):
280 | img_path, label = self.img_paths[idx], self.labels[idx]
281 | rgb_path = img_path[0]
282 | tiff_path = img_path[1]
283 | # ps_path = img_path[2]
284 | img = Image.open(rgb_path).convert('RGB')
285 | # add rotation
286 | # matrix = self.get_matrix(img, self.angle)
287 | # img = self.perspective_transform(img, matrix)
288 |
289 | img = self.rgb_transform(img)
290 | # ps= Image.open(ps_path).convert('RGB')
291 | # # add rotation
292 | # ps = self.perspective_transform(ps, matrix)
293 |
294 | # ps = self.rgb_transform(ps)
295 | if self.downsampling > 1:
296 | organized_pc = read_tiff_organized_pc(tiff_path)
297 | factor1 = int(math.floor(math.sqrt(self.downsampling)))
298 | factor2 = int(math.ceil(self.downsampling / factor1))
299 | organized_pc = organized_pc[::factor1, ::factor2]
300 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling)
301 | else:
302 | organized_pc = read_tiff_organized_pc(tiff_path)
303 |
304 | for x in range(organized_pc.shape[0]):
305 | for y in range(organized_pc.shape[1]):
306 | organized_pc[x, y, 0] = x
307 | organized_pc[x, y, 1] = y
308 |
309 | z_channel = organized_pc[:, :, 2]
310 | if np.isnan(z_channel).any():
311 | min_z = np.nanmin(z_channel) # minimum of the z (depth) channel, ignoring NaN
312 | z_channel[np.isnan(z_channel)] = min_z
313 | organized_pc[:, :, 2] = z_channel.astype(np.float32)
314 | print("train z_channel nan filled")
315 |
316 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
317 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
318 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
319 | resized_organized_pc = resized_organized_pc.clone().detach().float()
320 |
321 |
322 | return (img, resized_organized_pc, resized_depth_map_3channel), label
323 |
324 |
325 | class TestDataset(BaseAnomalyDetectionDataset):
326 | def __init__(self, class_name, img_size,downsampling,angle,small,dataset_path='datasets/eyecandies_preprocessed'):
327 | super().__init__(split="test", class_name=class_name, img_size=img_size,downsampling=downsampling,angle=angle, small=small, dataset_path=dataset_path)
328 | self.gt_transform = transforms.Compose([
329 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
330 | transforms.ToTensor()])
331 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
332 |
333 | def load_dataset(self):
334 | img_tot_paths = []
335 | gt_tot_paths = []
336 | tot_labels = []
337 | #ps_tot_paths = []
338 | defect_types = os.listdir(self.img_path)
339 | print(defect_types)
340 | # If types is not None, keep only GOOD and the specified defect types
341 | for defect_type in defect_types:
342 | if defect_type == 'GOOD':
343 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'rgb') + "/*.png")
344 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
345 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
346 | #gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png")
347 | # ps_paths = glob.glob(os.path.join(self.img_path, 'good', 'ps') + "/*.jpg")
348 | rgb_paths.sort()
349 | tiff_paths.sort()
350 | # ps_paths.sort()
351 | # In small mode, keep only the first 5 samples
352 | if self.small:
353 | rgb_paths = rgb_paths[:5]
354 | tiff_paths = tiff_paths[:5]
355 | # ps_paths = ps_paths[:5]
356 | sample_paths = list(zip(rgb_paths, tiff_paths))
357 | img_tot_paths.extend(sample_paths)
358 | gt_tot_paths.extend([0] * len(sample_paths))
359 | tot_labels.extend([0] * len(sample_paths))
360 | else:
361 | # rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
362 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
363 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
364 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt', 'rgb') + "/*.png")
365 | # ps_paths = glob.glob(os.path.join(self.img_path, defect_type, 'ps') + "/*.png")
366 | rgb_paths.sort()
367 | tiff_paths.sort()
368 | gt_paths.sort()
369 | # ps_paths.sort()
370 | # if self.small:
372 | # # Check the defect area ratio in each gt mask
372 | # valid_indices = []
373 | # for i, gt_path in enumerate(gt_paths):
374 | # gt_mask = np.array(Image.open(gt_path))
375 | # total_pixels = gt_mask.shape[0] * gt_mask.shape[1]
376 | # threshold = int(total_pixels * 0.005) # 0.5% of the total pixel count
377 | # defect_pixels = np.sum(gt_mask > 0)
378 | # if defect_pixels <= threshold: # compare pixel counts directly
379 | # valid_indices.append(i)
380 |
381 | # # Keep only samples whose defect area is <= 0.5% of the image
382 | # rgb_paths = [rgb_paths[i] for i in valid_indices]
383 | # tiff_paths = [tiff_paths[i] for i in valid_indices]
384 | # gt_paths = [gt_paths[i] for i in valid_indices]
385 | # # ps_paths = [ps_paths[i] for i in valid_indices]
386 | sample_paths = list(zip(rgb_paths, tiff_paths))
387 | print(f"rgb_paths: {len(rgb_paths)}, tiff_paths: {len(tiff_paths)}")
388 | img_tot_paths.extend(sample_paths)
389 | gt_tot_paths.extend(gt_paths)
390 | tot_labels.extend([1] * len(sample_paths))
391 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
392 |
393 | return img_tot_paths, gt_tot_paths, tot_labels
394 |
395 | def __len__(self):
396 | return len(self.img_paths)
397 |
398 | def __getitem__(self, idx):
399 |
400 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5))
401 |
402 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
403 | rgb_path = img_path[0]
404 | tiff_path = img_path[1]
405 | # ps_path = img_path[2]
406 | img_original = Image.open(rgb_path).convert('RGB')
407 | # matrix = self.get_matrix(img_original, self.angle)
408 | # img_original = self.perspective_transform(img_original, matrix)
409 |
410 | # axes[0].imshow(cv2.cvtColor(self.pillow_to_opencv(img_original), cv2.COLOR_BGR2RGB))
411 |
412 | img = self.rgb_transform(img_original)
413 | # ps_original = Image.open(ps_path).convert('RGB')
414 | # ps_original = self.perspective_transform(ps_original, matrix)
415 | # ps = self.rgb_transform(ps_original)
416 | # organized_pc = read_tiff_organized_pc(tiff_path)
417 | if self.downsampling > 1:
418 | organized_pc = read_tiff_organized_pc(tiff_path)
419 | factor1 = int(math.floor(math.sqrt(self.downsampling)))
420 | factor2 = int(math.ceil(self.downsampling / factor1))
421 | organized_pc = organized_pc[::factor1, ::factor2]
422 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling)
423 | else:
424 | organized_pc = read_tiff_organized_pc(tiff_path)
425 |
426 |
427 | for x in range(organized_pc.shape[0]):
428 | for y in range(organized_pc.shape[1]):
429 | organized_pc[x, y, 0] = x
430 | organized_pc[x, y, 1] = y
431 |
432 | z_channel = organized_pc[:, :, 2]
433 | if np.isnan(z_channel).any():
434 | min_z = np.nanmin(z_channel) # minimum of the z (depth) channel, ignoring NaN
435 | z_channel[np.isnan(z_channel)] = min_z
436 | organized_pc[:, :, 2] = z_channel.astype(np.float32)
437 | print("test z_channel nan filled")
438 |
439 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
440 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
441 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
442 | resized_organized_pc = resized_organized_pc.clone().detach().float()
443 |
444 |
445 |
446 |
447 | if gt == 0:
448 | gt = torch.zeros(
449 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-1]])
450 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray')
451 | else:
452 | gt = Image.open(gt).convert('L')
453 | # gt = self.perspective_transform(gt, matrix)
454 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray')
455 | gt = self.gt_transform(gt)
456 | gt = torch.where(gt > 0.5, 1., .0)
457 |
458 | # fig.show()
459 | # fig.savefig("test_image_3.png", dpi=300, bbox_inches='tight')
460 |
461 | return (img, resized_depth_map_3channel), gt[:1], label, rgb_path
462 |
463 | class ValidDataset(BaseAnomalyDetectionDataset):
464 | def __init__(self, class_name, img_size,downsampling,angle,small,defect_name,dataset_path='datasets/eyecandies_preprocessed'):
465 | super().__init__(split="test", class_name=class_name, img_size=img_size,downsampling=downsampling,angle=angle, small=small, dataset_path=dataset_path)
466 | self.gt_transform = transforms.Compose([
467 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
468 | transforms.ToTensor()])
469 | self.defect_name = defect_name
470 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
471 |
472 | def load_dataset(self):
473 | img_tot_paths = []
474 | gt_tot_paths = []
475 | tot_labels = []
476 | #ps_tot_paths = []
477 | defect_types = os.listdir(self.img_path)
478 | # print(defect_types)
479 | # If types is not None, keep only GOOD and the specified defect types
480 | for defect_type in defect_types:
481 | # print(defect_type)
482 | if defect_type == 'GOOD':
483 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
484 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
485 | ps_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'ps') + "/*.jpg")
486 | # print(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
487 | # print(rgb_paths)
488 | rgb_paths.sort()
489 | tiff_paths.sort()
490 | ps_paths.sort()
491 | # In small mode, keep only the first 5 samples
492 | if self.small:
493 | rgb_paths = rgb_paths[:5]
494 | tiff_paths = tiff_paths[:5]
495 | ps_paths = ps_paths[:5]
496 | sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
497 | sample_paths = sample_paths[:5]
498 | img_tot_paths.extend(sample_paths)
499 | gt_tot_paths.extend([0] * len(sample_paths))
500 | tot_labels.extend([0] * len(sample_paths))
501 | elif defect_type in self.defect_name:
502 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
503 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
504 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt','rgb') + "/*.png")
505 | ps_paths = glob.glob(os.path.join(self.img_path, defect_type, 'ps') + "/*.jpg")
506 | # print(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
507 | # print(rgb_paths)
508 | rgb_paths.sort()
509 | tiff_paths.sort()
510 | gt_paths.sort()
511 | ps_paths.sort()
512 | if self.small:
513 | # Check the defect area ratio in each gt mask
514 | valid_indices = []
515 | for i, gt_path in enumerate(gt_paths):
516 | gt_mask = np.array(Image.open(gt_path))
517 | total_pixels = gt_mask.shape[0] * gt_mask.shape[1]
518 | threshold = int(total_pixels * 0.005) # 0.5% of the total pixel count
519 | defect_pixels = np.sum(gt_mask > 0)
520 | if defect_pixels <= threshold: # compare pixel counts directly
521 | valid_indices.append(i)
522 |
523 | # Keep only samples whose defect area is <= 0.5% of the image
524 | rgb_paths = [rgb_paths[i] for i in valid_indices]
525 | tiff_paths = [tiff_paths[i] for i in valid_indices]
526 | gt_paths = [gt_paths[i] for i in valid_indices]
527 | ps_paths = [ps_paths[i] for i in valid_indices]
528 | sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
529 | print(f"rgb_paths: {len(rgb_paths)}, tiff_paths: {len(tiff_paths)}, ps_paths: {len(ps_paths)}")
530 | img_tot_paths.extend(sample_paths)
531 | gt_tot_paths.extend(gt_paths)
532 | tot_labels.extend([1] * len(sample_paths))
533 |
534 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
535 |
536 | return img_tot_paths, gt_tot_paths, tot_labels
537 |
538 | def __len__(self):
539 | return len(self.img_paths)
540 |
541 | def __getitem__(self, idx):
542 |
543 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5))
544 |
545 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
546 | rgb_path = img_path[0]
547 | tiff_path = img_path[1]
548 | ps_path = img_path[2]
549 | img_original = Image.open(rgb_path).convert('RGB')
550 | matrix = self.get_matrix(img_original, self.angle)
551 | img_original = self.perspective_transform(img_original, matrix)
552 |
553 | # axes[0].imshow(cv2.cvtColor(self.pillow_to_opencv(img_original), cv2.COLOR_BGR2RGB))
554 |
555 | img = self.rgb_transform(img_original)
556 | ps_original = Image.open(ps_path).convert('RGB')
557 | ps_original = self.perspective_transform(ps_original, matrix)
558 | ps = self.rgb_transform(ps_original)
559 | # organized_pc = read_tiff_organized_pc(tiff_path)
560 | if self.downsampling > 1:
561 | organized_pc = read_tiff_organized_pc(tiff_path)
562 | factor1 = int(math.floor(math.sqrt(self.downsampling)))
563 | factor2 = int(math.ceil(self.downsampling / factor1))
564 | organized_pc = organized_pc[::factor1, ::factor2]
565 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling)
566 | else:
567 | organized_pc = read_tiff_organized_pc(tiff_path)
568 |
569 | for x in range(organized_pc.shape[0]):
570 | for y in range(organized_pc.shape[1]):
571 | organized_pc[x, y, 0] = x
572 | organized_pc[x, y, 1] = y
573 |
574 | z_channel = organized_pc[:, :, 2]
575 | if np.isnan(z_channel).any():
576 | min_z = np.nanmin(z_channel) # minimum of the z (depth) channel, ignoring NaN
577 | z_channel[np.isnan(z_channel)] = min_z
578 | organized_pc[:, :, 2] = z_channel.astype(np.float32)
579 | print("val z_channel nan filled")
580 |
581 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
582 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
583 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
584 | resized_organized_pc = resized_organized_pc.clone().detach().float()
585 |
586 |
587 |
588 |
589 | if gt == 0:
590 | gt = torch.zeros(
591 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-1]])
592 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray')
593 | else:
594 | gt = Image.open(gt).convert('L')
595 | gt = self.perspective_transform(gt, matrix)
596 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray')
597 | gt = self.gt_transform(gt)
598 | gt = torch.where(gt > 0.5, 1., .0)
599 |
600 | # fig.show()
601 | # fig.savefig("test_image_3.png", dpi=300, bbox_inches='tight')
602 |
603 | return (img, resized_depth_map_3channel, ps), gt[:1], label, rgb_path
604 |
605 | from torch.utils.data import Subset
606 | def redistribute_dataset(dataset, chunk_size=50):
607 | """
608 | 将数据集每50个分成一组,并重新分配包含GOOD的组
609 | :param dataset: 原始数据集
610 | :param chunk_size: 每组的大小
611 | :return: 重新分配后的数据集列表
612 | """
613 | total_size = len(dataset)
614 | num_chunks = (total_size + chunk_size - 1) // chunk_size # 向上取整
615 |
616 | # 初始分组
617 | chunks = []
618 | for i in range(num_chunks):
619 | start_idx = i * chunk_size
620 | end_idx = min(start_idx + chunk_size, total_size)
621 | chunk_indices = list(range(start_idx, end_idx))
622 | chunks.append(Subset(dataset, chunk_indices))
623 |
624 | # Find the chunks that contain GOOD samples
625 | good_chunks = []
626 | normal_chunks = []
627 |
628 | for i, chunk in enumerate(chunks):
629 | has_good = False
630 | # Check whether this chunk contains any GOOD sample
631 | for idx in chunk.indices:
632 | if 'GOOD' in dataset.img_paths[idx][0]:
633 | has_good = True
634 | break
635 |
636 | if has_good:
637 | good_chunks.append(i)
638 | else:
639 | normal_chunks.append(i)
640 |
641 | # If there is no GOOD chunk or no normal chunk, return the original chunking as-is
642 | if not good_chunks or not normal_chunks:
643 | return chunks
644 |
645 | # Redistribute the data held by the GOOD chunks
646 | for good_chunk_idx in good_chunks:
647 | chunk = chunks[good_chunk_idx]
648 | good_indices = []
649 | normal_indices = []
650 |
651 | # Separate GOOD and non-GOOD samples
652 | for idx in chunk.indices:
653 | if 'GOOD' in dataset.img_paths[idx][0]:
654 | good_indices.append(idx)
655 | else:
656 | normal_indices.append(idx)
657 |
658 | # Spread the GOOD samples evenly over the other chunks
659 | num_good = len(good_indices)
660 | num_normal_chunks = len(normal_chunks)
661 |
662 | if num_normal_chunks > 0:
663 | # How many GOOD samples each normal chunk should receive
664 | indices_per_chunk = num_good // num_normal_chunks
665 | remainder = num_good % num_normal_chunks
666 |
667 | # Hand out the GOOD samples
668 | current_good_idx = 0
669 | for i, normal_chunk_idx in enumerate(normal_chunks):
670 | extra = 1 if i < remainder else 0
671 | num_to_add = indices_per_chunk + extra
672 |
673 | # Append GOOD samples to this normal chunk
674 | chunks[normal_chunk_idx] = Subset(dataset,
675 | list(chunks[normal_chunk_idx].indices) +
676 | good_indices[current_good_idx:current_good_idx + num_to_add]
677 | )
678 | current_good_idx += num_to_add
679 |
680 | # Update the original GOOD chunk so that it keeps only non-GOOD samples
681 | chunks[good_chunk_idx] = Subset(dataset, normal_indices)
682 |
683 | return chunks
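# Illustrative worked example (added comment, not from the original code): for a
# 120-sample dataset whose first 10 samples are GOOD and chunk_size=50, the initial
# chunks cover indices [0..49], [50..99], [100..119]; chunk 0 is the only GOOD chunk,
# so its 10 GOOD indices are split 5/5 between chunks 1 and 2 and chunk 0 keeps only
# its 40 non-GOOD indices, giving chunks of 40, 55 and 25 samples.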
684 | def get_data_loader(split, class_name, img_size, args, defect_name = None):
685 | if split in ['train']:
686 | dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small)
687 | elif split in ['test']:
688 | dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small)
689 | elif split in ['validation']:
690 | dataset = ValidDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small, defect_name = defect_name)
691 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False,
692 | pin_memory=True)
693 | return data_loader
694 |
695 | def get_data_set(split, class_name, img_size, args, defect_name = None):
696 | if split in ['train']:
697 | dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle, small=args.small)
698 | elif split in ['test']:
699 | print('test')
700 | dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle, small=args.small)
701 | elif split in ['validation']:
702 | print('validation')
703 | dataset = ValidDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small, defect_name = defect_name)
704 |
705 | return dataset
706 |
707 | # if __name__ == "__main__":
708 | # parser = argparse.ArgumentParser(description='Process some integers.')
709 | #
710 | # parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str,
711 | # choices=['DINO','Point_MAE','Fusion','DINO+Point_MAE','DINO+Point_MAE+Fusion','DINO+Point_MAE+add','DINO+FPFH','DINO+FPFH+Fusion',
712 | # 'DINO+FPFH+Fusion+ps','DINO+Point_MAE+Fusion+ps','DINO+Point_MAE+ps','DINO+FPFH+ps','ours','ours2','ours3','ours_final','ours_final1'
713 | # ,'ours_final1_VS'],
714 | # help='Anomaly detection modal name.')
715 | # parser.add_argument('--max_sample', default=400, type=int,
716 | # help='Max sample number.')
717 | # parser.add_argument('--memory_bank', default='multiple', type=str,
718 | # choices=["multiple", "single"],
719 | # help='memory bank mode: "multiple", "single".')
720 | # parser.add_argument('--rgb_backbone_name', default='vit_base_patch8_224_dino', type=str,
721 | # choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'],
722 | # help='Timm checkpoints name of RGB backbone.')
723 | # parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert','FPFH'],
724 | # help='Checkpoints name of RGB backbone[Point_MAE, Point_Bert, FPFH].')
725 | # parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str,
726 | # help='Checkpoints for fusion module.')
727 | # parser.add_argument('--save_feature', default=False, action='store_true',
728 | # help='Save feature for training fusion block.')
729 | # parser.add_argument('--use_uff', default=False, action='store_true',
730 | # help='Use UFF module.')
731 | # parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str,
732 | # help='Save feature for training fusion block.')
733 | # parser.add_argument('--save_preds', default=False, action='store_true',
734 | # help='Save predicts results.')
735 | # parser.add_argument('--group_size', default=128, type=int,
736 |
--------------------------------------------------------------------------------
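A minimal usage sketch for the loaders defined in /dataset.py above; it is not part of the repository. The values below are assumptions: in the repository the args object comes from main.py's argparse setup, and dataset.py only requires it to expose dataset_path, downsampling, rotate_angle and small. The preprocessed dataset is assumed to exist at the default path.

    from types import SimpleNamespace
    from dataset import get_data_loader, test_3d_classes

    # Hypothetical stand-in for the argparse namespace built in main.py.
    args = SimpleNamespace(dataset_path='datasets/eyecandies_preprocessed',
                           downsampling=1, rotate_angle=0, small=False)

    # Batch size is fixed to 1 inside get_data_loader.
    loader = get_data_loader('test', class_name=test_3d_classes()[0], img_size=224, args=args)
    for (img, depth_map_3ch), gt, label, rgb_path in loader:
        # img: normalised RGB tensor; depth_map_3ch: 3-channel depth map;
        # gt: ground-truth mask; label: 0 = good, 1 = anomaly
        print(img.shape, depth_map_3ch.shape, gt.shape, label, rgb_path)
        break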