├── .idea ├── .gitignore ├── vcs.xml ├── misc.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── main.iml ├── train.sh ├── utils ├── lr_sched.py ├── utils.py ├── mvtec3d_util.py ├── preprocessing1.py ├── preprocess_eyecandies.py ├── preprocessing.py ├── au_pro_util.py └── misc.py ├── requirement.txt ├── feature_extractors ├── change_ex.py └── features.py ├── engine_fusion_pretrain.py ├── .gitignore ├── README.md ├── models ├── pointnet2_utils.py ├── feature_fusion.py └── models.py ├── m3dm_runner1.py ├── fusion_pretrain.py ├── main.py ├── LICENSE ├── dataset2.py └── dataset.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python3 main.py \ 2 | --method_name DINO+Point_MAE+Fusion \ 3 | --use_uff \ 4 | --memory_bank multiple \ 5 | --rgb_backbone_name vit_base_patch8_224_dino \ 6 | --xyz_backbone_name Point_MAE \ 7 | --fusion_module_path checkpoints/checkpoint-0.pth 8 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/main.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /utils/lr_sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def adjust_learning_rate(optimizer, epoch, args): 4 | """Decay the learning rate with half-cycle cosine after warmup""" 5 | if epoch < args.warmup_epochs: 6 | lr = args.lr * epoch / args.warmup_epochs 7 | else: 8 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 9 | (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 10 | for param_group in optimizer.param_groups: 11 | if "lr_scale" in param_group: 12 | param_group["lr"] = lr * param_group["lr_scale"] 13 | else: 14 | param_group["lr"] = lr 15 | return lr 16 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from torchvision import transforms 5 | from PIL import ImageFilter 6 | 7 | def set_seeds(seed: int = 0) -> None: 8 | np.random.seed(seed) 9 | random.seed(seed) 10 | torch.manual_seed(seed) 11 | 12 | class KNNGaussianBlur(torch.nn.Module): 13 | def __init__(self, radius : int = 4): 14 | super().__init__() 15 | self.radius = radius 16 | self.unload = transforms.ToPILImage() 17 | self.load = transforms.ToTensor() 18 | self.blur_kernel = ImageFilter.GaussianBlur(radius=4) 19 | 20 | def __call__(self, img): 21 | map_max = img.max() 22 | final_map = self.load(self.unload(img[0] / map_max).filter(self.blur_kernel)) * map_max 23 | return final_map 24 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | * python>=3.9 2 | * torch>=2.5.1 (preferably with CUDA support, e.g., `torch>=2.5.1+cu121`) 3 | * torchvision>=0.20.1 4 | * torchaudio>=2.5.1 5 | * numpy>=1.26.3 6 | * opencv-python>=4.11.0 7 | * Pillow>=11.0.0 (PIL fork) 8 | * matplotlib>=3.4.2 (for plotting) 9 | * scikit-learn>=1.6.1 10 | * scikit-image>=0.24.0 11 | * einops>=0.8.1 12 | * timm>=1.0.15 (PyTorch Image Models) 13 | 14 | ### Utilities & Configuration: 15 | * pyyaml>=5.4.1 (for YAML configuration files) 16 | * tqdm>=4.67.1 (for progress bars) 17 | * easydict>=1.13 (or addict>=2.4.0, for easy dict access) 18 | * pandas>=2.2.3 (for data manipulation, if applicable) 19 | * h5py>=3.13.0 (if using HDF5 files) 20 | 21 | ### Experiment Tracking & Model Hubs (if used): 22 | * wandb>=0.19.10 (Weights & Biases) 23 | * huggingface-hub>=0.29.3 24 | 25 | ### Potentially 3D-Specific (include if your project uses 3D data): 26 | * open3d>=0.19.0 27 | * pointnet2-ops>=3.0.0 28 | * chamferdist>=1.0.3 29 | -------------------------------------------------------------------------------- /utils/mvtec3d_util.py: -------------------------------------------------------------------------------- 1 | import tifffile as tiff 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def organized_pc_to_unorganized_pc(organized_pc): 7 | print(f"organized_pc shape: {organized_pc.shape}") 8 | return organized_pc.reshape(organized_pc.shape[0] * organized_pc.shape[1], organized_pc.shape[2]) 9 | 10 | 11 | def read_tiff_organized_pc(path): 12 | tiff_img = tiff.imread(path) 13 | # nan_count = np.isnan(tiff_img).sum() 14 | # nan_percentage = nan_count / tiff_img.size * 100 15 | # print(f"NaN percentage: {nan_percentage:.2f}%") 16 | # 计算均值,忽略 NaN 值 17 | # mean_min = np.nanmin(tiff_img) 18 | 19 | # # 将 NaN 值替换为均值 20 | # tiff_img[np.isnan(tiff_img)] = mean_min 21 | # 将NaN值替换为0 22 | # tiff_img[np.isnan(tiff_img)] = 0 23 | #tiff_img = np.resize(tiff_img, (800, 800)) 24 | # 获取第三通道非nan的最小值 25 | # min_z = np.nanmin(tiff_img[:, :, 2]) 26 | 27 | # # 如果第三通道最小值小于等于0 28 | # if min_z <= 0: 29 | # # 第三通道非nan的值减去最小值再加1 30 | # mask = ~np.isnan(tiff_img[:, :, 2]) 31 | # tiff_img[mask, 2] = tiff_img[mask, 2] - min_z + 1 32 | # # 将NaN值置0 33 | 
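# (English summary of the Chinese comments and commented-out code in this function: they report
#  the NaN percentage of the tiff, try replacing NaNs with the global non-NaN minimum or with zero,
#  and shift the z channel so its non-NaN minimum becomes positive. In the committed version the
#  tiff is returned unchanged and NaNs are kept.)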
#tiff_img[np.isnan(tiff_img)] = 0 34 | return tiff_img 35 | 36 | 37 | def resize_organized_pc(organized_pc, target_height=224, target_width=224, tensor_out=True): 38 | torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous() 39 | torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc, size=(target_height, target_width), 40 | mode='nearest') 41 | if tensor_out: 42 | return torch_resized_organized_pc.squeeze(dim=0).contiguous() 43 | else: 44 | return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy() 45 | 46 | 47 | def organized_pc_to_depth_map(organized_pc): 48 | return organized_pc[:, :, 2] 49 | -------------------------------------------------------------------------------- /feature_extractors/change_ex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ChannelExchange(nn.Module): 5 | def __init__(self, p=2): 6 | super().__init__() 7 | self.p = p # 1/p of the features will be exchanged. 8 | 9 | def forward(self, x0, x1): 10 | # x0, x1: the bi-temporal feature maps. 11 | N, C, H, W = x0.shape 12 | exchange_mask = torch.arange(C) % self.p == 0 13 | exchange_mask = exchange_mask.unsqueeze(0).expand((N, -1)) 14 | 15 | out_x0 = torch.zeros_like(x0) 16 | out_x1 = torch.zeros_like(x1) 17 | 18 | out_x0[~exchange_mask] = x0[~exchange_mask] 19 | out_x1[~exchange_mask] = x1[~exchange_mask] 20 | out_x0[exchange_mask] = x1[exchange_mask] 21 | out_x1[exchange_mask] = x0[exchange_mask] 22 | 23 | return out_x0, out_x1 24 | 25 | class SpatialExchange(nn.Module): 26 | def __init__(self, p=2): 27 | super().__init__() 28 | self.p = p # 1/p of the features will be exchanged. 29 | 30 | def forward(self, x0, x1): 31 | # x0, x1: the bi-temporal feature maps. 
32 | N, C, H, W = x0.shape 33 | # Create a mask based on width dimension 34 | exchange_mask = torch.arange(W, device=x0.device) % self.p == 0 35 | # Expand mask to match feature dimensions 36 | exchange_mask = exchange_mask.view(1, 1, 1, W).expand(N, C, H, -1) 37 | 38 | out_x0 = x0.clone() 39 | out_x1 = x1.clone() 40 | 41 | # Perform column-wise exchange 42 | out_x0[..., exchange_mask] = x1[..., exchange_mask] 43 | out_x1[..., exchange_mask] = x0[..., exchange_mask] 44 | 45 | return out_x0, out_x1 46 | 47 | class CombinedExchange(nn.Module): 48 | def __init__(self, p=2): 49 | super().__init__() 50 | self.channel_exchange = ChannelExchange(p=p) 51 | self.spatial_exchange = SpatialExchange(p=p) 52 | 53 | def forward(self, x0, x1): 54 | # First perform channel exchange 55 | x0, x1 = self.channel_exchange(x0, x1) 56 | # Then perform spatial exchange 57 | x0, x1 = self.spatial_exchange(x0, x1) 58 | return x0, x1 59 | -------------------------------------------------------------------------------- /engine_fusion_pretrain.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable 4 | 5 | import torch 6 | 7 | import utils.misc as misc 8 | import utils.lr_sched as lr_sched 9 | 10 | def check_max_gradient(model): 11 | max_grad = 0.0 12 | max_grad_param_name = "" 13 | 14 | # 遍历模型所有参数,找到具有最大梯度的参数 15 | for name, param in model.named_parameters(): 16 | if param.grad is not None: 17 | grad_norm = param.grad.norm(2).item() 18 | if grad_norm > max_grad: 19 | max_grad = grad_norm 20 | max_grad_param_name = name 21 | 22 | print(f"Maximum gradient norm: {max_grad} (Parameter: {max_grad_param_name})") 23 | 24 | # 可以根据 max_grad 值判断是否有梯度爆炸的迹象 25 | if max_grad > 20000: # 可根据任务的实际情况设置阈值 26 | print("Warning: Potential gradient explosion detected!") 27 | 28 | 29 | 30 | def train_one_epoch(model: torch.nn.Module, 31 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 32 | device: torch.device, epoch: int, loss_scaler, 33 | log_writer=None, 34 | args=None): 35 | model.train(True) 36 | metric_logger = misc.MetricLogger(delimiter=" ") 37 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 38 | header = 'Epoch: [{}]'.format(epoch) 39 | print_freq = 20 40 | 41 | accum_iter = args.accum_iter 42 | 43 | optimizer.zero_grad() 44 | 45 | if log_writer is not None: 46 | print('log_dir: {}'.format(log_writer.log_dir)) 47 | 48 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 49 | 50 | # we use a per iteration (instead of per epoch) lr scheduler 51 | if data_iter_step % accum_iter == 0: 52 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 53 | 54 | 55 | xyz_samples = samples[:,:,:1152].to(device, non_blocking=True) 56 | rgb_samples = samples[:,:,1152:].to(device, non_blocking=True) 57 | 58 | with torch.cuda.amp.autocast(): 59 | # print('model device:', model.device) 60 | # print('data device:', xyz_samples.device) 61 | loss = model(xyz_samples, rgb_samples) 62 | 63 | check_max_gradient(model) 64 | loss_value = loss.item() 65 | 66 | if not math.isfinite(loss_value): 67 | print("Loss is {}, stopping training".format(loss_value)) 68 | sys.exit(1) 69 | 70 | loss /= accum_iter 71 | loss_scaler(loss, optimizer, parameters=model.parameters(), 72 | update_grad=(data_iter_step + 1) % accum_iter == 0) 73 | if (data_iter_step + 1) % accum_iter == 0: 74 | optimizer.zero_grad() 75 | 76 | torch.cuda.synchronize() 77 | 
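# NOTE: check_max_gradient() above is called right after the forward pass, i.e. before
# loss_scaler() runs backward(); the norms it prints therefore come from gradients left over
# from the previous iteration (or zeros/None right after an optimizer update), not from the
# current loss. torch.cuda.synchronize() waits for queued GPU work to finish so the
# per-iteration timing from metric_logger.log_every reflects completed work.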
78 | metric_logger.update(loss=loss_value) 79 | 80 | lr = optimizer.param_groups[0]["lr"] 81 | metric_logger.update(lr=lr) 82 | 83 | 84 | loss_value_reduce = misc.all_reduce_mean(loss_value) 85 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 86 | """ We use epoch_1000x as the x-axis in tensorboard. 87 | This calibrates different curves when batch size changes. 88 | """ 89 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 90 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 91 | log_writer.add_scalar('lr', lr, epoch_1000x) 92 | 93 | 94 | # gather the stats from all processes 95 | metric_logger.synchronize_between_processes() 96 | print("Averaged stats:", metric_logger) 97 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | D³M is a comprehensive framework for industrial anomaly detection, localization, and classification, leveraging a variety of deep learning-based feature extractors for 2D (RGB) and 3D (Point Cloud) data, along with flexible fusion strategies. This repository provides an implementation of the D3M framework and tools for evaluating its performance on standard and custom datasets. 2 | 3 | The project is designed to explore and benchmark different uni-modal and multi-modal approaches for anomaly detection in industrial settings, inspired by the challenges of real-world defect identification. 4 | 5 | ## 📄 Citation 6 | 7 | If you find this work useful in your research, please consider citing our associated paper: 8 | 9 | ```bibtex 10 | @article{zhu2025real, 11 | title={Real-IAD D3: A Real-World 2D/Pseudo-3D/3D Dataset for Industrial Anomaly Detection}, 12 | author={Zhu, Wenbing and Wang, Lidong and Zhou, Ziqing and Wang, Chengjie and Pan, Yurui and Zhang, Ruoyi and Chen, Zhuhao and Cheng, Linjie and Gao, Bin-Bin and Zhang, Jiangning and others}, 13 | journal={arXiv preprint arXiv:2504.14221}, 14 | year={2025} 15 | } 16 | ``` 17 | 18 | --- 19 | 20 | ## ✨ Features 21 | 22 | * **Versatile Anomaly Detection Framework:** Supports a wide range of feature extraction and fusion methodologies. 23 | * **Multi-Modal Support:** Handles both 2D RGB images and 3D point cloud data. 24 | * **RGB Feature Extractors:** Utilizes pre-trained models like DINO (e.g., `vit_base_patch8_224_dino`). 
25 | * **Point Cloud Feature Extractors:** Incorporates models like PointMAE, PointBert, and traditional FPFH.
26 | * **Flexible Fusion Strategies:** Implements various early and late fusion techniques for combining multi-modal features (e.g., simple addition, dedicated fusion modules).
27 | * **Multiple Method Configurations:** Easily experiment with different combinations of backbones and fusion approaches via command-line arguments (e.g., `DINO`, `Point_MAE`, `DINO+Point_MAE`, `DINO+Point_MAE+Fusion`, and custom "ours" variants).
28 | * **Dataset Compatibility:**
29 | * Works with standard 3D anomaly detection datasets like MVTec 3D-AD and Eyecandies.
30 | * Easily adaptable to custom datasets (`test_3d` option).
31 | * **Comprehensive Evaluation:**
32 | * Calculates image-level ROCAUC.
33 | * Calculates pixel/point-level ROCAUC for localization.
34 | * Calculates AU-PRO scores.
35 | * **Memory Bank & Coreset Subsampling:** Implements memory bank concepts and coreset subsampling for efficient training and inference, particularly with large feature sets.
36 | * **Extensible:** Designed to be easily extended with new feature extractors, fusion modules, or datasets.
37 | * **Result Visualization:** Option to save prediction maps for qualitative analysis.
38 |
39 | ---
40 |
41 | ## 🚀 Getting Started
42 |
43 | ### Prerequisites
44 |
45 | * Python >= 3.9
46 | * PyTorch >= 2.5.1 (preferably with CUDA support for GPU acceleration)
47 | * Other dependencies as listed in `requirements.txt`
48 |
49 | ### Installation
50 |
51 | 1. **Clone the repository:**
52 | ```bash
53 | git clone https://github.com/your-username/d3m.git
54 | cd d3m
55 | ```
56 |
57 | 2. **Create a virtual environment (recommended):**
58 | ```bash
59 | python -m venv venv
60 | source venv/bin/activate # On Windows use `venv\Scripts\activate`
61 | ```
62 | Or using Conda:
63 | ```bash
64 | conda create -n d3m python=3.9
65 | conda activate d3m
66 | ```
67 |
68 | 3. **Install dependencies:**
69 | A `requirements.txt` file is recommended. Create one with the following content or adapt it to your precise needs:
70 | ```text
71 | # requirements.txt
72 |
73 | # Core Deep Learning & Computer Vision
74 | torch>=2.5.1 # Install with specific CUDA version if needed, e.g., torch>=2.5.1+cu121
75 | torchvision>=0.20.1
76 | torchaudio>=2.5.1
77 | numpy>=1.26.3
78 | opencv-python>=4.11.0
79 | Pillow>=11.0.0
80 | matplotlib>=3.4.2
81 | scikit-learn>=1.6.1
82 | scikit-image>=0.24.0
83 | einops>=0.8.1
84 | timm>=1.0.15
85 |
86 | # Utilities & Configuration
87 | pyyaml>=5.4.1
88 | tqdm>=4.67.1
89 | easydict>=1.13 # or addict>=2.4.0
90 | pandas>=2.2.3
91 | h5py>=3.13.0
92 |
93 | # Experiment Tracking & Model Hubs (Optional, uncomment if used)
94 | # wandb>=0.19.10
95 | # huggingface-hub>=0.29.3
96 |
97 | # Potentially 3D-Specific (Uncomment if your project uses these)
98 | # open3d>=0.19.0
99 | # pointnet2-ops>=3.0.0 # May require custom compilation
100 | # chamferdist>=1.0.3
101 | ```
102 | Then install using:
103 | ```bash
104 | pip install -r requirements.txt
105 | ```
106 | **Note on PyTorch:** Ensure you install the PyTorch version compatible with your CUDA toolkit. Refer to the [official PyTorch website](https://pytorch.org/) for installation instructions.
107 |
108 | ---
109 |
110 | ## 💻 Usage
111 |
112 | The main script for running experiments is `main.py` (the script invoked by `train.sh`; it contains the `run_3d_ads` function and the `argparse` definitions). You can configure experiments using command-line arguments.
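A quick way to list every supported argument and its default value is the automatically generated `argparse` help:

```bash
python main.py --help
```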
113 |
114 | ### Example Command
115 |
116 | ```bash
117 | python main.py \
118 | --dataset_type mvtec3d \
119 | --dataset_path /path/to/your/mvtec3d_anomaly_detection \
120 | --method_name DINO+Point_MAE+Fusion \
121 | --rgb_backbone_name vit_base_patch8_224_dino \
122 | --xyz_backbone_name Point_MAE \
123 | --fusion_module_path /path/to/your/fusion_checkpoint.pth \
124 | --img_size 224 \
125 | --max_sample 400 \
126 | --coreset_eps 0.9 \
127 | --save_preds
128 | # Add other arguments as needed
129 | ```
130 | Refer to the `argparse` section in the main script for a full list of available arguments and their descriptions.
131 |
132 | ---
133 |
134 | ## 📊 Output
135 |
136 | The script will print tables summarizing the performance metrics for each class and method:
137 |
138 | * **Image ROCAUC Results**
139 | * **Pixel ROCAUC Results**
140 | * **AU-PRO Results**
141 |
142 | If `--save_preds` is enabled, anomaly maps will be saved to the specified directory (default or `./pred_maps`).
143 |
144 | ---
145 |
146 | ## 🙏 Acknowledgements
147 |
148 | * This framework builds upon concepts from various SOTA anomaly detection and feature extraction literature.
149 | * Utilizes awesome libraries like PyTorch, TIMM, Open3D, etc.
150 |
151 | ---
152 |
153 | ## 📞 Contact
154 |
155 | For questions, issues, or suggestions, please open an issue in this repository.
156 |
157 |
158 |
-------------------------------------------------------------------------------- /models/pointnet2_utils.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 | def timeit(tag, t):
8 | print("{}: {}s".format(tag, time() - t))
9 | return time()
10 |
11 | def pc_normalize(pc):
12 | l = pc.shape[0]
13 | centroid = np.mean(pc, axis=0)
14 | pc = pc - centroid
15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
16 | pc = pc / m
17 | return pc
18 |
19 | def square_distance(src, dst):
20 | """
21 | Calculate Euclidean distance between each two points.
22 | src^T * dst = xn * xm + yn * ym + zn * zm; 23 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 24 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 25 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 26 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 27 | Input: 28 | src: source points, [B, N, C] 29 | dst: target points, [B, M, C] 30 | Output: 31 | dist: per-point square distance, [B, N, M] 32 | """ 33 | B, N, _ = src.shape 34 | _, M, _ = dst.shape 35 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 36 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 37 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 38 | return dist 39 | 40 | 41 | def index_points(points, idx): 42 | """ 43 | Input: 44 | points: input points data, [B, N, C] 45 | idx: sample index data, [B, S] 46 | Return: 47 | new_points:, indexed points data, [B, S, C] 48 | """ 49 | device = points.device 50 | B = points.shape[0] 51 | view_shape = list(idx.shape) 52 | view_shape[1:] = [1] * (len(view_shape) - 1) 53 | repeat_shape = list(idx.shape) 54 | repeat_shape[0] = 1 55 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 56 | # print("points device:", points.device) 57 | # print("idx device:", idx.device) 58 | points = points.to(idx.device) 59 | new_points = points[batch_indices, idx, :] 60 | return new_points 61 | 62 | 63 | def farthest_point_sample(xyz, npoint): 64 | """ 65 | Input: 66 | xyz: pointcloud data, [B, N, 3] 67 | npoint: number of samples 68 | Return: 69 | centroids: sampled pointcloud index, [B, npoint] 70 | """ 71 | device = xyz.device 72 | B, N, C = xyz.shape 73 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 74 | distance = torch.ones(B, N).to(device) * 1e10 75 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 76 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 77 | for i in range(npoint): 78 | centroids[:, i] = farthest 79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 80 | dist = torch.sum((xyz - centroid) ** 2, -1) 81 | mask = dist < distance 82 | distance[mask] = dist[mask] 83 | farthest = torch.max(distance, -1)[1] 84 | return centroids 85 | 86 | 87 | def query_ball_point(radius, nsample, xyz, new_xyz): 88 | """ 89 | Input: 90 | radius: local region radius 91 | nsample: max sample number in local region 92 | xyz: all points, [B, N, 3] 93 | new_xyz: query points, [B, S, 3] 94 | Return: 95 | group_idx: grouped points index, [B, S, nsample] 96 | """ 97 | device = xyz.device 98 | B, N, C = xyz.shape 99 | _, S, _ = new_xyz.shape 100 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 101 | sqrdists = square_distance(new_xyz, xyz) 102 | group_idx[sqrdists > radius ** 2] = N 103 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 104 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 105 | mask = group_idx == N 106 | group_idx[mask] = group_first[mask] 107 | return group_idx 108 | 109 | 110 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 111 | """ 112 | Input: 113 | npoint: 114 | radius: 115 | nsample: 116 | xyz: input points position data, [B, N, 3] 117 | points: input points data, [B, N, D] 118 | Return: 119 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 120 | new_points: sampled points data, [B, npoint, nsample, 3+D] 121 | """ 122 | B, N, C = xyz.shape 123 | S = npoint 124 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 125 | new_xyz = index_points(xyz, fps_idx) 126 | idx 
= query_ball_point(radius, nsample, xyz, new_xyz) 127 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 128 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 129 | 130 | if points is not None: 131 | grouped_points = index_points(points, idx) 132 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 133 | else: 134 | new_points = grouped_xyz_norm 135 | if returnfps: 136 | return new_xyz, new_points, grouped_xyz, fps_idx 137 | else: 138 | return new_xyz, new_points 139 | 140 | 141 | def sample_and_group_all(xyz, points): 142 | """ 143 | Input: 144 | xyz: input points position data, [B, N, 3] 145 | points: input points data, [B, N, D] 146 | Return: 147 | new_xyz: sampled points position data, [B, 1, 3] 148 | new_points: sampled points data, [B, 1, N, 3+D] 149 | """ 150 | device = xyz.device 151 | B, N, C = xyz.shape 152 | new_xyz = torch.zeros(B, 1, C).to(device) 153 | grouped_xyz = xyz.view(B, 1, N, C) 154 | if points is not None: 155 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 156 | else: 157 | new_points = grouped_xyz 158 | return new_xyz, new_points 159 | 160 | def interpolating_points(xyz1, xyz2, points2): 161 | """ 162 | Input: 163 | xyz1: input points position data, [B, C, N] 164 | xyz2: sampled input points position data, [B, C, S] 165 | points2: input points data, [B, D, S] 166 | Return: 167 | new_points: upsampled points data, [B, D', N] 168 | """ 169 | xyz1 = xyz1.permute(0, 2, 1) 170 | xyz2 = xyz2.permute(0, 2, 1) 171 | 172 | points2 = points2.permute(0, 2, 1) 173 | B, N, C = xyz1.shape 174 | _, S, _ = xyz2.shape 175 | 176 | if S == 1: 177 | interpolated_points = points2.repeat(1, N, 1) 178 | else: 179 | dists = square_distance(xyz1, xyz2) 180 | dists, idx = dists.sort(dim=-1) 181 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 182 | 183 | dist_recip = 1.0 / (dists + 1e-8) 184 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 185 | weight = dist_recip / norm 186 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 187 | 188 | interpolated_points = interpolated_points.permute(0, 2, 1) 189 | return interpolated_points -------------------------------------------------------------------------------- /m3dm_runner1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | import numpy as np 5 | 6 | from feature_extractors import multiple_features_clean as multiple_features 7 | 8 | from dataset import get_data_loader 9 | 10 | class M3DM(): 11 | def __init__(self, args): 12 | self.args = args 13 | self.image_size = args.img_size 14 | self.count = args.max_sample 15 | if args.method_name == 'DINO': 16 | self.methods = { 17 | "DINO": multiple_features.RGBFeatures(args), 18 | } 19 | elif args.method_name == 'Point_MAE': 20 | self.methods = { 21 | "Point_MAE": multiple_features.PointFeatures(args), 22 | } 23 | elif args.method_name == 'Fusion': 24 | self.methods = { 25 | "Fusion": multiple_features.FusionFeatures(args), 26 | } 27 | elif args.method_name == 'DINO+Point_MAE': 28 | self.methods = { 29 | "DINO+Point_MAE": multiple_features.DoubleRGBPointFeatures(args), 30 | } 31 | elif args.method_name == 'DINO+Point_MAE+add': 32 | self.methods = { 33 | "DINO+Point_MAE": multiple_features.DoubleRGBPointFeatures_add(args), 34 | } 35 | elif args.method_name == 'DINO+Point_MAE+Fusion': 36 | self.methods = { 37 | "DINO+Point_MAE+Fusion": 
multiple_features.TripleFeatures(args), 38 | } 39 | elif args.method_name == 'DINO+FPFH': 40 | self.methods = { 41 | "DINO+FPFH": multiple_features.DoubleRGBPointFeatures(args), 42 | } 43 | elif args.method_name == 'DINO+FPFH+Fusion': 44 | self.methods = { 45 | "DINO+FPFH+Fusion": multiple_features.TripleFeatures(args), 46 | } 47 | elif args.method_name == 'DINO+FPFH+Fusion+ps': 48 | self.methods = { 49 | "DINO+FPFH+Fusion+ps": multiple_features.TripleFeatures_PS(args), 50 | } 51 | elif args.method_name == 'DINO+Point_MAE+Fusion+ps': 52 | self.methods = { 53 | "DINO+Point_MAE+Fusion+ps": multiple_features.TripleFeatures_PS(args), 54 | } 55 | elif args.method_name == 'DINO+Point_MAE+ps': 56 | self.methods = { 57 | "DINO+Point_MAE+ps": multiple_features.DoubleRGB_PS_Features(args), 58 | } 59 | elif args.method_name == 'DINO+FPFH+ps': 60 | self.methods = { 61 | "DINO+FPFH+ps": multiple_features.DoubleRGB_PS_Features(args), 62 | } 63 | elif args.method_name == 'ours': 64 | self.methods = { 65 | "ours": multiple_features.PSRGBPointFeatures_add(args), 66 | } 67 | elif args.method_name == 'ours2': 68 | self.methods = { 69 | "ours2": multiple_features.TripleFeatures_PS2(args), 70 | } 71 | elif args.method_name == 'ours3': 72 | self.methods = { 73 | "ours3": multiple_features.FourFeatures(args), 74 | } 75 | elif args.method_name == 'ours_final': 76 | self.methods = { 77 | "ours_final": multiple_features.TripleFeatures_PS_EX(args), 78 | } 79 | elif args.method_name == 'ours_final1': 80 | self.methods = { 81 | "ours_final1": multiple_features.PSRGBPointFeatures_add_EX(args), 82 | } 83 | elif args.method_name == 'm3dm_uninterpolate': 84 | self.methods = { 85 | "m3dm_uninterpolate": multiple_features.DoubleRGBPointFeatures_uninter_full(args), 86 | } 87 | 88 | def fit(self, class_name): 89 | train_loader = get_data_loader("train", class_name=class_name, img_size=self.image_size, args=self.args) 90 | 91 | flag = 0 92 | for sample, _ in tqdm(train_loader, desc=f'Extracting train features for class {class_name}'): 93 | for method in self.methods.values(): 94 | if self.args.save_feature: 95 | method.add_sample_to_mem_bank(sample, class_name=class_name) 96 | else: 97 | method.add_sample_to_mem_bank(sample) 98 | flag += 1 99 | if flag > self.count: 100 | flag = 0 101 | break 102 | 103 | for method_name, method in self.methods.items(): 104 | print(f'\n\nRunning coreset for {method_name} on class {class_name}...') 105 | method.run_coreset() 106 | 107 | 108 | if self.args.memory_bank == 'multiple': 109 | flag = 0 110 | for sample, _ in tqdm(train_loader, desc=f'Running late fusion for {method_name} on class {class_name}..'): 111 | for method_name, method in self.methods.items(): 112 | method.add_sample_to_late_fusion_mem_bank(sample) 113 | flag += 1 114 | if flag > self.count: 115 | flag = 0 116 | break 117 | 118 | for method_name, method in self.methods.items(): 119 | print(f'\n\nTraining Dicision Layer Fusion for {method_name} on class {class_name}...') 120 | method.run_late_fusion() 121 | 122 | def evaluate(self, class_name): 123 | image_rocaucs = dict() 124 | pixel_rocaucs = dict() 125 | au_pros = dict() 126 | valid_dir = os.path.join(self.args.dataset_path, class_name, 'validation') 127 | defect_names = os.listdir(valid_dir) 128 | # print(defect_names) 129 | path_list = [] 130 | for defect_name in defect_names: 131 | if defect_name == 'GOOD': 132 | continue 133 | test_loader = get_data_loader("validation", class_name=class_name, img_size=self.image_size, args=self.args, defect_name=defect_name) 134 | with 
torch.no_grad(): 135 | print(class_name, defect_name) 136 | print(f'len of loader:{len(test_loader)}') 137 | for sample, mask, label, rgb_path in tqdm(test_loader, desc=f'Extracting test features for class {class_name}'): 138 | for method in self.methods.values(): 139 | method.predict(sample, mask, label) 140 | path_list.append(rgb_path) 141 | 142 | 143 | for method_name, method in self.methods.items(): 144 | method.calculate_metrics() 145 | image_rocauc = method.image_rocauc 146 | pixel_rocauc = method.pixel_rocauc 147 | au_pro = method.au_pro 148 | 149 | print(f"Debug - Raw values: Image ROCAUC: {image_rocauc}, Pixel ROCAUC: {pixel_rocauc}, AU-PRO: {au_pro}") 150 | 151 | if np.isnan(image_rocauc) or np.isnan(pixel_rocauc) or np.isnan(au_pro): 152 | print(f"Warning: NaN detected for {method_name}") 153 | # 可以在这里添加更多的调试信息 154 | # 可以在这里添加更多的调试信息 155 | image_rocaucs[f'{method_name}_{defect_name}'] = round(method.image_rocauc, 3) 156 | pixel_rocaucs[f'{method_name}_{defect_name}'] = round(method.pixel_rocauc, 3) 157 | au_pros[f'{method_name}_{defect_name}'] = round(method.au_pro, 3) 158 | print( 159 | f'Debug - Class: {class_name}, {method_name}, defect_name:{defect_name} Image ROCAUC: {method.image_rocauc:.3f}, {method_name} Pixel ROCAUC: {method.pixel_rocauc:.3f}, {method_name} AU-PRO: {method.au_pro:.3f}') 160 | if self.args.save_preds: 161 | method.save_prediction_maps('./pred_maps', path_list) 162 | return image_rocaucs, pixel_rocaucs, au_pros 163 | -------------------------------------------------------------------------------- /utils/preprocessing1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tifffile as tiff 4 | import open3d as o3d 5 | from pathlib import Path 6 | from PIL import Image 7 | import math 8 | import mvtec3d_util as mvt_util 9 | import argparse 10 | 11 | 12 | def get_edges_of_pc(organized_pc): 13 | unorganized_edges_pc = organized_pc[0:10, :, :].reshape(organized_pc[0:10, :, :].shape[0]*organized_pc[0:10, :, :].shape[1],organized_pc[0:10, :, :].shape[2]) 14 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc,organized_pc[-10:, :, :].reshape(organized_pc[-10:, :, :].shape[0] * organized_pc[-10:, :, :].shape[1],organized_pc[-10:, :, :].shape[2])],axis=0) 15 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, 0:10, :].reshape(organized_pc[:, 0:10, :].shape[0] * organized_pc[:, 0:10, :].shape[1],organized_pc[:, 0:10, :].shape[2])], axis=0) 16 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, -10:, :].reshape(organized_pc[:, -10:, :].shape[0] * organized_pc[:, -10:, :].shape[1],organized_pc[:, -10:, :].shape[2])], axis=0) 17 | unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0],:] 18 | return unorganized_edges_pc 19 | 20 | def get_plane_eq(unorganized_pc,ransac_n_pts=50): 21 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc)) 22 | plane_model, inliers = o3d_pc.segment_plane(distance_threshold=0.004, ransac_n=ransac_n_pts, num_iterations=1000) 23 | return plane_model 24 | 25 | def remove_plane(organized_pc_clean, organized_rgb ,distance_threshold=0.005): 26 | # PREP PC 27 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean) 28 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb) 29 | clean_planeless_unorganized_pc = unorganized_pc.copy() 30 | planeless_unorganized_rgb = unorganized_rgb.copy() 31 | 32 | # 
REMOVE PLANE 33 | plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean)) 34 | distances = np.abs(np.dot(np.array(plane_model), np.hstack((clean_planeless_unorganized_pc, np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T)) 35 | plane_indices = np.argwhere(distances < distance_threshold) 36 | 37 | planeless_unorganized_rgb[plane_indices] = 0 38 | clean_planeless_unorganized_pc[plane_indices] = 0 39 | clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0], 40 | organized_pc_clean.shape[1], 41 | organized_pc_clean.shape[2]) 42 | planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0], 43 | organized_rgb.shape[1], 44 | organized_rgb.shape[2]) 45 | return clean_planeless_organized_pc, planeless_organized_rgb 46 | 47 | 48 | 49 | def connected_components_cleaning(organized_pc, organized_rgb, image_path): 50 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc) 51 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb) 52 | 53 | nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] 54 | unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :] 55 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros)) 56 | labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False)) 57 | 58 | 59 | unique_cluster_ids, cluster_size = np.unique(labels,return_counts=True) 60 | max_label = labels.max() 61 | if max_label>0: 62 | print("##########################################################################") 63 | print(f"Point cloud file {image_path} has {max_label + 1} clusters") 64 | print(f"Cluster ids: {unique_cluster_ids}. Cluster size {cluster_size}") 65 | print("##########################################################################\n\n") 66 | 67 | largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)] 68 | outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id) 69 | outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array] 70 | unorganized_pc[outlier_indices_original_pc_array] = 0 71 | unorganized_rgb[outlier_indices_original_pc_array] = 0 72 | organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0], 73 | organized_pc.shape[1], 74 | organized_pc.shape[2]) 75 | organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0], 76 | organized_rgb.shape[1], 77 | organized_rgb.shape[2]) 78 | return organized_clustered_pc, organized_clustered_rgb 79 | 80 | def roundup_next_100(x): 81 | return int(math.ceil(x / 100.0)) * 100 82 | 83 | def pad_cropped_pc(cropped_pc, single_channel=False): 84 | orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1] 85 | round_orig_h = roundup_next_100(orig_h) 86 | round_orig_w = roundup_next_100(orig_w) 87 | large_side = max(round_orig_h, round_orig_w) 88 | 89 | a = (large_side - orig_h) // 2 90 | aa = large_side - a - orig_h 91 | 92 | b = (large_side - orig_w) // 2 93 | bb = large_side - b - orig_w 94 | if single_channel: 95 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant') 96 | else: 97 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant') 98 | 99 | def preprocess_pc(tiff_path): 100 | # READ FILES 101 | organized_pc = mvt_util.read_tiff_organized_pc(tiff_path) 102 | rgb_path = str(tiff_path).replace("xyz", "rgb").replace("tiff", "png") 103 | gt_path = str(tiff_path).replace("xyz", "gt").replace("tiff", "png") 104 | organized_rgb = 
np.array(Image.open(rgb_path)) 105 | 106 | organized_gt = None 107 | gt_exists = os.path.isfile(gt_path) 108 | if gt_exists: 109 | organized_gt = np.array(Image.open(gt_path)) 110 | 111 | # REMOVE PLANE 112 | planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb) 113 | 114 | 115 | # PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE) 116 | padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False) 117 | padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False) 118 | if gt_exists: 119 | padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True) 120 | 121 | organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc, padded_planeless_organized_rgb, tiff_path) 122 | # SAVE PREPROCESSED FILES 123 | tiff.imsave(tiff_path, organized_clustered_pc) 124 | Image.fromarray(organized_clustered_rgb).save(rgb_path) 125 | if gt_exists: 126 | Image.fromarray(padded_organized_gt).save(gt_path) 127 | 128 | 129 | 130 | if __name__ == '__main__': 131 | parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD') 132 | parser.add_argument('dataset_path', type=str, help='The root path of the MVTec 3D-AD. The preprocessing is done inplace (i.e. the preprocessed dataset overrides the existing one)') 133 | args = parser.parse_args() 134 | 135 | 136 | root_path = args.dataset_path 137 | paths = Path(root_path).rglob('*.tiff') 138 | print(f"Found {len(list(paths))} tiff files in {root_path}") 139 | processed_files = 0 140 | for path in Path(root_path).rglob('*.tiff'): 141 | preprocess_pc(path) 142 | processed_files += 1 143 | if processed_files % 50 == 0: 144 | print(f"Processed {processed_files} tiff files...") 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /fusion_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import numpy as np 5 | import os 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils.tensorboard import SummaryWriter 12 | import torchvision.transforms as transforms 13 | 14 | import timm 15 | 16 | import timm.optim.optim_factory as optim_factory 17 | 18 | import utils.misc as misc 19 | from utils.misc import NativeScalerWithGradNormCount as NativeScaler 20 | 21 | 22 | from engine_fusion_pretrain import train_one_epoch 23 | 24 | import dataset 25 | 26 | import torch 27 | from models.feature_fusion import FeatureFusionBlock 28 | 29 | 30 | 31 | 32 | def get_args_parser(): 33 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 34 | parser.add_argument('--batch_size', default=64, type=int, 35 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 36 | parser.add_argument('--epochs', default=3, type=int) 37 | parser.add_argument('--accum_iter', default=1, type=int, 38 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 39 | 40 | # Model parameters 41 | 42 | parser.add_argument('--input_size', default=224, type=int, 43 | help='images input size') 44 | 45 | 46 | # Optimizer parameters 47 | parser.add_argument('--clip_grad', type=float, default=None, 48 | help='gradient clipping norm (default: None)') 49 | 50 | 
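    # NOTE: clip_grad is parsed and printed in main(), but the loss_scaler call in
    # engine_fusion_pretrain.train_one_epoch does not forward it, so no gradient clipping
    # is actually applied during fusion pre-training.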
parser.add_argument('--weight_decay', type=float, default=1.5e-6, 51 | help='weight decay (default: 0.05)') 52 | 53 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 54 | help='learning rate (absolute lr)') 55 | parser.add_argument('--blr', type=float, default=0.002, metavar='LR', 56 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 57 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 58 | help='lower lr bound for cyclic schedulers that hit 0') 59 | 60 | parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N', 61 | help='epochs to warmup LR') 62 | 63 | # Dataset parameters 64 | parser.add_argument('--data_path', default='', type=str, 65 | help='dataset path') 66 | 67 | parser.add_argument('--output_dir', default='./output_dir', 68 | help='path where to save, empty for no saving') 69 | parser.add_argument('--log_dir', default='./output_dir', 70 | help='path where to tensorboard log') 71 | parser.add_argument('--device', default='cuda', 72 | help='device to use for training / testing') 73 | parser.add_argument('--seed', default=0, type=int) 74 | parser.add_argument('--resume', default='', 75 | help='resume from checkpoint') 76 | 77 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 78 | help='start epoch') 79 | parser.add_argument('--num_workers', default=10, type=int) 80 | parser.add_argument('--pin_mem', action='store_true', 81 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 82 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 83 | parser.set_defaults(pin_mem=True) 84 | 85 | # distributed training parameters 86 | parser.add_argument('--world_size', default=1, type=int, 87 | help='number of distributed processes') 88 | parser.add_argument('--local_rank', default=-1, type=int) 89 | parser.add_argument('--dist_on_itp', action='store_true') 90 | parser.add_argument('--dist_url', default='env://', 91 | help='url used to set up distributed training') 92 | 93 | return parser 94 | 95 | 96 | 97 | 98 | def main(args): 99 | misc.init_distributed_mode(args) 100 | # args.gpu = 1 101 | # args.device = 'cuda:' + str(args.gpu) 102 | 103 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 104 | print("{}".format(args).replace(', ', ',\n')) 105 | 106 | device = torch.device(args.device) 107 | 108 | # fix the seed for reproducibility 109 | seed = args.seed + misc.get_rank() 110 | torch.manual_seed(seed) 111 | np.random.seed(seed) 112 | 113 | cudnn.benchmark = True 114 | 115 | 116 | dataset_train = dataset.PreTrainTensorDataset(args.data_path) 117 | 118 | print(dataset_train) 119 | 120 | 121 | if True: # args.distributed: 122 | num_tasks = misc.get_world_size() 123 | global_rank = misc.get_rank() 124 | sampler_train = torch.utils.data.DistributedSampler( 125 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 126 | ) 127 | print("Sampler_train = %s" % str(sampler_train)) 128 | else: 129 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 130 | 131 | if global_rank == 0 and args.log_dir is not None: 132 | os.makedirs(args.log_dir, exist_ok=True) 133 | log_writer = SummaryWriter(log_dir=args.log_dir) 134 | else: 135 | log_writer = None 136 | 137 | data_loader_train = torch.utils.data.DataLoader( 138 | dataset_train, sampler=sampler_train, 139 | batch_size=args.batch_size, 140 | num_workers=args.num_workers, 141 | pin_memory=args.pin_mem, 142 | drop_last=True, 143 | ) 144 | 145 | 146 | 
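    # Linear lr scaling rule (MAE-style), implemented just below: the absolute lr is derived
    # from the base lr (--blr) and the effective batch size relative to 256. With the defaults
    # above (blr=0.002, batch_size=64, accum_iter=1) on a single process:
    #   lr = 0.002 * (64 * 1 * 1) / 256 = 5.0e-4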
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 147 | 148 | if args.lr is None: # only base_lr is specified 149 | args.lr = args.blr * eff_batch_size / 256 150 | 151 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 152 | print("actual lr: %.2e" % args.lr) 153 | 154 | print("accumulate grad iterations: %d" % args.accum_iter) 155 | print("effective batch size: %d" % eff_batch_size) 156 | 157 | model = FeatureFusionBlock(1152, 768) 158 | 159 | model.to(device) 160 | 161 | if args.distributed: 162 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 163 | model_without_ddp = model.module 164 | print('gpu using:', args.gpu) 165 | # following timm: set wd as 0 for bias and norm layers 166 | optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr, betas=(0.9, 0.95)) 167 | print(optimizer) 168 | print('clip_grad:', args.clip_grad) 169 | loss_scaler = NativeScaler() 170 | 171 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 172 | 173 | print(f"Start training for {args.epochs} epochs") 174 | start_time = time.time() 175 | for epoch in range(args.start_epoch, args.epochs): 176 | if args.distributed: 177 | data_loader_train.sampler.set_epoch(epoch) 178 | train_stats = train_one_epoch( 179 | model, data_loader_train, 180 | optimizer, device, epoch, loss_scaler, 181 | log_writer=log_writer, 182 | args=args 183 | ) 184 | if args.output_dir and (epoch % 1 == 0 or epoch + 1 == args.epochs): 185 | misc.save_model( 186 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 187 | loss_scaler=loss_scaler, epoch=epoch) 188 | 189 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 190 | 'epoch': epoch,} 191 | 192 | if args.output_dir and misc.is_main_process(): 193 | if log_writer is not None: 194 | log_writer.flush() 195 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 196 | f.write(json.dumps(log_stats) + "\n") 197 | 198 | total_time = time.time() - start_time 199 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 200 | print('Training time {}'.format(total_time_str)) 201 | 202 | 203 | if __name__ == '__main__': 204 | args = get_args_parser() 205 | args = args.parse_args() 206 | if args.output_dir: 207 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 208 | main(args) 209 | -------------------------------------------------------------------------------- /models/feature_fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | def initialize_weights(model): 6 | for layer in model.modules(): 7 | if isinstance(layer, (nn.Conv2d, nn.Linear)): 8 | nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu') 9 | if layer.bias is not None: 10 | nn.init.constant_(layer.bias, 0) 11 | 12 | class Mlp(nn.Module): 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | self.fc1 = nn.Linear(in_features, hidden_features) 18 | self.act = act_layer() 19 | self.fc2 = nn.Linear(hidden_features, out_features) 20 | self.drop = nn.Dropout(drop) 21 | 22 | def forward(self, x): 23 | x = self.fc1(x) 24 | x = self.act(x) 25 | x = self.drop(x) 26 | x = self.fc2(x) 
27 | x = self.drop(x) 28 | return x 29 | # class Mlp(nn.Module): 30 | # def __init__(self, in_features, hidden_features=None, out_features=None, 31 | # act_layer=nn.GELU, drop=0.2, use_ln=True): 32 | # super().__init__() 33 | # out_features = out_features or in_features 34 | # hidden_features = hidden_features or in_features 35 | 36 | # self.use_ln = use_ln 37 | # self.scale = min(1.0, (in_features ** -0.5)) 38 | 39 | # # 第一个残差块 40 | # self.res_block1 = nn.Sequential( 41 | # nn.Linear(in_features, hidden_features), 42 | # nn.LayerNorm(hidden_features) if use_ln else nn.Identity(), 43 | # act_layer(), 44 | # nn.Dropout(drop), 45 | # nn.Linear(hidden_features, in_features), # 注意输出维度要和输入相同 46 | # nn.LayerNorm(in_features) if use_ln else nn.Identity(), 47 | # nn.Dropout(drop) 48 | # ) 49 | 50 | # # 第二个残差块 51 | # self.res_block2 = nn.Sequential( 52 | # nn.Linear(in_features, hidden_features), 53 | # nn.LayerNorm(hidden_features) if use_ln else nn.Identity(), 54 | # act_layer(), 55 | # nn.Dropout(drop), 56 | # nn.Linear(hidden_features, out_features), 57 | # nn.LayerNorm(out_features) if use_ln else nn.Identity(), 58 | # nn.Dropout(drop) 59 | # ) 60 | 61 | # # 如果输入输出维度不同,需要一个投影层 62 | # self.proj = None 63 | # if in_features != out_features: 64 | # self.proj = nn.Linear(in_features, out_features) 65 | 66 | # self._init_weights() 67 | 68 | # def _init_weights(self): 69 | # gain = 0.001 70 | # # 初始化第一个残差块 71 | # for m in self.res_block1.modules(): 72 | # if isinstance(m, nn.Linear): 73 | # nn.init.xavier_uniform_(m.weight, gain=gain) 74 | # if m.bias is not None: 75 | # nn.init.zeros_(m.bias) 76 | # m.weight.data *= 0.1 77 | 78 | # # 初始化第二个残差块 79 | # for m in self.res_block2.modules(): 80 | # if isinstance(m, nn.Linear): 81 | # nn.init.xavier_uniform_(m.weight, gain=gain) 82 | # if m.bias is not None: 83 | # nn.init.zeros_(m.bias) 84 | # m.weight.data *= 0.1 85 | 86 | # # 初始化投影层 87 | # if self.proj is not None: 88 | # nn.init.xavier_uniform_(self.proj.weight, gain=gain) 89 | # nn.init.zeros_(self.proj.bias) 90 | # self.proj.weight.data *= 0.1 91 | 92 | # def forward(self, x): 93 | # # 第一个残差块 94 | # identity = x 95 | # x = self.res_block1(x * self.scale) 96 | # x = x * self.scale + identity 97 | 98 | # # 第二个残差块 99 | # identity = x if self.proj is None else self.proj(x) 100 | # x = self.res_block2(x * self.scale) 101 | # x = x * self.scale + identity 102 | 103 | # return x 104 | # class Mlp(nn.Module): 105 | # def __init__(self, in_features, hidden_features=None, out_features=None, 106 | # act_layer=nn.GELU, drop=0.1, use_ln=True): 107 | # super().__init__() 108 | # out_features = out_features or in_features 109 | # hidden_features = hidden_features or in_features 110 | 111 | # self.use_ln = use_ln 112 | 113 | # # 第一层 114 | # self.fc1 = nn.Linear(in_features, hidden_features) 115 | # if use_ln: 116 | # self.ln1 = nn.LayerNorm(hidden_features) 117 | # self.act = act_layer() 118 | # self.drop1 = nn.Dropout(drop) 119 | 120 | # # 第二层 121 | # self.fc2 = nn.Linear(hidden_features, out_features) 122 | # if use_ln: 123 | # self.ln2 = nn.LayerNorm(out_features) 124 | # self.drop2 = nn.Dropout(drop) 125 | 126 | # # 初始化参数 127 | # self._init_weights() 128 | 129 | # def _init_weights(self): 130 | # # 使用较小的初始值来防止梯度爆炸 131 | # nn.init.xavier_uniform_(self.fc1.weight, gain=0.01) 132 | # nn.init.zeros_(self.fc1.bias) 133 | # nn.init.xavier_uniform_(self.fc2.weight, gain=0.01) 134 | # nn.init.zeros_(self.fc2.bias) 135 | 136 | # def forward(self, x): 137 | # x = self.fc1(x) 138 | # if self.use_ln: 139 | # x = 
self.ln1(x) 140 | # x = self.act(x) 141 | # x = self.drop1(x) 142 | 143 | # x = self.fc2(x) 144 | # if self.use_ln: 145 | # x = self.ln2(x) 146 | # x = self.drop2(x) 147 | 148 | # return x 149 | class FeatureFusionBlock(nn.Module): 150 | def __init__(self, xyz_dim, rgb_dim, mlp_ratio=4.): 151 | super().__init__() 152 | 153 | self.xyz_dim = xyz_dim 154 | self.rgb_dim = rgb_dim 155 | 156 | self.xyz_norm = nn.LayerNorm(xyz_dim) 157 | self.xyz_mlp = Mlp(in_features=xyz_dim, hidden_features=int(xyz_dim * mlp_ratio), act_layer=nn.GELU, drop=0.) 158 | initialize_weights(self.xyz_mlp) 159 | 160 | self.rgb_norm = nn.LayerNorm(rgb_dim) 161 | self.rgb_mlp = Mlp(in_features=rgb_dim, hidden_features=int(rgb_dim * mlp_ratio), act_layer=nn.GELU, drop=0.) 162 | initialize_weights(self.rgb_mlp) 163 | 164 | self.rgb_head = nn.Linear(rgb_dim, 256) 165 | self.xyz_head = nn.Linear(xyz_dim, 256) 166 | initialize_weights(self.rgb_head) 167 | initialize_weights(self.xyz_head) 168 | 169 | 170 | self.T = 1 171 | 172 | def feature_fusion(self, xyz_feature, rgb_feature): 173 | 174 | xyz_feature = self.xyz_mlp(self.xyz_norm(xyz_feature)) 175 | rgb_feature = self.rgb_mlp(self.rgb_norm(rgb_feature)) 176 | 177 | feature = torch.cat([xyz_feature, rgb_feature], dim=2) 178 | 179 | return feature 180 | 181 | def contrastive_loss(self, q, k): 182 | # normalize 183 | q = nn.functional.normalize(q, dim=1) 184 | k = nn.functional.normalize(k, dim=1) 185 | # gather all targets 186 | # Einstein sum is more intuitive 187 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.T 188 | N = logits.shape[0] # batch size per GPU 189 | labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda() 190 | return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) 191 | 192 | def reparameterize(self, mu, logvar): 193 | """ 194 | Will a single z be enough ti compute the expectation 195 | for the loss?? 
196 | :param mu: (Tensor) Mean of the latent Gaussian 197 | :param logvar: (Tensor) Standard deviation of the latent Gaussian 198 | :return: 199 | """ 200 | std = torch.exp(0.5 * logvar) 201 | eps = torch.randn_like(std) 202 | return eps * std + mu 203 | 204 | 205 | def forward(self, xyz_feature, rgb_feature): 206 | 207 | 208 | feature = self.feature_fusion(xyz_feature, rgb_feature) 209 | 210 | feature_xyz = feature[:,:, :self.xyz_dim] 211 | feature_rgb = feature[:,:, self.xyz_dim:] 212 | 213 | q = self.rgb_head(feature_rgb.view(-1, feature_rgb.shape[2])) 214 | k = self.xyz_head(feature_xyz.view(-1, feature_xyz.shape[2])) 215 | 216 | xyz_feature = xyz_feature.view(-1, xyz_feature.shape[2]) 217 | rgb_feature = rgb_feature.view(-1, rgb_feature.shape[2]) 218 | 219 | patch_no_zeros_indices = torch.nonzero(torch.all(xyz_feature != 0, dim=1)) 220 | 221 | loss = self.contrastive_loss(q[patch_no_zeros_indices,:].squeeze(), k[patch_no_zeros_indices,:].squeeze()) 222 | 223 | return loss 224 | 225 | -------------------------------------------------------------------------------- /utils/preprocess_eyecandies.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import copyfile 3 | import cv2 4 | import numpy as np 5 | import tifffile 6 | import yaml 7 | import imageio.v3 as iio 8 | import math 9 | import argparse 10 | 11 | # The same camera has been used for all the images 12 | FOCAL_LENGTH = 711.11 13 | 14 | def load_and_convert_depth(depth_img, info_depth): 15 | with open(info_depth) as f: 16 | data = yaml.safe_load(f) 17 | mind, maxd = data["normalization"]["min"], data["normalization"]["max"] 18 | 19 | dimg = iio.imread(depth_img) 20 | dimg = dimg.astype(np.float32) 21 | dimg = dimg / 65535.0 * (maxd - mind) + mind 22 | return dimg 23 | 24 | def depth_to_pointcloud(depth_img, info_depth, pose_txt, focal_length): 25 | # input depth map (in meters) --- cfr previous section 26 | depth_mt = load_and_convert_depth(depth_img, info_depth) 27 | 28 | # input pose 29 | pose = np.loadtxt(pose_txt) 30 | 31 | # camera intrinsics 32 | height, width = depth_mt.shape[:2] 33 | intrinsics_4x4 = np.array([ 34 | [focal_length, 0, width / 2, 0], 35 | [0, focal_length, height / 2, 0], 36 | [0, 0, 1, 0], 37 | [0, 0, 0, 1]] 38 | ) 39 | 40 | # build the camera projection matrix 41 | camera_proj = intrinsics_4x4 @ pose 42 | 43 | # build the (u, v, 1, 1/depth) vectors (non optimized version) 44 | camera_vectors = np.zeros((width * height, 4)) 45 | count=0 46 | for j in range(height): 47 | for i in range(width): 48 | camera_vectors[count, :] = np.array([i, j, 1, 1/depth_mt[j, i]]) 49 | count += 1 50 | 51 | # invert and apply to each 4-vector 52 | hom_3d_pts= np.linalg.inv(camera_proj) @ camera_vectors.T 53 | # print(hom_3d_pts.shape) 54 | # remove the homogeneous coordinate 55 | pcd = depth_mt.reshape(-1, 1) * hom_3d_pts.T 56 | return pcd[:, :3] 57 | 58 | def remove_point_cloud_background(pc): 59 | 60 | # The second dim is z 61 | dz = pc[256,1] - pc[-256,1] 62 | dy = pc[256,2] - pc[-256,2] 63 | 64 | norm = math.sqrt(dz**2 + dy**2) 65 | start_points = np.array([0, pc[-256, 1], pc[-256, 2]]) 66 | cos_theta = dy / norm 67 | sin_theta = dz / norm 68 | 69 | # Transform and rotation 70 | rotation_matrix = np.array([[1, 0, 0], [0, cos_theta, -sin_theta],[0, sin_theta, cos_theta]]) 71 | processed_pc = (rotation_matrix @ (pc - start_points).T).T 72 | 73 | # Remove background point 74 | for i in range(processed_pc.shape[0]): 75 | if processed_pc[i,1] > -0.02: 76 | 
processed_pc[i, :] = -start_points
 77 | if processed_pc[i,2] > 1.8:
 78 | processed_pc[i, :] = -start_points
 79 | elif processed_pc[i,0] > 1 or processed_pc[i,0] < -1:
 80 | processed_pc[i, :] = -start_points
 81 | 
 82 | processed_pc = (rotation_matrix.T @ processed_pc.T).T + start_points
 83 | 
 84 | index = [0, 2, 1]
 85 | processed_pc = processed_pc[:,index]
 86 | return processed_pc*[0.1, -0.1, 0.1]
 87 | 
 88 | 
 89 | if __name__ == '__main__':
 90 | 
 91 | parser = argparse.ArgumentParser(description='Process some integers.')
 92 | parser.add_argument('--dataset_path', default='datasets/eyecandies', type=str, help="Original Eyecandies dataset path.")
 93 | parser.add_argument('--target_dir', default='datasets/eyecandies_preprocessed', type=str, help="Processed Eyecandies dataset path")
 94 | args = parser.parse_args()
 95 | 
 96 | os.mkdir(args.target_dir)
 97 | categories_list = os.listdir(args.dataset_path)
 98 | 
 99 | for category_dir in categories_list:
 100 | category_root_path = os.path.join(args.dataset_path, category_dir)
 101 | 
 102 | category_train_path = os.path.join(category_root_path, 'train/data')
 103 | category_test_path = os.path.join(category_root_path, 'test_public/data')
 104 | 
 105 | category_target_path = os.path.join(args.target_dir, category_dir)
 106 | os.mkdir(category_target_path)
 107 | 
 108 | os.mkdir(os.path.join(category_target_path, 'train'))
 109 | category_target_train_good_path = os.path.join(category_target_path, 'train/good')
 110 | category_target_train_good_rgb_path = os.path.join(category_target_train_good_path, 'rgb')
 111 | category_target_train_good_xyz_path = os.path.join(category_target_train_good_path, 'xyz')
 112 | os.mkdir(category_target_train_good_path)
 113 | os.mkdir(category_target_train_good_rgb_path)
 114 | os.mkdir(category_target_train_good_xyz_path)
 115 | 
 116 | os.mkdir(os.path.join(category_target_path, 'test'))
 117 | category_target_test_good_path = os.path.join(category_target_path, 'test/good')
 118 | category_target_test_good_rgb_path = os.path.join(category_target_test_good_path, 'rgb')
 119 | category_target_test_good_xyz_path = os.path.join(category_target_test_good_path, 'xyz')
 120 | category_target_test_good_gt_path = os.path.join(category_target_test_good_path, 'gt')
 121 | os.mkdir(category_target_test_good_path)
 122 | os.mkdir(category_target_test_good_rgb_path)
 123 | os.mkdir(category_target_test_good_xyz_path)
 124 | os.mkdir(category_target_test_good_gt_path)
 125 | category_target_test_bad_path = os.path.join(category_target_path, 'test/bad')
 126 | category_target_test_bad_rgb_path = os.path.join(category_target_test_bad_path, 'rgb')
 127 | category_target_test_bad_xyz_path = os.path.join(category_target_test_bad_path, 'xyz')
 128 | category_target_test_bad_gt_path = os.path.join(category_target_test_bad_path, 'gt')
 129 | os.mkdir(category_target_test_bad_path)
 130 | os.mkdir(category_target_test_bad_rgb_path)
 131 | os.mkdir(category_target_test_bad_xyz_path)
 132 | os.mkdir(category_target_test_bad_gt_path)
 133 | 
 134 | category_train_files = os.listdir(category_train_path)
 135 | num_train_files = len(category_train_files)//17
 136 | for i in range(0, num_train_files):
 137 | pc = depth_to_pointcloud(
 138 | os.path.join(category_train_path,str(i).zfill(3)+'_depth.png'),
 139 | os.path.join(category_train_path,str(i).zfill(3)+'_info_depth.yaml'),
 140 | os.path.join(category_train_path,str(i).zfill(3)+'_pose.txt'),
 141 | FOCAL_LENGTH,
 142 | )
 143 | pc = remove_point_cloud_background(pc)
 144 | pc = pc.reshape(512,512,3)
 145 | 
tifffile.imwrite(os.path.join(category_target_train_good_xyz_path, str(i).zfill(3)+'.tiff'), pc) 146 | copyfile(os.path.join(category_train_path,str(i).zfill(3)+'_image_4.png'),os.path.join(category_target_train_good_rgb_path, str(i).zfill(3)+'.png')) 147 | 148 | 149 | category_test_files = os.listdir(category_test_path) 150 | num_test_files = len(category_test_files)//17 151 | for i in range(0, num_test_files): 152 | mask = cv2.imread(os.path.join(category_test_path,str(i).zfill(2)+'_mask.png')) 153 | if np.any(mask): 154 | pc = depth_to_pointcloud( 155 | os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'), 156 | os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'), 157 | os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'), 158 | FOCAL_LENGTH, 159 | ) 160 | pc = remove_point_cloud_background(pc) 161 | pc = pc.reshape(512,512,3) 162 | tifffile.imwrite(os.path.join(category_target_test_bad_xyz_path, str(i).zfill(3)+'.tiff'), pc) 163 | cv2.imwrite(os.path.join(category_target_test_bad_gt_path, str(i).zfill(3)+'.png'), mask) 164 | copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_bad_rgb_path, str(i).zfill(3)+'.png')) 165 | else: 166 | pc = depth_to_pointcloud( 167 | os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'), 168 | os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'), 169 | os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'), 170 | FOCAL_LENGTH, 171 | ) 172 | pc = remove_point_cloud_background(pc) 173 | pc = pc.reshape(512,512,3) 174 | tifffile.imwrite(os.path.join(category_target_test_good_xyz_path, str(i).zfill(3)+'.tiff'), pc) 175 | cv2.imwrite(os.path.join(category_target_test_good_gt_path, str(i).zfill(3)+'.png'), mask) 176 | copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_good_rgb_path, str(i).zfill(3)+'.png')) 177 | -------------------------------------------------------------------------------- /utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tifffile as tiff 4 | import open3d as o3d 5 | from pathlib import Path 6 | from PIL import Image 7 | import math 8 | import mvtec3d_util as mvt_util 9 | import argparse 10 | 11 | 12 | def get_edges_of_pc(organized_pc): 13 | unorganized_edges_pc = organized_pc[0:10, :, :].reshape(organized_pc[0:10, :, :].shape[0]*organized_pc[0:10, :, :].shape[1],organized_pc[0:10, :, :].shape[2]) 14 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc,organized_pc[-10:, :, :].reshape(organized_pc[-10:, :, :].shape[0] * organized_pc[-10:, :, :].shape[1],organized_pc[-10:, :, :].shape[2])],axis=0) 15 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, 0:10, :].reshape(organized_pc[:, 0:10, :].shape[0] * organized_pc[:, 0:10, :].shape[1],organized_pc[:, 0:10, :].shape[2])], axis=0) 16 | unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, -10:, :].reshape(organized_pc[:, -10:, :].shape[0] * organized_pc[:, -10:, :].shape[1],organized_pc[:, -10:, :].shape[2])], axis=0) 17 | unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0],:] 18 | return unorganized_edges_pc 19 | 20 | def get_plane_eq(unorganized_pc,ransac_n_pts=50): 21 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc)) 22 | plane_model, inliers = 
o3d_pc.segment_plane(distance_threshold=0.004, ransac_n=ransac_n_pts, num_iterations=1000) 23 | return plane_model 24 | 25 | def remove_plane(organized_pc_clean, organized_rgb ,distance_threshold=0.005): 26 | # PREP PC 27 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean) 28 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb) 29 | clean_planeless_unorganized_pc = unorganized_pc.copy() 30 | planeless_unorganized_rgb = unorganized_rgb.copy() 31 | 32 | # REMOVE PLANE 33 | plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean)) 34 | distances = np.abs(np.dot(np.array(plane_model), np.hstack((clean_planeless_unorganized_pc, np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T)) 35 | plane_indices = np.argwhere(distances < distance_threshold) 36 | 37 | planeless_unorganized_rgb[plane_indices] = 0 38 | clean_planeless_unorganized_pc[plane_indices] = 0 39 | clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0], 40 | organized_pc_clean.shape[1], 41 | organized_pc_clean.shape[2]) 42 | planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0], 43 | organized_rgb.shape[1], 44 | organized_rgb.shape[2]) 45 | return clean_planeless_organized_pc, planeless_organized_rgb 46 | 47 | 48 | 49 | def connected_components_cleaning(organized_pc, organized_rgb, image_path): 50 | unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc) 51 | unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb) 52 | 53 | nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] 54 | unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :] 55 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros)) 56 | labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False)) 57 | 58 | 59 | unique_cluster_ids, cluster_size = np.unique(labels,return_counts=True) 60 | max_label = labels.max() 61 | if max_label>0: 62 | print("##########################################################################") 63 | print(f"Point cloud file {image_path} has {max_label + 1} clusters") 64 | print(f"Cluster ids: {unique_cluster_ids}. 
Cluster size {cluster_size}") 65 | print("##########################################################################\n\n") 66 | 67 | largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)] 68 | outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id) 69 | outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array] 70 | unorganized_pc[outlier_indices_original_pc_array] = 0 71 | unorganized_rgb[outlier_indices_original_pc_array] = 0 72 | organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0], 73 | organized_pc.shape[1], 74 | organized_pc.shape[2]) 75 | organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0], 76 | organized_rgb.shape[1], 77 | organized_rgb.shape[2]) 78 | return organized_clustered_pc, organized_clustered_rgb 79 | 80 | def roundup_next_100(x): 81 | return int(math.ceil(x / 100.0)) * 100 82 | 83 | def pad_cropped_pc(cropped_pc, single_channel=False): 84 | orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1] 85 | round_orig_h = roundup_next_100(orig_h) 86 | round_orig_w = roundup_next_100(orig_w) 87 | large_side = max(round_orig_h, round_orig_w) 88 | 89 | a = (large_side - orig_h) // 2 90 | aa = large_side - a - orig_h 91 | 92 | b = (large_side - orig_w) // 2 93 | bb = large_side - b - orig_w 94 | if single_channel: 95 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant') 96 | else: 97 | return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant') 98 | 99 | def preprocess_pc(tiff_path): 100 | # READ FILES 101 | organized_pc = mvt_util.read_tiff_organized_pc(tiff_path) 102 | 103 | # rgb_path = str(tiff_path).replace("xyz", "rgb").replace("tiff", "jpg") 104 | # gt_path = str(tiff_path).replace("xyz", "gt").replace("tiff", "png") 105 | 106 | # 修改获取原图和gt的方式 107 | rgb_path = '' 108 | gt_path = '' 109 | # 获取 tiff 文件名 110 | tiff_file_name = os.path.basename(tiff_path) 111 | 112 | # 获取 tiff 文件所在的目录,并生成对应的 RGB 和 GT 目录 113 | tiff_dir = os.path.dirname(tiff_path) 114 | rgb_dir = tiff_dir.replace("xyz", "rgb") # 将 tiff 路径中的 xyz 替换为 rgb 115 | gt_dir = tiff_dir.replace("xyz", "gt") # 将 tiff 路径中的 xyz 替换为 gt 116 | 117 | # 获取 tiff 文件名中的匹配前缀(下划线前的部分) 118 | if "_" in tiff_file_name: 119 | tiff_prefix = tiff_file_name.split('_', 1)[0] # 获取 tiff 文件名中第一个下划线前的部分 120 | else: 121 | tiff_prefix = tiff_file_name # 没有下划线时,直接使用完整文件名作为前缀 122 | 123 | # 遍历 RGB 目录,找到与 tiff 前缀匹配的文件 124 | for rgb_file in os.listdir(rgb_dir): 125 | if "_" in rgb_file: 126 | rgb_prefix = rgb_file.split('_', 1)[0] # 获取 RGB 文件名中第一个下划线前的部分 127 | else: 128 | rgb_prefix = rgb_file # 没有下划线时,直接使用文件名 129 | 130 | # 如果前缀匹配,则生成新的 RGB 路径,并使用原始 RGB 文件名 131 | if tiff_prefix == rgb_prefix: 132 | rgb_path = os.path.join(rgb_dir, rgb_file) # 保留原始 RGB 文件名 133 | print(rgb_path) 134 | 135 | # 遍历 GT 目录,找到与 tiff 前缀匹配的文件 136 | for gt_file in os.listdir(gt_dir): 137 | if "_" in gt_file: 138 | gt_prefix = gt_file.split('_', 1)[0] # 获取 GT 文件名中第一个下划线前的部分 139 | else: 140 | gt_prefix = gt_file.split('.', 1)[0] # 没有下划线时,直接使用文件名 141 | 142 | # 如果前缀匹配,则生成新的 GT 路径,并使用原始 GT 文件名 143 | if tiff_prefix == gt_prefix: 144 | gt_path = os.path.join(gt_dir, gt_file) # 保留原始 GT 文件名 145 | print(gt_path) 146 | 147 | 148 | 149 | organized_rgb = np.array(Image.open(rgb_path)) 150 | 151 | organized_gt = None 152 | gt_exists = os.path.isfile(gt_path) 153 | if gt_exists: 154 | organized_gt = np.array(Image.open(gt_path)) 155 | 156 | # REMOVE PLANE 157 | planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb) 158 | 159 | 160 
| # PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE) 161 | padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False) 162 | padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False) 163 | if gt_exists: 164 | padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True) 165 | 166 | organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc, padded_planeless_organized_rgb, tiff_path) 167 | # SAVE PREPROCESSED FILES 168 | tiff.imsave(tiff_path, organized_clustered_pc) 169 | Image.fromarray(organized_clustered_rgb).save(rgb_path) 170 | if gt_exists: 171 | Image.fromarray(padded_organized_gt).save(gt_path) 172 | 173 | 174 | 175 | if __name__ == '__main__': 176 | parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD') 177 | parser.add_argument('dataset_path', type=str, help='The root path of the MVTec 3D-AD. The preprocessing is done inplace (i.e. the preprocessed dataset overrides the existing one)') 178 | args = parser.parse_args() 179 | 180 | 181 | root_path = args.dataset_path 182 | paths = Path(root_path).rglob('*.tiff') 183 | print(f"Found {len(list(paths))} tiff files in {root_path}") 184 | processed_files = 0 185 | for path in Path(root_path).rglob('*.tiff'): 186 | preprocess_pc(path) 187 | processed_files += 1 188 | if processed_files % 50 == 0: 189 | print(f"Processed {processed_files} tiff files...") 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from m3dm_runner import M3DM 3 | from dataset import eyecandies_classes, mvtec3d_classes, test_3d_classes 4 | import os 5 | import pandas as pd 6 | import torch 7 | #os.environ['CUDA_VISIBLE_DEVICES'] = '7' 8 | 9 | 10 | def run_3d_ads(args): 11 | if args.dataset_type=='eyecandies': 12 | classes = eyecandies_classes() 13 | elif args.dataset_type=='mvtec3d': 14 | classes = mvtec3d_classes() 15 | elif args.dataset_type=='test_3d': 16 | classes = test_3d_classes() 17 | 18 | METHOD_NAMES = [args.method_name] 19 | 20 | image_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method']) 21 | pixel_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method']) 22 | au_pros_df = pd.DataFrame(METHOD_NAMES, columns=['Method']) 23 | for cls in classes: 24 | model = M3DM(args) 25 | model.fit(cls) 26 | image_rocaucs, pixel_rocaucs, au_pros = model.evaluate(cls) 27 | 28 | image_rocaucs_df[cls.title()] = image_rocaucs_df['Method'].map(image_rocaucs) 29 | pixel_rocaucs_df[cls.title()] = pixel_rocaucs_df['Method'].map(pixel_rocaucs) 30 | au_pros_df[cls.title()] = au_pros_df['Method'].map(au_pros) 31 | 32 | print(f"\nFinished running on class {cls}") 33 | print("################################################################################\n\n") 34 | 35 | # image_rocaucs_df['Mean'] = round(image_rocaucs_df.iloc[:, 1:].mean(axis=1),3) 36 | # pixel_rocaucs_df['Mean'] = round(pixel_rocaucs_df.iloc[:, 1:].mean(axis=1),3) 37 | # au_pros_df['Mean'] = round(au_pros_df.iloc[:, 1:].mean(axis=1),3) 38 | 39 | # print("\n\n################################################################################") 40 | # print("############################# Image ROCAUC Results #############################") 41 | # print("################################################################################\n") 42 | # 
print(image_rocaucs_df.to_markdown(index=False)) 43 | 44 | # print("\n\n################################################################################") 45 | # print("############################# Pixel ROCAUC Results #############################") 46 | # print("################################################################################\n") 47 | # print(pixel_rocaucs_df.to_markdown(index=False)) 48 | 49 | # print("\n\n##########################################################################") 50 | # print("############################# AU PRO Results #############################") 51 | # print("##########################################################################\n") 52 | # print(au_pros_df.to_markdown(index=False)) 53 | print("\n\n################################################################################") 54 | print("############################# Image ROCAUC Results #############################") 55 | print("################################################################################\n") 56 | print(image_rocaucs_df.to_string(index=False)) 57 | 58 | print("\n\n################################################################################") 59 | print("############################# Pixel ROCAUC Results #############################") 60 | print("################################################################################\n") 61 | print(pixel_rocaucs_df.to_string(index=False)) 62 | 63 | print("\n\n##########################################################################") 64 | print("############################# AU PRO Results #############################") 65 | print("##########################################################################\n") 66 | print(au_pros_df.to_string(index=False)) 67 | 68 | 69 | 70 | # with open("results/image_rocauc_results.md", "a") as tf: 71 | # tf.write(image_rocaucs_df.to_markdown(index=False)) 72 | # with open("results/pixel_rocauc_results.md", "a") as tf: 73 | # tf.write(pixel_rocaucs_df.to_markdown(index=False)) 74 | # with open("results/aupro_results.md", "a") as tf: 75 | # tf.write(au_pros_df.to_markdown(index=False)) 76 | # with open("results/image_rocauc_results.txt", "a") as tf: 77 | # tf.write(image_rocaucs_df.to_string(index=False)) 78 | # with open("results/pixel_rocauc_results.txt", "a") as tf: 79 | # tf.write(pixel_rocaucs_df.to_string(index=False)) 80 | # with open("results/aupro_results.txt", "a") as tf: 81 | # tf.write(au_pros_df.to_string(index=False)) 82 | 83 | 84 | if __name__ == '__main__': 85 | #torch.cuda.set_device(7) 86 | parser = argparse.ArgumentParser(description='Process some integers.') 87 | 88 | parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str, 89 | choices=['DINO','Point_MAE','Fusion','DINO+Point_MAE','DINO+Point_MAE+Fusion','DINO+Point_MAE+add','DINO+FPFH','DINO+FPFH+Fusion', 90 | 'DINO+FPFH+Fusion+ps','DINO+Point_MAE+Fusion+ps','DINO+Point_MAE+ps','DINO+FPFH+ps','ours','ours2','ours3','ours_final','ours_final1' 91 | ,'ours_final1_VS', 'm3dm_uninterpolate', 'ours_PS','m3dm_VS','Point_MAE_VS','DINO_VS','patchcore_VS','PS_VS','OURS_EX_VS','NEW_OURS_EX_VS','patchcore_uninterpolate','shape'], 92 | help='Anomaly detection modal name.') 93 | parser.add_argument('--max_sample', default=400, type=int, 94 | help='Max sample number.') 95 | parser.add_argument('--memory_bank', default='multiple', type=str, 96 | choices=["multiple", "single"], 97 | help='memory bank mode: "multiple", "single".') 98 | parser.add_argument('--rgb_backbone_name', 
default='vit_base_patch8_224_dino', type=str,
 99 | choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'],
 100 | help='Timm checkpoints name of RGB backbone.')
 101 | parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert','FPFH'],
 102 | help='Checkpoints name of the point cloud (xyz) backbone [Point_MAE, Point_Bert, FPFH].')
 103 | parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str,
 104 | help='Checkpoints for fusion module.')
 105 | parser.add_argument('--save_feature', default=False, action='store_true',
 106 | help='Save feature for training fusion block.')
 107 | parser.add_argument('--use_uff', default=False, action='store_true',
 108 | help='Use UFF module.')
 109 | parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str,
 110 | help='Path to save features for training the fusion block.')
 111 | parser.add_argument('--save_preds', default=False, action='store_true',
 112 | help='Save prediction results.')
 113 | parser.add_argument('--group_size', default=128, type=int,
 114 | help='Point group size of Point Transformer.')
 115 | parser.add_argument('--num_group', default=1024, type=int,
 116 | help='Point groups number of Point Transformer.')
 117 | parser.add_argument('--random_state', default=None, type=int,
 118 | help='random_state for random projection')
 119 | parser.add_argument('--dataset_type', default='test_3d', type=str, choices=['mvtec3d', 'eyecandies','test_3d'],
 120 | help='Dataset type for training or testing')
 121 | parser.add_argument('--dataset_path', default='/fuxi_team14_intern/D3', type=str,
 122 | help='Dataset store path')
 123 | parser.add_argument('--img_size', default=224, type=int,
 124 | help='Image size for the model')
 125 | parser.add_argument('--xyz_s_lambda', default=1.0, type=float,
 126 | help='xyz_s_lambda')
 127 | parser.add_argument('--xyz_smap_lambda', default=1.0, type=float,
 128 | help='xyz_smap_lambda')
 129 | parser.add_argument('--rgb_s_lambda', default=0.1, type=float,
 130 | help='rgb_s_lambda')
 131 | parser.add_argument('--rgb_smap_lambda', default=0.1, type=float,
 132 | help='rgb_smap_lambda')
 133 | parser.add_argument('--ps_s_lambda', default=0.1, type=float,
 134 | help='ps_s_lambda')
 135 | parser.add_argument('--ps_smap_lambda', default=0.1, type=float,
 136 | help='ps_smap_lambda')
 137 | parser.add_argument('--fusion_s_lambda', default=1.0, type=float,
 138 | help='fusion_s_lambda')
 139 | parser.add_argument('--fusion_smap_lambda', default=1.0, type=float,
 140 | help='fusion_smap_lambda')
 141 | parser.add_argument('--coreset_eps', default=0.9, type=float,
 142 | help='eps for sparse random projection')
 143 | parser.add_argument('--f_coreset', default=0.1, type=float,
 144 | help='Fraction of the memory bank kept after coreset subsampling.')
 145 | parser.add_argument('--asy_memory_bank', default=None, type=int,
 146 | help='build an asymmetric memory bank for point clouds')
 147 | parser.add_argument('--ocsvm_nu', default=0.5, type=float,
 148 | help='ocsvm nu')
 149 | parser.add_argument('--ocsvm_maxiter', default=1000, type=int,
 150 | help='ocsvm maxiter')
 151 | parser.add_argument('--rm_zero_for_project', default=False, action='store_true',
 152 | help='Remove all-zero features before random projection.')
 153 | parser.add_argument('--downsampling', default=1, type=int,
 154 | help='downsampling factor')
 155 | parser.add_argument('--rotate_angle', default=0, type=float,
 156 | help='rotate angle')
 157 | parser.add_argument('--small', default=False, action='store_true',
 158 | help='small predict')
159 | parser.add_argument('--split', default=False, action='store_true', 160 | help='split_predict') 161 | parser.add_argument('--ex', default=1, type=int, 162 | help='ex_factor') 163 | 164 | args = parser.parse_args() 165 | run_3d_ads(args) 166 | -------------------------------------------------------------------------------- /utils/au_pro_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code based on the official MVTec 3D-AD evaluation code found at 3 | https://www.mydrive.ch/shares/45924/9ce7a138c69bbd4c8d648b72151f839d/download/428846918-1643297332/evaluation_code.tar.xz 4 | 5 | Utility functions that compute a PRO curve and its definite integral, given 6 | pairs of anomaly and ground truth maps. 7 | 8 | The PRO curve can also be integrated up to a constant integration limit. 9 | """ 10 | import numpy as np 11 | from scipy.ndimage.measurements import label 12 | from bisect import bisect 13 | 14 | 15 | class GroundTruthComponent: 16 | """ 17 | Stores sorted anomaly scores of a single ground truth component. 18 | Used to efficiently compute the region overlap for many increasing thresholds. 19 | """ 20 | 21 | def __init__(self, anomaly_scores): 22 | """ 23 | Initialize the module. 24 | 25 | Args: 26 | anomaly_scores: List of all anomaly scores within the ground truth 27 | component as numpy array. 28 | """ 29 | # Keep a sorted list of all anomaly scores within the component. 30 | self.anomaly_scores = anomaly_scores.copy() 31 | self.anomaly_scores.sort() 32 | 33 | # Pointer to the anomaly score where the current threshold divides the component into OK / NOK pixels. 34 | self.index = 0 35 | 36 | # The last evaluated threshold. 37 | self.last_threshold = None 38 | 39 | def compute_overlap(self, threshold): 40 | """ 41 | Compute the region overlap for a specific threshold. 42 | Thresholds must be passed in increasing order. 43 | 44 | Args: 45 | threshold: Threshold to compute the region overlap. 46 | 47 | Returns: 48 | Region overlap for the specified threshold. 49 | """ 50 | if self.last_threshold is not None: 51 | assert self.last_threshold <= threshold 52 | 53 | # Increase the index until it points to an anomaly score that is just above the specified threshold. 54 | while (self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold): 55 | self.index += 1 56 | 57 | # Compute the fraction of component pixels that are correctly segmented as anomalous. 58 | return 1.0 - self.index / len(self.anomaly_scores) 59 | 60 | 61 | def trapezoid(x, y, x_max=None): 62 | """ 63 | This function calculates the definit integral of a curve given by x- and corresponding y-values. 64 | In contrast to, e.g., 'numpy.trapz()', this function allows to define an upper bound to the integration range by 65 | setting a value x_max. 66 | 67 | Points that do not have a finite x or y value will be ignored with a warning. 68 | 69 | Args: 70 | x: Samples from the domain of the function to integrate need to be sorted in ascending order. May contain 71 | the same value multiple times. In that case, the order of the corresponding y values will affect the 72 | integration with the trapezoidal rule. 73 | y: Values of the function corresponding to x values. 74 | x_max: Upper limit of the integration. The y value at max_x will be determined by interpolating between its 75 | neighbors. Must not lie outside of the range of x. 76 | 77 | Returns: 78 | Area under the curve. 
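        If x_max does not coincide with a sample of x, the curve is linearly interpolated at x_max
        and the resulting partial trapezoid is added as a correction term (see the code below).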
79 | """ 80 | 81 | x = np.array(x) 82 | y = np.array(y) 83 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y)) 84 | if not finite_mask.all(): 85 | print( 86 | """WARNING: Not all x and y values passed to trapezoid are finite. Will continue with only the finite values.""") 87 | x = x[finite_mask] 88 | y = y[finite_mask] 89 | 90 | # Introduce a correction term if max_x is not an element of x. 91 | correction = 0. 92 | if x_max is not None: 93 | if x_max not in x: 94 | # Get the insertion index that would keep x sorted after np.insert(x, ins, x_max). 95 | ins = bisect(x, x_max) 96 | # x_max must be between the minimum and the maximum, so the insertion_point cannot be zero or len(x). 97 | assert 0 < ins < len(x) 98 | 99 | # Calculate the correction term which is the integral between the last x[ins-1] and x_max. Since we do not 100 | # know the exact value of y at x_max, we interpolate between y[ins] and y[ins-1]. 101 | y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1])) 102 | correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1]) 103 | 104 | # Cut off at x_max. 105 | mask = x <= x_max 106 | x = x[mask] 107 | y = y[mask] 108 | 109 | # Return area under the curve using the trapezoidal rule. 110 | return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction 111 | 112 | 113 | def collect_anomaly_scores(anomaly_maps, ground_truth_maps): 114 | """ 115 | Extract anomaly scores for each ground truth connected component as well as anomaly scores for each potential false 116 | positive pixel from anomaly maps. 117 | 118 | Args: 119 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel. 120 | 121 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels 122 | for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains 123 | an anomaly. 124 | 125 | Returns: 126 | ground_truth_components: A list of all ground truth connected components that appear in the dataset. For each 127 | component, a sorted list of its anomaly scores is stored. 128 | 129 | anomaly_scores_ok_pixels: A sorted list of anomaly scores of all anomaly-free pixels of the dataset. This list 130 | can be used to quickly select thresholds that fix a certain false positive rate. 131 | """ 132 | # Make sure an anomaly map is present for each ground truth map. 133 | assert len(anomaly_maps) == len(ground_truth_maps) 134 | 135 | # Initialize ground truth components and scores of potential fp pixels. 136 | ground_truth_components = [] 137 | anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size) 138 | 139 | # Structuring element for computing connected components. 140 | structure = np.ones((3, 3), dtype=int) 141 | 142 | # Collect anomaly scores within each ground truth region and for all potential fp pixels. 143 | ok_index = 0 144 | for gt_map, prediction in zip(ground_truth_maps, anomaly_maps): 145 | 146 | # Compute the connected components in the ground truth map. 147 | labeled, n_components = label(gt_map, structure) 148 | 149 | # Store all potential fp scores. 150 | num_ok_pixels = len(prediction[labeled == 0]) 151 | anomaly_scores_ok_pixels[ok_index:ok_index + num_ok_pixels] = prediction[labeled == 0].copy() 152 | ok_index += num_ok_pixels 153 | 154 | # Fetch anomaly scores within each GT component. 
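        # label() numbers the connected components 1..n_components; label 0 marks the anomaly-free
        # background pixels, whose scores were already collected above as potential false positives.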
155 | for k in range(n_components): 156 | component_scores = prediction[labeled == (k + 1)] 157 | ground_truth_components.append(GroundTruthComponent(component_scores)) 158 | 159 | # Sort all potential false positive scores. 160 | anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index) 161 | anomaly_scores_ok_pixels.sort() 162 | 163 | return ground_truth_components, anomaly_scores_ok_pixels 164 | 165 | 166 | def compute_pro(anomaly_maps, ground_truth_maps, num_thresholds): 167 | """ 168 | Compute the PRO curve at equidistant interpolation points for a set of anomaly maps with corresponding ground 169 | truth maps. The number of interpolation points can be set manually. 170 | 171 | Args: 172 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel. 173 | 174 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels 175 | for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains 176 | an anomaly. 177 | 178 | num_thresholds: Number of thresholds to compute the PRO curve. 179 | Returns: 180 | fprs: List of false positive rates. 181 | pros: List of correspoding PRO values. 182 | """ 183 | # Fetch sorted anomaly scores. 184 | ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps) 185 | 186 | # Select equidistant thresholds. 187 | threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int) 188 | 189 | fprs = [1.0] 190 | pros = [1.0] 191 | for pos in threshold_positions: 192 | threshold = anomaly_scores_ok_pixels[pos] 193 | 194 | # Compute the false positive rate for this threshold. 195 | fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels) 196 | 197 | # Compute the PRO value for this threshold. 198 | pro = 0.0 199 | for component in ground_truth_components: 200 | pro += component.compute_overlap(threshold) 201 | pro /= len(ground_truth_components) 202 | 203 | fprs.append(fpr) 204 | pros.append(pro) 205 | 206 | # Return (FPR/PRO) pairs in increasing FPR order. 207 | fprs = fprs[::-1] 208 | pros = pros[::-1] 209 | 210 | return fprs, pros 211 | 212 | 213 | def calculate_au_pro(gts, predictions, integration_limit=0.3, num_thresholds=100): 214 | """ 215 | Compute the area under the PRO curve for a set of ground truth images and corresponding anomaly images. 216 | Args: 217 | gts: List of tensors that contain the ground truth images for a single dataset object. 218 | predictions: List of tensors containing anomaly images for each ground truth image. 219 | integration_limit: Integration limit to use when computing the area under the PRO curve. 220 | num_thresholds: Number of thresholds to use to sample the area under the PRO curve. 221 | 222 | Returns: 223 | au_pro: Area under the PRO curve computed up to the given integration limit. 224 | pro_curve: PRO curve values for localization (fpr,pro). 225 | """ 226 | # Compute the PRO curve. 227 | pro_curve = compute_pro(anomaly_maps=predictions, ground_truth_maps=gts, num_thresholds=num_thresholds) 228 | 229 | # Compute the area under the PRO curve. 230 | au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max=integration_limit) 231 | au_pro /= integration_limit 232 | 233 | # Return the evaluation metrics. 
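    # Since au_pro was divided by integration_limit above, the returned value is rescaled to the
    # [0, 1] range regardless of the chosen false-positive-rate limit.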
234 | return au_pro, pro_curve 235 | -------------------------------------------------------------------------------- /feature_extractors/features.py: -------------------------------------------------------------------------------- 1 | """ 2 | PatchCore logic based on https://github.com/rvorias/ind_knn_ad 3 | """ 4 | import torch 5 | import numpy as np 6 | import os 7 | from tqdm import tqdm 8 | from matplotlib import pyplot as plt 9 | 10 | from sklearn import random_projection 11 | from sklearn import linear_model 12 | from sklearn.svm import OneClassSVM 13 | from sklearn.ensemble import IsolationForest 14 | from sklearn.metrics import roc_auc_score 15 | 16 | from timm.models.layers import DropPath, trunc_normal_ 17 | from pointnet2_ops import pointnet2_utils 18 | from knn_cuda import KNN 19 | 20 | from utils.utils import KNNGaussianBlur 21 | from utils.utils import set_seeds 22 | from utils.au_pro_util import calculate_au_pro 23 | 24 | from models.pointnet2_utils import interpolating_points 25 | from models.feature_fusion import FeatureFusionBlock 26 | from models.models import Model 27 | 28 | class Features(torch.nn.Module): 29 | 30 | def __init__(self, args, image_size=224, f_coreset=0.1, coreset_eps=0.9): 31 | super().__init__() 32 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 33 | self.deep_feature_extractor = Model( 34 | device=self.device, 35 | rgb_backbone_name=args.rgb_backbone_name, 36 | xyz_backbone_name=args.xyz_backbone_name, 37 | group_size = args.group_size, 38 | num_group=args.num_group 39 | ) 40 | self.deep_feature_extractor.to(self.device) 41 | 42 | self.args = args 43 | self.image_size = args.img_size 44 | self.f_coreset = args.f_coreset 45 | self.coreset_eps = args.coreset_eps 46 | self.ex_factor =args.ex 47 | self.blur = KNNGaussianBlur(4) 48 | self.n_reweight = 3 49 | set_seeds(0) 50 | self.patch_xyz_lib = [] 51 | self.patch_rgb_lib = [] 52 | self.patch_ps_lib = [] 53 | self.patch_fusion_lib = [] 54 | self.patch_lib = [] 55 | self.random_state = args.random_state 56 | 57 | self.xyz_dim = 0 58 | self.rgb_dim = 0 59 | 60 | self.xyz_mean=0 61 | self.xyz_std=0 62 | self.rgb_mean=0 63 | self.rgb_std=0 64 | self.fusion_mean=0 65 | self.fusion_std=0 66 | self.ps_mean=0 67 | self.ps_std=0 68 | 69 | self.average = torch.nn.AvgPool2d(3, stride=1) # torch.nn.AvgPool2d(1, stride=1) # 70 | self.resize = torch.nn.AdaptiveAvgPool2d((56, 56)) 71 | self.resize2 = torch.nn.AdaptiveAvgPool2d((56, 56)) 72 | 73 | self.image_preds = list() 74 | self.image_labels = list() 75 | self.pixel_preds = list() 76 | self.pixel_labels = list() 77 | self.gts = [] 78 | self.predictions = [] 79 | self.image_rocauc = 0 80 | self.pixel_rocauc = 0 81 | self.au_pro = 0 82 | self.ins_id = 0 83 | self.rgb_layernorm = torch.nn.LayerNorm(768, elementwise_affine=False) 84 | if self.args.use_uff: 85 | if self.args.xyz_backbone_name == 'FPFH' : 86 | self.fusion = FeatureFusionBlock(33, 768, mlp_ratio=4.) 87 | else : 88 | self.fusion = FeatureFusionBlock(1152, 768, mlp_ratio=4.) 
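            # The fusion block dimensions follow the backbones: 33-dim FPFH descriptors or 1152-dim
            # Point-MAE group features on the 3D side, and 768-dim ViT-B/8 tokens on the RGB side.
            # The checkpoint path below is environment-specific; the UFF weights are presumably the
            # ones produced by the contrastive pretraining in fusion_pretrain.py / engine_fusion_pretrain.py.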
 89 | # Commented out at first because the pretrained weights could not be found
 90 | ckpt = torch.load('/fuxi_team14_intern/m3dm/checkpoints/uff_pretrain.pth')['model']
 91 | #ckpt = torch.load('/workspace/data3/code/M3DM-main/checkpoints5/checkpoint-99.pth')['model']
 92 | #ckpt = torch.load('/workspace/data3/code/M3DM-main/checkpoints6/checkpoint-400.pth')['model']
 93 | 
 94 | incompatible = self.fusion.load_state_dict(ckpt, strict=False)
 95 | 
 96 | print('[Fusion Block]', incompatible)
 97 | 
 98 | self.detect_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu, max_iter=args.ocsvm_maxiter)
 99 | self.seg_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu, max_iter=args.ocsvm_maxiter)
 100 | 
 101 | self.s_lib = []
 102 | self.s_map_lib = []
 103 | 
 104 | def __call__(self, rgb, xyz):
 105 | # Extract the desired feature maps using the backbone model.
 106 | rgb = rgb.to(self.device)
 107 | xyz = xyz.to(self.device)
 108 | with torch.no_grad():
 109 | rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)
 110 | 
 111 | interpolate = True
 112 | if interpolate:
 113 | interpolated_feature_maps = interpolating_points(xyz, center.permute(0,2,1), xyz_feature_maps).to("cpu")
 114 | 
 115 | xyz_feature_maps = [fmap.to("cpu") for fmap in [xyz_feature_maps]]
 116 | rgb_feature_maps = [fmap.to("cpu") for fmap in [rgb_feature_maps]]
 117 | 
 118 | if interpolate:
 119 | return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx, interpolated_feature_maps
 120 | else:
 121 | return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx
 122 | 
 123 | def add_sample_to_mem_bank(self, sample):
 124 | raise NotImplementedError
 125 | 
 126 | def predict(self, sample, mask, label):
 127 | raise NotImplementedError
 128 | 
 129 | def add_sample_to_late_fusion_mem_bank(self, sample):
 130 | raise NotImplementedError
 131 | 
 132 | def interpolate_points(self, rgb, xyz):
 133 | with torch.no_grad():
 134 | rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)
 135 | return xyz_feature_maps, center, xyz
 136 | 
 137 | def compute_s_s_map(self, xyz_patch, rgb_patch, fusion_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):
 138 | raise NotImplementedError
 139 | 
 140 | def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):
 141 | raise NotImplementedError
 142 | 
 143 | def run_coreset(self):
 144 | raise NotImplementedError
 145 | 
 146 | def calculate_metrics(self):
 147 | self.image_preds = np.stack(self.image_preds)
 148 | self.image_labels = np.stack(self.image_labels)
 149 | self.pixel_preds = np.array(self.pixel_preds)
 150 | # Check for and handle NaN values before computing the metrics
 151 | if np.any(np.isnan(self.image_preds)):
 152 | print(f"Warning: Found {np.sum(np.isnan(self.image_preds))} NaN values in image_preds")
 153 | # Replace NaN with 0 (or the mean of that feature)
 154 | self.image_preds = np.nan_to_num(self.image_preds, nan=0.0)
 155 | 
 156 | if np.any(np.isnan(self.pixel_preds)):
 157 | print(f"Warning: Found {np.sum(np.isnan(self.pixel_preds))} NaN values in pixel_preds")
 158 | self.pixel_preds = np.nan_to_num(self.pixel_preds, nan=0.0)
 159 | 
 160 | try:
 161 | self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds)
 162 | except Exception as e:
 163 | print(f"Error calculating image_rocauc: {e}")
 164 | print(f"image_labels shape: {self.image_labels.shape}")
 165 | print(f"image_preds shape: {self.image_preds.shape}")
 166 | print(f"image_labels unique values: {np.unique(self.image_labels)}")
 167 | print(f"image_preds range: [{np.min(self.image_preds)}, {np.max(self.image_preds)}]")
 168 | self.image_rocauc = 0.0
 169 | #self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds)
 170 | self.pixel_rocauc = roc_auc_score(self.pixel_labels, self.pixel_preds)
 171 | self.au_pro, _ = calculate_au_pro(self.gts, self.predictions)
 172 | 
 173 | self.image_preds = list()
 174 | self.image_labels = list()
 175 | self.pixel_preds = list()
 176 | self.pixel_labels = list()
 177 | self.gts = []
 178 | self.predictions = []
 179 | 
 180 | def save_prediction_maps(self, output_path, rgb_path, save_num=5):
 181 | for i in range(min(save_num, len(self.predictions))): # save at most save_num example maps
 182 | # fig = plt.figure(dpi=300)
 183 | fig = plt.figure()
 184 | 
 185 | ax3 = fig.add_subplot(1,3,1)
 186 | gt = plt.imread(rgb_path[i][0])
 187 | ax3.imshow(gt)
 188 | 
 189 | ax2 = fig.add_subplot(1,3,2)
 190 | im2 = ax2.imshow(self.gts[i], cmap=plt.cm.gray)
 191 | 
 192 | ax = fig.add_subplot(1,3,3)
 193 | im = ax.imshow(self.predictions[i], cmap=plt.cm.jet)
 194 | 
 195 | class_dir = os.path.join(output_path, rgb_path[i][0].split('/')[-5])
 196 | if not os.path.exists(class_dir):
 197 | os.mkdir(class_dir)
 198 | 
 199 | ad_dir = os.path.join(class_dir, rgb_path[i][0].split('/')[-3])
 200 | if not os.path.exists(ad_dir):
 201 | os.mkdir(ad_dir)
 202 | 
 203 | plt.savefig(os.path.join(ad_dir, str(self.image_preds[i]) + '_pred_' + rgb_path[i][0].split('/')[-1] + '.jpg'))
 204 | 
 205 | def run_late_fusion(self):
 206 | self.s_lib = torch.cat(self.s_lib, 0)
 207 | self.s_map_lib = torch.cat(self.s_map_lib, 0)
 208 | 
 209 | # # Check for and handle missing values
 210 | # if torch.isnan(self.s_lib).any():
 211 | # print("Warning: self.s_lib contains missing values; handling them.")
 212 | # self.s_lib = self.s_lib[~torch.isnan(self.s_lib).any(dim=1)] # drop samples containing missing values
 213 | 
 214 | # if torch.isnan(self.s_map_lib).any():
 215 | # print("Warning: self.s_map_lib contains missing values; handling them.")
 216 | # self.s_map_lib = self.s_map_lib[~torch.isnan(self.s_map_lib).any(dim=1)] # drop samples containing missing values
 217 | self.detect_fuser.fit(self.s_lib)
 218 | self.seg_fuser.fit(self.s_map_lib)
 219 | 
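    # Greedy coreset selection in the spirit of PatchCore: the patch library is first reduced with a
    # sparse random projection (Johnson-Lindenstrauss style), then n entries are picked greedily so that
    # each new pick maximises its minimum distance to the already selected set.
    # A typical call from a subclass's run_coreset() (illustrative only, once the library has been
    # stacked into a single tensor) might look like:
    #   idx = self.get_coreset_idx_randomp(self.patch_rgb_lib,
    #                                      n=int(self.f_coreset * self.patch_rgb_lib.shape[0]),
    #                                      eps=self.coreset_eps)
    #   self.patch_rgb_lib = self.patch_rgb_lib[idx]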
 220 | def get_coreset_idx_randomp(self, z_lib, n=1000, eps=0.90, float16=True, force_cpu=False):
 221 | 
 222 | print(f" Fitting random projections. Start dim = {z_lib.shape}.")
 223 | try:
 224 | transformer = random_projection.SparseRandomProjection(eps=eps, random_state=self.random_state)
 225 | z_lib = torch.tensor(transformer.fit_transform(z_lib))
 226 | 
 227 | print(f" DONE. Transformed dim = {z_lib.shape}.")
 228 | except ValueError:
 229 | print(" Error: could not project vectors. 
Please increase `eps`.") 230 | 231 | select_idx = 0 232 | last_item = z_lib[select_idx:select_idx + 1] 233 | coreset_idx = [torch.tensor(select_idx)] 234 | min_distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True) 235 | 236 | if float16: 237 | last_item = last_item.half() 238 | z_lib = z_lib.half() 239 | min_distances = min_distances.half() 240 | if torch.cuda.is_available() and not force_cpu: 241 | last_item = last_item.to("cuda") 242 | z_lib = z_lib.to("cuda") 243 | min_distances = min_distances.to("cuda") 244 | 245 | for _ in tqdm(range(n - 1)): 246 | distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True) # broadcasting step 247 | min_distances = torch.minimum(distances, min_distances) # iterative step 248 | select_idx = torch.argmax(min_distances) # selection step 249 | 250 | # bookkeeping 251 | last_item = z_lib[select_idx:select_idx + 1] 252 | min_distances[select_idx] = 0 253 | coreset_idx.append(select_idx.to("cpu")) 254 | return torch.stack(coreset_idx) 255 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import datetime 3 | import os 4 | import time 5 | from collections import defaultdict, deque 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.distributed as dist 10 | # from torch._six import inf 11 | inf = float('inf') 12 | 13 | class SmoothedValue(object): 14 | """Track a series of values and provide access to smoothed values over a 15 | window or the global series average. 16 | """ 17 | 18 | def __init__(self, window_size=20, fmt=None): 19 | if fmt is None: 20 | fmt = "{median:.4f} ({global_avg:.4f})" 21 | self.deque = deque(maxlen=window_size) 22 | self.total = 0.0 23 | self.count = 0 24 | self.fmt = fmt 25 | 26 | def update(self, value, n=1): 27 | self.deque.append(value) 28 | self.count += n 29 | self.total += value * n 30 | 31 | def synchronize_between_processes(self): 32 | """ 33 | Warning: does not synchronize the deque! 34 | """ 35 | if not is_dist_avail_and_initialized(): 36 | return 37 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 38 | dist.barrier() 39 | dist.all_reduce(t) 40 | t = t.tolist() 41 | self.count = int(t[0]) 42 | self.total = t[1] 43 | 44 | @property 45 | def median(self): 46 | d = torch.tensor(list(self.deque)) 47 | return d.median().item() 48 | 49 | @property 50 | def avg(self): 51 | d = torch.tensor(list(self.deque), dtype=torch.float32) 52 | return d.mean().item() 53 | 54 | @property 55 | def global_avg(self): 56 | return self.total / self.count 57 | 58 | @property 59 | def max(self): 60 | return max(self.deque) 61 | 62 | @property 63 | def value(self): 64 | return self.deque[-1] 65 | 66 | def __str__(self): 67 | return self.fmt.format( 68 | median=self.median, 69 | avg=self.avg, 70 | global_avg=self.global_avg, 71 | max=self.max, 72 | value=self.value) 73 | 74 | 75 | class MetricLogger(object): 76 | def __init__(self, delimiter="\t"): 77 | self.meters = defaultdict(SmoothedValue) 78 | self.delimiter = delimiter 79 | 80 | def update(self, **kwargs): 81 | for k, v in kwargs.items(): 82 | if v is None: 83 | continue 84 | if isinstance(v, torch.Tensor): 85 | v = v.item() 86 | assert isinstance(v, (float, int)) 87 | self.meters[k].update(v) 88 | 89 | def __getattr__(self, attr): 90 | if attr in self.meters: 91 | return self.meters[attr] 92 | if attr in self.__dict__: 93 | return self.__dict__[attr] 94 | raise AttributeError("'{}' object has no attribute '{}'".format( 95 | type(self).__name__, attr)) 96 | 97 | def __str__(self): 98 | loss_str = [] 99 | for name, meter in self.meters.items(): 100 | loss_str.append( 101 | "{}: {}".format(name, str(meter)) 102 | ) 103 | return self.delimiter.join(loss_str) 104 | 105 | def synchronize_between_processes(self): 106 | for meter in self.meters.values(): 107 | meter.synchronize_between_processes() 108 | 109 | def add_meter(self, name, meter): 110 | self.meters[name] = meter 111 | 112 | def log_every(self, iterable, print_freq, header=None): 113 | i = 0 114 | if not header: 115 | header = '' 116 | start_time = time.time() 117 | end = time.time() 118 | iter_time = SmoothedValue(fmt='{avg:.4f}') 119 | data_time = SmoothedValue(fmt='{avg:.4f}') 120 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 121 | log_msg = [ 122 | header, 123 | '[{0' + space_fmt + '}/{1}]', 124 | 'eta: {eta}', 125 | '{meters}', 126 | 'time: {time}', 127 | 'data: {data}' 128 | ] 129 | if 
torch.cuda.is_available(): 130 | log_msg.append('max mem: {memory:.0f}') 131 | log_msg = self.delimiter.join(log_msg) 132 | MB = 1024.0 * 1024.0 133 | for obj in iterable: 134 | data_time.update(time.time() - end) 135 | yield obj 136 | iter_time.update(time.time() - end) 137 | if i % print_freq == 0 or i == len(iterable) - 1: 138 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 139 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 140 | if torch.cuda.is_available(): 141 | print(log_msg.format( 142 | i, len(iterable), eta=eta_string, 143 | meters=str(self), 144 | time=str(iter_time), data=str(data_time), 145 | memory=torch.cuda.max_memory_allocated() / MB)) 146 | else: 147 | print(log_msg.format( 148 | i, len(iterable), eta=eta_string, 149 | meters=str(self), 150 | time=str(iter_time), data=str(data_time))) 151 | i += 1 152 | end = time.time() 153 | total_time = time.time() - start_time 154 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 155 | print('{} Total time: {} ({:.4f} s / it)'.format( 156 | header, total_time_str, total_time / len(iterable))) 157 | 158 | 159 | def setup_for_distributed(is_master): 160 | """ 161 | This function disables printing when not in master process 162 | """ 163 | builtin_print = builtins.print 164 | 165 | def print(*args, **kwargs): 166 | force = kwargs.pop('force', False) 167 | force = force or (get_world_size() > 8) 168 | if is_master or force: 169 | now = datetime.datetime.now().time() 170 | builtin_print('[{}] '.format(now), end='') # print with time stamp 171 | builtin_print(*args, **kwargs) 172 | 173 | builtins.print = print 174 | 175 | 176 | def is_dist_avail_and_initialized(): 177 | if not dist.is_available(): 178 | return False 179 | if not dist.is_initialized(): 180 | return False 181 | return True 182 | 183 | 184 | def get_world_size(): 185 | if not is_dist_avail_and_initialized(): 186 | return 1 187 | return dist.get_world_size() 188 | 189 | 190 | def get_rank(): 191 | if not is_dist_avail_and_initialized(): 192 | return 0 193 | return dist.get_rank() 194 | 195 | 196 | def is_main_process(): 197 | return get_rank() == 0 198 | 199 | 200 | def save_on_master(*args, **kwargs): 201 | if is_main_process(): 202 | torch.save(*args, **kwargs) 203 | 204 | 205 | def init_distributed_mode(args): 206 | if args.dist_on_itp: 207 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 208 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 209 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 210 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 211 | os.environ['LOCAL_RANK'] = str(args.gpu) 212 | os.environ['RANK'] = str(args.rank) 213 | os.environ['WORLD_SIZE'] = str(args.world_size) 214 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 215 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 216 | args.rank = int(os.environ["RANK"]) 217 | args.world_size = int(os.environ['WORLD_SIZE']) 218 | args.gpu = int(os.environ['LOCAL_RANK']) 219 | elif 'SLURM_PROCID' in os.environ: 220 | args.rank = int(os.environ['SLURM_PROCID']) 221 | args.gpu = args.rank % torch.cuda.device_count() 222 | else: 223 | print('Not using distributed mode') 224 | setup_for_distributed(is_master=True) # hack 225 | args.distributed = False 226 | return 227 | 228 | args.distributed = True 229 | 230 | torch.cuda.set_device(args.gpu) 231 | args.dist_backend = 'nccl' 232 | print('| distributed init (rank {}): {}, gpu {}'.format( 233 | args.rank, args.dist_url, 
args.gpu), flush=True) 234 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 235 | world_size=args.world_size, rank=args.rank) 236 | torch.distributed.barrier() 237 | setup_for_distributed(args.rank == 0) 238 | 239 | 240 | class NativeScalerWithGradNormCount: 241 | state_dict_key = "amp_scaler" 242 | 243 | def __init__(self): 244 | self._scaler = torch.cuda.amp.GradScaler() 245 | 246 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 247 | self._scaler.scale(loss).backward(create_graph=create_graph) 248 | if update_grad: 249 | if clip_grad is not None: 250 | assert parameters is not None 251 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 252 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 253 | else: 254 | self._scaler.unscale_(optimizer) 255 | norm = get_grad_norm_(parameters) 256 | self._scaler.step(optimizer) 257 | self._scaler.update() 258 | else: 259 | norm = None 260 | return norm 261 | 262 | def state_dict(self): 263 | return self._scaler.state_dict() 264 | 265 | def load_state_dict(self, state_dict): 266 | self._scaler.load_state_dict(state_dict) 267 | 268 | 269 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 270 | if isinstance(parameters, torch.Tensor): 271 | parameters = [parameters] 272 | parameters = [p for p in parameters if p.grad is not None] 273 | norm_type = float(norm_type) 274 | if len(parameters) == 0: 275 | return torch.tensor(0.) 276 | device = parameters[0].grad.device 277 | if norm_type == inf: 278 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 279 | else: 280 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 281 | return total_norm 282 | 283 | 284 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 285 | output_dir = Path(args.output_dir) 286 | epoch_name = str(epoch) 287 | if loss_scaler is not None: 288 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 289 | for checkpoint_path in checkpoint_paths: 290 | to_save = { 291 | 'model': model_without_ddp.state_dict(), 292 | 'optimizer': optimizer.state_dict(), 293 | 'epoch': epoch, 294 | 'scaler': loss_scaler.state_dict(), 295 | 'args': args, 296 | } 297 | 298 | save_on_master(to_save, checkpoint_path) 299 | else: 300 | client_state = {'epoch': epoch} 301 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 302 | 303 | def save_model_gan(args, epoch, model, discriminator, model_without_ddp, discriminator_without_ddp, 304 | optimizer_g, optimizer_d, loss_scaler): 305 | output_dir = Path(args.output_dir) 306 | epoch_name = str(epoch) 307 | if loss_scaler is not None: 308 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 309 | for checkpoint_path in checkpoint_paths: 310 | to_save = { 311 | 'model': model_without_ddp.state_dict(), 312 | 'discriminator_without_ddp': discriminator_without_ddp.state_dict(), 313 | 'optimizer_g': optimizer_g.state_dict(), 314 | 'optimizer_d': optimizer_d.state_dict(), 315 | 'epoch': epoch, 316 | 'scaler': loss_scaler.state_dict(), 317 | 'args': args, 318 | } 319 | 320 | save_on_master(to_save, checkpoint_path) 321 | else: 322 | client_state = {'epoch': epoch} 323 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 
324 | discriminator.save_checkpoint(save_dir=args.output_dir, tag="checkpoint_d-%s" % epoch_name, client_state=client_state) 325 | 326 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 327 | if args.resume: 328 | if args.resume.startswith('https'): 329 | checkpoint = torch.hub.load_state_dict_from_url( 330 | args.resume, map_location='cpu', check_hash=True) 331 | else: 332 | checkpoint = torch.load(args.resume, map_location='cpu') 333 | model_without_ddp.load_state_dict(checkpoint['model']) 334 | print("Resume checkpoint %s" % args.resume) 335 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 336 | optimizer.load_state_dict(checkpoint['optimizer']) 337 | args.start_epoch = checkpoint['epoch'] + 1 338 | if 'scaler' in checkpoint: 339 | loss_scaler.load_state_dict(checkpoint['scaler']) 340 | print("With optim & sched!") 341 | 342 | def load_model_gan(args, model_without_ddp, discriminator_without_ddp, 343 | optimizer_g, optimizer_d, loss_scaler): 344 | if args.resume: 345 | if args.resume.startswith('https'): 346 | checkpoint = torch.hub.load_state_dict_from_url( 347 | args.resume, map_location='cpu', check_hash=True) 348 | else: 349 | checkpoint = torch.load(args.resume, map_location='cpu') 350 | model_without_ddp.load_state_dict(checkpoint['model']) 351 | discriminator_without_ddp.load_state_dict(checkpoint['discriminator']) 352 | print("Resume checkpoint %s" % args.resume) 353 | if 'optimizer_d' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 354 | optimizer_d.load_state_dict(checkpoint['optimizer_d']) 355 | if 'optimizer_g' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 356 | optimizer_g.load_state_dict(checkpoint['optimizer_g']) 357 | args.start_epoch = checkpoint['epoch'] + 1 358 | if 'scaler' in checkpoint: 359 | loss_scaler.load_state_dict(checkpoint['scaler']) 360 | print("With optim & sched!") 361 | 362 | 363 | 364 | def all_reduce_mean(x): 365 | world_size = get_world_size() 366 | if world_size > 1: 367 | x_reduce = torch.tensor(x).cuda() 368 | dist.all_reduce(x_reduce) 369 | x_reduce /= world_size 370 | return x_reduce.item() 371 | else: 372 | return x -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | from timm.models.layers import DropPath, trunc_normal_ 5 | from pointnet2_ops import pointnet2_utils 6 | from knn_cuda import KNN 7 | 8 | class Model(torch.nn.Module): 9 | 10 | def __init__(self, device, rgb_backbone_name='vit_base_patch8_224_dino', out_indices=None, 11 | checkpoint_path='/fuxi_team14_intern/m3dm/checkpoints/dino_vitbase8_pretrain.pth', 12 | pool_last=False, xyz_backbone_name='Point_MAE', group_size=128, num_group=1024): 13 | super().__init__() 14 | # 'vit_base_patch8_224_dino' 15 | # Determine if to output features. 
16 | self.device = device 17 | 18 | kwargs = {'features_only': True if out_indices else False} 19 | if out_indices: 20 | kwargs.update({'out_indices': out_indices}) 21 | 22 | ## RGB backbone 23 | self.rgb_backbone = timm.create_model(model_name=rgb_backbone_name, pretrained=False, 24 | checkpoint_path=checkpoint_path, 25 | **kwargs) 26 | 27 | ## XYZ backbone 28 | 29 | if xyz_backbone_name == 'Point_MAE': 30 | self.xyz_backbone = PointTransformer(group_size=group_size, num_group=num_group) 31 | self.xyz_backbone.load_model_from_ckpt("/fuxi_team14_intern/m3dm/checkpoints/pointmae_pretrain.pth") 32 | elif xyz_backbone_name == 'Point-BERT': 33 | self.xyz_backbone=PointTransformer(group_size=group_size, num_group=num_group, encoder_dims=256) 34 | self.xyz_backbone.load_model_from_pb_ckpt("/fuxi_team14_intern/m3dm/checkpoints/Point-BERT.pth") 35 | elif xyz_backbone_name == 'FPFH': 36 | self.xyz_backbone=FPFH(group_size=group_size, num_group=num_group,voxel_size=0.05) 37 | #self.xyz_backbone.load_model_from_pb_ckpt("/workspace/data2/checkpoints/Point-BERT.pth") 38 | 39 | 40 | 41 | 42 | def forward_rgb_features(self, x): 43 | x = self.rgb_backbone.patch_embed(x) 44 | x = self.rgb_backbone._pos_embed(x) 45 | x = self.rgb_backbone.norm_pre(x) 46 | if self.rgb_backbone.grad_checkpointing and not torch.jit.is_scripting(): 47 | x = checkpoint_seq(self.blocks, x) 48 | else: 49 | x = self.rgb_backbone.blocks(x) 50 | x = self.rgb_backbone.norm(x) 51 | 52 | feat = x[:,1:].permute(0, 2, 1).view(1, -1, 28, 28) 53 | return feat 54 | 55 | 56 | def forward(self, rgb, xyz): 57 | 58 | rgb_features = self.forward_rgb_features(rgb) 59 | 60 | xyz_features, center, ori_idx, center_idx = self.xyz_backbone(xyz) 61 | 62 | xyz_features.permute(0, 2, 1) 63 | 64 | return rgb_features, xyz_features, center, ori_idx, center_idx 65 | 66 | 67 | 68 | def fps(data, number): 69 | ''' 70 | data B N 3 71 | number int 72 | ''' 73 | fps_idx = pointnet2_utils.furthest_point_sample(data, number) 74 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous() 75 | return fps_data, fps_idx 76 | 77 | class Group(nn.Module): 78 | def __init__(self, num_group, group_size): 79 | super().__init__() 80 | self.num_group = num_group 81 | self.group_size = group_size 82 | self.knn = KNN(k=self.group_size, transpose_mode=True) 83 | 84 | def forward(self, xyz): 85 | ''' 86 | input: B N 3 87 | --------------------------- 88 | output: B G M 3 89 | center : B G 3 90 | ''' 91 | batch_size, num_points, _ = xyz.shape 92 | # fps the centers out 93 | center, center_idx = fps(xyz.contiguous(), self.num_group) # B G 3 94 | # knn to get the neighborhood 95 | _, idx = self.knn(xyz, center) # B G M 96 | assert idx.size(1) == self.num_group 97 | assert idx.size(2) == self.group_size 98 | ori_idx = idx 99 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points 100 | idx = idx + idx_base 101 | idx = idx.view(-1) 102 | neighborhood = xyz.reshape(batch_size * num_points, -1)[idx, :] 103 | neighborhood = neighborhood.reshape(batch_size, self.num_group, self.group_size, 3).contiguous() 104 | # normalize 105 | neighborhood = neighborhood - center.unsqueeze(2) 106 | return neighborhood, center, ori_idx, center_idx 107 | 108 | 109 | class Encoder(nn.Module): 110 | def __init__(self, encoder_channel): 111 | super().__init__() 112 | self.encoder_channel = encoder_channel 113 | self.first_conv = nn.Sequential( 114 | nn.Conv1d(3, 128, 1), 115 | nn.BatchNorm1d(128), 116 | 
nn.ReLU(inplace=True), 117 | nn.Conv1d(128, 256, 1) 118 | ) 119 | self.second_conv = nn.Sequential( 120 | nn.Conv1d(512, 512, 1), 121 | nn.BatchNorm1d(512), 122 | nn.ReLU(inplace=True), 123 | nn.Conv1d(512, self.encoder_channel, 1) 124 | ) 125 | 126 | def forward(self, point_groups): 127 | ''' 128 | point_groups : B G N 3 129 | ----------------- 130 | feature_global : B G C 131 | ''' 132 | bs, g, n, _ = point_groups.shape 133 | point_groups = point_groups.reshape(bs * g, n, 3) 134 | # encoder 135 | feature = self.first_conv(point_groups.transpose(2, 1)) 136 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] 137 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) 138 | feature = self.second_conv(feature) 139 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] 140 | return feature_global.reshape(bs, g, self.encoder_channel) 141 | 142 | 143 | class Mlp(nn.Module): 144 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 145 | super().__init__() 146 | out_features = out_features or in_features 147 | hidden_features = hidden_features or in_features 148 | self.fc1 = nn.Linear(in_features, hidden_features) 149 | self.act = act_layer() 150 | self.fc2 = nn.Linear(hidden_features, out_features) 151 | self.drop = nn.Dropout(drop) 152 | 153 | def forward(self, x): 154 | x = self.fc1(x) 155 | x = self.act(x) 156 | x = self.drop(x) 157 | x = self.fc2(x) 158 | x = self.drop(x) 159 | return x 160 | 161 | 162 | class Attention(nn.Module): 163 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 164 | super().__init__() 165 | self.num_heads = num_heads 166 | head_dim = dim // num_heads 167 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 168 | self.scale = qk_scale or head_dim ** -0.5 169 | 170 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 171 | self.attn_drop = nn.Dropout(attn_drop) 172 | self.proj = nn.Linear(dim, dim) 173 | self.proj_drop = nn.Dropout(proj_drop) 174 | 175 | def forward(self, x): 176 | B, N, C = x.shape 177 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 178 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 179 | 180 | attn = (q * self.scale) @ k.transpose(-2, -1) 181 | attn = attn.softmax(dim=-1) 182 | attn = self.attn_drop(attn) 183 | 184 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 185 | x = self.proj(x) 186 | x = self.proj_drop(x) 187 | return x 188 | 189 | 190 | class Block(nn.Module): 191 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 192 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 193 | super().__init__() 194 | self.norm1 = norm_layer(dim) 195 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 196 | self.norm2 = norm_layer(dim) 197 | mlp_hidden_dim = int(dim * mlp_ratio) 198 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 199 | 200 | self.attn = Attention( 201 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 202 | 203 | def forward(self, x): 204 | x = x + self.drop_path(self.attn(self.norm1(x))) 205 | x = x + self.drop_path(self.mlp(self.norm2(x))) 206 | return x 207 | 208 | 209 | class TransformerEncoder(nn.Module): 210 | """ Transformer Encoder without hierarchical structure 211 | """ 212 | 213 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 214 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.): 215 | super().__init__() 216 | 217 | self.blocks = nn.ModuleList([ 218 | Block( 219 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 220 | drop=drop_rate, attn_drop=attn_drop_rate, 221 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 222 | ) 223 | for i in range(depth)]) 224 | 225 | def forward(self, x, pos): 226 | feature_list = [] 227 | fetch_idx = [3, 7, 11] 228 | for i, block in enumerate(self.blocks): 229 | x = block(x + pos) 230 | if i in fetch_idx: 231 | feature_list.append(x) 232 | return feature_list 233 | 234 | 235 | class PointTransformer(nn.Module): 236 | def __init__(self, group_size=128, num_group=1024, encoder_dims=384): 237 | super().__init__() 238 | 239 | self.trans_dim = 384 240 | self.depth = 12 241 | self.drop_path_rate = 0.1 242 | self.num_heads = 6 243 | 244 | self.group_size = group_size 245 | self.num_group = num_group 246 | # grouper 247 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) 248 | # define the encoder 249 | self.encoder_dims = encoder_dims 250 | if self.encoder_dims != self.trans_dim: 251 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 252 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim)) 253 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim) 254 | self.encoder = Encoder(encoder_channel=self.encoder_dims) 255 | # bridge encoder and transformer 256 | 257 | self.pos_embed = nn.Sequential( 258 | nn.Linear(3, 128), 259 | nn.GELU(), 260 | nn.Linear(128, self.trans_dim) 261 | ) 262 | 263 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 264 | self.blocks = TransformerEncoder( 265 | embed_dim=self.trans_dim, 266 | depth=self.depth, 267 | drop_path_rate=dpr, 268 | num_heads=self.num_heads 269 | ) 270 | 271 | self.norm = nn.LayerNorm(self.trans_dim) 272 | 273 | def load_model_from_ckpt(self, bert_ckpt_path): 274 | if bert_ckpt_path is not None: 275 | ckpt = torch.load(bert_ckpt_path) 276 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()} 277 | 278 | for k in list(base_ckpt.keys()): 279 | if k.startswith('MAE_encoder'): 280 | base_ckpt[k[len('MAE_encoder.'):]] = base_ckpt[k] 281 | del base_ckpt[k] 282 | elif k.startswith('base_model'): 283 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k] 284 | del base_ckpt[k] 285 | 286 | incompatible = self.load_state_dict(base_ckpt, strict=False) 287 | 288 | #if incompatible.missing_keys: 289 | # print('missing_keys') 290 | # print( 291 | # incompatible.missing_keys 292 | # ) 293 | #if incompatible.unexpected_keys: 294 | # print('unexpected_keys') 295 | # print( 296 | # incompatible.unexpected_keys 297 | 298 | # ) 299 | 
300 | # print(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}') 301 | 302 | def load_model_from_pb_ckpt(self, bert_ckpt_path): 303 | ckpt = torch.load(bert_ckpt_path) 304 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()} 305 | for k in list(base_ckpt.keys()): 306 | if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'): 307 | base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k] 308 | elif k.startswith('base_model'): 309 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k] 310 | del base_ckpt[k] 311 | 312 | incompatible = self.load_state_dict(base_ckpt, strict=False) 313 | 314 | if incompatible.missing_keys: 315 | print('missing_keys') 316 | print( 317 | incompatible.missing_keys 318 | ) 319 | if incompatible.unexpected_keys: 320 | print('unexpected_keys') 321 | print( 322 | incompatible.unexpected_keys 323 | 324 | ) 325 | 326 | print(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}') 327 | 328 | 329 | def forward(self, pts): 330 | if self.encoder_dims != self.trans_dim: 331 | B,C,N = pts.shape 332 | pts = pts.transpose(-1, -2) # B N 3 333 | # divide the point clo ud in the same form. This is important 334 | neighborhood, center, ori_idx, center_idx = self.group_divider(pts) 335 | # # generate mask 336 | # bool_masked_pos = self._mask_center(center, no_mask = False) # B G 337 | # encoder the input cloud blocks 338 | group_input_tokens = self.encoder(neighborhood) # B G N 339 | group_input_tokens = self.reduce_dim(group_input_tokens) 340 | # prepare cls 341 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1) 342 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1) 343 | # add pos embedding 344 | pos = self.pos_embed(center) 345 | # final input 346 | x = torch.cat((cls_tokens, group_input_tokens), dim=1) 347 | pos = torch.cat((cls_pos, pos), dim=1) 348 | # transformer 349 | feature_list = self.blocks(x, pos) 350 | feature_list = [self.norm(x)[:,1:].transpose(-1, -2).contiguous() for x in feature_list] 351 | x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152 352 | return x, center, ori_idx, center_idx 353 | else: 354 | B, C, N = pts.shape 355 | pts = pts.transpose(-1, -2) # B N 3 356 | # divide the point clo ud in the same form. 
This is important 357 | neighborhood, center, ori_idx, center_idx = self.group_divider(pts) 358 | 359 | group_input_tokens = self.encoder(neighborhood) # B G N 360 | 361 | pos = self.pos_embed(center) 362 | # final input 363 | x = group_input_tokens 364 | # transformer 365 | feature_list = self.blocks(x, pos) 366 | feature_list = [self.norm(x).transpose(-1, -2).contiguous() for x in feature_list] 367 | x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152 368 | return x, center, ori_idx, center_idx 369 | 370 | # class FPFH(nn.Module): 371 | # def __init__(self, group_size=32, num_group=512, voxel_size=0.05): 372 | # super(FPFH, self).__init__() 373 | # self.group_size = group_size 374 | # self.num_group = num_group 375 | # self.voxel_size = voxel_size 376 | # self.resize = nn.AdaptiveAvgPool2d((28, 28)) 377 | # self.average = nn.AvgPool2d(2, 2) 378 | 379 | # def organized_pc_to_unorganized_pc(self, organized_pc): 380 | # return organized_pc.reshape(-1, 3) 381 | 382 | # def get_fpfh_features(self, organized_pc): 383 | # organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() 384 | # unorganized_pc = self.organized_pc_to_unorganized_pc(organized_pc_np) 385 | 386 | # nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] 387 | # unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :] 388 | 389 | # o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros)) 390 | 391 | # radius_normal = self.voxel_size * 2 392 | # o3d_pc.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30)) 393 | 394 | # radius_feature = self.voxel_size * 5 395 | # pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature( 396 | # o3d_pc, 397 | # o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100) 398 | # ) 399 | # fpfh = pcd_fpfh.data.T 400 | 401 | # full_fpfh = np.zeros((unorganized_pc.shape[0], fpfh.shape[1]), dtype=fpfh.dtype) 402 | # full_fpfh[nonzero_indices, :] = fpfh 403 | # full_fpfh_reshaped = full_fpfh.reshape((organized_pc_np.shape[0], organized_pc_np.shape[1], fpfh.shape[1])) 404 | # full_fpfh_tensor = torch.tensor(full_fpfh_reshaped).permute(2, 0, 1).unsqueeze(dim=0) 405 | 406 | # return full_fpfh_tensor 407 | 408 | # def forward(self, xyz): 409 | # batch_size, _, height, width = xyz.shape 410 | 411 | # # Compute FPFH features 412 | # xyz_features = self.get_fpfh_features(xyz) 413 | 414 | # # Resize and average 415 | # xyz_features_resized = self.resize(self.average(xyz_features)) 416 | 417 | # # Randomly sample center points 418 | # center_idx = torch.randperm(height * width)[:self.num_group] 419 | # center = xyz.view(batch_size, 3, -1).permute(0, 2, 1)[:, center_idx, :] 420 | 421 | # # Create original indices 422 | # ori_idx = torch.arange(height * width).view(1, height, width).expand(batch_size, -1, -1) 423 | 424 | # return xyz_features_resized, center, ori_idx, center_idx 425 | 426 | # def add_sample_to_mem_bank(self, sample): 427 | # fpfh_feature_maps = self.get_fpfh_features(sample[1]) 428 | # fpfh_feature_maps_resized = self.resize(self.average(fpfh_feature_maps)) 429 | # fpfh_patch = fpfh_feature_maps_resized.reshape(fpfh_feature_maps_resized.shape[1], -1).T 430 | # return fpfh_patch 431 | 432 | # def predict(self, sample): 433 | # depth_feature_maps = self.get_fpfh_features(sample[1]) 434 | # depth_feature_maps_resized = self.resize(self.average(depth_feature_maps)) 435 | # patch = depth_feature_maps_resized.reshape(depth_feature_maps_resized.shape[1], -1).T 436 | # return 
patch, depth_feature_maps_resized.shape[-2:]
437 | import numpy as np
438 | import open3d as o3d
439 | class FPFH(nn.Module):
440 |     def __init__(self, group_size=32, num_group=512, voxel_size=0.05):
441 |         super(FPFH, self).__init__()
442 |         self.group_size = group_size
443 |         self.num_group = num_group
444 |         self.voxel_size = voxel_size
445 | 
446 |     def get_fpfh_features(self, unorganized_pc):
447 |         # make sure unorganized_pc is a numpy array on the CPU
448 |         if isinstance(unorganized_pc, torch.Tensor):
449 |             unorganized_pc = unorganized_pc.cpu().numpy()
450 | 
451 |         # make sure the shape is (N, 3)
452 |         if unorganized_pc.shape[0] == 3:
453 |             unorganized_pc = unorganized_pc.T
454 | 
455 |         # remove zero points
456 |         nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
457 |         unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]
458 | 
459 |         # make sure the dtype is float64
460 |         unorganized_pc_no_zeros = unorganized_pc_no_zeros.astype(np.float64)
461 | 
462 |         o3d_pc = o3d.geometry.PointCloud()
463 |         o3d_pc.points = o3d.utility.Vector3dVector(unorganized_pc_no_zeros)
464 | 
465 |         radius_normal = self.voxel_size * 2
466 |         o3d_pc.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30))
467 | 
468 |         radius_feature = self.voxel_size * 5
469 |         pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
470 |             o3d_pc,
471 |             o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100)
472 |         )
473 |         fpfh = pcd_fpfh.data  # shape (33, M), where M is the number of non-zero points
474 | 
475 |         # convert the FPFH features to a torch.Tensor
476 |         fpfh_tensor = torch.tensor(fpfh, dtype=torch.float32)
477 | 
478 |         return fpfh_tensor
479 | 
480 |     def forward(self, xyz):
481 |         # xyz arrives as a (B, 3, N) torch.Tensor; permute to (B, N, 3), where B is the batch size and N is the number of points
482 |         xyz = xyz.permute(0, 2, 1)
483 |         batch_size,num_points, _ = xyz.shape
484 | 
485 |         # compute FPFH features
486 |         fpfh_features = []
487 |         for i in range(batch_size):
488 |             fpfh = self.get_fpfh_features(xyz[i])
489 |             fpfh_features.append(fpfh)
490 | 
491 |         fpfh_features = torch.stack(fpfh_features)
492 | 
493 |         # randomly sample center points
494 |         center_idx = torch.randperm(num_points)[:self.num_group]
495 |         center = xyz[:, center_idx, :]
496 | 
497 |         ori_idx = torch.arange(num_points)
498 | 
499 |         return fpfh_features, center, ori_idx,center_idx
500 | 
501 |     def add_sample_to_mem_bank(self, sample):
502 |         print(sample.shape)
503 |         # sample is assumed to be a torch.Tensor of shape (N, 3)
504 |         fpfh_features = self.get_fpfh_features(sample)
505 |         return fpfh_features
506 | 
507 |     def predict(self, sample):
508 |         # sample is assumed to be a torch.Tensor of shape (N, 3)
509 |         fpfh_features = self.get_fpfh_features(sample)
510 |         return fpfh_features, None  # return None for the spatial shape, since an unorganized point cloud has no fixed spatial dimensions
511 | 
-------------------------------------------------------------------------------- /dataset2.py: --------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torchvision import transforms
4 | import glob
5 | from torch.utils.data import Dataset
6 | from utils.mvtec3d_util import *
7 | from torch.utils.data import DataLoader
8 | import numpy as np
9 | import math
10 | import cv2
11 | import argparse
12 | from matplotlib import pyplot as plt
13 | 
14 | def eyecandies_classes():
15 |     return [
16 |         'CandyCane',
17 |         'ChocolateCookie',
18 |         'ChocolatePraline',
19 |         'Confetto',
20 |         'GummyBear',
21 |         'HazelnutTruffle',
22 |         'LicoriceSandwich',
23 |         'Lollipop',
24 |         'Marshmallow',
25 |         'PeppermintCandy',
26 |     ]
27 | 
28 | def mvtec3d_classes():
29 |     return [
30 |         "bagel",
31 |         "cable_gland",
32 |         "carrot",
33 |         "cookie",
34 |         "dowel",
35 |         "foam",
36 |         "peach",
37 |         "potato",
"potato", 38 | "rope", 39 | "tire", 40 | ] 41 | def test_3d_classes(): 42 | return [ 43 | 'audio_jack_socket', 44 | 'common_mode_filter', 45 | 'connector_housing-female', 46 | 'crimp_st_cable_mount_box', 47 | 'dc_power_connector', 48 | 'fork_crimp_terminal', 49 | 'headphone_jack_socket', 50 | 'miniature_lifting_motor', 51 | 'purple-clay-pot', 52 | 'power_jack', 53 | 54 | 'ethernet_connector', 55 | 'ferrite_bead', 56 | 'fuse_holder', 57 | 'humidity_sensor', 58 | 'knob-cap', 59 | 'lattice_block_plug', 60 | 'lego_pin_connector_plate', 61 | 'lego_propeller', 62 | 'limit-switch', 63 | 'telephone_spring_switch', 64 | # "bagel", 65 | # "cable_gland", 66 | # "carrot", 67 | # "cookie", 68 | # "dowel", 69 | # "foam", 70 | # "peach", 71 | # "potato", 72 | # "rope", 73 | # "tire", 74 | ] 75 | 76 | RGB_SIZE = 224 77 | 78 | class BaseAnomalyDetectionDataset(Dataset): 79 | 80 | def __init__(self, split, class_name, img_size,downsampling, angle, small,dataset_path='datasets/eyecandies_preprocessed'): 81 | self.IMAGENET_MEAN = [0.485, 0.456, 0.406] 82 | self.IMAGENET_STD = [0.229, 0.224, 0.225] 83 | self.cls = class_name 84 | self.size = img_size 85 | self.img_path = os.path.join(dataset_path, self.cls, split) 86 | self.downsampling = downsampling 87 | self.angle = angle 88 | self.small = small 89 | self.rgb_transform = transforms.Compose( 90 | [transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.BICUBIC), 91 | transforms.ToTensor(), 92 | transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)]) 93 | def analyze_depth_importance(self,organized_pc): 94 | """分析深度图的重要性,返回行列的重要性标记""" 95 | depth = organized_pc[:,:,2] # 获取深度通道 96 | 97 | # 计算每行的非零点占比 98 | row_importance = np.mean(depth != 0, axis=1) 99 | col_importance = np.mean(depth != 0, axis=0) 100 | 101 | # 使用中位数作为分界线 102 | row_median = np.median(row_importance[row_importance > 0]) 103 | col_median = np.median(col_importance[col_importance > 0]) 104 | 105 | # 标记重要性(True表示重要行/列) 106 | important_rows = row_importance >= row_median 107 | important_cols = col_importance >= col_median 108 | 109 | return important_rows, important_cols 110 | 111 | def smart_downsample(self, organized_pc, target_factor): 112 | """智能降采样,重要区域降采样更多,保持总体降采样因子不变""" 113 | important_rows, important_cols = self.analyze_depth_importance(organized_pc) 114 | 115 | # 获取重要和不重要的行列索引 116 | important_row_indices = np.where(important_rows)[0] 117 | unimportant_row_indices = np.where(~important_rows)[0] 118 | important_col_indices = np.where(important_cols)[0] 119 | unimportant_col_indices = np.where(~important_cols)[0] 120 | 121 | # 计算原始的重要和不重要区域的比例 122 | total_rows = len(important_row_indices) + len(unimportant_row_indices) 123 | total_cols = len(important_col_indices) + len(unimportant_col_indices) 124 | factor = int(math.sqrt(target_factor)) 125 | # 目标总行列数 126 | target_total_rows = total_rows // factor 127 | target_total_cols = total_cols // factor 128 | 129 | # 设置重要区域的更高降采样率(比如降采样率为1/4) 130 | important_factor = factor*2 131 | 132 | # 计算重要区域的目标数量 133 | n_important_rows = len(important_row_indices) // important_factor 134 | n_important_cols = len(important_col_indices) // important_factor 135 | 136 | # 计算不重要区域需要保留的数量(确保总数符合目标) 137 | n_unimportant_rows = target_total_rows - n_important_rows 138 | n_unimportant_cols = target_total_cols - n_important_cols 139 | 140 | # 确保不重要区域的数量不会超过原始数量 141 | n_unimportant_rows = min(n_unimportant_rows, len(unimportant_row_indices)) 142 | n_unimportant_cols = min(n_unimportant_cols, len(unimportant_col_indices)) 143 | 144 
144 |         # select rows
145 |         selected_important_rows = np.linspace(0, len(important_row_indices)-1, n_important_rows, dtype=int)
146 |         selected_important_rows = important_row_indices[selected_important_rows]
147 | 
148 |         selected_unimportant_rows = np.linspace(0, len(unimportant_row_indices)-1, n_unimportant_rows, dtype=int)
149 |         selected_unimportant_rows = unimportant_row_indices[selected_unimportant_rows]
150 | 
151 |         # select columns
152 |         selected_important_cols = np.linspace(0, len(important_col_indices)-1, n_important_cols, dtype=int)
153 |         selected_important_cols = important_col_indices[selected_important_cols]
154 | 
155 |         selected_unimportant_cols = np.linspace(0, len(unimportant_col_indices)-1, n_unimportant_cols, dtype=int)
156 |         selected_unimportant_cols = unimportant_col_indices[selected_unimportant_cols]
157 | 
158 |         # merge the selected rows and columns
159 |         selected_rows = np.sort(np.concatenate([selected_important_rows, selected_unimportant_rows]))
160 |         selected_cols = np.sort(np.concatenate([selected_important_cols, selected_unimportant_cols]))
161 | 
162 |         # log the downsampling statistics
163 |         print(f"original size: {organized_pc.shape[:2]}")
164 |         print(f"important rows: {len(important_row_indices)} -> {len(selected_important_rows)} (1/{important_factor})")
165 |         print(f"unimportant rows: {len(unimportant_row_indices)} -> {len(selected_unimportant_rows)}")
166 |         print(f"important cols: {len(important_col_indices)} -> {len(selected_important_cols)} (1/{important_factor})")
167 |         print(f"unimportant cols: {len(unimportant_col_indices)} -> {len(selected_unimportant_cols)}")
168 |         print(f"final size: {len(selected_rows)}x{len(selected_cols)}")
169 |         print(f"effective downsampling factor: {(total_rows*total_cols)/(len(selected_rows)*len(selected_cols)):.2f}")
170 | 
171 |         return organized_pc[selected_rows][:, selected_cols]
172 | 
173 |     def get_matrix(self, image, angle):
174 |         image = self.pillow_to_opencv(image)
175 |         (h, w) = image.shape[:2]
176 |         src_points = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
177 |         dst_points = self.calculate_destination_points(0, w, 0, h, angle)
178 |         M = cv2.getPerspectiveTransform(src_points, dst_points)
179 |         return M
180 | 
181 |     def calculate_destination_points(self, left, right, top, bottom, angle):
182 |         # compute the center point
183 |         center_x = (left + right) / 2
184 |         center_y = (top + bottom) / 2
185 | 
186 |         # convert the angle to radians
187 |         angle_rad = math.radians(angle)
188 | 
189 |         # compute the destination points
190 |         dst_points = []
191 |         for x, y in [(left, top), (right, top), (right, bottom), (left, bottom)]:
192 |             new_x = center_x + (x - center_x) * math.cos(angle_rad) - (y - center_y) * math.sin(angle_rad)
193 |             new_y = center_y + (x - center_x) * math.sin(angle_rad) + (y - center_y) * math.cos(angle_rad)
194 |             dst_points.append([new_x, new_y])
195 | 
196 |         return np.float32(dst_points)
197 | 
198 |     def perspective_transform(self, image, matrix):
199 |         """
200 |         Apply a perspective transform to a Pillow image.
201 |         :param image: input image (Pillow)
202 |         :param matrix: perspective transform matrix
203 |         :return: transformed image (Pillow)
204 |         """
205 |         # convert to an OpenCV image
206 |         image = self.pillow_to_opencv(image)
207 |         (h, w) = image.shape[:2]
208 |         # apply the perspective transform
209 |         transformed = cv2.warpPerspective(image, matrix, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=image[0, 0].tolist())
210 |         transformed = self.opencv_to_pillow(transformed)
211 |         return transformed
212 | 
213 |     def opencv_to_pillow(self, opencv_image):
214 |         """
215 |         Convert an OpenCV image to a Pillow image.
216 | 
217 |         :param opencv_image: OpenCV image (BGR format)
218 |         :return: Pillow image
219 |         """
220 |         # for color images, convert the channel order
221 |         if len(opencv_image.shape) == 3:  # color image
222 |             opencv_image_rgb = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
223 |             return
Image.fromarray(opencv_image_rgb) 224 | else: # 灰度图像 225 | return Image.fromarray(opencv_image) 226 | 227 | def pillow_to_opencv(self, pil_image): 228 | """ 229 | 将 Pillow 图像转换为 OpenCV 图像。 230 | 231 | :param pil_image: Pillow 图像对象 232 | :return: OpenCV 图像对象(BGR 格式) 233 | """ 234 | # 将 Pillow 图像转换为 numpy 数组 235 | opencv_image = np.array(pil_image) 236 | 237 | # 如果是 RGB 图像,转换为 BGR 格式 238 | if opencv_image.ndim == 3 and opencv_image.shape[2] == 3: # 检查是否为 RGB 图像 239 | opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR) 240 | 241 | return opencv_image 242 | 243 | class PreTrainTensorDataset(Dataset): 244 | def __init__(self, root_path): 245 | super().__init__() 246 | self.root_path = root_path 247 | self.tensor_paths = os.listdir(self.root_path) 248 | 249 | 250 | def __len__(self): 251 | return len(self.tensor_paths) 252 | 253 | def __getitem__(self, idx): 254 | tensor_path = self.tensor_paths[idx] 255 | 256 | tensor = torch.load(os.path.join(self.root_path, tensor_path)) 257 | 258 | label = 0 259 | 260 | return tensor, label 261 | 262 | class TrainDataset(BaseAnomalyDetectionDataset): 263 | def __init__(self, class_name, img_size,downsampling,angle, small, dataset_path='datasets/eyecandies_preprocessed'): 264 | super().__init__(split="train", class_name=class_name, img_size=img_size,downsampling=downsampling, angle=angle, small=small, dataset_path=dataset_path) 265 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 266 | 267 | def load_dataset(self): 268 | img_tot_paths = [] 269 | tot_labels = [] 270 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png") 271 | rgb_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'rgb', '*', '*L05*RGB*.jpg')) 272 | tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff") 273 | ps_paths = glob.glob(os.path.join(self.img_path, 'good', 'ps') + "/*.jpg") 274 | rgb_paths.sort() 275 | tiff_paths.sort() 276 | ps_paths.sort() 277 | sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths)) 278 | # sample_paths = list(zip(rgb_paths, tiff_paths)) 279 | img_tot_paths.extend(sample_paths) 280 | tot_labels.extend([0] * len(sample_paths)) 281 | # img_tot_paths = img_tot_paths[0:10] 282 | # tot_labels = tot_labels[0:10] 283 | return img_tot_paths, tot_labels 284 | 285 | def __len__(self): 286 | return len(self.img_paths) 287 | 288 | def __getitem__(self, idx): 289 | img_path, label = self.img_paths[idx], self.labels[idx] 290 | rgb_path = img_path[0] 291 | tiff_path = img_path[1] 292 | ps_path = img_path[2] 293 | img = Image.open(rgb_path).convert('RGB') 294 | # add rotation 295 | # matrix = self.get_matrix(img, self.angle) 296 | # img = self.perspective_transform(img, matrix) 297 | 298 | img = self.rgb_transform(img) 299 | ps= Image.open(ps_path).convert('RGB') 300 | # # add rotation 301 | # ps = self.perspective_transform(ps, matrix) 302 | 303 | ps = self.rgb_transform(ps) 304 | if self.downsampling > 1: 305 | organized_pc = read_tiff_organized_pc(tiff_path) 306 | factor1 = int(math.floor(math.sqrt(self.downsampling))) 307 | factor2 = int(math.ceil(self.downsampling / factor1)) 308 | organized_pc = organized_pc[::factor1, ::factor2] 309 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling) 310 | else: 311 | organized_pc = read_tiff_organized_pc(tiff_path) 312 | 313 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) 314 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) 315 | 
resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) 316 | resized_organized_pc = resized_organized_pc.clone().detach().float() 317 | 318 | 319 | return (img, resized_organized_pc, resized_depth_map_3channel,ps), label 320 | 321 | 322 | class TestDataset(BaseAnomalyDetectionDataset): 323 | def __init__(self, class_name, img_size,downsampling,angle,small,dataset_path='datasets/eyecandies_preprocessed'): 324 | super().__init__(split="test", class_name=class_name, img_size=img_size,downsampling=downsampling,angle=angle, small=small, dataset_path=dataset_path) 325 | self.gt_transform = transforms.Compose([ 326 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST), 327 | transforms.ToTensor()]) 328 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 329 | 330 | def load_dataset(self): 331 | img_tot_paths = [] 332 | gt_tot_paths = [] 333 | tot_labels = [] 334 | ps_tot_paths = [] 335 | defect_types = os.listdir(self.img_path) 336 | print(defect_types) 337 | # 如果types不为NONE,只保留GOOD和指定的types 338 | for defect_type in defect_types: 339 | if defect_type == 'good': 340 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png") 341 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg')) 342 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff") 343 | #gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png") 344 | ps_paths = glob.glob(os.path.join(self.img_path, 'good', 'ps') + "/*.jpg") 345 | rgb_paths.sort() 346 | tiff_paths.sort() 347 | ps_paths.sort() 348 | # 只保留前5个样本 349 | if self.small: 350 | rgb_paths = rgb_paths[:5] 351 | tiff_paths = tiff_paths[:5] 352 | ps_paths = ps_paths[:5] 353 | sample_paths = list(zip(rgb_paths, tiff_paths,ps_paths)) 354 | img_tot_paths.extend(sample_paths) 355 | gt_tot_paths.extend([0] * len(sample_paths)) 356 | tot_labels.extend([0] * len(sample_paths)) 357 | else: 358 | # rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png") 359 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg')) 360 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff") 361 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png") 362 | ps_paths = glob.glob(os.path.join(self.img_path, defect_type, 'ps') + "/*.png") 363 | rgb_paths.sort() 364 | tiff_paths.sort() 365 | gt_paths.sort() 366 | ps_paths.sort() 367 | # if self.small: 368 | # # 检查每个gt mask中缺陷占比 369 | # valid_indices = [] 370 | # for i, gt_path in enumerate(gt_paths): 371 | # gt_mask = np.array(Image.open(gt_path)) 372 | # total_pixels = gt_mask.shape[0] * gt_mask.shape[1] 373 | # threshold = int(total_pixels * 0.005) # 计算1%像素数量 374 | # defect_pixels = np.sum(gt_mask > 0) 375 | # if defect_pixels <= threshold: # 直接比较像素数量 376 | # valid_indices.append(i) 377 | 378 | # # 只保留缺陷占比<=1%的样本 379 | # rgb_paths = [rgb_paths[i] for i in valid_indices] 380 | # tiff_paths = [tiff_paths[i] for i in valid_indices] 381 | # gt_paths = [gt_paths[i] for i in valid_indices] 382 | # # ps_paths = [ps_paths[i] for i in valid_indices] 383 | sample_paths = list(zip(rgb_paths, tiff_paths,ps_paths)) 384 | print(f"rgb_paths: {len(rgb_paths)}, tiff_paths: {len(tiff_paths)}, ps_paths: {len(ps_paths)}") 385 | img_tot_paths.extend(sample_paths) 386 | gt_tot_paths.extend(gt_paths) 387 | 
tot_labels.extend([1] * len(sample_paths)) 388 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 389 | 390 | return img_tot_paths, gt_tot_paths, tot_labels 391 | 392 | def __len__(self): 393 | return len(self.img_paths) 394 | 395 | def __getitem__(self, idx): 396 | 397 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 398 | 399 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 400 | rgb_path = img_path[0] 401 | tiff_path = img_path[1] 402 | ps_path = img_path[2] 403 | img_original = Image.open(rgb_path).convert('RGB') 404 | # matrix = self.get_matrix(img_original, self.angle) 405 | # img_original = self.perspective_transform(img_original, matrix) 406 | 407 | # axes[0].imshow(cv2.cvtColor(self.pillow_to_opencv(img_original), cv2.COLOR_BGR2RGB)) 408 | 409 | img = self.rgb_transform(img_original) 410 | ps_original = Image.open(ps_path).convert('RGB') 411 | # ps_original = self.perspective_transform(ps_original, matrix) 412 | ps = self.rgb_transform(ps_original) 413 | # organized_pc = read_tiff_organized_pc(tiff_path) 414 | if self.downsampling > 1: 415 | organized_pc = read_tiff_organized_pc(tiff_path) 416 | factor1 = int(math.floor(math.sqrt(self.downsampling))) 417 | factor2 = int(math.ceil(self.downsampling / factor1)) 418 | organized_pc = organized_pc[::factor1, ::factor2] 419 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling) 420 | else: 421 | organized_pc = read_tiff_organized_pc(tiff_path) 422 | 423 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) 424 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) 425 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) 426 | resized_organized_pc = resized_organized_pc.clone().detach().float() 427 | 428 | 429 | 430 | 431 | if gt == 0: 432 | gt = torch.zeros( 433 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-2]]) 434 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray') 435 | else: 436 | gt = Image.open(gt).convert('L') 437 | # gt = self.perspective_transform(gt, matrix) 438 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray') 439 | gt = self.gt_transform(gt) 440 | gt = torch.where(gt > 0.5, 1., .0) 441 | 442 | # fig.show() 443 | # fig.savefig("test_image_3.png", dpi=300, bbox_inches='tight') 444 | 445 | return (img, resized_depth_map_3channel,ps), gt[:1], label, rgb_path 446 | 447 | class ValidDataset(BaseAnomalyDetectionDataset): 448 | def __init__(self, class_name, img_size,downsampling,angle,small,defect_name,dataset_path='datasets/eyecandies_preprocessed'): 449 | super().__init__(split="test", class_name=class_name, img_size=img_size,downsampling=downsampling,angle=angle, small=small, dataset_path=dataset_path) 450 | self.gt_transform = transforms.Compose([ 451 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST), 452 | transforms.ToTensor()]) 453 | self.defect_name = defect_name 454 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 455 | 456 | def load_dataset(self): 457 | img_tot_paths = [] 458 | gt_tot_paths = [] 459 | tot_labels = [] 460 | #ps_tot_paths = [] 461 | defect_types = os.listdir(self.img_path) 462 | # print(defect_types) 463 | # 如果types不为NONE,只保留GOOD和指定的types 464 | for defect_type in defect_types: 465 | # print(defect_type) 466 | if defect_type == 'GOOD': 467 
| rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
468 |                 tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
469 |                 ps_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'ps') + "/*.jpg")
470 |                 # print(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
471 |                 # print(rgb_paths)
472 |                 rgb_paths.sort()
473 |                 tiff_paths.sort()
474 |                 ps_paths.sort()
475 |                 # keep only the first 5 samples
476 |                 if self.small:
477 |                     rgb_paths = rgb_paths[:5]
478 |                     tiff_paths = tiff_paths[:5]
479 |                     ps_paths = ps_paths[:5]
480 |                 sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
481 |                 sample_paths = sample_paths[:5]
482 |                 img_tot_paths.extend(sample_paths)
483 |                 gt_tot_paths.extend([0] * len(sample_paths))
484 |                 tot_labels.extend([0] * len(sample_paths))
485 |             elif defect_type in self.defect_name:
486 |                 rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
487 |                 tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
488 |                 gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt','rgb') + "/*.png")
489 |                 ps_paths = glob.glob(os.path.join(self.img_path, defect_type, 'ps') + "/*.jpg")
490 |                 # print(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
491 |                 # print(rgb_paths)
492 |                 rgb_paths.sort()
493 |                 tiff_paths.sort()
494 |                 gt_paths.sort()
495 |                 ps_paths.sort()
496 |                 if self.small:
497 |                     # check the defect ratio in each gt mask
498 |                     valid_indices = []
499 |                     for i, gt_path in enumerate(gt_paths):
500 |                         gt_mask = np.array(Image.open(gt_path))
501 |                         total_pixels = gt_mask.shape[0] * gt_mask.shape[1]
502 |                         threshold = int(total_pixels * 0.005)  # pixel-count threshold (0.5% of the image)
503 |                         defect_pixels = np.sum(gt_mask > 0)
504 |                         if defect_pixels <= threshold:  # compare pixel counts directly
505 |                             valid_indices.append(i)
506 | 
507 |                     # keep only samples whose defect area is at or below the threshold
508 |                     rgb_paths = [rgb_paths[i] for i in valid_indices]
509 |                     tiff_paths = [tiff_paths[i] for i in valid_indices]
510 |                     gt_paths = [gt_paths[i] for i in valid_indices]
511 |                     ps_paths = [ps_paths[i] for i in valid_indices]
512 |                 sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
513 |                 print(f"rgb_paths: {len(rgb_paths)}, tiff_paths: {len(tiff_paths)}, ps_paths: {len(ps_paths)}")
514 |                 img_tot_paths.extend(sample_paths)
515 |                 gt_tot_paths.extend(gt_paths)
516 |                 tot_labels.extend([1] * len(sample_paths))
517 | 
518 |         assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
519 | 520 | return img_tot_paths, gt_tot_paths, tot_labels 521 | 522 | def __len__(self): 523 | return len(self.img_paths) 524 | 525 | def __getitem__(self, idx): 526 | 527 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 528 | 529 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 530 | rgb_path = img_path[0] 531 | tiff_path = img_path[1] 532 | ps_path = img_path[2] 533 | img_original = Image.open(rgb_path).convert('RGB') 534 | matrix = self.get_matrix(img_original, self.angle) 535 | img_original = self.perspective_transform(img_original, matrix) 536 | 537 | # axes[0].imshow(cv2.cvtColor(self.pillow_to_opencv(img_original), cv2.COLOR_BGR2RGB)) 538 | 539 | img = self.rgb_transform(img_original) 540 | ps_original = Image.open(ps_path).convert('RGB') 541 | ps_original = self.perspective_transform(ps_original, matrix) 542 | ps = self.rgb_transform(ps_original) 543 | # organized_pc = read_tiff_organized_pc(tiff_path) 544 | if self.downsampling > 1: 545 | organized_pc = read_tiff_organized_pc(tiff_path) 546 | factor1 = int(math.floor(math.sqrt(self.downsampling))) 547 | factor2 = int(math.ceil(self.downsampling / factor1)) 548 | organized_pc = organized_pc[::factor1, ::factor2] 549 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling) 550 | else: 551 | organized_pc = read_tiff_organized_pc(tiff_path) 552 | 553 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) 554 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) 555 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) 556 | resized_organized_pc = resized_organized_pc.clone().detach().float() 557 | 558 | 559 | 560 | 561 | if gt == 0: 562 | gt = torch.zeros( 563 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-2]]) 564 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray') 565 | else: 566 | gt = Image.open(gt).convert('L') 567 | gt = self.perspective_transform(gt, matrix) 568 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray') 569 | gt = self.gt_transform(gt) 570 | gt = torch.where(gt > 0.5, 1., .0) 571 | 572 | # fig.show() 573 | # fig.savefig("test_image_3.png", dpi=300, bbox_inches='tight') 574 | 575 | return (img, resized_depth_map_3channel, ps), gt[:1], label, rgb_path 576 | 577 | from torch.utils.data import Subset 578 | def redistribute_dataset(dataset, chunk_size=50): 579 | """ 580 | 将数据集每50个分成一组,并重新分配包含GOOD的组 581 | :param dataset: 原始数据集 582 | :param chunk_size: 每组的大小 583 | :return: 重新分配后的数据集列表 584 | """ 585 | total_size = len(dataset) 586 | num_chunks = (total_size + chunk_size - 1) // chunk_size # 向上取整 587 | 588 | # 初始分组 589 | chunks = [] 590 | for i in range(num_chunks): 591 | start_idx = i * chunk_size 592 | end_idx = min(start_idx + chunk_size, total_size) 593 | chunk_indices = list(range(start_idx, end_idx)) 594 | chunks.append(Subset(dataset, chunk_indices)) 595 | 596 | # 找出包含GOOD的组 597 | good_chunks = [] 598 | normal_chunks = [] 599 | 600 | for i, chunk in enumerate(chunks): 601 | has_good = False 602 | # 检查这个chunk中是否包含GOOD 603 | for idx in chunk.indices: 604 | if 'GOOD' in dataset.img_paths[idx][0]: 605 | has_good = True 606 | break 607 | 608 | if has_good: 609 | good_chunks.append(i) 610 | else: 611 | normal_chunks.append(i) 612 | 613 | # 如果没有GOOD组或没有正常组,直接返回原始分组 614 | if not good_chunks or not normal_chunks: 615 | return chunks 616 | 617 | # 重新分配GOOD组的数据 618 | for good_chunk_idx in good_chunks: 619 | chunk = 
chunks[good_chunk_idx] 620 | good_indices = [] 621 | normal_indices = [] 622 | 623 | # 分离GOOD和非GOOD数据 624 | for idx in chunk.indices: 625 | if 'GOOD' in dataset.img_paths[idx][0]: 626 | good_indices.append(idx) 627 | else: 628 | normal_indices.append(idx) 629 | 630 | # 将GOOD数据平均分配给其他组 631 | num_good = len(good_indices) 632 | num_normal_chunks = len(normal_chunks) 633 | 634 | if num_normal_chunks > 0: 635 | # 计算每个正常组应该获得多少GOOD数据 636 | indices_per_chunk = num_good // num_normal_chunks 637 | remainder = num_good % num_normal_chunks 638 | 639 | # 分配GOOD数据 640 | current_good_idx = 0 641 | for i, normal_chunk_idx in enumerate(normal_chunks): 642 | extra = 1 if i < remainder else 0 643 | num_to_add = indices_per_chunk + extra 644 | 645 | # 添加GOOD数据到正常组 646 | chunks[normal_chunk_idx] = Subset(dataset, 647 | list(chunks[normal_chunk_idx].indices) + 648 | good_indices[current_good_idx:current_good_idx + num_to_add] 649 | ) 650 | current_good_idx += num_to_add 651 | 652 | # 更新原始GOOD组,只保留非GOOD数据 653 | chunks[good_chunk_idx] = Subset(dataset, normal_indices) 654 | 655 | return chunks 656 | def get_data_loader(split, class_name, img_size, args, defect_name = None): 657 | if split in ['train']: 658 | dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small) 659 | elif split in ['test']: 660 | dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small) 661 | elif split in ['validation']: 662 | dataset = ValidDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small, defect_name = defect_name) 663 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False, 664 | pin_memory=True) 665 | return data_loader 666 | 667 | def get_data_set(split, class_name, img_size, args, defect_name = None): 668 | if split in ['train']: 669 | dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle) 670 | elif split in ['test']: 671 | print('test') 672 | dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle) 673 | elif split in ['validation']: 674 | print('validation') 675 | dataset = ValidDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small, defect_name = defect_name) 676 | 677 | return dataset 678 | 679 | # if __name__ == "__main__": 680 | # parser = argparse.ArgumentParser(description='Process some integers.') 681 | # 682 | # parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str, 683 | # choices=['DINO','Point_MAE','Fusion','DINO+Point_MAE','DINO+Point_MAE+Fusion','DINO+Point_MAE+add','DINO+FPFH','DINO+FPFH+Fusion', 684 | # 'DINO+FPFH+Fusion+ps','DINO+Point_MAE+Fusion+ps','DINO+Point_MAE+ps','DINO+FPFH+ps','ours','ours2','ours3','ours_final','ours_final1' 685 | # ,'ours_final1_VS'], 686 | # help='Anomaly detection modal name.') 687 | # parser.add_argument('--max_sample', default=400, type=int, 688 | # help='Max sample number.') 689 | # parser.add_argument('--memory_bank', default='multiple', type=str, 690 | # choices=["multiple", "single"], 691 | # 
help='memory bank mode: "multiple", "single".') 692 | # parser.add_argument('--rgb_backbone_name', default='vit_base_patch8_224_dino', type=str, 693 | # choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'], 694 | # help='Timm checkpoints name of RGB backbone.') 695 | # parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert','FPFH'], 696 | # help='Checkpoints name of RGB backbone[Point_MAE, Point_Bert, FPFH].') 697 | # parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str, 698 | # help='Checkpoints for fusion module.') 699 | # parser.add_argument('--save_feature', default=False, action='store_true', 700 | # help='Save feature for training fusion block.') 701 | # parser.add_argument('--use_uff', default=False, action='store_true', 702 | # help='Use UFF module.') 703 | # parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str, 704 | # help='Save feature for training fusion block.') 705 | # parser.add_argument('--save_preds', default=False, action='store_true', 706 | # help='Save predicts results.') 707 | # parser.add_argument('--group_size', default=128, type=int, -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torchvision import transforms 4 | import glob 5 | from torch.utils.data import Dataset 6 | from utils.mvtec3d_util import * 7 | from torch.utils.data import DataLoader 8 | import numpy as np 9 | import math 10 | import cv2 11 | import argparse 12 | from matplotlib import pyplot as plt 13 | 14 | def eyecandies_classes(): 15 | return [ 16 | 'CandyCane', 17 | 'ChocolateCookie', 18 | 'ChocolatePraline', 19 | 'Confetto', 20 | 'GummyBear', 21 | 'HazelnutTruffle', 22 | 'LicoriceSandwich', 23 | 'Lollipop', 24 | 'Marshmallow', 25 | 'PeppermintCandy', 26 | ] 27 | 28 | def mvtec3d_classes(): 29 | return [ 30 | "bagel", 31 | "cable_gland", 32 | "carrot", 33 | "cookie", 34 | "dowel", 35 | "foam", 36 | "peach", 37 | "potato", 38 | "rope", 39 | "tire", 40 | ] 41 | def test_3d_classes(): 42 | return [ 43 | # 'audio_jack_socket', 44 | # 'common_mode_filter', 45 | # 'connector_housing-female', 46 | # 'crimp_st_cable_mount_box', 47 | # 'dc_power_connector', 48 | # 'fork_crimp_terminal', 49 | # 'headphone_jack_socket', 50 | # 'miniature_lifting_motor', 51 | # 'purple-clay-pot', 52 | # 'power_jack', 53 | # 'ethernet_connector', 54 | # 'ferrite_bead', 55 | # 'fuse_holder', 56 | # 'humidity_sensor', 57 | # 'knob-cap', 58 | # 'lattice_block_plug', 59 | 'lego_pin_connector_plate', 60 | 'lego_propeller', 61 | 'limit-switch', 62 | 'telephone_spring_switch', 63 | ] 64 | 65 | RGB_SIZE = 224 66 | 67 | class BaseAnomalyDetectionDataset(Dataset): 68 | 69 | def __init__(self, split, class_name, img_size,downsampling, angle, small,dataset_path='datasets/eyecandies_preprocessed'): 70 | self.IMAGENET_MEAN = [0.485, 0.456, 0.406] 71 | self.IMAGENET_STD = [0.229, 0.224, 0.225] 72 | self.cls = class_name 73 | self.size = img_size 74 | self.img_path = os.path.join(dataset_path, self.cls, split) 75 | self.downsampling = downsampling 76 | self.angle = angle 77 | self.small = small 78 | self.rgb_transform = transforms.Compose( 79 | [transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.BICUBIC), 80 | transforms.ToTensor(), 81 | 
transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)])
82 |     def analyze_depth_importance(self, organized_pc):
83 |         """Analyse the depth map and flag which rows/columns are important."""
84 |         depth = organized_pc[:, :, 2]  # depth channel
85 |
86 |         # Fraction of non-zero points in each row / column
87 |         row_importance = np.mean(depth != 0, axis=1)
88 |         col_importance = np.mean(depth != 0, axis=0)
89 |
90 |         # Use the median over non-empty rows/columns as the cut-off
91 |         row_median = np.median(row_importance[row_importance > 0])
92 |         col_median = np.median(col_importance[col_importance > 0])
93 |
94 |         # Mark importance (True means an important row/column)
95 |         important_rows = row_importance >= row_median
96 |         important_cols = col_importance >= col_median
97 |
98 |         return important_rows, important_cols
99 |
100 |     def smart_downsample(self, organized_pc, target_factor):
101 |         """Content-aware downsampling: important regions are downsampled more heavily while the overall downsampling factor stays the same."""
102 |         important_rows, important_cols = self.analyze_depth_importance(organized_pc)
103 |
104 |         # Indices of the important and unimportant rows/columns
105 |         important_row_indices = np.where(important_rows)[0]
106 |         unimportant_row_indices = np.where(~important_rows)[0]
107 |         important_col_indices = np.where(important_cols)[0]
108 |         unimportant_col_indices = np.where(~important_cols)[0]
109 |
110 |         # Original number of rows/columns in the two regions
111 |         total_rows = len(important_row_indices) + len(unimportant_row_indices)
112 |         total_cols = len(important_col_indices) + len(unimportant_col_indices)
113 |         factor = int(math.sqrt(target_factor))
114 |         # Target total number of rows/columns
115 |         target_total_rows = total_rows // factor
116 |         target_total_cols = total_cols // factor
117 |
118 |         # Important regions get a higher downsampling rate (e.g. keep one row/column in four)
119 |         important_factor = factor * 2
120 |
121 |         # How many rows/columns to keep from the important regions
122 |         n_important_rows = len(important_row_indices) // important_factor
123 |         n_important_cols = len(important_col_indices) // important_factor
124 |
125 |         # How many rows/columns to keep from the unimportant regions (so the total matches the target)
126 |         n_unimportant_rows = target_total_rows - n_important_rows
127 |         n_unimportant_cols = target_total_cols - n_important_cols
128 |
129 |         # Never request more rows/columns than actually exist
130 |         n_unimportant_rows = min(n_unimportant_rows, len(unimportant_row_indices))
131 |         n_unimportant_cols = min(n_unimportant_cols, len(unimportant_col_indices))
132 |
133 |         # Select rows
134 |         selected_important_rows = np.linspace(0, len(important_row_indices)-1, n_important_rows, dtype=int)
135 |         selected_important_rows = important_row_indices[selected_important_rows]
136 |
137 |         selected_unimportant_rows = np.linspace(0, len(unimportant_row_indices)-1, n_unimportant_rows, dtype=int)
138 |         selected_unimportant_rows = unimportant_row_indices[selected_unimportant_rows]
139 |
140 |         # Select columns
141 |         selected_important_cols = np.linspace(0, len(important_col_indices)-1, n_important_cols, dtype=int)
142 |         selected_important_cols = important_col_indices[selected_important_cols]
143 |
144 |         selected_unimportant_cols = np.linspace(0, len(unimportant_col_indices)-1, n_unimportant_cols, dtype=int)
145 |         selected_unimportant_cols = unimportant_col_indices[selected_unimportant_cols]
146 |
147 |         # Merge the selected rows and columns
148 |         selected_rows = np.sort(np.concatenate([selected_important_rows, selected_unimportant_rows]))
149 |         selected_cols = np.sort(np.concatenate([selected_important_cols, selected_unimportant_cols]))
150 |
151 |         # Log the downsampling statistics
152 |         print(f"original size: {organized_pc.shape[:2]}")
153 |         print(f"important rows: {len(important_row_indices)} -> {len(selected_important_rows)} (1/{important_factor})")
154 |         print(f"unimportant rows: {len(unimportant_row_indices)} -> {len(selected_unimportant_rows)}")
155 |         print(f"important cols: {len(important_col_indices)} -> {len(selected_important_cols)} (1/{important_factor})")
156 |         print(f"unimportant cols: {len(unimportant_col_indices)} -> {len(selected_unimportant_cols)}")
157 |         print(f"final size: {len(selected_rows)}x{len(selected_cols)}")
158 |         print(f"actual downsampling factor: {(total_rows*total_cols)/(len(selected_rows)*len(selected_cols)):.2f}")
159 |
160 |         return organized_pc[selected_rows][:, selected_cols]
161 |
162 |     def get_matrix(self, image, angle):
163 |         image = self.pillow_to_opencv(image)
164 |         (h, w) = image.shape[:2]
165 |         src_points = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
166 |         dst_points = self.calculate_destination_points(0, w, 0, h, angle)
167 |         M = cv2.getPerspectiveTransform(src_points, dst_points)
168 |         return M
169 |
170 |     def calculate_destination_points(self, left, right, top, bottom, angle):
171 |         # Centre of the rectangle
172 |         center_x = (left + right) / 2
173 |         center_y = (top + bottom) / 2
174 |
175 |         # Rotation angle in radians
176 |         angle_rad = math.radians(angle)
177 |
178 |         # Rotate each corner around the centre to get the destination points
179 |         dst_points = []
180 |         for x, y in [(left, top), (right, top), (right, bottom), (left, bottom)]:
181 |             new_x = center_x + (x - center_x) * math.cos(angle_rad) - (y - center_y) * math.sin(angle_rad)
182 |             new_y = center_y + (x - center_x) * math.sin(angle_rad) + (y - center_y) * math.cos(angle_rad)
183 |             dst_points.append([new_x, new_y])
184 |
185 |         return np.float32(dst_points)
186 |
187 |     def perspective_transform(self, image, matrix):
188 |         """
189 |         Apply the perspective (rotation) transform to a PIL image.
190 |         :param image: input PIL image
191 |         :param matrix: 3x3 perspective transform matrix from get_matrix
192 |         :return: transformed PIL image
193 |         """
194 |         # Convert to OpenCV format
195 |         image = self.pillow_to_opencv(image)
196 |         (h, w) = image.shape[:2]
197 |         # Apply the perspective transform, padding the border with the top-left pixel colour
198 |         transformed = cv2.warpPerspective(image, matrix, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=image[0, 0].tolist())
199 |         transformed = self.opencv_to_pillow(transformed)
200 |         return transformed
201 |
202 |     def opencv_to_pillow(self, opencv_image):
203 |         """
204 |         Convert an OpenCV image to a Pillow image.
205 |
206 |         :param opencv_image: OpenCV image (BGR format)
207 |         :return: Pillow image
208 |         """
209 |         # For colour images, convert the channel order first
210 |         if len(opencv_image.shape) == 3:  # colour image
211 |             opencv_image_rgb = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
212 |             return Image.fromarray(opencv_image_rgb)
213 |         else:  # greyscale image
214 |             return Image.fromarray(opencv_image)
215 |
216 |     def pillow_to_opencv(self, pil_image):
217 |         """
218 |         Convert a Pillow image to an OpenCV image.
219 |
220 |         :param pil_image: Pillow image
221 |         :return: OpenCV image (BGR format)
222 |         """
223 |         # Convert the Pillow image to a numpy array
224 |         opencv_image = np.array(pil_image)
225 |
226 |         # If it is an RGB image, convert to BGR
227 |         if opencv_image.ndim == 3 and opencv_image.shape[2] == 3:  # RGB image
228 |             opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
229 |
230 |         return opencv_image
231 |
232 | class PreTrainTensorDataset(Dataset):
233 |     def __init__(self, root_path):
234 |         super().__init__()
235 |         self.root_path = root_path
236 |         self.tensor_paths = os.listdir(self.root_path)
237 |
238 |
239 |     def __len__(self):
240 |         return len(self.tensor_paths)
241 |
242 |     def __getitem__(self, idx):
243 |         tensor_path = self.tensor_paths[idx]
244 |
245 |         tensor = torch.load(os.path.join(self.root_path, tensor_path))
246 |
247 |         label = 0
248 |
249 |         return tensor, label
250 |
251 | class TrainDataset(BaseAnomalyDetectionDataset):
252 |     def __init__(self, class_name, img_size, downsampling, angle, small, dataset_path='datasets/eyecandies_preprocessed'):
253 |         super().__init__(split="train", class_name=class_name, img_size=img_size, downsampling=downsampling, angle=angle, small=small, dataset_path=dataset_path)
254 |         self.img_paths, self.labels = self.load_dataset()  # self.labels => good : 0, anomaly : 1
255 |
256 |     def 
load_dataset(self): 257 | img_tot_paths = [] 258 | tot_labels = [] 259 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'rgb') + "/*.png") 260 | rgb_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'rgb', '*', '*L05*RGB*.jpg')) 261 | tiff_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'xyz') + "/*.tiff") 262 | # ps_paths = glob.glob(os.path.join(self.img_path, 'good', 'ps') + "/*.jpg") 263 | rgb_paths.sort() 264 | tiff_paths.sort() 265 | 266 | # print(len(rgb_paths)) 267 | # ps_paths.sort() 268 | # sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths)) 269 | sample_paths = list(zip(rgb_paths, tiff_paths)) 270 | img_tot_paths.extend(sample_paths) 271 | tot_labels.extend([0] * len(sample_paths)) 272 | # img_tot_paths = img_tot_paths[0:10] 273 | # tot_labels = tot_labels[0:10] 274 | return img_tot_paths, tot_labels 275 | 276 | def __len__(self): 277 | return len(self.img_paths) 278 | 279 | def __getitem__(self, idx): 280 | img_path, label = self.img_paths[idx], self.labels[idx] 281 | rgb_path = img_path[0] 282 | tiff_path = img_path[1] 283 | # ps_path = img_path[2] 284 | img = Image.open(rgb_path).convert('RGB') 285 | # add rotation 286 | # matrix = self.get_matrix(img, self.angle) 287 | # img = self.perspective_transform(img, matrix) 288 | 289 | img = self.rgb_transform(img) 290 | # ps= Image.open(ps_path).convert('RGB') 291 | # # add rotation 292 | # ps = self.perspective_transform(ps, matrix) 293 | 294 | # ps = self.rgb_transform(ps) 295 | if self.downsampling > 1: 296 | organized_pc = read_tiff_organized_pc(tiff_path) 297 | factor1 = int(math.floor(math.sqrt(self.downsampling))) 298 | factor2 = int(math.ceil(self.downsampling / factor1)) 299 | organized_pc = organized_pc[::factor1, ::factor2] 300 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling) 301 | else: 302 | organized_pc = read_tiff_organized_pc(tiff_path) 303 | 304 | for x in range(organized_pc.shape[0]): 305 | for y in range(organized_pc.shape[1]): 306 | organized_pc[x, y, 0] = x 307 | organized_pc[x, y, 1] = y 308 | 309 | z_channel = organized_pc[:, :, 2] 310 | if np.isnan(z_channel).any(): 311 | min_z = np.nanmin(z_channel) # 计算 g 通道的最小值,忽略 NaN 312 | z_channel[np.isnan(z_channel)] = min_z 313 | organized_pc[:, :, 2] = z_channel.astype(np.float32) 314 | print("train z_channel nan filled") 315 | 316 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) 317 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) 318 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) 319 | resized_organized_pc = resized_organized_pc.clone().detach().float() 320 | 321 | 322 | return (img, resized_organized_pc, resized_depth_map_3channel), label 323 | 324 | 325 | class TestDataset(BaseAnomalyDetectionDataset): 326 | def __init__(self, class_name, img_size,downsampling,angle,small,dataset_path='datasets/eyecandies_preprocessed'): 327 | super().__init__(split="test", class_name=class_name, img_size=img_size,downsampling=downsampling,angle=angle, small=small, dataset_path=dataset_path) 328 | self.gt_transform = transforms.Compose([ 329 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST), 330 | transforms.ToTensor()]) 331 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 332 | 333 | def load_dataset(self): 334 | img_tot_paths = [] 335 | gt_tot_paths = [] 336 | tot_labels = [] 337 | 
#ps_tot_paths = [] 338 | defect_types = os.listdir(self.img_path) 339 | print(defect_types) 340 | # 如果types不为NONE,只保留GOOD和指定的types 341 | for defect_type in defect_types: 342 | if defect_type == 'GOOD': 343 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'rgb') + "/*.png") 344 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg')) 345 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff") 346 | #gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png") 347 | # ps_paths = glob.glob(os.path.join(self.img_path, 'good', 'ps') + "/*.jpg") 348 | rgb_paths.sort() 349 | tiff_paths.sort() 350 | # ps_paths.sort() 351 | # 只保留前5个样本 352 | if self.small: 353 | rgb_paths = rgb_paths[:5] 354 | tiff_paths = tiff_paths[:5] 355 | # ps_paths = ps_paths[:5] 356 | sample_paths = list(zip(rgb_paths, tiff_paths)) 357 | img_tot_paths.extend(sample_paths) 358 | gt_tot_paths.extend([0] * len(sample_paths)) 359 | tot_labels.extend([0] * len(sample_paths)) 360 | else: 361 | # rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png") 362 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg')) 363 | tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff") 364 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt', 'rgb') + "/*.png") 365 | # ps_paths = glob.glob(os.path.join(self.img_path, defect_type, 'ps') + "/*.png") 366 | rgb_paths.sort() 367 | tiff_paths.sort() 368 | gt_paths.sort() 369 | # ps_paths.sort() 370 | # if self.small: 371 | # # 检查每个gt mask中缺陷占比 372 | # valid_indices = [] 373 | # for i, gt_path in enumerate(gt_paths): 374 | # gt_mask = np.array(Image.open(gt_path)) 375 | # total_pixels = gt_mask.shape[0] * gt_mask.shape[1] 376 | # threshold = int(total_pixels * 0.005) # 计算1%像素数量 377 | # defect_pixels = np.sum(gt_mask > 0) 378 | # if defect_pixels <= threshold: # 直接比较像素数量 379 | # valid_indices.append(i) 380 | 381 | # # 只保留缺陷占比<=1%的样本 382 | # rgb_paths = [rgb_paths[i] for i in valid_indices] 383 | # tiff_paths = [tiff_paths[i] for i in valid_indices] 384 | # gt_paths = [gt_paths[i] for i in valid_indices] 385 | # # ps_paths = [ps_paths[i] for i in valid_indices] 386 | sample_paths = list(zip(rgb_paths, tiff_paths)) 387 | print(f"rgb_paths: {len(rgb_paths)}, tiff_paths: {len(tiff_paths)}") 388 | img_tot_paths.extend(sample_paths) 389 | gt_tot_paths.extend(gt_paths) 390 | tot_labels.extend([1] * len(sample_paths)) 391 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 
392 | 393 | return img_tot_paths, gt_tot_paths, tot_labels 394 | 395 | def __len__(self): 396 | return len(self.img_paths) 397 | 398 | def __getitem__(self, idx): 399 | 400 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 401 | 402 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 403 | rgb_path = img_path[0] 404 | tiff_path = img_path[1] 405 | # ps_path = img_path[2] 406 | img_original = Image.open(rgb_path).convert('RGB') 407 | # matrix = self.get_matrix(img_original, self.angle) 408 | # img_original = self.perspective_transform(img_original, matrix) 409 | 410 | # axes[0].imshow(cv2.cvtColor(self.pillow_to_opencv(img_original), cv2.COLOR_BGR2RGB)) 411 | 412 | img = self.rgb_transform(img_original) 413 | # ps_original = Image.open(ps_path).convert('RGB') 414 | # ps_original = self.perspective_transform(ps_original, matrix) 415 | # ps = self.rgb_transform(ps_original) 416 | # organized_pc = read_tiff_organized_pc(tiff_path) 417 | if self.downsampling > 1: 418 | organized_pc = read_tiff_organized_pc(tiff_path) 419 | factor1 = int(math.floor(math.sqrt(self.downsampling))) 420 | factor2 = int(math.ceil(self.downsampling / factor1)) 421 | organized_pc = organized_pc[::factor1, ::factor2] 422 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling) 423 | else: 424 | organized_pc = read_tiff_organized_pc(tiff_path) 425 | 426 | 427 | for x in range(organized_pc.shape[0]): 428 | for y in range(organized_pc.shape[1]): 429 | organized_pc[x, y, 0] = x 430 | organized_pc[x, y, 1] = y 431 | 432 | z_channel = organized_pc[:, :, 2] 433 | if np.isnan(z_channel).any(): 434 | min_z = np.nanmin(z_channel) # 计算 g 通道的最小值,忽略 NaN 435 | z_channel[np.isnan(z_channel)] = min_z 436 | organized_pc[:, :, 2] = z_channel.astype(np.float32) 437 | print("test z_channel nan filled") 438 | 439 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) 440 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) 441 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) 442 | resized_organized_pc = resized_organized_pc.clone().detach().float() 443 | 444 | 445 | 446 | 447 | if gt == 0: 448 | gt = torch.zeros( 449 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-2]]) 450 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray') 451 | else: 452 | gt = Image.open(gt).convert('L') 453 | # gt = self.perspective_transform(gt, matrix) 454 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray') 455 | gt = self.gt_transform(gt) 456 | gt = torch.where(gt > 0.5, 1., .0) 457 | 458 | # fig.show() 459 | # fig.savefig("test_image_3.png", dpi=300, bbox_inches='tight') 460 | 461 | return (img, resized_depth_map_3channel), gt[:1], label, rgb_path 462 | 463 | class ValidDataset(BaseAnomalyDetectionDataset): 464 | def __init__(self, class_name, img_size,downsampling,angle,small,defect_name,dataset_path='datasets/eyecandies_preprocessed'): 465 | super().__init__(split="test", class_name=class_name, img_size=img_size,downsampling=downsampling,angle=angle, small=small, dataset_path=dataset_path) 466 | self.gt_transform = transforms.Compose([ 467 | transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST), 468 | transforms.ToTensor()]) 469 | self.defect_name = defect_name 470 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 471 | 472 | def load_dataset(self): 473 | 
img_tot_paths = []
474 |         gt_tot_paths = []
475 |         tot_labels = []
476 |         # ps_tot_paths = []
477 |         defect_types = os.listdir(self.img_path)
478 |         # print(defect_types)
479 |         # If a defect filter is given, keep only GOOD plus the specified defect types
480 |         for defect_type in defect_types:
481 |             # print(defect_type)
482 |             if defect_type == 'GOOD':
483 |                 rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
484 |                 tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
485 |                 ps_paths = glob.glob(os.path.join(self.img_path, 'GOOD', 'ps') + "/*.jpg")
486 |                 # print(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
487 |                 # print(rgb_paths)
488 |                 rgb_paths.sort()
489 |                 tiff_paths.sort()
490 |                 ps_paths.sort()
491 |                 # Keep only the first five samples
492 |                 if self.small:
493 |                     rgb_paths = rgb_paths[:5]
494 |                     tiff_paths = tiff_paths[:5]
495 |                     ps_paths = ps_paths[:5]
496 |                 sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
497 |                 sample_paths = sample_paths[:5]
498 |                 img_tot_paths.extend(sample_paths)
499 |                 gt_tot_paths.extend([0] * len(sample_paths))
500 |                 tot_labels.extend([0] * len(sample_paths))
501 |             elif defect_type in self.defect_name:
502 |                 rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
503 |                 tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
504 |                 gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt', 'rgb') + "/*.png")
505 |                 ps_paths = glob.glob(os.path.join(self.img_path, defect_type, 'ps') + "/*.jpg")
506 |                 # print(os.path.join(self.img_path, defect_type, 'rgb', '*', '*L05*RGB*.jpg'))
507 |                 # print(rgb_paths)
508 |                 rgb_paths.sort()
509 |                 tiff_paths.sort()
510 |                 gt_paths.sort()
511 |                 ps_paths.sort()
512 |                 if self.small:
513 |                     # Check the defect-pixel ratio of every gt mask
514 |                     valid_indices = []
515 |                     for i, gt_path in enumerate(gt_paths):
516 |                         gt_mask = np.array(Image.open(gt_path))
517 |                         total_pixels = gt_mask.shape[0] * gt_mask.shape[1]
518 |                         threshold = int(total_pixels * 0.005)  # pixel budget: 0.5% of the image
519 |                         defect_pixels = np.sum(gt_mask > 0)
520 |                         if defect_pixels <= threshold:  # compare pixel counts directly
521 |                             valid_indices.append(i)
522 |
523 |                     # Keep only samples whose defect area is at most 0.5% of the image
524 |                     rgb_paths = [rgb_paths[i] for i in valid_indices]
525 |                     tiff_paths = [tiff_paths[i] for i in valid_indices]
526 |                     gt_paths = [gt_paths[i] for i in valid_indices]
527 |                     ps_paths = [ps_paths[i] for i in valid_indices]
528 |                 sample_paths = list(zip(rgb_paths, tiff_paths, ps_paths))
529 |                 print(f"rgb_paths: {len(rgb_paths)}, tiff_paths: {len(tiff_paths)}, ps_paths: {len(ps_paths)}")
530 |                 img_tot_paths.extend(sample_paths)
531 |                 gt_tot_paths.extend(gt_paths)
532 |                 tot_labels.extend([1] * len(sample_paths))
533 |
534 |         assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
535 | 536 | return img_tot_paths, gt_tot_paths, tot_labels 537 | 538 | def __len__(self): 539 | return len(self.img_paths) 540 | 541 | def __getitem__(self, idx): 542 | 543 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 544 | 545 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 546 | rgb_path = img_path[0] 547 | tiff_path = img_path[1] 548 | ps_path = img_path[2] 549 | img_original = Image.open(rgb_path).convert('RGB') 550 | matrix = self.get_matrix(img_original, self.angle) 551 | img_original = self.perspective_transform(img_original, matrix) 552 | 553 | # axes[0].imshow(cv2.cvtColor(self.pillow_to_opencv(img_original), cv2.COLOR_BGR2RGB)) 554 | 555 | img = self.rgb_transform(img_original) 556 | ps_original = Image.open(ps_path).convert('RGB') 557 | ps_original = self.perspective_transform(ps_original, matrix) 558 | ps = self.rgb_transform(ps_original) 559 | # organized_pc = read_tiff_organized_pc(tiff_path) 560 | if self.downsampling > 1: 561 | organized_pc = read_tiff_organized_pc(tiff_path) 562 | factor1 = int(math.floor(math.sqrt(self.downsampling))) 563 | factor2 = int(math.ceil(self.downsampling / factor1)) 564 | organized_pc = organized_pc[::factor1, ::factor2] 565 | #organized_pc = self.smart_downsample(organized_pc, self.downsampling) 566 | else: 567 | organized_pc = read_tiff_organized_pc(tiff_path) 568 | 569 | for x in range(organized_pc.shape[0]): 570 | for y in range(organized_pc.shape[1]): 571 | organized_pc[x, y, 0] = x 572 | organized_pc[x, y, 1] = y 573 | 574 | z_channel = organized_pc[:, :, 2] 575 | if np.isnan(z_channel).any(): 576 | min_z = np.nanmin(z_channel) # 计算 g 通道的最小值,忽略 NaN 577 | z_channel[np.isnan(z_channel)] = min_z 578 | organized_pc[:, :, 2] = z_channel.astype(np.float32) 579 | print("val z_channel nan filled") 580 | 581 | depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) 582 | resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) 583 | resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) 584 | resized_organized_pc = resized_organized_pc.clone().detach().float() 585 | 586 | 587 | 588 | 589 | if gt == 0: 590 | gt = torch.zeros( 591 | [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-2]]) 592 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray') 593 | else: 594 | gt = Image.open(gt).convert('L') 595 | gt = self.perspective_transform(gt, matrix) 596 | # axes[1].imshow(self.pillow_to_opencv(gt), cmap='gray') 597 | gt = self.gt_transform(gt) 598 | gt = torch.where(gt > 0.5, 1., .0) 599 | 600 | # fig.show() 601 | # fig.savefig("test_image_3.png", dpi=300, bbox_inches='tight') 602 | 603 | return (img, resized_depth_map_3channel, ps), gt[:1], label, rgb_path 604 | 605 | from torch.utils.data import Subset 606 | def redistribute_dataset(dataset, chunk_size=50): 607 | """ 608 | 将数据集每50个分成一组,并重新分配包含GOOD的组 609 | :param dataset: 原始数据集 610 | :param chunk_size: 每组的大小 611 | :return: 重新分配后的数据集列表 612 | """ 613 | total_size = len(dataset) 614 | num_chunks = (total_size + chunk_size - 1) // chunk_size # 向上取整 615 | 616 | # 初始分组 617 | chunks = [] 618 | for i in range(num_chunks): 619 | start_idx = i * chunk_size 620 | end_idx = min(start_idx + chunk_size, total_size) 621 | chunk_indices = list(range(start_idx, end_idx)) 622 | chunks.append(Subset(dataset, chunk_indices)) 623 | 624 | # 找出包含GOOD的组 625 | good_chunks = [] 626 | normal_chunks = [] 627 | 628 | for i, chunk in enumerate(chunks): 629 
| has_good = False 630 | # 检查这个chunk中是否包含GOOD 631 | for idx in chunk.indices: 632 | if 'GOOD' in dataset.img_paths[idx][0]: 633 | has_good = True 634 | break 635 | 636 | if has_good: 637 | good_chunks.append(i) 638 | else: 639 | normal_chunks.append(i) 640 | 641 | # 如果没有GOOD组或没有正常组,直接返回原始分组 642 | if not good_chunks or not normal_chunks: 643 | return chunks 644 | 645 | # 重新分配GOOD组的数据 646 | for good_chunk_idx in good_chunks: 647 | chunk = chunks[good_chunk_idx] 648 | good_indices = [] 649 | normal_indices = [] 650 | 651 | # 分离GOOD和非GOOD数据 652 | for idx in chunk.indices: 653 | if 'GOOD' in dataset.img_paths[idx][0]: 654 | good_indices.append(idx) 655 | else: 656 | normal_indices.append(idx) 657 | 658 | # 将GOOD数据平均分配给其他组 659 | num_good = len(good_indices) 660 | num_normal_chunks = len(normal_chunks) 661 | 662 | if num_normal_chunks > 0: 663 | # 计算每个正常组应该获得多少GOOD数据 664 | indices_per_chunk = num_good // num_normal_chunks 665 | remainder = num_good % num_normal_chunks 666 | 667 | # 分配GOOD数据 668 | current_good_idx = 0 669 | for i, normal_chunk_idx in enumerate(normal_chunks): 670 | extra = 1 if i < remainder else 0 671 | num_to_add = indices_per_chunk + extra 672 | 673 | # 添加GOOD数据到正常组 674 | chunks[normal_chunk_idx] = Subset(dataset, 675 | list(chunks[normal_chunk_idx].indices) + 676 | good_indices[current_good_idx:current_good_idx + num_to_add] 677 | ) 678 | current_good_idx += num_to_add 679 | 680 | # 更新原始GOOD组,只保留非GOOD数据 681 | chunks[good_chunk_idx] = Subset(dataset, normal_indices) 682 | 683 | return chunks 684 | def get_data_loader(split, class_name, img_size, args, defect_name = None): 685 | if split in ['train']: 686 | dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small) 687 | elif split in ['test']: 688 | dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small) 689 | elif split in ['validation']: 690 | dataset = ValidDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small, defect_name = defect_name) 691 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False, 692 | pin_memory=True) 693 | return data_loader 694 | 695 | def get_data_set(split, class_name, img_size, args, defect_name = None): 696 | if split in ['train']: 697 | dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle) 698 | elif split in ['test']: 699 | print('test') 700 | dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle) 701 | elif split in ['validation']: 702 | print('validation') 703 | dataset = ValidDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path, downsampling=args.downsampling, angle=args.rotate_angle,small=args.small, defect_name = defect_name) 704 | 705 | return dataset 706 | 707 | # if __name__ == "__main__": 708 | # parser = argparse.ArgumentParser(description='Process some integers.') 709 | # 710 | # parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str, 711 | # choices=['DINO','Point_MAE','Fusion','DINO+Point_MAE','DINO+Point_MAE+Fusion','DINO+Point_MAE+add','DINO+FPFH','DINO+FPFH+Fusion', 712 
| # 'DINO+FPFH+Fusion+ps','DINO+Point_MAE+Fusion+ps','DINO+Point_MAE+ps','DINO+FPFH+ps','ours','ours2','ours3','ours_final','ours_final1' 713 | # ,'ours_final1_VS'], 714 | # help='Anomaly detection modal name.') 715 | # parser.add_argument('--max_sample', default=400, type=int, 716 | # help='Max sample number.') 717 | # parser.add_argument('--memory_bank', default='multiple', type=str, 718 | # choices=["multiple", "single"], 719 | # help='memory bank mode: "multiple", "single".') 720 | # parser.add_argument('--rgb_backbone_name', default='vit_base_patch8_224_dino', type=str, 721 | # choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'], 722 | # help='Timm checkpoints name of RGB backbone.') 723 | # parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert','FPFH'], 724 | # help='Checkpoints name of RGB backbone[Point_MAE, Point_Bert, FPFH].') 725 | # parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str, 726 | # help='Checkpoints for fusion module.') 727 | # parser.add_argument('--save_feature', default=False, action='store_true', 728 | # help='Save feature for training fusion block.') 729 | # parser.add_argument('--use_uff', default=False, action='store_true', 730 | # help='Use UFF module.') 731 | # parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str, 732 | # help='Save feature for training fusion block.') 733 | # parser.add_argument('--save_preds', default=False, action='store_true', 734 | # help='Save predicts results.') 735 | # parser.add_argument('--group_size', default=128, type=int, 736 | --------------------------------------------------------------------------------
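For context, a minimal usage sketch of the entry points defined in dataset.py (`get_data_loader` and `redistribute_dataset`). The dataset root, class name, and argument values below are placeholder assumptions, not taken from the repository; in the actual pipeline these options come from the argparse namespace built in main.py.

```python
# Minimal sketch (assumed values) of how the data loaders in dataset.py are wired together.
from types import SimpleNamespace

from dataset import get_data_loader, redistribute_dataset

# get_data_loader reads its options from an argparse-style namespace; the
# values here (dataset root, downsampling factor, rotation angle) are placeholders.
args = SimpleNamespace(
    dataset_path='datasets/eyecandies_preprocessed',  # assumed dataset root
    downsampling=4,    # keep roughly every 2nd row and column of the organized point cloud
    rotate_angle=0,    # no perspective rotation
    small=False,       # use the full test split
)

# Test loader: batch_size is fixed to 1 inside get_data_loader.
test_loader = get_data_loader('test', class_name='lego_propeller', img_size=224, args=args)
for (img, depth_map_3c), gt, label, rgb_path in test_loader:
    print(img.shape, depth_map_3c.shape, gt.shape, int(label))
    break

# Optional: split the test set into chunks of 50 samples and spread the GOOD
# samples across the defect chunks (reuses the dataset already built by the loader).
chunks = redistribute_dataset(test_loader.dataset, chunk_size=50)
print(f"{len(chunks)} chunks, sizes: {[len(c) for c in chunks]}")
```

Note that the loaders are evaluation-oriented by design: batch size 1, no shuffling, and each item carries its ground-truth mask and source RGB path so per-sample anomaly maps can be matched back to the original files.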