├── models ├── __init__.py ├── necks │ ├── __init__.py │ └── mfcn.py ├── reconstructions │ ├── __init__.py │ └── vis_decoder.py ├── backbones │ ├── __init__.py │ └── efficientnet │ │ └── __init__.py ├── model_helper.py └── initializer.py ├── utils ├── __init__.py ├── lr_helper.py ├── optimizer_helper.py ├── criterion_helper.py ├── dist_helper.py ├── vis_helper.py └── misc_helper.py ├── data ├── CIFAR-10 │ └── .gitkeep ├── Real-IAD │ ├── realiad_1024 │ │ └── .gitkeep │ ├── realiad_jsons │ │ └── .gitkeep │ ├── realiad_raw │ │ └── .gitkeep │ ├── realiad_jsons_sv │ │ └── .gitkeep │ ├── realiad_jsons_fuiad_0.0 │ │ └── .gitkeep │ ├── realiad_jsons_fuiad_0.1 │ │ └── .gitkeep │ ├── realiad_jsons_fuiad_0.2 │ │ └── .gitkeep │ └── realiad_jsons_fuiad_0.4 │ │ └── .gitkeep └── MVTec-AD │ └── json_vis_decoder │ ├── test_toothbrush.json │ ├── test_grid.json │ ├── test_wood.json │ ├── test_bottle.json │ └── test_transistor.json ├── datasets ├── __init__.py ├── image_reader.py ├── data_builder.py ├── base_dataset.py ├── custom_dataset.py ├── explicit_dataset.py ├── cifar_dataset.py └── transforms.py ├── pretrained └── .gitkeep ├── docs ├── setting.jpg ├── res_mvtec.jpg ├── query_bottle.jpg └── query_capsule.jpg ├── experiments ├── MVTec-AD │ ├── train_torch.sh │ ├── eval_torch.sh │ ├── eval.sh │ ├── train.sh │ └── config.yaml ├── RealIAD-C1 │ ├── eval_torch.sh │ ├── train_torch.sh │ └── config.yaml ├── RealIAD-full │ ├── train_torch.sh │ ├── eval_torch.sh │ └── config.yaml ├── RealIAD-fuad-n0 │ ├── eval_torch.sh │ ├── train_torch.sh │ └── config.yaml ├── RealIAD-fuad-n1 │ ├── eval_torch.sh │ ├── train_torch.sh │ └── config.yaml ├── RealIAD-fuad-n2 │ ├── eval_torch.sh │ ├── train_torch.sh │ └── config.yaml ├── RealIAD-fuad-n4 │ ├── eval_torch.sh │ ├── train_torch.sh │ └── config.yaml ├── CIFAR-10 │ ├── 13579 │ │ ├── eval_torch.sh │ │ ├── train_torch.sh │ │ ├── train.sh │ │ ├── eval.sh │ │ └── config.yaml │ ├── 56789 │ │ ├── eval_torch.sh │ │ ├── train_torch.sh │ │ ├── train.sh │ │ ├── eval.sh │ │ └── config.yaml │ ├── 01234 │ │ ├── eval_torch.sh │ │ ├── train_torch.sh │ │ ├── train.sh │ │ ├── eval.sh │ │ └── config.yaml │ └── 02468 │ │ ├── eval_torch.sh │ │ ├── train_torch.sh │ │ ├── train.sh │ │ ├── eval.sh │ │ └── config.yaml ├── vis_query │ ├── vis_query_torch.sh │ ├── vis_query.sh │ └── config.yaml ├── vis_recon │ ├── vis_recon_torch.sh │ ├── vis_recon.sh │ └── config.yaml └── train_vis_decoder │ ├── train_torch.sh │ ├── train.sh │ └── config.yaml ├── prepare.sh ├── requirements.txt ├── .gitignore ├── README.md ├── README_uniad.md └── tools ├── vis_recon.py ├── vis_query.py └── train_vis_decoder.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/CIFAR-10/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- 1 | 
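Note on layout: the data/ subtrees above contain only .gitkeep placeholders and must be populated before training. A minimal setup sketch — the destination path below is the one the experiment configs further down in this dump actually reference, while the /path/to/... source is a placeholder for wherever the dataset lives:

    ln -s /path/to/mvtec_anomaly_detection data/MVTec-AD/mvtec_anomaly_detection

The Real-IAD placeholders (realiad_1024, realiad_jsons, ...) and data/CIFAR-10/ (the root_dir used by the CIFAR-10 configs) are filled in the same way.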
-------------------------------------------------------------------------------- /data/Real-IAD/realiad_1024/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/Real-IAD/realiad_jsons/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/Real-IAD/realiad_raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/Real-IAD/realiad_jsons_sv/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/Real-IAD/realiad_jsons_fuiad_0.0/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/Real-IAD/realiad_jsons_fuiad_0.1/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/Real-IAD/realiad_jsons_fuiad_0.2/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/Real-IAD/realiad_jsons_fuiad_0.4/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .mfcn import * # noqa F401 2 | -------------------------------------------------------------------------------- /docs/setting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/AnomalyDetection_Real-IAD/HEAD/docs/setting.jpg -------------------------------------------------------------------------------- /docs/res_mvtec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/AnomalyDetection_Real-IAD/HEAD/docs/res_mvtec.jpg -------------------------------------------------------------------------------- /docs/query_bottle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/AnomalyDetection_Real-IAD/HEAD/docs/query_bottle.jpg -------------------------------------------------------------------------------- /docs/query_capsule.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/AnomalyDetection_Real-IAD/HEAD/docs/query_capsule.jpg -------------------------------------------------------------------------------- /models/reconstructions/__init__.py: -------------------------------------------------------------------------------- 1 | from .uniad import UniAD # noqa F401 2 | from .vis_decoder import VisDecoder # noqa F401 3 | -------------------------------------------------------------------------------- /experiments/MVTec-AD/train_torch.sh: -------------------------------------------------------------------------------- 1 | export 
PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/MVTec-AD/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-C1/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-C1/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-full/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n0/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n0/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n1/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n1/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n2/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n2/train_torch.sh: 
-------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n4/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n4/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/RealIAD-full/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python3 -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/01234/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/01234/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/02468/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/02468/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/13579/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/13579/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../../tools/train_val.py 4 | 
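All of the *_torch.sh launchers above share one calling convention: $1 is the process count passed to --nproc_per_node and $2 is the GPU list intended for CUDA_VISIBLE_DEVICES. A usage sketch (not itself part of the repo) for a single-node, two-GPU MVTec-AD run:

    cd experiments/MVTec-AD
    sh train_torch.sh 2 0,1   # train with 2 processes on GPUs 0 and 1
    sh eval_torch.sh 2 0,1    # same launcher; -e switches to evaluation

One caveat: CUDA_VISIBLE_DEVICES=$2 on its own line sets an unexported shell variable, so unless it was already exported in the calling environment it never reaches the python child process; making $2 effective would require export CUDA_VISIBLE_DEVICES=$2 or prefixing the launch command with the assignment (CUDA_VISIBLE_DEVICES=$2 python -m torch.distributed.launch ...).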
-------------------------------------------------------------------------------- /experiments/CIFAR-10/56789/eval_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/56789/train_torch.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../:$PYTHONPATH 2 | CUDA_VISIBLE_DEVICES=$2 3 | python -m torch.distributed.launch --nproc_per_node=$1 ../../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /prepare.sh: -------------------------------------------------------------------------------- 1 | # start from a commonly used pytorch 1.13 training image 2 | python3 -m pip install \ 3 | easydict \ 4 | einops==0.4.1 \ 5 | scikit-learn==0.24.2 \ 6 | tabulate==0.8.10 7 | -------------------------------------------------------------------------------- /experiments/MVTec-AD/eval.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=mvtec \ 3 | python -u ../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/MVTec-AD/train.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=mvtec \ 3 | python -u ../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/01234/train.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=01234 \ 3 | python -u ../../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/02468/train.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=02468 \ 3 | python -u ../../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/13579/train.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=13579 \ 3 | python -u ../../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/56789/train.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=56789 \ 3 | python -u ../../../tools/train_val.py 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/01234/eval.sh:
-------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=01234 \ 3 | python -u ../../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/02468/eval.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=02468 \ 3 | python -u ../../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/13579/eval.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=13579 \ 3 | python -u ../../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/56789/eval.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=$PYTHONPATH:../../../ \ 2 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=56789 \ 3 | python -u ../../../tools/train_val.py -e 4 | -------------------------------------------------------------------------------- /utils/lr_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_scheduler(optimizer, config): 5 | if config.type == "StepLR": 6 | return torch.optim.lr_scheduler.StepLR(optimizer, **config.kwargs) 7 | else: 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /experiments/vis_query/vis_query_torch.sh: -------------------------------------------------------------------------------- 1 | # class_name: bottle cable capsule carpet grid hazelnut leather metal_nut pill screw tile toothbrush transistor wood zipper 2 | export PYTHONPATH=../../:$PYTHONPATH 3 | python ../../tools/vis_query.py --class_name $1 4 | -------------------------------------------------------------------------------- /experiments/vis_recon/vis_recon_torch.sh: -------------------------------------------------------------------------------- 1 | # class_name: bottle cable capsule carpet grid hazelnut leather metal_nut pill screw tile toothbrush transistor wood zipper 2 | export PYTHONPATH=../../:$PYTHONPATH 3 | python -u ../../tools/vis_recon.py --class_name $1 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict 2 | einops==0.4.1 3 | numpy 4 | opencv-python 5 | Pillow 6 | pycryptodome 7 | PyYAML 8 | scikit-image 9 | scikit-learn==0.24.2 10 | scipy==1.9.1 11 | tabulate==0.8.10 12 | timm==0.6.12 13 | torch==1.13.1+cu117 14 | torchvision==0.14.1+cu117 15 | tqdm 16 | wheel 17 | -------------------------------------------------------------------------------- /experiments/train_vis_decoder/train_torch.sh: -------------------------------------------------------------------------------- 1 | # class_name: bottle cable capsule carpet grid hazelnut leather metal_nut pill screw tile toothbrush transistor wood zipper 2 | export PYTHONPATH=../../:$PYTHONPATH 3 | 
CUDA_VISIBLE_DEVICES=$2 4 | python -m torch.distributed.launch --nproc_per_node=$1 ../../tools/train_vis_decoder.py --class_name $3 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__/* 2 | *pyc 3 | .vscode/ 4 | 5 | experiments/*/*/checkpoints*/* 6 | experiments/*/*/log*/* 7 | experiments/*/*/arun_log*/* 8 | experiments/*/*/*result*/* 9 | experiments/*/*/*vis*/* 10 | 11 | experiments/*/checkpoints*/* 12 | experiments/*/log*/* 13 | experiments/*/arun_log*/* 14 | experiments/*/*result*/* 15 | experiments/*/*vis*/* 16 | -------------------------------------------------------------------------------- /utils/optimizer_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_optimizer(parameters, config): 5 | if config.type == "AdamW": 6 | return torch.optim.AdamW(parameters, **config.kwargs) 7 | elif config.type == "Adam": 8 | return torch.optim.Adam(parameters, **config.kwargs) 9 | elif config.type == "SGD": 10 | return torch.optim.SGD(parameters, **config.kwargs) 11 | else: 12 | raise NotImplementedError 13 | -------------------------------------------------------------------------------- /experiments/vis_query/vis_query.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT=../../ 3 | export PYTHONPATH=$ROOT:$PYTHONPATH 4 | if [ ! -d "./log_srun/" ];then 5 | mkdir log_srun 6 | fi 7 | 8 | for cls in bottle cable capsule carpet grid hazelnut leather metal_nut pill screw tile toothbrush transistor wood zipper 9 | do 10 | srun --mpi=pmi2 -p$1 -n1 --gres=gpu:1 --ntasks-per-node=1 --cpus-per-task=4 --job-name=$cls \ 11 | python ../../tools/vis_query.py --class_name $cls > log_srun/log_$cls.txt 2>&1 & 12 | done 13 | -------------------------------------------------------------------------------- /experiments/vis_recon/vis_recon.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT=../../ 3 | export PYTHONPATH=$ROOT:$PYTHONPATH 4 | if [ ! -d "./log_srun/" ];then 5 | mkdir log_srun 6 | fi 7 | 8 | for cls in bottle cable capsule carpet grid hazelnut leather metal_nut pill screw tile toothbrush transistor wood zipper 9 | do 10 | srun --mpi=pmi2 -p$1 -n1 --gres=gpu:1 --ntasks-per-node=1 --cpus-per-task=4 --job-name=$cls \ 11 | python -u ../../tools/vis_recon.py --class_name $cls > log_srun/log_$cls.txt 2>&1 & 12 | done 13 | -------------------------------------------------------------------------------- /experiments/train_vis_decoder/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT=../../ 3 | export PYTHONPATH=$ROOT:$PYTHONPATH 4 | if [ ! 
-d "./log_srun/" ];then 5 | mkdir log_srun 6 | fi 7 | 8 | for cls in bottle cable capsule carpet grid hazelnut leather metal_nut pill screw tile toothbrush transistor wood zipper 9 | do 10 | srun --mpi=pmi2 -p$2 -n$1 --gres=gpu:$1 --ntasks-per-node=$1 --cpus-per-task=4 --job-name=$cls \ 11 | python -u ../../tools/train_vis_decoder.py --class_name $cls > log_srun/log_$cls.txt 2>&1 & 12 | done 13 | -------------------------------------------------------------------------------- /experiments/vis_recon/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | 3 | data: 4 | dataset_dir: ../../data/MVTec-AD/mvtec_anomaly_detection/ 5 | feature_dir: ../MVTec-AD/result_recon/ 6 | input_size: [224,224] # [h,w] 7 | pixel_mean: [0.485, 0.456, 0.406] 8 | pixel_std: [0.229, 0.224, 0.225] 9 | 10 | saver: 11 | load_path: ../train_vis_decoder/{class_name}/checkpoints/ckpt.pth.tar 12 | save_dir: result_vis_recon 13 | log_dir: log 14 | 15 | net: 16 | - name: backbone 17 | type: models.backbones.efficientnet_b4 18 | frozen: True 19 | kwargs: 20 | pretrained: True 21 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 22 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 23 | outlayers: [1,2,3,4] 24 | - name: neck 25 | prev: backbone 26 | type: models.necks.MFCN 27 | kwargs: 28 | outstrides: [16] 29 | - name: reconstruction 30 | prev: neck 31 | type: models.reconstructions.VisDecoder 32 | kwargs: 33 | block_type: basic 34 | -------------------------------------------------------------------------------- /experiments/vis_query/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | 3 | data: 4 | input_size: [224, 224] # [h,w] 5 | pixel_mean: [0.485, 0.456, 0.406] 6 | pixel_std: [0.229, 0.224, 0.225] 7 | 8 | saver: 9 | load_path: ../train_vis_decoder/{class_name}/checkpoints/ckpt.pth.tar 10 | save_dir: result_vis_query 11 | log_dir: log 12 | 13 | vis_query: 14 | model_path: ../MVTec-AD/checkpoints/ckpt.pth.tar 15 | hidden_dim: 256 16 | num_decoder_layers: 4 17 | with_text: True 18 | 19 | net: 20 | - name: backbone 21 | type: models.backbones.efficientnet_b4 22 | frozen: True 23 | kwargs: 24 | pretrained: True 25 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 26 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 27 | outlayers: [1,2,3,4] 28 | - name: neck 29 | prev: backbone 30 | type: models.necks.MFCN 31 | kwargs: 32 | outstrides: [16] 33 | - name: reconstruction 34 | prev: neck 35 | type: models.reconstructions.VisDecoder 36 | kwargs: 37 | block_type: basic 38 | -------------------------------------------------------------------------------- /utils/criterion_helper.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FeatureMSELoss(nn.Module): 5 | def __init__(self, weight): 6 | super().__init__() 7 | self.criterion_mse = nn.MSELoss() 8 | self.weight = weight 9 | 10 | def forward(self, input): 11 | feature_rec = input["feature_rec"] 12 | feature_align = input["feature_align"] 13 | return self.criterion_mse(feature_rec, feature_align) 14 | 15 | 16 | class ImageMSELoss(nn.Module): 17 | """Train a decoder for visualization of reconstructed features""" 18 | 19 | def __init__(self, weight): 20 | super().__init__() 21 | self.criterion_mse = nn.MSELoss() 22 | self.weight = weight 23 | 24 | def 
forward(self, input): 25 | image = input["image"] 26 | image_rec = input["image_rec"] 27 | return self.criterion_mse(image, image_rec) 28 | 29 | 30 | def build_criterion(config): 31 | loss_dict = {} 32 | for i in range(len(config)): 33 | cfg = config[i] 34 | loss_name = cfg["name"] 35 | loss_dict[loss_name] = globals()[cfg["type"]](**cfg["kwargs"]) 36 | return loss_dict 37 | -------------------------------------------------------------------------------- /datasets/image_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | 5 | 6 | class OpenCVReader: 7 | def __init__(self, image_dir, color_mode): 8 | self.image_dir = image_dir 9 | self.color_mode = color_mode 10 | assert color_mode in ["RGB", "BGR", "GRAY"], f"{color_mode} not supported" 11 | if color_mode != "BGR": 12 | self.cvt_color = getattr(cv2, f"COLOR_BGR2{color_mode}") 13 | else: 14 | self.cvt_color = None 15 | 16 | def __call__(self, filename, is_mask=False): 17 | filename = os.path.join(self.image_dir, filename) 18 | assert os.path.exists(filename), filename 19 | if is_mask: 20 | img = cv2.imread(filename, cv2.IMREAD_GRAYSCALE) 21 | return img 22 | img = cv2.imread(filename, cv2.IMREAD_COLOR) 23 | if self.color_mode != "BGR": 24 | img = cv2.cvtColor(img, self.cvt_color) 25 | return img 26 | 27 | 28 | def build_image_reader(cfg_reader): 29 | if cfg_reader["type"] == "opencv": 30 | return OpenCVReader(**cfg_reader["kwargs"]) 31 | else: 32 | raise TypeError("no supported image reader type: {}".format(cfg_reader["type"])) 33 | -------------------------------------------------------------------------------- /datasets/data_builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from datasets.cifar_dataset import build_cifar10_dataloader 4 | from datasets.custom_dataset import build_custom_dataloader 5 | from datasets.explicit_dataset import build_explicit_dataloader 6 | 7 | logger = logging.getLogger("global") 8 | 9 | 10 | def build(cfg, training, distributed): 11 | if training: 12 | cfg.update(cfg.get("train", {})) 13 | else: 14 | cfg.update(cfg.get("test", {})) 15 | 16 | dataset = cfg["type"] 17 | if dataset == "custom": 18 | data_loader = build_custom_dataloader(cfg, training, distributed) 19 | elif dataset == "cifar10": 20 | data_loader = build_cifar10_dataloader(cfg, training, distributed) 21 | elif dataset == 'explicit': 22 | data_loader = build_explicit_dataloader(cfg, training, distributed) 23 | else: 24 | raise NotImplementedError(f"{dataset} is not supported") 25 | 26 | return data_loader 27 | 28 | 29 | def build_dataloader(cfg_dataset, distributed=True): 30 | train_loader = None 31 | if cfg_dataset.get("train", None): 32 | train_loader = build(cfg_dataset, training=True, distributed=distributed) 33 | 34 | test_loader = None 35 | if cfg_dataset.get("test", None): 36 | test_loader = build(cfg_dataset, training=False, distributed=distributed) 37 | 38 | logger.info("build dataset done") 39 | return train_loader, test_loader 40 | -------------------------------------------------------------------------------- /utils/dist_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def setup_distributed(backend="nccl", port=None): 9 | """Initialize distributed training environment. 
10 | support both slurm and torch.distributed.launch 11 | see torch.distributed.init_process_group() for more details 12 | """ 13 | num_gpus = torch.cuda.device_count() 14 | 15 | if "SLURM_JOB_ID" in os.environ: 16 | rank = int(os.environ["SLURM_PROCID"]) 17 | world_size = int(os.environ["SLURM_NTASKS"]) 18 | node_list = os.environ["SLURM_NODELIST"] 19 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") 20 | # specify master port 21 | if port is not None: 22 | os.environ["MASTER_PORT"] = str(port) 23 | elif "MASTER_PORT" not in os.environ: 24 | os.environ["MASTER_PORT"] = "29500" 25 | if "MASTER_ADDR" not in os.environ: 26 | os.environ["MASTER_ADDR"] = addr 27 | os.environ["WORLD_SIZE"] = str(world_size) 28 | os.environ["LOCAL_RANK"] = str(rank % num_gpus) 29 | os.environ["RANK"] = str(rank) 30 | else: 31 | rank = int(os.environ["RANK"]) 32 | world_size = int(os.environ["WORLD_SIZE"]) 33 | 34 | torch.cuda.set_device(rank % num_gpus) 35 | 36 | dist.init_process_group( 37 | backend=backend, 38 | world_size=world_size, 39 | rank=rank, 40 | ) 41 | return rank, world_size 42 | -------------------------------------------------------------------------------- /models/necks/mfcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # MFCN: multi-scale feature concat network 5 | __all__ = ["MFCN"] 6 | 7 | 8 | class MFCN(nn.Module): 9 | def __init__(self, inplanes, outplanes, instrides, outstrides): 10 | super(MFCN, self).__init__() 11 | 12 | assert isinstance(inplanes, list) 13 | assert isinstance(outplanes, list) and len(outplanes) == 1 14 | assert isinstance(outstrides, list) and len(outstrides) == 1 15 | assert outplanes[0] == sum(inplanes) # concat 16 | self.inplanes = inplanes 17 | self.outplanes = outplanes 18 | self.instrides = instrides 19 | self.outstrides = outstrides 20 | self.scale_factors = [ 21 | in_stride / outstrides[0] for in_stride in instrides 22 | ] # for resize 23 | self.upsample_list = [ 24 | nn.UpsamplingBilinear2d(scale_factor=scale_factor) 25 | for scale_factor in self.scale_factors 26 | ] 27 | 28 | def forward(self, input): 29 | features = input["features"] 30 | assert len(self.inplanes) == len(features) 31 | 32 | feature_list = [] 33 | # resize & concatenate 34 | for i in range(len(features)): 35 | upsample = self.upsample_list[i] 36 | feature_resize = upsample(features[i]) 37 | feature_list.append(feature_resize) 38 | 39 | feature_align = torch.cat(feature_list, dim=1) 40 | 41 | return {"feature_align": feature_align, "outplane": self.get_outplanes()} 42 | 43 | def get_outplanes(self): 44 | return self.outplanes 45 | 46 | def get_outstrides(self): 47 | return self.outstrides 48 | -------------------------------------------------------------------------------- /experiments/train_vis_decoder/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 131 3 | port: 11111 4 | 5 | dataset: 6 | type: custom 7 | 8 | image_reader: 9 | type: opencv 10 | kwargs: 11 | image_dir: ../../data/MVTec-AD/mvtec_anomaly_detection/ 12 | color_mode: RGB 13 | 14 | train: 15 | meta_file: ../../data/MVTec-AD/json_vis_decoder/test_{class_name}.json 16 | rebalance: False 17 | hflip: False 18 | vflip: False 19 | rotate: False 20 | 21 | input_size: [224,224] # [h,w] 22 | pixel_mean: [0.485, 0.456, 0.406] 23 | pixel_std: [0.229, 0.224, 0.225] 24 | batch_size: 32 25 | workers: 4 # number of workers of dataloader for each 
process 26 | 27 | criterion: 28 | - name: ImageMSELoss 29 | type: ImageMSELoss 30 | kwargs: 31 | weight: 1.0 32 | 33 | trainer: 34 | max_epoch: 1000 35 | clip_max_norm: 0.1 36 | print_freq_step: 1 37 | tb_freq_step: 1 38 | lr_scheduler: 39 | type: StepLR 40 | kwargs: 41 | step_size: 800 42 | gamma: 0.1 43 | optimizer: 44 | type: AdamW 45 | kwargs: 46 | lr: 0.0001 47 | betas: [0.9, 0.999] 48 | weight_decay: 0.0001 49 | 50 | # Optional, set False to disable 51 | visualization: 52 | vis_freq_epoch: 10 53 | vis_dir: vis 54 | 55 | saver: 56 | auto_resume: True 57 | always_save: False 58 | # load_path: checkpoints/ckpt_best.pth.tar 59 | save_dir: checkpoints/ 60 | log_dir: log/ 61 | 62 | frozen_layers: [backbone] 63 | 64 | net: 65 | - name: backbone 66 | type: models.backbones.efficientnet_b4 67 | frozen: True 68 | kwargs: 69 | pretrained: True 70 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 71 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 72 | outlayers: [1,2,3,4] 73 | - name: neck 74 | prev: backbone 75 | type: models.necks.MFCN 76 | kwargs: 77 | outstrides: [16] 78 | - name: reconstruction 79 | prev: neck 80 | type: models.reconstructions.VisDecoder 81 | kwargs: 82 | block_type: basic 83 | -------------------------------------------------------------------------------- /datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | 7 | import datasets.transforms as T 8 | 9 | 10 | class BaseDataset(Dataset): 11 | """ 12 | A dataset should implement 13 | 1. __len__ to get size of the dataset, Required 14 | 2. 
__getitem__ to get a single data, Required 15 | 16 | """ 17 | def __init__(self): 18 | super(BaseDataset, self).__init__() 19 | 20 | def __len__(self): 21 | raise NotImplementedError 22 | 23 | def __getitem__(self, idx): 24 | raise NotImplementedError 25 | 26 | 27 | class TrainBaseTransform(object): 28 | """ 29 | Resize, flip, rotation for image and mask 30 | """ 31 | def __init__(self, input_size, hflip, vflip, rotate): 32 | self.input_size = input_size # h x w 33 | self.hflip = hflip 34 | self.vflip = vflip 35 | self.rotate = rotate 36 | 37 | def __call__(self, image, mask): 38 | transform_fn = transforms.Resize(self.input_size, Image.BILINEAR) 39 | image = transform_fn(image) 40 | transform_fn = transforms.Resize(self.input_size, Image.NEAREST) 41 | mask = transform_fn(mask) 42 | if self.hflip: 43 | transform_fn = T.RandomHFlip() 44 | image, mask = transform_fn(image, mask) 45 | if self.vflip: 46 | transform_fn = T.RandomVFlip() 47 | image, mask = transform_fn(image, mask) 48 | if self.rotate: 49 | transform_fn = T.RandomRotation([0, 90, 180, 270]) 50 | image, mask = transform_fn(image, mask) 51 | return image, mask 52 | 53 | 54 | class TestBaseTransform(object): 55 | """ 56 | Resize for image and mask 57 | """ 58 | def __init__(self, input_size): 59 | self.input_size = input_size # h x w 60 | 61 | def __call__(self, image, mask): 62 | transform_fn = transforms.Resize(self.input_size, Image.BILINEAR) 63 | image = transform_fn(image) 64 | transform_fn = transforms.Resize(self.input_size, Image.NEAREST) 65 | mask = transform_fn(mask) 66 | return image, mask 67 | -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .efficientnet import * # noqa F401 2 | from .resnet import * # noqa F401 3 | 4 | backbone_info = { 5 | "resnet18": { 6 | "layers": [1, 2, 3, 4], 7 | "planes": [64, 128, 256, 512], 8 | "strides": [4, 8, 16, 32], 9 | }, 10 | "resnet34": { 11 | "layers": [1, 2, 3, 4], 12 | "planes": [64, 128, 256, 512], 13 | "strides": [4, 8, 16, 32], 14 | }, 15 | "resnet50": { 16 | "layers": [1, 2, 3, 4], 17 | "planes": [256, 512, 1024, 2048], 18 | "strides": [4, 8, 16, 32], 19 | }, 20 | "resnet101": { 21 | "layers": [1, 2, 3, 4], 22 | "planes": [256, 512, 1024, 2048], 23 | "strides": [4, 8, 16, 32], 24 | }, 25 | "wide_resnet50_2": { 26 | "layers": [1, 2, 3, 4], 27 | "planes": [256, 512, 1024, 2048], 28 | "strides": [4, 8, 16, 32], 29 | }, 30 | "efficientnet_b0": { 31 | "layers": [1, 2, 3, 4, 5], 32 | "blocks": [0, 2, 4, 10, 15], 33 | "planes": [16, 24, 40, 112, 320], 34 | "strides": [2, 4, 8, 16, 32], 35 | }, 36 | "efficientnet_b1": { 37 | "layers": [1, 2, 3, 4, 5], 38 | "blocks": [1, 4, 7, 15, 22], 39 | "planes": [16, 24, 40, 112, 320], 40 | "strides": [2, 4, 8, 16, 32], 41 | }, 42 | "efficientnet_b2": { 43 | "layers": [1, 2, 3, 4, 5], 44 | "blocks": [1, 4, 7, 15, 22], 45 | "planes": [16, 24, 48, 120, 352], 46 | "strides": [2, 4, 8, 16, 32], 47 | }, 48 | "efficientnet_b3": { 49 | "layers": [1, 2, 3, 4, 5], 50 | "blocks": [1, 4, 7, 17, 25], 51 | "planes": [24, 32, 48, 136, 384], 52 | "strides": [2, 4, 8, 16, 32], 53 | }, 54 | "efficientnet_b4": { 55 | "layers": [1, 2, 3, 4, 5], 56 | "blocks": [1, 5, 9, 21, 31], 57 | "planes": [24, 32, 56, 160, 448], 58 | "strides": [2, 4, 8, 16, 32], 59 | }, 60 | "efficientnet_b5": { 61 | "layers": [1, 2, 3, 4, 5], 62 | "blocks": [2, 7, 12, 26, 38], 63 | "planes": [24, 40, 64, 176, 512], 64 | "strides": [2, 4, 8, 16, 
32], 65 | }, 66 | "efficientnet_b6": { 67 | "layers": [1, 2, 3, 4, 5], 68 | "blocks": [2, 8, 14, 30, 44], 69 | "planes": [32, 40, 72, 200, 576], 70 | "strides": [2, 4, 8, 16, 32], 71 | }, 72 | } 73 | -------------------------------------------------------------------------------- /models/model_helper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import importlib 3 | 4 | import torch 5 | import torch.nn as nn 6 | from utils.misc_helper import to_device 7 | 8 | 9 | class ModelHelper(nn.Module): 10 | """Build model from cfg""" 11 | 12 | def __init__(self, cfg): 13 | super(ModelHelper, self).__init__() 14 | 15 | self.frozen_layers = [] 16 | for cfg_subnet in cfg: 17 | mname = cfg_subnet["name"] 18 | kwargs = cfg_subnet["kwargs"] 19 | mtype = cfg_subnet["type"] 20 | if cfg_subnet.get("frozen", False): 21 | self.frozen_layers.append(mname) 22 | if cfg_subnet.get("prev", None) is not None: 23 | prev_module = getattr(self, cfg_subnet["prev"]) 24 | kwargs["inplanes"] = prev_module.get_outplanes() 25 | kwargs["instrides"] = prev_module.get_outstrides() 26 | 27 | module = self.build(mtype, kwargs) 28 | self.add_module(mname, module) 29 | 30 | def build(self, mtype, kwargs): 31 | module_name, cls_name = mtype.rsplit(".", 1) 32 | module = importlib.import_module(module_name) 33 | cls = getattr(module, cls_name) 34 | return cls(**kwargs) 35 | 36 | def cuda(self): 37 | self.device = torch.device("cuda") 38 | return super(ModelHelper, self).cuda() 39 | 40 | def cpu(self): 41 | self.device = torch.device("cpu") 42 | return super(ModelHelper, self).cpu() 43 | 44 | def forward(self, input): 45 | input = copy.copy(input) 46 | if input["image"].device != self.device: 47 | input = to_device(input, device=self.device) 48 | for submodule in self.children(): 49 | output = submodule(input) 50 | input.update(output) 51 | return input 52 | 53 | def freeze_layer(self, module): 54 | module.eval() 55 | for param in module.parameters(): 56 | param.requires_grad = False 57 | 58 | def train(self, mode=True): 59 | """ 60 | Sets the module in training mode. 61 | This has any effect only on modules such as Dropout or BatchNorm. 
62 | 63 | Returns: 64 | Module: self 65 | """ 66 | self.training = mode 67 | for mname, module in self.named_children(): 68 | if mname in self.frozen_layers: 69 | self.freeze_layer(module) 70 | else: 71 | module.train(mode) 72 | return self 73 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/01234/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 131 3 | port: 11111 4 | 5 | dataset: 6 | type: cifar10 7 | train: 8 | root_dir: ../../../data/CIFAR-10/ 9 | test: 10 | root_dir: ../../../data/CIFAR-10/ 11 | input_size: [224,224] # [h,w] 12 | normals: [0,1,2,3,4] 13 | batch_size: 16 14 | workers: 4 # number of workers of dataloader for each process 15 | 16 | criterion: 17 | - name: FeatureMSELoss 18 | type: FeatureMSELoss 19 | kwargs: 20 | weight: 1.0 21 | 22 | trainer: 23 | max_epoch: 1000 24 | clip_max_norm: 0.1 25 | val_freq_epoch: 10 26 | print_freq_step: 1 27 | tb_freq_step: 1 28 | lr_scheduler: 29 | type: StepLR 30 | kwargs: 31 | step_size: 800 32 | gamma: 0.1 33 | optimizer: 34 | type: AdamW 35 | kwargs: 36 | lr: 0.0001 37 | weight_decay: 0.0001 38 | 39 | saver: 40 | auto_resume: False 41 | always_save: False 42 | load_path: checkpoints/ckpt.pth.tar 43 | save_dir: checkpoints/ 44 | log_dir: log/ 45 | 46 | evaluator: 47 | save_dir: result_eval_temp 48 | key_metric: mean_mean_auc 49 | metrics: 50 | auc: 51 | - name: mean 52 | - name: std 53 | - name: max 54 | kwargs: 55 | avgpool_size: [16, 16] 56 | 57 | frozen_layers: [backbone] 58 | 59 | net: 60 | - name: backbone 61 | type: models.backbones.efficientnet_b4 62 | frozen: True 63 | kwargs: 64 | pretrained: True 65 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 66 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 67 | outlayers: [1,2,3,4,5] 68 | - name: neck 69 | prev: backbone 70 | type: models.necks.MFCN 71 | kwargs: 72 | outstrides: [16] 73 | - name: reconstruction 74 | prev: neck 75 | type: models.reconstructions.UniAD 76 | kwargs: 77 | pos_embed_type: learned 78 | hidden_dim: 256 79 | nhead: 8 80 | num_encoder_layers: 4 81 | num_decoder_layers: 4 82 | dim_feedforward: 1024 83 | dropout: 0.1 84 | activation: relu 85 | normalize_before: False 86 | feature_jitter: 87 | scale: 20.0 88 | prob: 1.0 89 | neighbor_mask: 90 | neighbor_size: [7,7] 91 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 92 | save_recon: False # save time 93 | initializer: 94 | method: xavier_uniform 95 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/02468/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 131 3 | port: 11111 4 | 5 | dataset: 6 | type: cifar10 7 | train: 8 | root_dir: ../../../data/CIFAR-10/ 9 | test: 10 | root_dir: ../../../data/CIFAR-10/ 11 | input_size: [224,224] # [h,w] 12 | normals: [0,2,4,6,8] 13 | batch_size: 16 14 | workers: 4 # number of workers of dataloader for each process 15 | 16 | criterion: 17 | - name: FeatureMSELoss 18 | type: FeatureMSELoss 19 | kwargs: 20 | weight: 1.0 21 | 22 | trainer: 23 | max_epoch: 1000 24 | clip_max_norm: 0.1 25 | val_freq_epoch: 10 26 | print_freq_step: 1 27 | tb_freq_step: 1 28 | lr_scheduler: 29 | type: StepLR 30 | kwargs: 31 | step_size: 800 32 | gamma: 0.1 33 | optimizer: 34 | type: AdamW 35 | kwargs: 36 | lr: 0.0001 37 | weight_decay: 0.0001 38 | 39 | 
saver: 40 | auto_resume: False 41 | always_save: False 42 | load_path: checkpoints/ckpt.pth.tar 43 | save_dir: checkpoints/ 44 | log_dir: log/ 45 | 46 | evaluator: 47 | save_dir: result_eval_temp 48 | key_metric: mean_mean_auc 49 | metrics: 50 | auc: 51 | - name: mean 52 | - name: std 53 | - name: max 54 | kwargs: 55 | avgpool_size: [16, 16] 56 | 57 | frozen_layers: [backbone] 58 | 59 | net: 60 | - name: backbone 61 | type: models.backbones.efficientnet_b4 62 | frozen: True 63 | kwargs: 64 | pretrained: True 65 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 66 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 67 | outlayers: [1,2,3,4,5] 68 | - name: neck 69 | prev: backbone 70 | type: models.necks.MFCN 71 | kwargs: 72 | outstrides: [16] 73 | - name: reconstruction 74 | prev: neck 75 | type: models.reconstructions.UniAD 76 | kwargs: 77 | pos_embed_type: learned 78 | hidden_dim: 256 79 | nhead: 8 80 | num_encoder_layers: 4 81 | num_decoder_layers: 4 82 | dim_feedforward: 1024 83 | dropout: 0.1 84 | activation: relu 85 | normalize_before: False 86 | feature_jitter: 87 | scale: 20.0 88 | prob: 1.0 89 | neighbor_mask: 90 | neighbor_size: [7,7] 91 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 92 | save_recon: False # save time 93 | initializer: 94 | method: xavier_uniform 95 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/13579/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 131 3 | port: 11111 4 | 5 | dataset: 6 | type: cifar10 7 | train: 8 | root_dir: ../../../data/CIFAR-10/ 9 | test: 10 | root_dir: ../../../data/CIFAR-10/ 11 | input_size: [224,224] # [h,w] 12 | normals: [1,3,5,7,9] 13 | batch_size: 16 14 | workers: 4 # number of workers of dataloader for each process 15 | 16 | criterion: 17 | - name: FeatureMSELoss 18 | type: FeatureMSELoss 19 | kwargs: 20 | weight: 1.0 21 | 22 | trainer: 23 | max_epoch: 1000 24 | clip_max_norm: 0.1 25 | val_freq_epoch: 10 26 | print_freq_step: 1 27 | tb_freq_step: 1 28 | lr_scheduler: 29 | type: StepLR 30 | kwargs: 31 | step_size: 800 32 | gamma: 0.1 33 | optimizer: 34 | type: AdamW 35 | kwargs: 36 | lr: 0.0001 37 | weight_decay: 0.0001 38 | 39 | saver: 40 | auto_resume: False 41 | always_save: False 42 | load_path: checkpoints/ckpt.pth.tar 43 | save_dir: checkpoints/ 44 | log_dir: log/ 45 | 46 | evaluator: 47 | save_dir: result_eval_temp 48 | key_metric: mean_mean_auc 49 | metrics: 50 | auc: 51 | - name: mean 52 | - name: std 53 | - name: max 54 | kwargs: 55 | avgpool_size: [16, 16] 56 | 57 | frozen_layers: [backbone] 58 | 59 | net: 60 | - name: backbone 61 | type: models.backbones.efficientnet_b4 62 | frozen: True 63 | kwargs: 64 | pretrained: True 65 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 66 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 67 | outlayers: [1,2,3,4,5] 68 | - name: neck 69 | prev: backbone 70 | type: models.necks.MFCN 71 | kwargs: 72 | outstrides: [16] 73 | - name: reconstruction 74 | prev: neck 75 | type: models.reconstructions.UniAD 76 | kwargs: 77 | pos_embed_type: learned 78 | hidden_dim: 256 79 | nhead: 8 80 | num_encoder_layers: 4 81 | num_decoder_layers: 4 82 | dim_feedforward: 1024 83 | dropout: 0.1 84 | activation: relu 85 | normalize_before: False 86 | feature_jitter: 87 | scale: 20.0 88 | prob: 1.0 89 | neighbor_mask: 90 | 
neighbor_size: [7,7] 91 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 92 | save_recon: False # save time 93 | initializer: 94 | method: xavier_uniform 95 | -------------------------------------------------------------------------------- /experiments/CIFAR-10/56789/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 131 3 | port: 11111 4 | 5 | dataset: 6 | type: cifar10 7 | train: 8 | root_dir: ../../../data/CIFAR-10/ 9 | test: 10 | root_dir: ../../../data/CIFAR-10/ 11 | input_size: [224,224] # [h,w] 12 | normals: [5,6,7,8,9] 13 | batch_size: 16 14 | workers: 4 # number of workers of dataloader for each process 15 | 16 | criterion: 17 | - name: FeatureMSELoss 18 | type: FeatureMSELoss 19 | kwargs: 20 | weight: 1.0 21 | 22 | trainer: 23 | max_epoch: 1000 24 | clip_max_norm: 0.1 25 | val_freq_epoch: 10 26 | print_freq_step: 1 27 | tb_freq_step: 1 28 | lr_scheduler: 29 | type: StepLR 30 | kwargs: 31 | step_size: 800 32 | gamma: 0.1 33 | optimizer: 34 | type: AdamW 35 | kwargs: 36 | lr: 0.0001 37 | weight_decay: 0.0001 38 | 39 | saver: 40 | auto_resume: False 41 | always_save: False 42 | load_path: checkpoints/ckpt.pth.tar 43 | save_dir: checkpoints/ 44 | log_dir: log/ 45 | 46 | evaluator: 47 | save_dir: result_eval_temp 48 | key_metric: mean_mean_auc 49 | metrics: 50 | auc: 51 | - name: mean 52 | - name: std 53 | - name: max 54 | kwargs: 55 | avgpool_size: [16, 16] 56 | 57 | frozen_layers: [backbone] 58 | 59 | net: 60 | - name: backbone 61 | type: models.backbones.efficientnet_b4 62 | frozen: True 63 | kwargs: 64 | pretrained: True 65 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 66 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 67 | outlayers: [1,2,3,4,5] 68 | - name: neck 69 | prev: backbone 70 | type: models.necks.MFCN 71 | kwargs: 72 | outstrides: [16] 73 | - name: reconstruction 74 | prev: neck 75 | type: models.reconstructions.UniAD 76 | kwargs: 77 | pos_embed_type: learned 78 | hidden_dim: 256 79 | nhead: 8 80 | num_encoder_layers: 4 81 | num_decoder_layers: 4 82 | dim_feedforward: 1024 83 | dropout: 0.1 84 | activation: relu 85 | normalize_before: False 86 | feature_jitter: 87 | scale: 20.0 88 | prob: 1.0 89 | neighbor_mask: 90 | neighbor_size: [7,7] 91 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 92 | save_recon: False # save time 93 | initializer: 94 | method: xavier_uniform 95 | -------------------------------------------------------------------------------- /models/initializer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from torch import nn 4 | 5 | 6 | def init_weights_normal(module, std=0.01): 7 | for m in module.modules(): 8 | if ( 9 | isinstance(m, nn.Conv2d) 10 | or isinstance(m, nn.Linear) 11 | or isinstance(m, nn.ConvTranspose2d) 12 | ): 13 | nn.init.normal_(m.weight.data, std=std) 14 | if m.bias is not None: 15 | m.bias.data.zero_() 16 | 17 | 18 | def init_weights_xavier(module, method): 19 | for m in module.modules(): 20 | if ( 21 | isinstance(m, nn.Conv2d) 22 | or isinstance(m, nn.Linear) 23 | or isinstance(m, nn.ConvTranspose2d) 24 | ): 25 | if "normal" in method: 26 | nn.init.xavier_normal_(m.weight.data) 27 | elif "uniform" in method: 28 | nn.init.xavier_uniform_(m.weight.data) 29 | else: 30 | raise NotImplementedError(f"{method} not supported") 31 | if m.bias is not None: 32 | m.bias.data.zero_() 33 | 34 | 35 | def 
init_weights_msra(module, method): 36 | for m in module.modules(): 37 | if ( 38 | isinstance(m, nn.Conv2d) 39 | or isinstance(m, nn.Linear) 40 | or isinstance(m, nn.ConvTranspose2d) 41 | ): 42 | if "normal" in method: 43 | nn.init.kaiming_normal_(m.weight.data, a=1) 44 | elif "uniform" in method: 45 | nn.init.kaiming_uniform_(m.weight.data, a=1) 46 | else: 47 | raise NotImplementedError(f"{method} not supported") 48 | if m.bias is not None: 49 | m.bias.data.zero_() 50 | 51 | 52 | def initialize(model, method, **kwargs): 53 | # initialize BN, Conv, & FC with different methods 54 | # initialize BN 55 | for m in model.modules(): 56 | if isinstance(m, nn.BatchNorm2d): 57 | m.weight.data.fill_(1) 58 | m.bias.data.zero_() 59 | 60 | # initialize Conv & FC 61 | if method == "normal": 62 | init_weights_normal(model, **kwargs) 63 | elif "msra" in method: 64 | init_weights_msra(model, method) 65 | elif "xavier" in method: 66 | init_weights_xavier(model, method) 67 | else: 68 | raise NotImplementedError(f"{method} not supported") 69 | 70 | 71 | def initialize_from_cfg(model, cfg): 72 | if cfg is None: 73 | initialize(model, "normal", std=0.01) 74 | return 75 | 76 | cfg = copy.deepcopy(cfg) 77 | method = cfg.pop("method") 78 | initialize(model, method, **cfg) 79 | -------------------------------------------------------------------------------- /experiments/MVTec-AD/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 133 3 | port: 11111 4 | 5 | dataset: 6 | type: custom 7 | 8 | image_reader: 9 | type: opencv 10 | kwargs: 11 | image_dir: ../../data/MVTec-AD/mvtec_anomaly_detection/ 12 | color_mode: RGB 13 | 14 | train: 15 | meta_file: ../../data/MVTec-AD/train.json 16 | rebalance: False 17 | hflip: False 18 | vflip: False 19 | rotate: False 20 | 21 | test: 22 | meta_file: ../../data/MVTec-AD/test.json 23 | 24 | input_size: [224,224] # [h,w] 25 | pixel_mean: [0.485, 0.456, 0.406] 26 | pixel_std: [0.229, 0.224, 0.225] 27 | batch_size: 8 28 | workers: 4 # number of workers of dataloader for each process 29 | 30 | criterion: 31 | - name: FeatureMSELoss 32 | type: FeatureMSELoss 33 | kwargs: 34 | weight: 1.0 35 | 36 | trainer: 37 | max_epoch: 1000 38 | clip_max_norm: 0.1 39 | val_freq_epoch: 10 40 | print_freq_step: 1 41 | tb_freq_step: 1 42 | lr_scheduler: 43 | type: StepLR 44 | kwargs: 45 | step_size: 800 46 | gamma: 0.1 47 | optimizer: 48 | type: AdamW 49 | kwargs: 50 | lr: 0.0001 51 | betas: [0.9, 0.999] 52 | weight_decay: 0.0001 53 | 54 | saver: 55 | auto_resume: False 56 | always_save: False 57 | load_path: checkpoints/ckpt.pth.tar 58 | save_dir: checkpoints/ 59 | log_dir: log/ 60 | 61 | evaluator: 62 | save_dir: result_eval_temp 63 | key_metric: mean_pixel_auc 64 | metrics: 65 | auc: 66 | - name: std 67 | - name: max 68 | kwargs: 69 | avgpool_size: [16, 16] 70 | - name: pixel 71 | vis_compound: 72 | save_dir: vis_compound 73 | max_score: null 74 | min_score: null 75 | # vis_single: 76 | # save_dir: vis_single 77 | # max_score: null 78 | # min_score: null 79 | 80 | frozen_layers: [backbone] 81 | 82 | net: 83 | - name: backbone 84 | type: models.backbones.efficientnet_b4 85 | frozen: True 86 | kwargs: 87 | pretrained: True 88 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 89 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 90 | outlayers: [1,2,3,4] 91 | - name: neck 92 | prev: backbone 93 | type: models.necks.MFCN 94 | kwargs: 95 | outstrides: [16] 96 | - name: 
reconstruction 97 | prev: neck 98 | type: models.reconstructions.UniAD 99 | kwargs: 100 | pos_embed_type: learned 101 | hidden_dim: 256 102 | nhead: 8 103 | num_encoder_layers: 4 104 | num_decoder_layers: 4 105 | dim_feedforward: 1024 106 | dropout: 0.1 107 | activation: relu 108 | normalize_before: False 109 | feature_jitter: 110 | scale: 20.0 111 | prob: 1.0 112 | neighbor_mask: 113 | neighbor_size: [7,7] 114 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 115 | save_recon: 116 | save_dir: result_recon 117 | initializer: 118 | method: xavier_uniform 119 | -------------------------------------------------------------------------------- /models/backbones/efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | """__init__.py - all efficientnet models. 2 | """ 3 | 4 | # Author: lukemelas (github username) 5 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch 6 | # With adjustments and added comments by workingcoder (github username). 7 | 8 | 9 | __version__ = "0.7.1" 10 | from .model import EfficientNet 11 | 12 | __all__ = [ 13 | "EfficientNet", 14 | "efficientnet_b0", 15 | "efficientnet_b1", 16 | "efficientnet_b2", 17 | "efficientnet_b3", 18 | "efficientnet_b4", 19 | "efficientnet_b5", 20 | "efficientnet_b6", 21 | "efficientnet_b7", 22 | "efficientnet_b8", 23 | "efficientnet_l2", 24 | ] 25 | 26 | 27 | def efficientnet_b0(pretrained, outblocks, outstrides, pretrained_model=""): 28 | return build_efficient( 29 | "efficientnet-b0", pretrained, outblocks, outstrides, pretrained_model 30 | ) 31 | 32 | 33 | def efficientnet_b1(pretrained, outblocks, outstrides, pretrained_model=""): 34 | return build_efficient( 35 | "efficientnet-b1", pretrained, outblocks, outstrides, pretrained_model 36 | ) 37 | 38 | 39 | def efficientnet_b2(pretrained, outblocks, outstrides, pretrained_model=""): 40 | return build_efficient( 41 | "efficientnet-b2", pretrained, outblocks, outstrides, pretrained_model 42 | ) 43 | 44 | 45 | def efficientnet_b3(pretrained, outblocks, outstrides, pretrained_model=""): 46 | return build_efficient( 47 | "efficientnet-b3", pretrained, outblocks, outstrides, pretrained_model 48 | ) 49 | 50 | 51 | def efficientnet_b4(pretrained, outblocks, outstrides, pretrained_model=""): 52 | return build_efficient( 53 | "efficientnet-b4", pretrained, outblocks, outstrides, pretrained_model 54 | ) 55 | 56 | 57 | def efficientnet_b5(pretrained, outblocks, outstrides, pretrained_model=""): 58 | return build_efficient( 59 | "efficientnet-b5", pretrained, outblocks, outstrides, pretrained_model 60 | ) 61 | 62 | 63 | def efficientnet_b6(pretrained, outblocks, outstrides, pretrained_model=""): 64 | return build_efficient( 65 | "efficientnet-b6", pretrained, outblocks, outstrides, pretrained_model 66 | ) 67 | 68 | 69 | def efficientnet_b7(pretrained, outblocks, outstrides, pretrained_model=""): 70 | return build_efficient( 71 | "efficientnet-b7", pretrained, outblocks, outstrides, pretrained_model 72 | ) 73 | 74 | 75 | def efficientnet_b8(pretrained, outblocks, outstrides, pretrained_model=""): 76 | return build_efficient( 77 | "efficientnet-b8", pretrained, outblocks, outstrides, pretrained_model 78 | ) 79 | 80 | 81 | def efficientnet_l2(pretrained, outblocks, outstrides, pretrained_model=""): 82 | return build_efficient( 83 | "efficientnet-l2", pretrained, outblocks, outstrides, pretrained_model 84 | ) 85 | 86 | 87 | def build_efficient(model_name, pretrained, outblocks, outstrides, pretrained_model=""): 88 | if 
pretrained: 89 | model = EfficientNet.from_pretrained( 90 | model_name, 91 | outblocks=outblocks, 92 | outstrides=outstrides, 93 | pretrained_model=pretrained_model, 94 | ) 95 | else: 96 | model = EfficientNet.from_name( 97 | model_name, outblocks=outblocks, outstrides=outstrides 98 | ) 99 | return model 100 |
-------------------------------------------------------------------------------- /utils/vis_helper.py: --------------------------------------------------------------------------------
1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | from datasets.image_reader import build_image_reader 6 | 7 | 8 | def normalize(pred, max_value=None, min_value=None): 9 | if max_value is None or min_value is None: 10 | return (pred - pred.min()) / (pred.max() - pred.min()) 11 | else: 12 | return (pred - min_value) / (max_value - min_value) 13 | 14 | 15 | def apply_ad_scoremap(image, scoremap, alpha=0.5): 16 | np_image = np.asarray(image, dtype=np.float32) 17 | scoremap = (scoremap * 255).astype(np.uint8) 18 | scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET) 19 | scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB) 20 | return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8) 21 | 22 | 23 | def visualize_compound(fileinfos, preds, masks, cfg_vis, cfg_reader): 24 | vis_dir = cfg_vis.save_dir 25 | max_score = cfg_vis.get("max_score", None) 26 | min_score = cfg_vis.get("min_score", None) 27 | max_score = preds.max() if not max_score else max_score 28 | min_score = preds.min() if not min_score else min_score 29 | 30 | image_reader = build_image_reader(cfg_reader) 31 | 32 | for i, fileinfo in enumerate(fileinfos): 33 | clsname = fileinfo["clsname"] 34 | filename = fileinfo["filename"] 35 | filedir, filename = os.path.split(filename) 36 | _, defename = os.path.split(filedir) 37 | save_dir = os.path.join(vis_dir, clsname, defename) 38 | os.makedirs(save_dir, exist_ok=True) 39 | 40 | # read image 41 | h, w = int(fileinfo["height"]), int(fileinfo["width"]) 42 | image = image_reader(fileinfo["filename"]) 43 | pred = preds[i][:, :, None].repeat(3, 2) 44 | pred = cv2.resize(pred, (w, h)) 45 | 46 | # self normalize just for analysis 47 | scoremap_self = apply_ad_scoremap(image, normalize(pred)) 48 | # global normalize 49 | pred = np.clip(pred, min_score, max_score) 50 | pred = normalize(pred, max_score, min_score) 51 | scoremap_global = apply_ad_scoremap(image, pred) 52 | 53 | save_path = os.path.join(save_dir, filename)  # assigned before the branch: also needed below when masks is None 54 | if masks is not None: 55 | mask = (masks[i] * 255).astype(np.uint8)[:, :, None].repeat(3, 2) 56 | mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST) 57 | if mask.sum() == 0: 58 | scoremap = np.vstack([image, scoremap_global]) 59 | else: 60 | scoremap = np.vstack([image, mask, scoremap_global, scoremap_self]) 61 | else: 62 | scoremap = np.vstack([image, scoremap_global, scoremap_self]) 63 | 64 | scoremap = cv2.cvtColor(scoremap, cv2.COLOR_RGB2BGR) 65 | cv2.imwrite(save_path, scoremap) 66 | 67 | 68 | def visualize_single(fileinfos, preds, cfg_vis, cfg_reader): 69 | vis_dir = cfg_vis.save_dir 70 | max_score = cfg_vis.get("max_score", None) 71 | min_score = cfg_vis.get("min_score", None) 72 | max_score = preds.max() if not max_score else max_score 73 | min_score = preds.min() if not min_score else min_score 74 | 75 | image_reader = build_image_reader(cfg_reader) 76 | 77 | for i, fileinfo in enumerate(fileinfos): 78 | clsname = fileinfo["clsname"] 79 | filename = fileinfo["filename"] 80 | filedir, filename = os.path.split(filename) 81 | _, defename = 
os.path.split(filedir) 82 | save_dir = os.path.join(vis_dir, clsname, defename) 83 | os.makedirs(save_dir, exist_ok=True) 84 | 85 | # read image 86 | h, w = int(fileinfo["height"]), int(fileinfo["width"]) 87 | image = image_reader(fileinfo["filename"]) 88 | pred = preds[i][:, :, None].repeat(3, 2) 89 | pred = cv2.resize(pred, (w, h)) 90 | 91 | # write global normalize image 92 | pred = np.clip(pred, min_score, max_score) 93 | pred = normalize(pred, max_score, min_score) 94 | scoremap_global = apply_ad_scoremap(image, pred) 95 | 96 | save_path = os.path.join(save_dir, filename) 97 | scoremap_global = cv2.cvtColor(scoremap_global, cv2.COLOR_RGB2BGR) 98 | cv2.imwrite(save_path, scoremap_global) 99 | -------------------------------------------------------------------------------- /datasets/custom_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import json 4 | import logging 5 | 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data.sampler import RandomSampler 12 | 13 | from datasets.base_dataset import BaseDataset, TestBaseTransform, TrainBaseTransform 14 | from datasets.image_reader import build_image_reader 15 | from datasets.transforms import RandomColorJitter 16 | 17 | logger = logging.getLogger("global_logger") 18 | 19 | 20 | def build_custom_dataloader(cfg, training, distributed=True): 21 | 22 | image_reader = build_image_reader(cfg.image_reader) 23 | 24 | normalize_fn = transforms.Normalize(mean=cfg["pixel_mean"], std=cfg["pixel_std"]) 25 | if training: 26 | transform_fn = TrainBaseTransform( 27 | cfg["input_size"], cfg["hflip"], cfg["vflip"], cfg["rotate"] 28 | ) 29 | else: 30 | transform_fn = TestBaseTransform(cfg["input_size"]) 31 | 32 | colorjitter_fn = None 33 | if cfg.get("colorjitter", None) and training: 34 | colorjitter_fn = RandomColorJitter.from_params(cfg["colorjitter"]) 35 | 36 | logger.info("building CustomDataset from: {}".format(cfg["meta_file"])) 37 | 38 | dataset = CustomDataset( 39 | image_reader, 40 | cfg["meta_file"], 41 | training, 42 | transform_fn=transform_fn, 43 | normalize_fn=normalize_fn, 44 | colorjitter_fn=colorjitter_fn, 45 | ) 46 | 47 | if distributed: 48 | sampler = DistributedSampler(dataset) 49 | else: 50 | sampler = RandomSampler(dataset) 51 | 52 | data_loader = DataLoader( 53 | dataset, 54 | batch_size=cfg["batch_size"], 55 | num_workers=cfg["workers"], 56 | pin_memory=True, 57 | sampler=sampler, 58 | ) 59 | 60 | return data_loader 61 | 62 | 63 | class CustomDataset(BaseDataset): 64 | def __init__( 65 | self, 66 | image_reader, 67 | meta_file, 68 | training, 69 | transform_fn, 70 | normalize_fn, 71 | colorjitter_fn=None, 72 | ): 73 | self.image_reader = image_reader 74 | self.meta_file = meta_file 75 | self.training = training 76 | self.transform_fn = transform_fn 77 | self.normalize_fn = normalize_fn 78 | self.colorjitter_fn = colorjitter_fn 79 | 80 | # construct metas 81 | with open(meta_file, "r") as f_r: 82 | self.metas = [] 83 | for line in f_r: 84 | meta = json.loads(line) 85 | self.metas.append(meta) 86 | 87 | def __len__(self): 88 | return len(self.metas) 89 | 90 | def __getitem__(self, index): 91 | input = {} 92 | meta = self.metas[index] 93 | 94 | # read image 95 | filename = meta["filename"] 96 | label = meta["label"] 97 | image = self.image_reader(meta["filename"]) 98 | 
input.update( 99 | { 100 | "filename": filename, 101 | "height": image.shape[0], 102 | "width": image.shape[1], 103 | "label": label, 104 | } 105 | ) 106 | 107 | if meta.get("clsname", None): 108 | input["clsname"] = meta["clsname"] 109 | else: 110 | input["clsname"] = filename.split("/")[-4] 111 | 112 | image = Image.fromarray(image, "RGB") 113 | 114 | # read / generate mask 115 | if meta.get("maskname", None): 116 | mask = self.image_reader(meta["maskname"], is_mask=True) 117 | else: 118 | if label == 0:  # good 119 | mask = np.zeros((image.height, image.width)).astype(np.uint8) 120 | elif label == 1:  # defective 121 | mask = (np.ones((image.height, image.width)) * 255).astype(np.uint8) 122 | else: 123 | raise ValueError("Labels must be [None, 0, 1]!") 124 | 125 | mask = Image.fromarray(mask, "L") 126 | 127 | if self.transform_fn: 128 | image, mask = self.transform_fn(image, mask) 129 | if self.colorjitter_fn: 130 | image = self.colorjitter_fn(image) 131 | image = transforms.ToTensor()(image) 132 | mask = transforms.ToTensor()(mask) 133 | if self.normalize_fn: 134 | image = self.normalize_fn(image) 135 | input.update({"image": image, "mask": mask}) 136 | return input 137 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Real-IAD Dataset 2 | Official experiment example of the [Real-IAD](https://realiad4ad.github.io/Real-IAD) Dataset using [UniAD](README_uniad.md) 3 | 4 | ## 1. Preparation 5 | 6 | ### 1.1. Download and decompress the dataset 7 | - Download the jsons of the Real-IAD dataset (named `realiad_jsons.zip`) and extract them into `data/Real-IAD/` 8 | - Download the images (of resolution 1024 pixels) of the Real-IAD dataset (one ZIP archive per object) and extract them into `data/Real-IAD/realiad_1024/` 9 | - [Optional] Download the images (original resolution) of the Real-IAD dataset (one ZIP archive per object) and extract them into `data/Real-IAD/realiad_raw/` if you want to conduct experiments on the raw images 10 | 11 | The Real-IAD dataset directory should be as follows (`audiojack` is one of the 30 objects in Real-IAD): 12 | ```shell 13 | data 14 | └── Real-IAD 15 | ├── realiad_1024 16 | │ ├── audiojack 17 | │ │ │── *.jpg 18 | │ │ │── *.png 19 | │ │ ... 20 | │ ... 21 | ├── realiad_jsons 22 | │ ├── audiojack.json 23 | │ ... 24 | ├── realiad_jsons_sv 25 | │ ├── audiojack.json 26 | │ ... 27 | ├── realiad_jsons_fuiad_0.0 28 | │ ├── audiojack.json 29 | │ ... 30 | ├── realiad_jsons_fuiad_0.1 31 | │ ├── audiojack.json 32 | │ ... 33 | ├── realiad_jsons_fuiad_0.2 34 | │ ├── audiojack.json 35 | │ ... 36 | ├── realiad_jsons_fuiad_0.4 37 | │ ├── audiojack.json 38 | │ ... 39 | └── realiad_raw 40 | ├── audiojack 41 | │ │── *.jpg 42 | │ │── *.png 43 | │ ... 44 | ... 45 | ``` 46 | 47 | ### 1.2. Setup environment 48 | Set up the `python` environment following `requirements.txt`. We have tested the code under an environment with the package versions listed below: 49 | ```text 50 | einops==0.4.1 51 | scikit-learn==0.24.2 52 | scipy==1.9.1 53 | tabulate==0.8.10 54 | timm==0.6.12 55 | torch==1.13.1+cu117 56 | torchvision==0.14.1+cu117 57 | ``` 58 | You may use other versions if you have to, but you should adjust the code accordingly. 59 | 60 | ## 2. Training
61 | We provide configs for Single-View UIAD, Multi-View UIAD, and FUIAD; they are located under the `experiments` directory as follows: 62 | ```shell 63 | experiments 64 | ├── RealIAD-C1 # Single-View UIAD 65 | ├── RealIAD-fuad-n0 # FUIAD (NR=0.0) 66 | ├── RealIAD-fuad-n1 # FUIAD (NR=0.1) 67 | ├── RealIAD-fuad-n2 # FUIAD (NR=0.2) 68 | ├── RealIAD-fuad-n4 # FUIAD (NR=0.4) 69 | ├── RealIAD-full # Multi-View UIAD 70 | ... 71 | ``` 72 | - Single-View UIAD: 73 | ```shell 74 | cd experiments/RealIAD-C1 && sh train_torch.sh 8 0,1,2,3,4,5,6,7 75 | # run locally with 8 GPUs 76 | ``` 77 | - Multi-View UIAD: 78 | ```shell 79 | cd experiments/RealIAD-full && sh train_torch.sh 8 0,1,2,3,4,5,6,7 80 | # run locally with 8 GPUs 81 | ``` 82 | - FUIAD: 83 | ```shell 84 | # under bash 85 | pushd experiments/RealIAD-fuad-n0 && sh train_torch.sh 8 0,1,2,3,4,5,6,7 && popd 86 | pushd experiments/RealIAD-fuad-n1 && sh train_torch.sh 8 0,1,2,3,4,5,6,7 && popd 87 | pushd experiments/RealIAD-fuad-n2 && sh train_torch.sh 8 0,1,2,3,4,5,6,7 && popd 88 | pushd experiments/RealIAD-fuad-n4 && sh train_torch.sh 8 0,1,2,3,4,5,6,7 && popd 89 | # run locally with 8 GPUs 90 | ``` 91 | 92 | - [Optional] Experiments on Images of Original Resolution 93 | 94 | To conduct experiments on images of original resolution, change the config value `dataset.image_reader.kwargs.image_dir` from `data/Real-IAD/realiad_1024` to `data/Real-IAD/realiad_raw` in the config file `experiments/{your_setting}/config.yaml`. 95 | 96 | ## 3. Evaluating 97 | After training finishes, anomaly maps of the evaluation set are generated under `experiments/{your_setting}/checkpoints/` and stored in `*.pkl` files, one file per object. Then use [ADEval](https://pypi.org/project/ADEval/) to evaluate the results. 98 | 99 | - Install ADEval 100 | ```shell 101 | python3 -m pip install ADEval 102 | ``` 103 | 104 | - Execute the evaluation command 105 | 106 | Take Multi-View UIAD as an example: 107 | 108 | ```shell 109 | # calculate S-AUROC, I-AUROC and P-AUPRO for each object 110 | find experiments/RealIAD-full/checkpoints/ | \ 111 | grep pkl$ | sort | \ 112 | xargs -n 1 python3 -m adeval --sample_key_pat "([a-zA-Z][a-zA-Z0-9_]*_[0-9]{4}_[A-Z][A-Z_]*[A-Z])_C[0-9]_" 113 | ``` 114 | > Note: the argument `--sample_key_pat` is identical for all experiment settings of Real-IAD; a short demonstration of what this pattern extracts is given at the end of this README 115 | 116 | ## Acknowledgement 117 | This repo is built on top of the official implementation of [UniAD](https://github.com/zhiyuanyou/UniAD.git), which uses some code from repositories including [detr](https://github.com/facebookresearch/detr) and [efficientnet](https://github.com/lukemelas/EfficientNet-PyTorch). 118 | 119 | ## Notice 120 | The copyright notice pertaining to the Tencent code in this repo was previously in the name of "THL A29 Limited." That entity has now been de-registered. You should treat all previously distributed copies of the code as if the copyright notice were in the name of "Tencent".
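
## Appendix: What `--sample_key_pat` Extracts

The regular expression above groups entries under a per-sample key, presumably so that the sample-level metric (S-AUROC) is computed per physical sample rather than per image; the `_C[0-9]_` suffix it requires, but leaves outside the capture group, is the camera/view index. Below is a minimal sketch of the extraction; the file-name stem is a made-up example for illustration only, not a real Real-IAD file name:

```python
import re

# the pattern from the evaluation command above
pat = re.compile(r"([a-zA-Z][a-zA-Z0-9_]*_[0-9]{4}_[A-Z][A-Z_]*[A-Z])_C[0-9]_")

stem = "audiojack_0001_NG_C3_example"  # hypothetical stem, for illustration only
m = pat.search(stem)
print(m.group(1) if m else None)  # -> audiojack_0001_NG
```

All entries whose names share the captured group (here `audiojack_0001_NG`) would then be scored as views of one sample.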
121 | -------------------------------------------------------------------------------- /data/MVTec-AD/json_vis_decoder/test_toothbrush.json: -------------------------------------------------------------------------------- 1 | {"filename": "toothbrush/test/good/004.png", "label": 0, "label_name": "good"} 2 | {"filename": "toothbrush/test/good/007.png", "label": 0, "label_name": "good"} 3 | {"filename": "toothbrush/test/good/006.png", "label": 0, "label_name": "good"} 4 | {"filename": "toothbrush/test/good/002.png", "label": 0, "label_name": "good"} 5 | {"filename": "toothbrush/test/good/010.png", "label": 0, "label_name": "good"} 6 | {"filename": "toothbrush/test/good/003.png", "label": 0, "label_name": "good"} 7 | {"filename": "toothbrush/test/good/009.png", "label": 0, "label_name": "good"} 8 | {"filename": "toothbrush/test/good/000.png", "label": 0, "label_name": "good"} 9 | {"filename": "toothbrush/test/good/011.png", "label": 0, "label_name": "good"} 10 | {"filename": "toothbrush/test/good/008.png", "label": 0, "label_name": "good"} 11 | {"filename": "toothbrush/test/good/005.png", "label": 0, "label_name": "good"} 12 | {"filename": "toothbrush/test/good/001.png", "label": 0, "label_name": "good"} 13 | {"filename": "toothbrush/test/defective/023.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/023_mask.png"} 14 | {"filename": "toothbrush/test/defective/012.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/012_mask.png"} 15 | {"filename": "toothbrush/test/defective/024.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/024_mask.png"} 16 | {"filename": "toothbrush/test/defective/004.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/004_mask.png"} 17 | {"filename": "toothbrush/test/defective/007.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/007_mask.png"} 18 | {"filename": "toothbrush/test/defective/006.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/006_mask.png"} 19 | {"filename": "toothbrush/test/defective/002.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/002_mask.png"} 20 | {"filename": "toothbrush/test/defective/022.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/022_mask.png"} 21 | {"filename": "toothbrush/test/defective/010.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/010_mask.png"} 22 | {"filename": "toothbrush/test/defective/017.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/017_mask.png"} 23 | {"filename": "toothbrush/test/defective/019.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/019_mask.png"} 24 | {"filename": "toothbrush/test/defective/003.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/003_mask.png"} 25 | {"filename": "toothbrush/test/defective/027.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/027_mask.png"} 26 | {"filename": "toothbrush/test/defective/015.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/015_mask.png"} 27 | {"filename": "toothbrush/test/defective/018.png", "label": 1, "label_name": "defective", "maskname": 
"toothbrush/ground_truth/defective/018_mask.png"} 28 | {"filename": "toothbrush/test/defective/021.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/021_mask.png"} 29 | {"filename": "toothbrush/test/defective/026.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/026_mask.png"} 30 | {"filename": "toothbrush/test/defective/025.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/025_mask.png"} 31 | {"filename": "toothbrush/test/defective/009.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/009_mask.png"} 32 | {"filename": "toothbrush/test/defective/028.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/028_mask.png"} 33 | {"filename": "toothbrush/test/defective/016.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/016_mask.png"} 34 | {"filename": "toothbrush/test/defective/000.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/000_mask.png"} 35 | {"filename": "toothbrush/test/defective/011.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/011_mask.png"} 36 | {"filename": "toothbrush/test/defective/020.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/020_mask.png"} 37 | {"filename": "toothbrush/test/defective/008.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/008_mask.png"} 38 | {"filename": "toothbrush/test/defective/013.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/013_mask.png"} 39 | {"filename": "toothbrush/test/defective/014.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/014_mask.png"} 40 | {"filename": "toothbrush/test/defective/005.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/005_mask.png"} 41 | {"filename": "toothbrush/test/defective/029.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/029_mask.png"} 42 | {"filename": "toothbrush/test/defective/001.png", "label": 1, "label_name": "defective", "maskname": "toothbrush/ground_truth/defective/001_mask.png"} 43 | -------------------------------------------------------------------------------- /README_uniad.md: -------------------------------------------------------------------------------- 1 | # UniAD 2 | Official PyTorch Implementation of [A Unified Model for Multi-class Anomaly Detection](https://arxiv.org/abs/2206.03687), Accepted by NeurIPS 2022 Spotlight. 3 | 4 | ![Image text](docs/setting.jpg) 5 | ![Image text](docs/res_mvtec.jpg) 6 | 7 | ## 1. Quick Start 8 | 9 | ### 1.1 MVTec-AD 10 | 11 | - **Create the MVTec-AD dataset directory**. Download the MVTec-AD dataset from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad). Unzip the file and move some to `./data/MVTec-AD/`. The MVTec-AD dataset directory should be as follows. 12 | 13 | ``` 14 | |-- data 15 | |-- MVTec-AD 16 | |-- mvtec_anomaly_detection 17 | |-- json_vis_decoder 18 | |-- train.json 19 | |-- test.json 20 | ``` 21 | 22 | - **cd the experiment directory** by running `cd ./experiments/MVTec-AD/`. 23 | 24 | - **Train or eval** by running: 25 | 26 | (1) For slurm group: `sh train.sh #NUM_GPUS #PARTITION` or `sh eval.sh #NUM_GPUS #PARTITION`. 
27 | 28 | (2) For torch.distributed.launch: `sh train_torch.sh #NUM_GPUS #GPU_IDS` or `sh eval_torch.sh #NUM_GPUS #GPU_IDS`, *e.g.*, train with GPUs 1,3,4,6 (4 GPUs in total): `sh train_torch.sh 4 1,3,4,6`. 29 | 30 | **Note**: During eval, please *set config.saver.load_path* to load the checkpoints. 31 | 32 | - **Results and checkpoints**. 33 | 34 | | Platform | GPU | Detection AUROC | Localization AUROC | Checkpoints | Note | 35 | | ------ | ------ | ------ | ------ | ------ | ------ | 36 | | slurm group | 8 GPUs (NVIDIA Tesla V100 16GB)| 96.7 | 96.8 | [here](https://drive.google.com/file/d/1q03ysv_5VJATlDN-A-c9zvcTuyEeaQHG/view?usp=sharing) | ***A unified model for all categories*** | 37 | | torch.distributed.launch | 1 GPU (NVIDIA GeForce GTX 1080 Ti 11 GB)| 97.6 | 97.0 | [here](https://drive.google.com/file/d/1v282ZlibC-b0H9sjLUlOSCFNzEv-TIuh/view?usp=sharing) | ***A unified model for all categories*** | 38 | 39 | 40 | ### 1.2 CIFAR-10 41 | 42 | - **Create the CIFAR-10 dataset directory**. Download the CIFAR-10 dataset from [here](http://www.cs.toronto.edu/~kriz/cifar.html). Unzip the file and move the extracted contents to `./data/CIFAR-10/`. The CIFAR-10 dataset directory should be as follows. 43 | 44 | ``` 45 | |-- data 46 | |-- CIFAR-10 47 | |-- cifar-10-batches-py 48 | ``` 49 | 50 | - **cd the experiment directory** by running `cd ./experiments/CIFAR-10/01234/`. Here we take classes 0,1,2,3,4 as normal samples; the other settings are similar. 51 | 52 | - **Train or eval** by running: 53 | 54 | (1) For slurm group: `sh train.sh #NUM_GPUS #PARTITION` or `sh eval.sh #NUM_GPUS #PARTITION`. 55 | 56 | (2) For torch.distributed.launch: `sh train_torch.sh #NUM_GPUS #GPU_IDS` or `sh eval_torch.sh #NUM_GPUS #GPU_IDS`. 57 | 58 | **Note**: During eval, please *set config.saver.load_path* to load the checkpoints. 59 | 60 | - **Results and checkpoints**. Training on 8 GPUs (NVIDIA Tesla V100 16GB) results in the following performance. 61 | 62 | | Normal Samples | {01234} | {56789} | {02468} | {13579} | Mean | 63 | | ------ | ------ | ------ | ------ | ------ | ------ | 64 | | AUROC | 84.4 | 79.6 | 93.0 | 89.1 | 86.5 | 65 | 66 | 67 | ## 2. Visualize Reconstructed Features 68 | 69 | We **highly recommend** visualizing reconstructed features, since this directly demonstrates that our UniAD *reconstructs anomalies to their corresponding normal samples*. 70 | 71 | ### 2.1 Train Decoders for Visualization 72 | 73 | - **cd the experiment directory** by running `cd ./experiments/train_vis_decoder/`. 74 | 75 | - **Train** by running: 76 | 77 | (1) For slurm group: `sh train.sh #NUM_GPUS #PARTITION`. 78 | 79 | (2) For torch.distributed.launch: `sh train_torch.sh #NUM_GPUS #GPU_IDS #CLASS_NAME`. 80 | 81 | **Note**: for torch.distributed.launch, you should *train one vis_decoder for one specific class at a time*. 82 | 83 | ### 2.2 Visualize Reconstructed Features 84 | 85 | - **cd the experiment directory** by running `cd ./experiments/vis_recon/`. 86 | 87 | - **Visualize** by running (only 1 GPU is supported): 88 | 89 | (1) For slurm group: `sh vis_recon.sh #PARTITION`. 90 | 91 | (2) For torch.distributed.launch: `sh vis_recon_torch.sh #CLASS_NAME`. 92 | 93 | **Note**: for torch.distributed.launch, you should *visualize one specific class at a time*. 94 | 95 | ## 3. Questions 96 | 97 | ### 3.1 Explanation of Evaluation Results 98 | 99 | The first line of the evaluation results is shown as follows. 
100 | 101 | | clsname | pixel | mean | max | std | 102 | |:----------:|:--------:|:--------:|:--------:|:--------:| 103 | 104 | The *pixel* column reports the anomaly localization results. 105 | 106 | The *mean*, *max*, and *std* columns denote **post-processing methods** for anomaly detection. That is, the anomaly localization result is an anomaly map with the shape of *H x W*, and we need to *convert this map to a scalar* that serves as the anomaly score for the whole image. For this conversion, you have 3 options (a minimal sketch is given at the end of this README): 107 | 108 | - use the *mean* value of the anomaly map. 109 | - use the *max* value of the (averagely pooled) anomaly map. 110 | - use the *std* value of the anomaly map. 111 | 112 | In our paper, we use *max* for MVTec-AD and *mean* for CIFAR-10. 113 | 114 | ### 3.2 Visualize Learned Query Embedding 115 | 116 | If you have finished training the main model and the decoders (used for visualization) for MVTec-AD, you can also choose to visualize the learned query embedding in the main model. 117 | 118 | - **cd the experiment directory** by running `cd ./experiments/vis_query/`. 119 | 120 | - **Visualize** by running (only 1 GPU is supported): 121 | 122 | (1) For slurm group: `sh vis_query.sh #PARTITION`. 123 | 124 | (2) For torch.distributed.launch: `sh vis_query_torch.sh #CLASS_NAME`. 125 | 126 | **Note**: for torch.distributed.launch, you should *visualize one specific class at a time*. 127 | 128 | Some results are very interesting: the learned query embedding partly contains some features of normal samples. However, we ***did not*** fully figure this out, and this part ***was not*** included in our paper. 129 | 130 | ![Image text](docs/query_bottle.jpg) 131 | ![Image text](docs/query_capsule.jpg) 132 | 133 | ## Acknowledgement 134 | 135 | We use some code from repositories including [detr](https://github.com/facebookresearch/detr) and [efficientnet](https://github.com/lukemelas/EfficientNet-PyTorch). 
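
## Appendix: Post-processing Sketch

For concreteness, here is a minimal sketch of the three map-to-scalar options described in Section 3.1, assuming the anomaly map is a NumPy array. The average pooling shown is a simplified non-overlapping version with the window size taken from the configs (`avgpool_size: [16, 16]`); the actual implementation in this repository may differ in pooling details such as stride:

```python
import numpy as np


def image_score(amap: np.ndarray, method: str, avgpool_size=(16, 16)) -> float:
    """Convert an H x W anomaly map into a scalar image-level anomaly score."""
    if method == "mean":
        return float(amap.mean())
    if method == "max":
        # average-pool the map first, then take the max of the pooled map
        # (non-overlapping pooling, for simplicity)
        ph, pw = avgpool_size
        h, w = amap.shape
        cropped = amap[: h - h % ph, : w - w % pw]
        pooled = cropped.reshape(h // ph, ph, w // pw, pw).mean(axis=(1, 3))
        return float(pooled.max())
    if method == "std":
        return float(amap.std())
    raise ValueError(f"unknown post-processing method: {method}")


# e.g., max (as used for MVTec-AD) on a random 224 x 224 map
score = image_score(np.random.rand(224, 224).astype(np.float32), "max")
```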
136 | -------------------------------------------------------------------------------- /tools/vis_recon.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import importlib 4 | import os 5 | import pprint 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.optim 11 | import yaml 12 | from easydict import EasyDict 13 | from utils.misc_helper import create_logger 14 | 15 | parser = argparse.ArgumentParser(description="UniAD") 16 | parser.add_argument("--config", default="./config.yaml") 17 | parser.add_argument("--class_name", default="") 18 | 19 | 20 | def update_config(config): 21 | # update planes & strides 22 | backbone_path, backbone_type = config.net[0].type.rsplit(".", 1) 23 | module = importlib.import_module(backbone_path) 24 | backbone_info = getattr(module, "backbone_info") 25 | backbone = backbone_info[backbone_type] 26 | outplanes = [] 27 | for layer in config.net[0].kwargs.outlayers: 28 | if layer not in backbone["layers"]: 29 | raise ValueError( 30 | "only layer {} for backbone {} is allowed, but get {}!".format( 31 | backbone["layers"], backbone_type, layer 32 | ) 33 | ) 34 | idx = backbone["layers"].index(layer) 35 | outplanes.append(backbone["planes"][idx]) 36 | 37 | config.net[2].kwargs.instrides = config.net[1].kwargs.outstrides 38 | config.net[2].kwargs.inplanes = [sum(outplanes)] 39 | return config 40 | 41 | 42 | def load_state_decoder(path, model): 43 | def map_func(storage, location): 44 | return storage.cuda() 45 | 46 | if os.path.isfile(path): 47 | print("=> loading checkpoint '{}'".format(path)) 48 | 49 | checkpoint = torch.load(path, map_location=map_func) 50 | state_dict = checkpoint["state_dict"] 51 | 52 | # state_dict of decoder 53 | state_dict_decoder = {} 54 | for k, v in state_dict.items(): 55 | if "module.reconstruction." 
in k: 56 | k_new = k.replace("module.reconstruction.", "") 57 | state_dict_decoder[k_new] = v 58 | 59 | # fix size mismatch error 60 | ignore_keys = [] 61 | for k, v in state_dict_decoder.items(): 62 | if k in model.state_dict().keys(): 63 | v_dst = model.state_dict()[k] 64 | if v.shape != v_dst.shape: 65 | ignore_keys.append(k) 66 | print( 67 | "caution: size-mismatch key: {} size: {} -> {}".format( 68 | k, v.shape, v_dst.shape 69 | ) 70 | ) 71 | 72 | for k in ignore_keys: 73 | state_dict_decoder.pop(k) 74 | 75 | model.load_state_dict(state_dict_decoder, strict=False) 76 | 77 | ckpt_keys = set(state_dict_decoder.keys()) 78 | own_keys = set(model.state_dict().keys()) 79 | missing_keys = own_keys - ckpt_keys 80 | for k in missing_keys: 81 | print("caution: missing keys from checkpoint {}: {}".format(path, k)) 82 | else: 83 | print("=> no checkpoint found at '{}'".format(path)) 84 | 85 | 86 | def main(): 87 | args = parser.parse_args() 88 | 89 | with open(args.config) as f: 90 | config = EasyDict(yaml.load(f, Loader=yaml.FullLoader)) 91 | 92 | update_config(config) 93 | config.saver.load_path = config.saver.load_path.replace( 94 | "{class_name}", args.class_name 95 | ) 96 | 97 | config.exp_path = os.path.dirname(args.config) 98 | config.save_path = os.path.join(config.exp_path, config.saver.save_dir) 99 | config.log_path = os.path.join(config.exp_path, config.saver.log_dir) 100 | os.makedirs(config.save_path, exist_ok=True) 101 | os.makedirs(config.log_path, exist_ok=True) 102 | 103 | logger = create_logger( 104 | "global_logger", config.log_path + "/dec_{}.log".format(args.class_name) 105 | ) 106 | logger.info("args: {}".format(pprint.pformat(args))) 107 | logger.info("config: {}".format(pprint.pformat(config))) 108 | 109 | # create model 110 | module_name, cls_name = config.net[2].type.rsplit(".", 1) 111 | module = importlib.import_module(module_name) 112 | model = getattr(module, cls_name)(**config.net[2].kwargs) 113 | load_state_decoder(config.saver.load_path, model) 114 | model.cuda() 115 | 116 | mean = ( 117 | torch.tensor(config.data.pixel_mean).cuda().unsqueeze(0).unsqueeze(0) 118 | ) # 1 x 1 x 3 119 | std = ( 120 | torch.tensor(config.data.pixel_std).cuda().unsqueeze(0).unsqueeze(0) 121 | ) # 1 x 1 x 3 122 | 123 | feature_paths = glob.glob( 124 | os.path.join(config.data.feature_dir, args.class_name, "*/*.npy") 125 | ) 126 | for feature_path in feature_paths: 127 | feature = np.load(feature_path) 128 | feature = torch.tensor(feature).cuda().unsqueeze(0) 129 | input = {"feature_align": feature} 130 | with torch.no_grad(): 131 | output = model(input) 132 | image_rec = ( 133 | output["image_rec"].squeeze(0).permute(1, 2, 0) 134 | ) # 1 x 3 x h x w -> h x w x 3 135 | image_rec = (image_rec * std + mean) * 255 136 | image_rec = image_rec.cpu().numpy() 137 | 138 | # image 139 | filedir, filename = os.path.split(feature_path) 140 | filedir, defename = os.path.split(filedir) 141 | _, clsname = os.path.split(filedir) 142 | filename_, _ = os.path.splitext(filename) 143 | imagepath = os.path.join( 144 | config.data.dataset_dir, clsname, "test", defename, filename_ + ".png" 145 | ) 146 | image = cv2.imread(imagepath) 147 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 148 | image = cv2.resize( 149 | image, (config.data.input_size[1], config.data.input_size[0]) 150 | ) # h,w -> w,h 151 | 152 | image = np.concatenate([image, image_rec], axis=0) # 2h x w x 3 153 | vis_recon_dir = os.path.join(config.save_path, clsname, defename) 154 | os.makedirs(vis_recon_dir, exist_ok=True) 155 | savepath = 
os.path.join(vis_recon_dir, filename_ + ".jpg") 156 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 157 | cv2.imwrite(savepath, image) 158 | 159 | print(f"Success: Feature: {feature_path}\n Saved: {savepath}") 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /experiments/RealIAD-full/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 133 3 | port: 11111 4 | 5 | dataset: 6 | type: explicit 7 | 8 | image_reader: 9 | type: opencv 10 | kwargs: 11 | image_dir: data/Real-IAD/realiad_1024 12 | color_mode: RGB 13 | 14 | train: 15 | meta_file: 16 | - data/Real-IAD/realiad_jsons/audiojack.json 17 | - data/Real-IAD/realiad_jsons/bottle_cap.json 18 | - data/Real-IAD/realiad_jsons/button_battery.json 19 | - data/Real-IAD/realiad_jsons/end_cap.json 20 | - data/Real-IAD/realiad_jsons/eraser.json 21 | - data/Real-IAD/realiad_jsons/fire_hood.json 22 | - data/Real-IAD/realiad_jsons/mint.json 23 | - data/Real-IAD/realiad_jsons/mounts.json 24 | - data/Real-IAD/realiad_jsons/pcb.json 25 | - data/Real-IAD/realiad_jsons/phone_battery.json 26 | - data/Real-IAD/realiad_jsons/plastic_nut.json 27 | - data/Real-IAD/realiad_jsons/plastic_plug.json 28 | - data/Real-IAD/realiad_jsons/porcelain_doll.json 29 | - data/Real-IAD/realiad_jsons/regulator.json 30 | - data/Real-IAD/realiad_jsons/rolled_strip_base.json 31 | - data/Real-IAD/realiad_jsons/sim_card_set.json 32 | - data/Real-IAD/realiad_jsons/switch.json 33 | - data/Real-IAD/realiad_jsons/tape.json 34 | - data/Real-IAD/realiad_jsons/terminalblock.json 35 | - data/Real-IAD/realiad_jsons/toothbrush.json 36 | - data/Real-IAD/realiad_jsons/toy.json 37 | - data/Real-IAD/realiad_jsons/toy_brick.json 38 | - data/Real-IAD/realiad_jsons/transistor1.json 39 | - data/Real-IAD/realiad_jsons/u_block.json 40 | - data/Real-IAD/realiad_jsons/usb.json 41 | - data/Real-IAD/realiad_jsons/usb_adaptor.json 42 | - data/Real-IAD/realiad_jsons/vcpill.json 43 | - data/Real-IAD/realiad_jsons/wooden_beads.json 44 | - data/Real-IAD/realiad_jsons/woodstick.json 45 | - data/Real-IAD/realiad_jsons/zipper.json 46 | rebalance: False 47 | hflip: False 48 | vflip: False 49 | rotate: False 50 | 51 | test: 52 | meta_file: 53 | - data/Real-IAD/realiad_jsons/audiojack.json 54 | - data/Real-IAD/realiad_jsons/bottle_cap.json 55 | - data/Real-IAD/realiad_jsons/button_battery.json 56 | - data/Real-IAD/realiad_jsons/end_cap.json 57 | - data/Real-IAD/realiad_jsons/eraser.json 58 | - data/Real-IAD/realiad_jsons/fire_hood.json 59 | - data/Real-IAD/realiad_jsons/mint.json 60 | - data/Real-IAD/realiad_jsons/mounts.json 61 | - data/Real-IAD/realiad_jsons/pcb.json 62 | - data/Real-IAD/realiad_jsons/phone_battery.json 63 | - data/Real-IAD/realiad_jsons/plastic_nut.json 64 | - data/Real-IAD/realiad_jsons/plastic_plug.json 65 | - data/Real-IAD/realiad_jsons/porcelain_doll.json 66 | - data/Real-IAD/realiad_jsons/regulator.json 67 | - data/Real-IAD/realiad_jsons/rolled_strip_base.json 68 | - data/Real-IAD/realiad_jsons/sim_card_set.json 69 | - data/Real-IAD/realiad_jsons/switch.json 70 | - data/Real-IAD/realiad_jsons/tape.json 71 | - data/Real-IAD/realiad_jsons/terminalblock.json 72 | - data/Real-IAD/realiad_jsons/toothbrush.json 73 | - data/Real-IAD/realiad_jsons/toy.json 74 | - data/Real-IAD/realiad_jsons/toy_brick.json 75 | - data/Real-IAD/realiad_jsons/transistor1.json 76 | - data/Real-IAD/realiad_jsons/u_block.json 77 | - 
data/Real-IAD/realiad_jsons/usb.json 78 | - data/Real-IAD/realiad_jsons/usb_adaptor.json 79 | - data/Real-IAD/realiad_jsons/vcpill.json 80 | - data/Real-IAD/realiad_jsons/wooden_beads.json 81 | - data/Real-IAD/realiad_jsons/woodstick.json 82 | - data/Real-IAD/realiad_jsons/zipper.json 83 | 84 | input_size: [224,224] # [h,w] 85 | pixel_mean: [0.485, 0.456, 0.406] 86 | pixel_std: [0.229, 0.224, 0.225] 87 | batch_size: 8 88 | workers: 8 # number of workers of dataloader for each process 89 | 90 | criterion: 91 | - name: FeatureMSELoss 92 | type: FeatureMSELoss 93 | kwargs: 94 | weight: 1.0 95 | 96 | trainer: 97 | max_epoch: 300 98 | clip_max_norm: 0.1 99 | val_freq_epoch: 10 100 | print_freq_step: 1 101 | tb_freq_step: 1 102 | lr_scheduler: 103 | type: StepLR 104 | kwargs: 105 | step_size: 800 106 | gamma: 0.1 107 | optimizer: 108 | type: AdamW 109 | kwargs: 110 | lr: 0.0001 111 | betas: [0.9, 0.999] 112 | weight_decay: 0.0001 113 | 114 | saver: 115 | auto_resume: True 116 | always_save: False 117 | load_path: checkpoints/ckpt.pth.tar 118 | save_dir: checkpoints/ 119 | log_dir: log/ 120 | 121 | evaluator: 122 | save_dir: result_eval_temp 123 | key_metric: mean_pixel_auc 124 | no_metrics: True 125 | metrics: 126 | auc: 127 | - name: std 128 | - name: max 129 | kwargs: 130 | avgpool_size: [16, 16] 131 | - name: pixel 132 | # vis_compound: 133 | # save_dir: vis_compound 134 | # max_score: null 135 | # min_score: null 136 | # vis_single: 137 | # save_dir: vis_single 138 | # max_score: null 139 | # min_score: null 140 | 141 | frozen_layers: [backbone] 142 | 143 | net: 144 | - name: backbone 145 | type: models.backbones.efficientnet_b4 146 | frozen: True 147 | kwargs: 148 | pretrained: True 149 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 150 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 151 | outlayers: [1,2,3,4] 152 | - name: neck 153 | prev: backbone 154 | type: models.necks.MFCN 155 | kwargs: 156 | outstrides: [16] 157 | - name: reconstruction 158 | prev: neck 159 | type: models.reconstructions.UniAD 160 | kwargs: 161 | pos_embed_type: learned 162 | hidden_dim: 256 163 | nhead: 8 164 | num_encoder_layers: 4 165 | num_decoder_layers: 4 166 | dim_feedforward: 1024 167 | dropout: 0.1 168 | activation: relu 169 | normalize_before: False 170 | feature_jitter: 171 | scale: 20.0 172 | prob: 1.0 173 | neighbor_mask: 174 | neighbor_size: [7,7] 175 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 176 | save_recon: False 177 | # save_dir: result_recon 178 | initializer: 179 | method: xavier_uniform 180 | -------------------------------------------------------------------------------- /experiments/RealIAD-C1/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 133 3 | port: 11111 4 | 5 | dataset: 6 | type: explicit 7 | 8 | image_reader: 9 | type: opencv 10 | kwargs: 11 | image_dir: data/Real-IAD/realiad_1024 12 | color_mode: RGB 13 | 14 | train: 15 | meta_file: 16 | - data/Real-IAD/realiad_jsons_sv/audiojack.json 17 | - data/Real-IAD/realiad_jsons_sv/bottle_cap.json 18 | - data/Real-IAD/realiad_jsons_sv/button_battery.json 19 | - data/Real-IAD/realiad_jsons_sv/end_cap.json 20 | - data/Real-IAD/realiad_jsons_sv/eraser.json 21 | - data/Real-IAD/realiad_jsons_sv/fire_hood.json 22 | - data/Real-IAD/realiad_jsons_sv/mint.json 23 | - data/Real-IAD/realiad_jsons_sv/mounts.json 24 | - data/Real-IAD/realiad_jsons_sv/pcb.json 25 | - 
data/Real-IAD/realiad_jsons_sv/phone_battery.json 26 | - data/Real-IAD/realiad_jsons_sv/plastic_nut.json 27 | - data/Real-IAD/realiad_jsons_sv/plastic_plug.json 28 | - data/Real-IAD/realiad_jsons_sv/porcelain_doll.json 29 | - data/Real-IAD/realiad_jsons_sv/regulator.json 30 | - data/Real-IAD/realiad_jsons_sv/rolled_strip_base.json 31 | - data/Real-IAD/realiad_jsons_sv/sim_card_set.json 32 | - data/Real-IAD/realiad_jsons_sv/switch.json 33 | - data/Real-IAD/realiad_jsons_sv/tape.json 34 | - data/Real-IAD/realiad_jsons_sv/terminalblock.json 35 | - data/Real-IAD/realiad_jsons_sv/toothbrush.json 36 | - data/Real-IAD/realiad_jsons_sv/toy.json 37 | - data/Real-IAD/realiad_jsons_sv/toy_brick.json 38 | - data/Real-IAD/realiad_jsons_sv/transistor1.json 39 | - data/Real-IAD/realiad_jsons_sv/u_block.json 40 | - data/Real-IAD/realiad_jsons_sv/usb.json 41 | - data/Real-IAD/realiad_jsons_sv/usb_adaptor.json 42 | - data/Real-IAD/realiad_jsons_sv/vcpill.json 43 | - data/Real-IAD/realiad_jsons_sv/wooden_beads.json 44 | - data/Real-IAD/realiad_jsons_sv/woodstick.json 45 | - data/Real-IAD/realiad_jsons_sv/zipper.json 46 | rebalance: False 47 | hflip: False 48 | vflip: False 49 | rotate: False 50 | 51 | test: 52 | meta_file: 53 | - data/Real-IAD/realiad_jsons_sv/audiojack.json 54 | - data/Real-IAD/realiad_jsons_sv/bottle_cap.json 55 | - data/Real-IAD/realiad_jsons_sv/button_battery.json 56 | - data/Real-IAD/realiad_jsons_sv/end_cap.json 57 | - data/Real-IAD/realiad_jsons_sv/eraser.json 58 | - data/Real-IAD/realiad_jsons_sv/fire_hood.json 59 | - data/Real-IAD/realiad_jsons_sv/mint.json 60 | - data/Real-IAD/realiad_jsons_sv/mounts.json 61 | - data/Real-IAD/realiad_jsons_sv/pcb.json 62 | - data/Real-IAD/realiad_jsons_sv/phone_battery.json 63 | - data/Real-IAD/realiad_jsons_sv/plastic_nut.json 64 | - data/Real-IAD/realiad_jsons_sv/plastic_plug.json 65 | - data/Real-IAD/realiad_jsons_sv/porcelain_doll.json 66 | - data/Real-IAD/realiad_jsons_sv/regulator.json 67 | - data/Real-IAD/realiad_jsons_sv/rolled_strip_base.json 68 | - data/Real-IAD/realiad_jsons_sv/sim_card_set.json 69 | - data/Real-IAD/realiad_jsons_sv/switch.json 70 | - data/Real-IAD/realiad_jsons_sv/tape.json 71 | - data/Real-IAD/realiad_jsons_sv/terminalblock.json 72 | - data/Real-IAD/realiad_jsons_sv/toothbrush.json 73 | - data/Real-IAD/realiad_jsons_sv/toy.json 74 | - data/Real-IAD/realiad_jsons_sv/toy_brick.json 75 | - data/Real-IAD/realiad_jsons_sv/transistor1.json 76 | - data/Real-IAD/realiad_jsons_sv/u_block.json 77 | - data/Real-IAD/realiad_jsons_sv/usb.json 78 | - data/Real-IAD/realiad_jsons_sv/usb_adaptor.json 79 | - data/Real-IAD/realiad_jsons_sv/vcpill.json 80 | - data/Real-IAD/realiad_jsons_sv/wooden_beads.json 81 | - data/Real-IAD/realiad_jsons_sv/woodstick.json 82 | - data/Real-IAD/realiad_jsons_sv/zipper.json 83 | 84 | input_size: [224,224] # [h,w] 85 | pixel_mean: [0.485, 0.456, 0.406] 86 | pixel_std: [0.229, 0.224, 0.225] 87 | batch_size: 8 88 | workers: 8 # number of workers of dataloader for each process 89 | 90 | criterion: 91 | - name: FeatureMSELoss 92 | type: FeatureMSELoss 93 | kwargs: 94 | weight: 1.0 95 | 96 | trainer: 97 | max_epoch: 300 98 | clip_max_norm: 0.1 99 | val_freq_epoch: 10 100 | print_freq_step: 1 101 | tb_freq_step: 1 102 | lr_scheduler: 103 | type: StepLR 104 | kwargs: 105 | step_size: 800 106 | gamma: 0.1 107 | optimizer: 108 | type: AdamW 109 | kwargs: 110 | lr: 0.0001 111 | betas: [0.9, 0.999] 112 | weight_decay: 0.0001 113 | 114 | saver: 115 | auto_resume: True 116 | always_save: False 117 | load_path: 
checkpoints/ckpt.pth.tar 118 | save_dir: checkpoints/ 119 | log_dir: log/ 120 | 121 | evaluator: 122 | save_dir: result_eval_temp 123 | key_metric: mean_pixel_auc 124 | no_metrics: True 125 | metrics: 126 | auc: 127 | - name: std 128 | - name: max 129 | kwargs: 130 | avgpool_size: [16, 16] 131 | - name: pixel 132 | # vis_compound: 133 | # save_dir: vis_compound 134 | # max_score: null 135 | # min_score: null 136 | # vis_single: 137 | # save_dir: vis_single 138 | # max_score: null 139 | # min_score: null 140 | 141 | frozen_layers: [backbone] 142 | 143 | net: 144 | - name: backbone 145 | type: models.backbones.efficientnet_b4 146 | frozen: True 147 | kwargs: 148 | pretrained: True 149 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 150 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 151 | outlayers: [1,2,3,4] 152 | - name: neck 153 | prev: backbone 154 | type: models.necks.MFCN 155 | kwargs: 156 | outstrides: [16] 157 | - name: reconstruction 158 | prev: neck 159 | type: models.reconstructions.UniAD 160 | kwargs: 161 | pos_embed_type: learned 162 | hidden_dim: 256 163 | nhead: 8 164 | num_encoder_layers: 4 165 | num_decoder_layers: 4 166 | dim_feedforward: 1024 167 | dropout: 0.1 168 | activation: relu 169 | normalize_before: False 170 | feature_jitter: 171 | scale: 20.0 172 | prob: 1.0 173 | neighbor_mask: 174 | neighbor_size: [7,7] 175 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 176 | save_recon: False 177 | # save_dir: result_recon 178 | initializer: 179 | method: xavier_uniform 180 | -------------------------------------------------------------------------------- /datasets/explicit_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from typing import List, Dict, Union 3 | 4 | import json 5 | import logging 6 | import os.path as osp 7 | 8 | import numpy as np 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.utils.data.sampler import RandomSampler 14 | 15 | from datasets.base_dataset import BaseDataset, TestBaseTransform, TrainBaseTransform 16 | from datasets.image_reader import build_image_reader 17 | from datasets.transforms import RandomColorJitter 18 | 19 | logger = logging.getLogger("global_logger") 20 | 21 | 22 | def build_explicit_dataloader(cfg, training, distributed=True): 23 | 24 | image_reader = build_image_reader(cfg.image_reader) 25 | 26 | normalize_fn = transforms.Normalize(mean=cfg["pixel_mean"], std=cfg["pixel_std"]) 27 | if training: 28 | transform_fn = TrainBaseTransform( 29 | cfg["input_size"], cfg["hflip"], cfg["vflip"], cfg["rotate"] 30 | ) 31 | else: 32 | transform_fn = TestBaseTransform(cfg["input_size"]) 33 | 34 | colorjitter_fn = None 35 | if cfg.get("colorjitter", None) and training: 36 | colorjitter_fn = RandomColorJitter.from_params(cfg["colorjitter"]) 37 | 38 | logger.info("building ExplicitDataset from: {}".format(cfg["meta_file"])) 39 | 40 | dataset = ExplicitDataset( 41 | image_reader, 42 | cfg["meta_file"], 43 | training, 44 | transform_fn=transform_fn, 45 | normalize_fn=normalize_fn, 46 | colorjitter_fn=colorjitter_fn, 47 | ) 48 | 49 | if distributed: 50 | sampler = DistributedSampler(dataset) 51 | else: 52 | sampler = RandomSampler(dataset) 53 | 54 | data_loader = DataLoader( 55 | dataset, 56 | batch_size=cfg["batch_size"], 57 | 
num_workers=cfg["workers"], 58 | pin_memory=True, 59 | sampler=sampler, 60 | persistent_workers=True, 61 | ) 62 | 63 | return data_loader 64 | 65 | 66 | class ExplicitDataset(BaseDataset): 67 | def __init__( 68 | self, 69 | image_reader, 70 | meta_file, 71 | training, 72 | transform_fn, 73 | normalize_fn, 74 | colorjitter_fn=None, 75 | ): 76 | self.image_reader = image_reader 77 | self.meta_file = meta_file 78 | self.training = training 79 | self.transform_fn = transform_fn 80 | self.normalize_fn = normalize_fn 81 | self.colorjitter_fn = colorjitter_fn 82 | 83 | if isinstance(self.meta_file, str): 84 | self.meta_file = [meta_file] 85 | 86 | # construct metas 87 | self.metas = sum((self.load_explicit(path, self.training) 88 | for path in self.meta_file), []) 89 | 90 | @staticmethod 91 | def load_explicit(path: str, is_training: bool) -> List[Dict[str, Union[str, int]]]: 92 | SAMPLE_KEYS = {'category', 'anomaly_class', 'image_path', 'mask_path'} 93 | 94 | with open(path, 'r') as fp: 95 | info = json.load(fp) 96 | assert isinstance(info, dict) and all( 97 | key in info for key in ('meta', 'train', 'test') 98 | ) 99 | meta = info['meta'] 100 | train = info['train'] 101 | test = info['test'] 102 | raw_samples = train if is_training else test 103 | 104 | assert isinstance(raw_samples, list) and all( 105 | isinstance(sample, dict) and set(sample.keys()) == SAMPLE_KEYS 106 | for sample in raw_samples 107 | ) 108 | assert isinstance(meta, dict) 109 | prefix = meta['prefix'] 110 | normal_class = meta['normal_class'] 111 | 112 | if is_training: 113 | return [dict(filename=osp.join(prefix, sample['image_path']), 114 | label_name=normal_class, label=0, 115 | clsname=sample['category']) 116 | for sample in raw_samples] 117 | else: 118 | def as_normal(sample): 119 | return (sample['mask_path'] is None or 120 | sample['anomaly_class'] == normal_class) 121 | 122 | return [dict( 123 | filename=osp.join(prefix, sample['image_path']), 124 | maskname=None if as_normal(sample) 125 | else osp.join(prefix, sample['mask_path']), 126 | label=0 if as_normal(sample) else 1, 127 | label_name=sample['anomaly_class'], 128 | clsname=sample['category'] 129 | ) for sample in raw_samples] 130 | 131 | def __len__(self): 132 | return len(self.metas) 133 | 134 | def __getitem__(self, index): 135 | input = {} 136 | meta = self.metas[index] 137 | 138 | # read image 139 | filename = meta["filename"] 140 | label = meta["label"] 141 | image = self.image_reader(meta["filename"]) 142 | input.update( 143 | { 144 | "filename": filename, 145 | "height": image.shape[0], 146 | "width": image.shape[1], 147 | "label": label, 148 | } 149 | ) 150 | 151 | if meta.get("clsname", None): 152 | input["clsname"] = meta["clsname"] 153 | else: 154 | input["clsname"] = filename.split("/")[-4] 155 | 156 | image = Image.fromarray(image, "RGB") 157 | 158 | # read / generate mask 159 | if meta.get("maskname", None): 160 | input['maskname'] = meta['maskname'] 161 | mask = self.image_reader(meta["maskname"], is_mask=True) 162 | else: 163 | input['maskname'] = '' 164 | if label == 0: # good 165 | mask = np.zeros((image.height, image.width)).astype(np.uint8) 166 | elif label == 1: # defective 167 | mask = (np.ones((image.height, image.width)) * 255).astype(np.uint8) 168 | else: 169 | raise ValueError("Labels must be [None, 0, 1]!") 170 | 171 | mask = Image.fromarray(mask, "L") 172 | 173 | if self.transform_fn: 174 | image, mask = self.transform_fn(image, mask) 175 | if self.colorjitter_fn: 176 | image = self.colorjitter_fn(image) 177 | image = 
transforms.ToTensor()(image) 178 | mask = transforms.ToTensor()(mask) 179 | if self.normalize_fn: 180 | image = self.normalize_fn(image) 181 | input.update({"image": image, "mask": mask}) 182 | return input 183 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n0/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 133 3 | port: 11111 4 | 5 | dataset: 6 | type: explicit 7 | 8 | image_reader: 9 | type: opencv 10 | kwargs: 11 | image_dir: data/Real-IAD/realiad_1024 12 | color_mode: RGB 13 | 14 | train: 15 | meta_file: 16 | - data/Real-IAD/realiad_jsons_fuiad_0.0/audiojack.json 17 | - data/Real-IAD/realiad_jsons_fuiad_0.0/bottle_cap.json 18 | - data/Real-IAD/realiad_jsons_fuiad_0.0/button_battery.json 19 | - data/Real-IAD/realiad_jsons_fuiad_0.0/end_cap.json 20 | - data/Real-IAD/realiad_jsons_fuiad_0.0/eraser.json 21 | - data/Real-IAD/realiad_jsons_fuiad_0.0/fire_hood.json 22 | - data/Real-IAD/realiad_jsons_fuiad_0.0/mint.json 23 | - data/Real-IAD/realiad_jsons_fuiad_0.0/mounts.json 24 | - data/Real-IAD/realiad_jsons_fuiad_0.0/pcb.json 25 | - data/Real-IAD/realiad_jsons_fuiad_0.0/phone_battery.json 26 | - data/Real-IAD/realiad_jsons_fuiad_0.0/plastic_nut.json 27 | - data/Real-IAD/realiad_jsons_fuiad_0.0/plastic_plug.json 28 | - data/Real-IAD/realiad_jsons_fuiad_0.0/porcelain_doll.json 29 | - data/Real-IAD/realiad_jsons_fuiad_0.0/regulator.json 30 | - data/Real-IAD/realiad_jsons_fuiad_0.0/rolled_strip_base.json 31 | - data/Real-IAD/realiad_jsons_fuiad_0.0/sim_card_set.json 32 | - data/Real-IAD/realiad_jsons_fuiad_0.0/switch.json 33 | - data/Real-IAD/realiad_jsons_fuiad_0.0/tape.json 34 | - data/Real-IAD/realiad_jsons_fuiad_0.0/terminalblock.json 35 | - data/Real-IAD/realiad_jsons_fuiad_0.0/toothbrush.json 36 | - data/Real-IAD/realiad_jsons_fuiad_0.0/toy.json 37 | - data/Real-IAD/realiad_jsons_fuiad_0.0/toy_brick.json 38 | - data/Real-IAD/realiad_jsons_fuiad_0.0/transistor1.json 39 | - data/Real-IAD/realiad_jsons_fuiad_0.0/u_block.json 40 | - data/Real-IAD/realiad_jsons_fuiad_0.0/usb.json 41 | - data/Real-IAD/realiad_jsons_fuiad_0.0/usb_adaptor.json 42 | - data/Real-IAD/realiad_jsons_fuiad_0.0/vcpill.json 43 | - data/Real-IAD/realiad_jsons_fuiad_0.0/wooden_beads.json 44 | - data/Real-IAD/realiad_jsons_fuiad_0.0/woodstick.json 45 | - data/Real-IAD/realiad_jsons_fuiad_0.0/zipper.json 46 | rebalance: False 47 | hflip: False 48 | vflip: False 49 | rotate: False 50 | 51 | test: 52 | meta_file: 53 | - data/Real-IAD/realiad_jsons_fuiad_0.0/audiojack.json 54 | - data/Real-IAD/realiad_jsons_fuiad_0.0/bottle_cap.json 55 | - data/Real-IAD/realiad_jsons_fuiad_0.0/button_battery.json 56 | - data/Real-IAD/realiad_jsons_fuiad_0.0/end_cap.json 57 | - data/Real-IAD/realiad_jsons_fuiad_0.0/eraser.json 58 | - data/Real-IAD/realiad_jsons_fuiad_0.0/fire_hood.json 59 | - data/Real-IAD/realiad_jsons_fuiad_0.0/mint.json 60 | - data/Real-IAD/realiad_jsons_fuiad_0.0/mounts.json 61 | - data/Real-IAD/realiad_jsons_fuiad_0.0/pcb.json 62 | - data/Real-IAD/realiad_jsons_fuiad_0.0/phone_battery.json 63 | - data/Real-IAD/realiad_jsons_fuiad_0.0/plastic_nut.json 64 | - data/Real-IAD/realiad_jsons_fuiad_0.0/plastic_plug.json 65 | - data/Real-IAD/realiad_jsons_fuiad_0.0/porcelain_doll.json 66 | - data/Real-IAD/realiad_jsons_fuiad_0.0/regulator.json 67 | - data/Real-IAD/realiad_jsons_fuiad_0.0/rolled_strip_base.json 68 | - data/Real-IAD/realiad_jsons_fuiad_0.0/sim_card_set.json 69 | - 
data/Real-IAD/realiad_jsons_fuiad_0.0/switch.json 70 | - data/Real-IAD/realiad_jsons_fuiad_0.0/tape.json 71 | - data/Real-IAD/realiad_jsons_fuiad_0.0/terminalblock.json 72 | - data/Real-IAD/realiad_jsons_fuiad_0.0/toothbrush.json 73 | - data/Real-IAD/realiad_jsons_fuiad_0.0/toy.json 74 | - data/Real-IAD/realiad_jsons_fuiad_0.0/toy_brick.json 75 | - data/Real-IAD/realiad_jsons_fuiad_0.0/transistor1.json 76 | - data/Real-IAD/realiad_jsons_fuiad_0.0/u_block.json 77 | - data/Real-IAD/realiad_jsons_fuiad_0.0/usb.json 78 | - data/Real-IAD/realiad_jsons_fuiad_0.0/usb_adaptor.json 79 | - data/Real-IAD/realiad_jsons_fuiad_0.0/vcpill.json 80 | - data/Real-IAD/realiad_jsons_fuiad_0.0/wooden_beads.json 81 | - data/Real-IAD/realiad_jsons_fuiad_0.0/woodstick.json 82 | - data/Real-IAD/realiad_jsons_fuiad_0.0/zipper.json 83 | 84 | input_size: [224,224] # [h,w] 85 | pixel_mean: [0.485, 0.456, 0.406] 86 | pixel_std: [0.229, 0.224, 0.225] 87 | batch_size: 8 88 | workers: 8 # number of workers of dataloader for each process 89 | 90 | criterion: 91 | - name: FeatureMSELoss 92 | type: FeatureMSELoss 93 | kwargs: 94 | weight: 1.0 95 | 96 | trainer: 97 | max_epoch: 300 98 | clip_max_norm: 0.1 99 | val_freq_epoch: 10 100 | print_freq_step: 1 101 | tb_freq_step: 1 102 | lr_scheduler: 103 | type: StepLR 104 | kwargs: 105 | step_size: 800 106 | gamma: 0.1 107 | optimizer: 108 | type: AdamW 109 | kwargs: 110 | lr: 0.0001 111 | betas: [0.9, 0.999] 112 | weight_decay: 0.0001 113 | 114 | saver: 115 | auto_resume: True 116 | always_save: False 117 | load_path: checkpoints/ckpt.pth.tar 118 | save_dir: checkpoints/ 119 | log_dir: log/ 120 | 121 | evaluator: 122 | save_dir: result_eval_temp 123 | key_metric: mean_pixel_auc 124 | no_metrics: True 125 | metrics: 126 | auc: 127 | - name: std 128 | - name: max 129 | kwargs: 130 | avgpool_size: [16, 16] 131 | - name: pixel 132 | # vis_compound: 133 | # save_dir: vis_compound 134 | # max_score: null 135 | # min_score: null 136 | # vis_single: 137 | # save_dir: vis_single 138 | # max_score: null 139 | # min_score: null 140 | 141 | frozen_layers: [backbone] 142 | 143 | net: 144 | - name: backbone 145 | type: models.backbones.efficientnet_b4 146 | frozen: True 147 | kwargs: 148 | pretrained: True 149 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 150 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 151 | outlayers: [1,2,3,4] 152 | - name: neck 153 | prev: backbone 154 | type: models.necks.MFCN 155 | kwargs: 156 | outstrides: [16] 157 | - name: reconstruction 158 | prev: neck 159 | type: models.reconstructions.UniAD 160 | kwargs: 161 | pos_embed_type: learned 162 | hidden_dim: 256 163 | nhead: 8 164 | num_encoder_layers: 4 165 | num_decoder_layers: 4 166 | dim_feedforward: 1024 167 | dropout: 0.1 168 | activation: relu 169 | normalize_before: False 170 | feature_jitter: 171 | scale: 20.0 172 | prob: 1.0 173 | neighbor_mask: 174 | neighbor_size: [7,7] 175 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 176 | save_recon: False 177 | # save_dir: result_recon 178 | initializer: 179 | method: xavier_uniform 180 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n1/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 133 3 | port: 11111 4 | 5 | dataset: 6 | type: explicit 7 | 8 | image_reader: 9 | type: opencv 10 | kwargs: 11 | image_dir: data/Real-IAD/realiad_1024 12 | 
color_mode: RGB 13 | 14 | train: 15 | meta_file: 16 | - data/Real-IAD/realiad_jsons_fuiad_0.1/audiojack.json 17 | - data/Real-IAD/realiad_jsons_fuiad_0.1/bottle_cap.json 18 | - data/Real-IAD/realiad_jsons_fuiad_0.1/button_battery.json 19 | - data/Real-IAD/realiad_jsons_fuiad_0.1/end_cap.json 20 | - data/Real-IAD/realiad_jsons_fuiad_0.1/eraser.json 21 | - data/Real-IAD/realiad_jsons_fuiad_0.1/fire_hood.json 22 | - data/Real-IAD/realiad_jsons_fuiad_0.1/mint.json 23 | - data/Real-IAD/realiad_jsons_fuiad_0.1/mounts.json 24 | - data/Real-IAD/realiad_jsons_fuiad_0.1/pcb.json 25 | - data/Real-IAD/realiad_jsons_fuiad_0.1/phone_battery.json 26 | - data/Real-IAD/realiad_jsons_fuiad_0.1/plastic_nut.json 27 | - data/Real-IAD/realiad_jsons_fuiad_0.1/plastic_plug.json 28 | - data/Real-IAD/realiad_jsons_fuiad_0.1/porcelain_doll.json 29 | - data/Real-IAD/realiad_jsons_fuiad_0.1/regulator.json 30 | - data/Real-IAD/realiad_jsons_fuiad_0.1/rolled_strip_base.json 31 | - data/Real-IAD/realiad_jsons_fuiad_0.1/sim_card_set.json 32 | - data/Real-IAD/realiad_jsons_fuiad_0.1/switch.json 33 | - data/Real-IAD/realiad_jsons_fuiad_0.1/tape.json 34 | - data/Real-IAD/realiad_jsons_fuiad_0.1/terminalblock.json 35 | - data/Real-IAD/realiad_jsons_fuiad_0.1/toothbrush.json 36 | - data/Real-IAD/realiad_jsons_fuiad_0.1/toy.json 37 | - data/Real-IAD/realiad_jsons_fuiad_0.1/toy_brick.json 38 | - data/Real-IAD/realiad_jsons_fuiad_0.1/transistor1.json 39 | - data/Real-IAD/realiad_jsons_fuiad_0.1/u_block.json 40 | - data/Real-IAD/realiad_jsons_fuiad_0.1/usb.json 41 | - data/Real-IAD/realiad_jsons_fuiad_0.1/usb_adaptor.json 42 | - data/Real-IAD/realiad_jsons_fuiad_0.1/vcpill.json 43 | - data/Real-IAD/realiad_jsons_fuiad_0.1/wooden_beads.json 44 | - data/Real-IAD/realiad_jsons_fuiad_0.1/woodstick.json 45 | - data/Real-IAD/realiad_jsons_fuiad_0.1/zipper.json 46 | rebalance: False 47 | hflip: False 48 | vflip: False 49 | rotate: False 50 | 51 | test: 52 | meta_file: 53 | - data/Real-IAD/realiad_jsons_fuiad_0.1/audiojack.json 54 | - data/Real-IAD/realiad_jsons_fuiad_0.1/bottle_cap.json 55 | - data/Real-IAD/realiad_jsons_fuiad_0.1/button_battery.json 56 | - data/Real-IAD/realiad_jsons_fuiad_0.1/end_cap.json 57 | - data/Real-IAD/realiad_jsons_fuiad_0.1/eraser.json 58 | - data/Real-IAD/realiad_jsons_fuiad_0.1/fire_hood.json 59 | - data/Real-IAD/realiad_jsons_fuiad_0.1/mint.json 60 | - data/Real-IAD/realiad_jsons_fuiad_0.1/mounts.json 61 | - data/Real-IAD/realiad_jsons_fuiad_0.1/pcb.json 62 | - data/Real-IAD/realiad_jsons_fuiad_0.1/phone_battery.json 63 | - data/Real-IAD/realiad_jsons_fuiad_0.1/plastic_nut.json 64 | - data/Real-IAD/realiad_jsons_fuiad_0.1/plastic_plug.json 65 | - data/Real-IAD/realiad_jsons_fuiad_0.1/porcelain_doll.json 66 | - data/Real-IAD/realiad_jsons_fuiad_0.1/regulator.json 67 | - data/Real-IAD/realiad_jsons_fuiad_0.1/rolled_strip_base.json 68 | - data/Real-IAD/realiad_jsons_fuiad_0.1/sim_card_set.json 69 | - data/Real-IAD/realiad_jsons_fuiad_0.1/switch.json 70 | - data/Real-IAD/realiad_jsons_fuiad_0.1/tape.json 71 | - data/Real-IAD/realiad_jsons_fuiad_0.1/terminalblock.json 72 | - data/Real-IAD/realiad_jsons_fuiad_0.1/toothbrush.json 73 | - data/Real-IAD/realiad_jsons_fuiad_0.1/toy.json 74 | - data/Real-IAD/realiad_jsons_fuiad_0.1/toy_brick.json 75 | - data/Real-IAD/realiad_jsons_fuiad_0.1/transistor1.json 76 | - data/Real-IAD/realiad_jsons_fuiad_0.1/u_block.json 77 | - data/Real-IAD/realiad_jsons_fuiad_0.1/usb.json 78 | - data/Real-IAD/realiad_jsons_fuiad_0.1/usb_adaptor.json 79 | - 
data/Real-IAD/realiad_jsons_fuiad_0.1/vcpill.json 80 | - data/Real-IAD/realiad_jsons_fuiad_0.1/wooden_beads.json 81 | - data/Real-IAD/realiad_jsons_fuiad_0.1/woodstick.json 82 | - data/Real-IAD/realiad_jsons_fuiad_0.1/zipper.json 83 | 84 | input_size: [224,224] # [h,w] 85 | pixel_mean: [0.485, 0.456, 0.406] 86 | pixel_std: [0.229, 0.224, 0.225] 87 | batch_size: 8 88 | workers: 8 # number of workers of dataloader for each process 89 | 90 | criterion: 91 | - name: FeatureMSELoss 92 | type: FeatureMSELoss 93 | kwargs: 94 | weight: 1.0 95 | 96 | trainer: 97 | max_epoch: 300 98 | clip_max_norm: 0.1 99 | val_freq_epoch: 10 100 | print_freq_step: 1 101 | tb_freq_step: 1 102 | lr_scheduler: 103 | type: StepLR 104 | kwargs: 105 | step_size: 800 106 | gamma: 0.1 107 | optimizer: 108 | type: AdamW 109 | kwargs: 110 | lr: 0.0001 111 | betas: [0.9, 0.999] 112 | weight_decay: 0.0001 113 | 114 | saver: 115 | auto_resume: True 116 | always_save: False 117 | load_path: checkpoints/ckpt.pth.tar 118 | save_dir: checkpoints/ 119 | log_dir: log/ 120 | 121 | evaluator: 122 | save_dir: result_eval_temp 123 | key_metric: mean_pixel_auc 124 | no_metrics: True 125 | metrics: 126 | auc: 127 | - name: std 128 | - name: max 129 | kwargs: 130 | avgpool_size: [16, 16] 131 | - name: pixel 132 | # vis_compound: 133 | # save_dir: vis_compound 134 | # max_score: null 135 | # min_score: null 136 | # vis_single: 137 | # save_dir: vis_single 138 | # max_score: null 139 | # min_score: null 140 | 141 | frozen_layers: [backbone] 142 | 143 | net: 144 | - name: backbone 145 | type: models.backbones.efficientnet_b4 146 | frozen: True 147 | kwargs: 148 | pretrained: True 149 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 150 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 151 | outlayers: [1,2,3,4] 152 | - name: neck 153 | prev: backbone 154 | type: models.necks.MFCN 155 | kwargs: 156 | outstrides: [16] 157 | - name: reconstruction 158 | prev: neck 159 | type: models.reconstructions.UniAD 160 | kwargs: 161 | pos_embed_type: learned 162 | hidden_dim: 256 163 | nhead: 8 164 | num_encoder_layers: 4 165 | num_decoder_layers: 4 166 | dim_feedforward: 1024 167 | dropout: 0.1 168 | activation: relu 169 | normalize_before: False 170 | feature_jitter: 171 | scale: 20.0 172 | prob: 1.0 173 | neighbor_mask: 174 | neighbor_size: [7,7] 175 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 176 | save_recon: False 177 | # save_dir: result_recon 178 | initializer: 179 | method: xavier_uniform 180 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n2/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 133 3 | port: 11111 4 | 5 | dataset: 6 | type: explicit 7 | 8 | image_reader: 9 | type: opencv 10 | kwargs: 11 | image_dir: data/Real-IAD/realiad_1024 12 | color_mode: RGB 13 | 14 | train: 15 | meta_file: 16 | - data/Real-IAD/realiad_jsons_fuiad_0.2/audiojack.json 17 | - data/Real-IAD/realiad_jsons_fuiad_0.2/bottle_cap.json 18 | - data/Real-IAD/realiad_jsons_fuiad_0.2/button_battery.json 19 | - data/Real-IAD/realiad_jsons_fuiad_0.2/end_cap.json 20 | - data/Real-IAD/realiad_jsons_fuiad_0.2/eraser.json 21 | - data/Real-IAD/realiad_jsons_fuiad_0.2/fire_hood.json 22 | - data/Real-IAD/realiad_jsons_fuiad_0.2/mint.json 23 | - data/Real-IAD/realiad_jsons_fuiad_0.2/mounts.json 24 | - data/Real-IAD/realiad_jsons_fuiad_0.2/pcb.json 25 | - 
data/Real-IAD/realiad_jsons_fuiad_0.2/phone_battery.json 26 | - data/Real-IAD/realiad_jsons_fuiad_0.2/plastic_nut.json 27 | - data/Real-IAD/realiad_jsons_fuiad_0.2/plastic_plug.json 28 | - data/Real-IAD/realiad_jsons_fuiad_0.2/porcelain_doll.json 29 | - data/Real-IAD/realiad_jsons_fuiad_0.2/regulator.json 30 | - data/Real-IAD/realiad_jsons_fuiad_0.2/rolled_strip_base.json 31 | - data/Real-IAD/realiad_jsons_fuiad_0.2/sim_card_set.json 32 | - data/Real-IAD/realiad_jsons_fuiad_0.2/switch.json 33 | - data/Real-IAD/realiad_jsons_fuiad_0.2/tape.json 34 | - data/Real-IAD/realiad_jsons_fuiad_0.2/terminalblock.json 35 | - data/Real-IAD/realiad_jsons_fuiad_0.2/toothbrush.json 36 | - data/Real-IAD/realiad_jsons_fuiad_0.2/toy.json 37 | - data/Real-IAD/realiad_jsons_fuiad_0.2/toy_brick.json 38 | - data/Real-IAD/realiad_jsons_fuiad_0.2/transistor1.json 39 | - data/Real-IAD/realiad_jsons_fuiad_0.2/u_block.json 40 | - data/Real-IAD/realiad_jsons_fuiad_0.2/usb.json 41 | - data/Real-IAD/realiad_jsons_fuiad_0.2/usb_adaptor.json 42 | - data/Real-IAD/realiad_jsons_fuiad_0.2/vcpill.json 43 | - data/Real-IAD/realiad_jsons_fuiad_0.2/wooden_beads.json 44 | - data/Real-IAD/realiad_jsons_fuiad_0.2/woodstick.json 45 | - data/Real-IAD/realiad_jsons_fuiad_0.2/zipper.json 46 | rebalance: False 47 | hflip: False 48 | vflip: False 49 | rotate: False 50 | 51 | test: 52 | meta_file: 53 | - data/Real-IAD/realiad_jsons_fuiad_0.2/audiojack.json 54 | - data/Real-IAD/realiad_jsons_fuiad_0.2/bottle_cap.json 55 | - data/Real-IAD/realiad_jsons_fuiad_0.2/button_battery.json 56 | - data/Real-IAD/realiad_jsons_fuiad_0.2/end_cap.json 57 | - data/Real-IAD/realiad_jsons_fuiad_0.2/eraser.json 58 | - data/Real-IAD/realiad_jsons_fuiad_0.2/fire_hood.json 59 | - data/Real-IAD/realiad_jsons_fuiad_0.2/mint.json 60 | - data/Real-IAD/realiad_jsons_fuiad_0.2/mounts.json 61 | - data/Real-IAD/realiad_jsons_fuiad_0.2/pcb.json 62 | - data/Real-IAD/realiad_jsons_fuiad_0.2/phone_battery.json 63 | - data/Real-IAD/realiad_jsons_fuiad_0.2/plastic_nut.json 64 | - data/Real-IAD/realiad_jsons_fuiad_0.2/plastic_plug.json 65 | - data/Real-IAD/realiad_jsons_fuiad_0.2/porcelain_doll.json 66 | - data/Real-IAD/realiad_jsons_fuiad_0.2/regulator.json 67 | - data/Real-IAD/realiad_jsons_fuiad_0.2/rolled_strip_base.json 68 | - data/Real-IAD/realiad_jsons_fuiad_0.2/sim_card_set.json 69 | - data/Real-IAD/realiad_jsons_fuiad_0.2/switch.json 70 | - data/Real-IAD/realiad_jsons_fuiad_0.2/tape.json 71 | - data/Real-IAD/realiad_jsons_fuiad_0.2/terminalblock.json 72 | - data/Real-IAD/realiad_jsons_fuiad_0.2/toothbrush.json 73 | - data/Real-IAD/realiad_jsons_fuiad_0.2/toy.json 74 | - data/Real-IAD/realiad_jsons_fuiad_0.2/toy_brick.json 75 | - data/Real-IAD/realiad_jsons_fuiad_0.2/transistor1.json 76 | - data/Real-IAD/realiad_jsons_fuiad_0.2/u_block.json 77 | - data/Real-IAD/realiad_jsons_fuiad_0.2/usb.json 78 | - data/Real-IAD/realiad_jsons_fuiad_0.2/usb_adaptor.json 79 | - data/Real-IAD/realiad_jsons_fuiad_0.2/vcpill.json 80 | - data/Real-IAD/realiad_jsons_fuiad_0.2/wooden_beads.json 81 | - data/Real-IAD/realiad_jsons_fuiad_0.2/woodstick.json 82 | - data/Real-IAD/realiad_jsons_fuiad_0.2/zipper.json 83 | 84 | input_size: [224,224] # [h,w] 85 | pixel_mean: [0.485, 0.456, 0.406] 86 | pixel_std: [0.229, 0.224, 0.225] 87 | batch_size: 8 88 | workers: 8 # number of workers of dataloader for each process 89 | 90 | criterion: 91 | - name: FeatureMSELoss 92 | type: FeatureMSELoss 93 | kwargs: 94 | weight: 1.0 95 | 96 | trainer: 97 | max_epoch: 300 98 | clip_max_norm: 0.1 99 | 
val_freq_epoch: 10 100 | print_freq_step: 1 101 | tb_freq_step: 1 102 | lr_scheduler: 103 | type: StepLR 104 | kwargs: 105 | step_size: 800 106 | gamma: 0.1 107 | optimizer: 108 | type: AdamW 109 | kwargs: 110 | lr: 0.0001 111 | betas: [0.9, 0.999] 112 | weight_decay: 0.0001 113 | 114 | saver: 115 | auto_resume: True 116 | always_save: False 117 | load_path: checkpoints/ckpt.pth.tar 118 | save_dir: checkpoints/ 119 | log_dir: log/ 120 | 121 | evaluator: 122 | save_dir: result_eval_temp 123 | key_metric: mean_pixel_auc 124 | no_metrics: True 125 | metrics: 126 | auc: 127 | - name: std 128 | - name: max 129 | kwargs: 130 | avgpool_size: [16, 16] 131 | - name: pixel 132 | # vis_compound: 133 | # save_dir: vis_compound 134 | # max_score: null 135 | # min_score: null 136 | # vis_single: 137 | # save_dir: vis_single 138 | # max_score: null 139 | # min_score: null 140 | 141 | frozen_layers: [backbone] 142 | 143 | net: 144 | - name: backbone 145 | type: models.backbones.efficientnet_b4 146 | frozen: True 147 | kwargs: 148 | pretrained: True 149 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 150 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 151 | outlayers: [1,2,3,4] 152 | - name: neck 153 | prev: backbone 154 | type: models.necks.MFCN 155 | kwargs: 156 | outstrides: [16] 157 | - name: reconstruction 158 | prev: neck 159 | type: models.reconstructions.UniAD 160 | kwargs: 161 | pos_embed_type: learned 162 | hidden_dim: 256 163 | nhead: 8 164 | num_encoder_layers: 4 165 | num_decoder_layers: 4 166 | dim_feedforward: 1024 167 | dropout: 0.1 168 | activation: relu 169 | normalize_before: False 170 | feature_jitter: 171 | scale: 20.0 172 | prob: 1.0 173 | neighbor_mask: 174 | neighbor_size: [7,7] 175 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 176 | save_recon: False 177 | # save_dir: result_recon 178 | initializer: 179 | method: xavier_uniform 180 | -------------------------------------------------------------------------------- /experiments/RealIAD-fuad-n4/config.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 133 3 | port: 11111 4 | 5 | dataset: 6 | type: explicit 7 | 8 | image_reader: 9 | type: opencv 10 | kwargs: 11 | image_dir: data/Real-IAD/realiad_1024 12 | color_mode: RGB 13 | 14 | train: 15 | meta_file: 16 | - data/Real-IAD/realiad_jsons_fuiad_0.4/audiojack.json 17 | - data/Real-IAD/realiad_jsons_fuiad_0.4/bottle_cap.json 18 | - data/Real-IAD/realiad_jsons_fuiad_0.4/button_battery.json 19 | - data/Real-IAD/realiad_jsons_fuiad_0.4/end_cap.json 20 | - data/Real-IAD/realiad_jsons_fuiad_0.4/eraser.json 21 | - data/Real-IAD/realiad_jsons_fuiad_0.4/fire_hood.json 22 | - data/Real-IAD/realiad_jsons_fuiad_0.4/mint.json 23 | - data/Real-IAD/realiad_jsons_fuiad_0.4/mounts.json 24 | - data/Real-IAD/realiad_jsons_fuiad_0.4/pcb.json 25 | - data/Real-IAD/realiad_jsons_fuiad_0.4/phone_battery.json 26 | - data/Real-IAD/realiad_jsons_fuiad_0.4/plastic_nut.json 27 | - data/Real-IAD/realiad_jsons_fuiad_0.4/plastic_plug.json 28 | - data/Real-IAD/realiad_jsons_fuiad_0.4/porcelain_doll.json 29 | - data/Real-IAD/realiad_jsons_fuiad_0.4/regulator.json 30 | - data/Real-IAD/realiad_jsons_fuiad_0.4/rolled_strip_base.json 31 | - data/Real-IAD/realiad_jsons_fuiad_0.4/sim_card_set.json 32 | - data/Real-IAD/realiad_jsons_fuiad_0.4/switch.json 33 | - data/Real-IAD/realiad_jsons_fuiad_0.4/tape.json 34 | - data/Real-IAD/realiad_jsons_fuiad_0.4/terminalblock.json 
35 | - data/Real-IAD/realiad_jsons_fuiad_0.4/toothbrush.json 36 | - data/Real-IAD/realiad_jsons_fuiad_0.4/toy.json 37 | - data/Real-IAD/realiad_jsons_fuiad_0.4/toy_brick.json 38 | - data/Real-IAD/realiad_jsons_fuiad_0.4/transistor1.json 39 | - data/Real-IAD/realiad_jsons_fuiad_0.4/u_block.json 40 | - data/Real-IAD/realiad_jsons_fuiad_0.4/usb.json 41 | - data/Real-IAD/realiad_jsons_fuiad_0.4/usb_adaptor.json 42 | - data/Real-IAD/realiad_jsons_fuiad_0.4/vcpill.json 43 | - data/Real-IAD/realiad_jsons_fuiad_0.4/wooden_beads.json 44 | - data/Real-IAD/realiad_jsons_fuiad_0.4/woodstick.json 45 | - data/Real-IAD/realiad_jsons_fuiad_0.4/zipper.json 46 | rebalance: False 47 | hflip: False 48 | vflip: False 49 | rotate: False 50 | 51 | test: 52 | meta_file: 53 | - data/Real-IAD/realiad_jsons_fuiad_0.4/audiojack.json 54 | - data/Real-IAD/realiad_jsons_fuiad_0.4/bottle_cap.json 55 | - data/Real-IAD/realiad_jsons_fuiad_0.4/button_battery.json 56 | - data/Real-IAD/realiad_jsons_fuiad_0.4/end_cap.json 57 | - data/Real-IAD/realiad_jsons_fuiad_0.4/eraser.json 58 | - data/Real-IAD/realiad_jsons_fuiad_0.4/fire_hood.json 59 | - data/Real-IAD/realiad_jsons_fuiad_0.4/mint.json 60 | - data/Real-IAD/realiad_jsons_fuiad_0.4/mounts.json 61 | - data/Real-IAD/realiad_jsons_fuiad_0.4/pcb.json 62 | - data/Real-IAD/realiad_jsons_fuiad_0.4/phone_battery.json 63 | - data/Real-IAD/realiad_jsons_fuiad_0.4/plastic_nut.json 64 | - data/Real-IAD/realiad_jsons_fuiad_0.4/plastic_plug.json 65 | - data/Real-IAD/realiad_jsons_fuiad_0.4/porcelain_doll.json 66 | - data/Real-IAD/realiad_jsons_fuiad_0.4/regulator.json 67 | - data/Real-IAD/realiad_jsons_fuiad_0.4/rolled_strip_base.json 68 | - data/Real-IAD/realiad_jsons_fuiad_0.4/sim_card_set.json 69 | - data/Real-IAD/realiad_jsons_fuiad_0.4/switch.json 70 | - data/Real-IAD/realiad_jsons_fuiad_0.4/tape.json 71 | - data/Real-IAD/realiad_jsons_fuiad_0.4/terminalblock.json 72 | - data/Real-IAD/realiad_jsons_fuiad_0.4/toothbrush.json 73 | - data/Real-IAD/realiad_jsons_fuiad_0.4/toy.json 74 | - data/Real-IAD/realiad_jsons_fuiad_0.4/toy_brick.json 75 | - data/Real-IAD/realiad_jsons_fuiad_0.4/transistor1.json 76 | - data/Real-IAD/realiad_jsons_fuiad_0.4/u_block.json 77 | - data/Real-IAD/realiad_jsons_fuiad_0.4/usb.json 78 | - data/Real-IAD/realiad_jsons_fuiad_0.4/usb_adaptor.json 79 | - data/Real-IAD/realiad_jsons_fuiad_0.4/vcpill.json 80 | - data/Real-IAD/realiad_jsons_fuiad_0.4/wooden_beads.json 81 | - data/Real-IAD/realiad_jsons_fuiad_0.4/woodstick.json 82 | - data/Real-IAD/realiad_jsons_fuiad_0.4/zipper.json 83 | 84 | input_size: [224,224] # [h,w] 85 | pixel_mean: [0.485, 0.456, 0.406] 86 | pixel_std: [0.229, 0.224, 0.225] 87 | batch_size: 8 88 | workers: 8 # number of workers of dataloader for each process 89 | 90 | criterion: 91 | - name: FeatureMSELoss 92 | type: FeatureMSELoss 93 | kwargs: 94 | weight: 1.0 95 | 96 | trainer: 97 | max_epoch: 300 98 | clip_max_norm: 0.1 99 | val_freq_epoch: 10 100 | print_freq_step: 1 101 | tb_freq_step: 1 102 | lr_scheduler: 103 | type: StepLR 104 | kwargs: 105 | step_size: 800 106 | gamma: 0.1 107 | optimizer: 108 | type: AdamW 109 | kwargs: 110 | lr: 0.0001 111 | betas: [0.9, 0.999] 112 | weight_decay: 0.0001 113 | 114 | saver: 115 | auto_resume: True 116 | always_save: False 117 | load_path: checkpoints/ckpt.pth.tar 118 | save_dir: checkpoints/ 119 | log_dir: log/ 120 | 121 | evaluator: 122 | save_dir: result_eval_temp 123 | key_metric: mean_pixel_auc 124 | no_metrics: True 125 | metrics: 126 | auc: 127 | - name: std 128 | - name: max 129 | kwargs: 
130 | avgpool_size: [16, 16] 131 | - name: pixel 132 | # vis_compound: 133 | # save_dir: vis_compound 134 | # max_score: null 135 | # min_score: null 136 | # vis_single: 137 | # save_dir: vis_single 138 | # max_score: null 139 | # min_score: null 140 | 141 | frozen_layers: [backbone] 142 | 143 | net: 144 | - name: backbone 145 | type: models.backbones.efficientnet_b4 146 | frozen: True 147 | kwargs: 148 | pretrained: True 149 | # select outlayers from: resnet [1,2,3,4], efficientnet [1,2,3,4,5] 150 | # empirically, for industrial: resnet [1,2,3] or [2,3], efficientnet [1,2,3,4] or [2,3,4] 151 | outlayers: [1,2,3,4] 152 | - name: neck 153 | prev: backbone 154 | type: models.necks.MFCN 155 | kwargs: 156 | outstrides: [16] 157 | - name: reconstruction 158 | prev: neck 159 | type: models.reconstructions.UniAD 160 | kwargs: 161 | pos_embed_type: learned 162 | hidden_dim: 256 163 | nhead: 8 164 | num_encoder_layers: 4 165 | num_decoder_layers: 4 166 | dim_feedforward: 1024 167 | dropout: 0.1 168 | activation: relu 169 | normalize_before: False 170 | feature_jitter: 171 | scale: 20.0 172 | prob: 1.0 173 | neighbor_mask: 174 | neighbor_size: [7,7] 175 | mask: [True, True, True] # whether use mask in [enc, dec1, dec2] 176 | save_recon: False 177 | # save_dir: result_recon 178 | initializer: 179 | method: xavier_uniform 180 | -------------------------------------------------------------------------------- /datasets/cifar_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import logging 4 | import os.path 5 | import pickle 6 | import random 7 | from typing import Any, List 8 | 9 | import numpy as np 10 | import torch 11 | import torchvision.transforms as transforms 12 | from PIL import Image 13 | from torch.utils.data import DataLoader, Dataset 14 | from torch.utils.data.distributed import DistributedSampler 15 | from torch.utils.data.sampler import RandomSampler 16 | 17 | logger = logging.getLogger("global_logger") 18 | 19 | classes = [ 20 | "airplane", 21 | "automobile", 22 | "bird", 23 | "cat", 24 | "deer", 25 | "dog", 26 | "frog", 27 | "horse", 28 | "ship", 29 | "truck", 30 | ] 31 | 32 | 33 | def build_cifar10_dataloader(cfg, training, distributed=True): 34 | 35 | logger.info("building CustomDataset from: {}".format(cfg["root_dir"])) 36 | 37 | dataset = CIFAR10( 38 | root=cfg["root_dir"], 39 | train=training, 40 | resize=cfg["input_size"], 41 | normals=cfg["normals"], 42 | ) 43 | 44 | if distributed: 45 | sampler = DistributedSampler(dataset) 46 | else: 47 | sampler = RandomSampler(dataset) 48 | 49 | data_loader = DataLoader( 50 | dataset, 51 | batch_size=cfg["batch_size"], 52 | num_workers=cfg["workers"], 53 | pin_memory=True, 54 | sampler=sampler, 55 | ) 56 | 57 | return data_loader 58 | 59 | 60 | class CIFAR10(Dataset): 61 | """`CIFAR10 `_ Dataset. 62 | Args: 63 | root (string): Root directory of dataset where directory 64 | train (bool, optional): If True, creates dataset from training set, otherwise 65 | creates from test set. 
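        resize (list of int): target [h, w] size each image is resized to.
        normals (list of int): class indices treated as normal; at test time an
            equal number of samples from the remaining classes is kept and
            labeled as anomalous (see _select_normal below).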
66 | """ 67 | 68 | base_folder = "cifar-10-batches-py" 69 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 70 | filename = "cifar-10-python.tar.gz" 71 | tgz_md5 = "c58f30108f718f92721af3b95e74349a" 72 | train_list = [ 73 | ["data_batch_1", "c99cafc152244af753f735de768cd75f"], 74 | ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"], 75 | ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"], 76 | ["data_batch_4", "634d18415352ddfa80567beed471001a"], 77 | ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"], 78 | ] 79 | 80 | test_list = [ 81 | ["test_batch", "40351d587109b95175f43aff81a1287e"], 82 | ] 83 | meta = { 84 | "filename": "batches.meta", 85 | "key": "label_names", 86 | "md5": "5ff9c542aee3614f3951f8cda6e48888", 87 | } 88 | 89 | def __init__( 90 | self, 91 | root: str, 92 | train: bool, 93 | resize: List[int], 94 | normals: List[int], 95 | ) -> None: 96 | 97 | self.root = root 98 | self.normals = normals 99 | self.train = train # training set or test set 100 | 101 | self.transform = transforms.Compose( 102 | [ 103 | transforms.Resize(resize, Image.ANTIALIAS), 104 | transforms.ToTensor(), 105 | transforms.Normalize( 106 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 107 | ), 108 | ] 109 | ) 110 | 111 | self.mask_transform = transforms.Compose( 112 | [ 113 | transforms.ToTensor(), 114 | ] 115 | ) 116 | 117 | if self.train: 118 | downloaded_list = self.train_list 119 | else: 120 | downloaded_list = self.test_list 121 | 122 | self.data: Any = [] 123 | self.targets = [] 124 | 125 | # now load the picked numpy arrays 126 | for file_name, checksum in downloaded_list: 127 | file_path = os.path.join(self.root, self.base_folder, file_name) 128 | with open(file_path, "rb") as f: 129 | entry = pickle.load(f, encoding="latin1") 130 | self.data.append(entry["data"]) 131 | if "labels" in entry: 132 | self.targets.extend(entry["labels"]) 133 | else: 134 | self.targets.extend(entry["fine_labels"]) 135 | 136 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 137 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 138 | 139 | self._load_meta() 140 | self._select_normal() 141 | 142 | def _select_normal(self) -> None: 143 | assert self.data.shape[0] == len(self.targets) 144 | _data_normal = [] 145 | _data_defect = [] 146 | _targets_normal = [] 147 | _targets_defect = [] 148 | for datum, target in zip(self.data, self.targets): 149 | if target in self.normals: 150 | _data_normal.append(datum) 151 | _targets_normal.append(target) 152 | elif not self.train: 153 | _data_defect.append(datum) 154 | _targets_defect.append(target) 155 | 156 | if not self.train: 157 | ids = random.sample(range(len(_data_defect)), len(_data_normal)) 158 | _data_defect = [_data_defect[idx] for idx in ids] 159 | _targets_defect = [_targets_defect[idx] for idx in ids] 160 | 161 | self.data = _data_normal + _data_defect 162 | self.targets = _targets_normal + _targets_defect 163 | 164 | def _load_meta(self) -> None: 165 | path = os.path.join(self.root, self.base_folder, self.meta["filename"]) 166 | with open(path, "rb") as infile: 167 | data = pickle.load(infile, encoding="latin1") 168 | self.classes = data[self.meta["key"]] 169 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 170 | 171 | def __getitem__(self, index: int): 172 | img, target = self.data[index], self.targets[index] 173 | 174 | label = 0 if target in self.normals else 1 175 | 176 | # doing this so that it is consistent with all other datasets 177 | # to return a PIL Image 178 | img = 
Image.fromarray(img) 179 | 180 | if self.transform is not None: 181 | img = self.transform(img) 182 | 183 | height = img.shape[1] 184 | width = img.shape[2] 185 | 186 | if label == 0: 187 | mask = torch.zeros((1, height, width)) 188 | else: 189 | mask = torch.ones((1, height, width)) 190 | 191 | input = { 192 | "filename": "{}/{}.jpg".format(classes[target], index), 193 | "image": img, 194 | "mask": mask, 195 | "height": height, 196 | "width": width, 197 | "label": label, 198 | "clsname": "cifar", 199 | } 200 | 201 | return input 202 | 203 | def __len__(self) -> int: 204 | return len(self.data) 205 | 206 | def extra_repr(self) -> str: 207 | split = "Train" if self.train is True else "Test" 208 | return f"Split: {split}" 209 | -------------------------------------------------------------------------------- /tools/vis_query.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import os 4 | import pprint 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.optim 10 | import yaml 11 | from easydict import EasyDict 12 | from einops import rearrange 13 | from torch import nn 14 | from utils.misc_helper import create_logger 15 | 16 | parser = argparse.ArgumentParser(description="UniAD") 17 | parser.add_argument("--config", default="./config.yaml") 18 | parser.add_argument("--class_name", default="") 19 | 20 | 21 | def update_config(config): 22 | # update planes & strides 23 | backbone_path, backbone_type = config.net[0].type.rsplit(".", 1) 24 | module = importlib.import_module(backbone_path) 25 | backbone_info = getattr(module, "backbone_info") 26 | backbone = backbone_info[backbone_type] 27 | outplanes = [] 28 | for layer in config.net[0].kwargs.outlayers: 29 | if layer not in backbone["layers"]: 30 | raise ValueError( 31 | "only layers {} are allowed for backbone {}, but got {}!".format( 32 | backbone["layers"], backbone_type, layer 33 | ) 34 | ) 35 | idx = backbone["layers"].index(layer) 36 | outplanes.append(backbone["planes"][idx]) 37 | 38 | config.net[2].kwargs.instrides = config.net[1].kwargs.outstrides 39 | config.net[2].kwargs.inplanes = [sum(outplanes)] 40 | return config 41 | 42 | 43 | def load_state_decoder(path, model): 44 | def map_func(storage, location): 45 | return storage.cuda() 46 | 47 | if os.path.isfile(path): 48 | print("=> loading checkpoint '{}'".format(path)) 49 | 50 | checkpoint = torch.load(path, map_location=map_func) 51 | state_dict = checkpoint["state_dict"] 52 | 53 | # state_dict of decoder 54 | state_dict_decoder = {} 55 | for k, v in state_dict.items(): 56 | if "module.reconstruction." 
in k: 57 | k_new = k.replace("module.reconstruction.", "") 58 | state_dict_decoder[k_new] = v 59 | 60 | # fix size mismatch error 61 | ignore_keys = [] 62 | for k, v in state_dict_decoder.items(): 63 | if k in model.state_dict().keys(): 64 | v_dst = model.state_dict()[k] 65 | if v.shape != v_dst.shape: 66 | ignore_keys.append(k) 67 | print( 68 | "caution: size-mismatch key: {} size: {} -> {}".format( 69 | k, v.shape, v_dst.shape 70 | ) 71 | ) 72 | 73 | for k in ignore_keys: 74 | state_dict_decoder.pop(k) 75 | 76 | model.load_state_dict(state_dict_decoder, strict=False) 77 | 78 | ckpt_keys = set(state_dict_decoder.keys()) 79 | own_keys = set(model.state_dict().keys()) 80 | missing_keys = own_keys - ckpt_keys 81 | for k in missing_keys: 82 | print("caution: missing keys from checkpoint {}: {}".format(path, k)) 83 | else: 84 | print("=> no checkpoint found at '{}'".format(path)) 85 | 86 | 87 | def main(): 88 | args = parser.parse_args() 89 | 90 | with open(args.config) as f: 91 | config = EasyDict(yaml.load(f, Loader=yaml.FullLoader)) 92 | 93 | update_config(config) 94 | config.saver.load_path = config.saver.load_path.replace( 95 | "{class_name}", args.class_name 96 | ) 97 | hidden_dim = config.vis_query.hidden_dim 98 | feat_dim = config.net[2].kwargs.inplanes[0] 99 | instride = config.net[2].kwargs.instrides[0] 100 | feat_size = [_ // instride for _ in config.data.input_size] 101 | 102 | config.exp_path = os.path.dirname(args.config) 103 | config.save_path = os.path.join(config.exp_path, config.saver.save_dir) 104 | config.log_path = os.path.join(config.exp_path, config.saver.log_dir) 105 | os.makedirs(config.save_path, exist_ok=True) 106 | os.makedirs(config.log_path, exist_ok=True) 107 | 108 | logger = create_logger( 109 | "global_logger", config.log_path + "/dec_{}.log".format(args.class_name) 110 | ) 111 | logger.info("args: {}".format(pprint.pformat(args))) 112 | logger.info("config: {}".format(pprint.pformat(config))) 113 | 114 | # create model 115 | module_name, cls_name = config.net[2].type.rsplit(".", 1) 116 | module = importlib.import_module(module_name) 117 | model = getattr(module, cls_name)(**config.net[2].kwargs) 118 | load_state_decoder(config.saver.load_path, model) 119 | model.cuda() 120 | 121 | mean = ( 122 | torch.tensor(config.data.pixel_mean).cuda().unsqueeze(0).unsqueeze(0) 123 | ) # 1 x 1 x 3 124 | std = ( 125 | torch.tensor(config.data.pixel_std).cuda().unsqueeze(0).unsqueeze(0) 126 | ) # 1 x 1 x 3 127 | 128 | model_query = torch.load(config.vis_query.model_path)["state_dict"] 129 | 130 | # proj learned_embed from hidden_dim to feat_dim 131 | output_proj = nn.Linear(hidden_dim, feat_dim).cuda() 132 | state_dict_proj = {} 133 | for k in model_query.keys(): 134 | if "module.reconstruction.output_proj." 
in k: 135 | k_new = k.replace("module.reconstruction.output_proj.", "") 136 | state_dict_proj[k_new] = model_query[k] 137 | output_proj.load_state_dict(state_dict_proj) 138 | 139 | queries = [] 140 | queries.append(torch.rand(feat_size[0] * feat_size[1], hidden_dim).cuda()) 141 | queries.append(torch.ones(feat_size[0] * feat_size[1], hidden_dim).cuda()) 142 | queries.append(torch.zeros(feat_size[0] * feat_size[1], hidden_dim).cuda()) 143 | for idx_layer in range(config.vis_query.num_decoder_layers): 144 | k_query = f"module.reconstruction.transformer.decoder.layers.{idx_layer}.learned_embed.weight" 145 | learned_embed = model_query[k_query].clone().detach() 146 | queries.append(learned_embed) 147 | 148 | images = [] 149 | for learned_embed in queries: 150 | learned_embed = output_proj(learned_embed.cuda()).unsqueeze(0) 151 | learned_embed = rearrange( 152 | learned_embed, "b (h w) c -> b c h w", h=feat_size[0] 153 | ) # b x c X h x w 154 | input = {"feature_align": learned_embed} 155 | with torch.no_grad(): 156 | output = model(input) 157 | image_rec = ( 158 | output["image_rec"].squeeze(0).permute(1, 2, 0) 159 | ) # 1 x 3 x h x w -> h x w x 3 160 | image_rec = (image_rec * std + mean) * 255 161 | image_rec = image_rec.cpu().numpy() 162 | images.append(image_rec) 163 | 164 | # write image 165 | image = np.ascontiguousarray(np.concatenate(images, axis=1)) # h x 4w x 3 166 | if config.vis_query.with_text: 167 | texts = ["baseline rand", "baseline ones", "baseline zeros"] 168 | for idx_layer in range(config.vis_query.num_decoder_layers): 169 | texts.append(f"query layer{idx_layer}") 170 | for idx, text in enumerate(texts): 171 | x = idx * config.data.input_size[1] + 10 172 | y = config.data.input_size[0] - 10 173 | cv2.putText( 174 | image, text, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2 175 | ) 176 | savepath = os.path.join(config.save_path, f"query_{args.class_name}.jpg") 177 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 178 | cv2.imwrite(savepath, image) 179 | 180 | print(f"Success: Class: {args.class_name}, Saved: {savepath}") 181 | 182 | 183 | if __name__ == "__main__": 184 | main() 185 | -------------------------------------------------------------------------------- /utils/misc_helper.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | import random 5 | import shutil 6 | from collections.abc import Mapping 7 | from datetime import datetime 8 | 9 | import numpy as np 10 | import torch 11 | import torch.distributed as dist 12 | 13 | 14 | def basicConfig(*args, **kwargs): 15 | return 16 | 17 | 18 | # To prevent duplicate logs, we mask this baseConfig setting 19 | logging.basicConfig = basicConfig 20 | 21 | 22 | def create_logger(name, log_file, level=logging.INFO): 23 | log = logging.getLogger(name) 24 | formatter = logging.Formatter( 25 | "[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s" 26 | ) 27 | fh = logging.FileHandler(log_file) 28 | fh.setFormatter(formatter) 29 | sh = logging.StreamHandler() 30 | sh.setFormatter(formatter) 31 | log.setLevel(level) 32 | log.addHandler(fh) 33 | log.addHandler(sh) 34 | return log 35 | 36 | 37 | def get_current_time(): 38 | current_time = datetime.now().strftime("%Y%m%d_%H%M%S") 39 | return current_time 40 | 41 | 42 | class AverageMeter(object): 43 | """Computes and stores the average and current value""" 44 | 45 | def __init__(self, length=0): 46 | self.length = length 47 | self.reset() 48 | 49 | def reset(self): 50 | 
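        # Two modes: with length > 0, statistics come from a sliding window kept in
        # self.history; with length == 0, running sum/count statistics are kept instead.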
if self.length > 0: 51 | self.history = [] 52 | else: 53 | self.count = 0 54 | self.sum = 0.0 55 | self.val = 0.0 56 | self.avg = 0.0 57 | 58 | def update(self, val, num=1): 59 | if self.length > 0: 60 | # currently assert num==1 to avoid bad usage; refine when there are explicit requirements 61 | assert num == 1 62 | self.history.append(val) 63 | if len(self.history) > self.length: 64 | del self.history[0] 65 | 66 | self.val = self.history[-1] 67 | self.avg = np.mean(self.history) 68 | else: 69 | self.val = val 70 | self.sum += val * num 71 | self.count += num 72 | self.avg = self.sum / self.count 73 | 74 | 75 | def save_checkpoint(state, is_best, config): 76 | folder = config.save_path 77 | 78 | torch.save(state, os.path.join(folder, "ckpt.pth.tar")) 79 | if is_best: 80 | shutil.copyfile( 81 | os.path.join(folder, "ckpt.pth.tar"), 82 | os.path.join(folder, "ckpt_best.pth.tar"), 83 | ) 84 | 85 | if config.saver.get( 86 | "always_save", True 87 | ):  # default: save checkpoint after validate() 88 | epoch = state["epoch"] 89 | shutil.copyfile( 90 | os.path.join(folder, "ckpt.pth.tar"), 91 | os.path.join(folder, f"ckpt_{epoch}.pth.tar"), 92 | ) 93 | 94 | 95 | def load_state(path, model, optimizer=None): 96 | 97 | rank = dist.get_rank() 98 | 99 | def map_func(storage, location): 100 | return storage.cuda() 101 | 102 | if os.path.isfile(path): 103 | if rank == 0: 104 | print("=> loading checkpoint '{}'".format(path)) 105 | 106 | checkpoint = torch.load(path, map_location=map_func) 107 | 108 | # fix size mismatch error 109 | ignore_keys = [] 110 | for k, v in checkpoint["state_dict"].items(): 111 | if k in model.state_dict().keys(): 112 | v_dst = model.state_dict()[k] 113 | if v.shape != v_dst.shape: 114 | ignore_keys.append(k) 115 | if rank == 0: 116 | print( 117 | "caution: size-mismatch key: {} size: {} -> {}".format( 118 | k, v.shape, v_dst.shape 119 | ) 120 | ) 121 | 122 | for k in ignore_keys: 123 | checkpoint["state_dict"].pop(k) 124 | 125 | model.load_state_dict(checkpoint["state_dict"], strict=False) 126 | 127 | if rank == 0: 128 | ckpt_keys = set(checkpoint["state_dict"].keys()) 129 | own_keys = set(model.state_dict().keys()) 130 | missing_keys = own_keys - ckpt_keys 131 | for k in missing_keys: 132 | print("caution: missing keys from checkpoint {}: {}".format(path, k)) 133 | 134 | if optimizer is not None: 135 | best_metric = checkpoint["best_metric"] 136 | epoch = checkpoint["epoch"] 137 | # optimizer.load_state_dict(checkpoint["optimizer"]) 138 | if rank == 0: 139 | print( 140 | "=> also loaded optimizer from checkpoint '{}' (Epoch {})".format( 141 | path, epoch 142 | ) 143 | ) 144 | return best_metric, epoch 145 | else: 146 | if rank == 0: 147 | print("=> no checkpoint found at '{}'".format(path)) 148 | 149 | 150 | def set_random_seed(seed=233, reproduce=False): 151 | np.random.seed(seed) 152 | torch.manual_seed(seed ** 2) 153 | torch.cuda.manual_seed(seed ** 3) 154 | random.seed(seed ** 4) 155 | 156 | if reproduce: 157 | torch.backends.cudnn.benchmark = False 158 | torch.backends.cudnn.deterministic = True 159 | else: 160 | torch.backends.cudnn.benchmark = True 161 | 162 | 163 | def to_device(input, device="cuda", dtype=None): 164 | """Transfer data between devices""" 165 | 166 | if "image" in input: 167 | input["image"] = input["image"].to(dtype=dtype) 168 | 169 | def transfer(x): 170 | if torch.is_tensor(x): 171 | return x.to(device=device) 172 | elif isinstance(x, list): 173 | return [transfer(_) for _ in x] 174 | elif isinstance(x, Mapping): 175 | return type(x)({k: 
transfer(v) for k, v in x.items()}) 176 | else: 177 | return x 178 | 179 | return {k: transfer(v) for k, v in input.items()} 180 | 181 | 182 | def update_config(config): 183 | # update feature size 184 | _, reconstruction_type = config.net[2].type.rsplit(".", 1) 185 | if reconstruction_type == "UniAD": 186 | input_size = config.dataset.input_size 187 | outstride = config.net[1].kwargs.outstrides[0] 188 | assert ( 189 | input_size[0] % outstride == 0 190 | ), "input_size must be divisible by outstride exactly!" 191 | assert ( 192 | input_size[1] % outstride == 0 193 | ), "input_size must be divisible by outstride exactly!" 194 | feature_size = [s // outstride for s in input_size] 195 | config.net[2].kwargs.feature_size = feature_size 196 | 197 | # update planes & strides 198 | backbone_path, backbone_type = config.net[0].type.rsplit(".", 1) 199 | module = importlib.import_module(backbone_path) 200 | backbone_info = getattr(module, "backbone_info") 201 | backbone = backbone_info[backbone_type] 202 | outblocks = None 203 | if "efficientnet" in backbone_type: 204 | outblocks = [] 205 | outstrides = [] 206 | outplanes = [] 207 | for layer in config.net[0].kwargs.outlayers: 208 | if layer not in backbone["layers"]: 209 | raise ValueError( 210 | "only layers {} are allowed for backbone {}, but got {}!".format( 211 | backbone["layers"], backbone_type, layer 212 | ) 213 | ) 214 | idx = backbone["layers"].index(layer) 215 | if "efficientnet" in backbone_type: 216 | outblocks.append(backbone["blocks"][idx]) 217 | outstrides.append(backbone["strides"][idx]) 218 | outplanes.append(backbone["planes"][idx]) 219 | if "efficientnet" in backbone_type: 220 | config.net[0].kwargs.pop("outlayers") 221 | config.net[0].kwargs.outblocks = outblocks 222 | config.net[0].kwargs.outstrides = outstrides 223 | config.net[1].kwargs.outplanes = [sum(outplanes)] 224 | 225 | return config 226 | -------------------------------------------------------------------------------- /models/reconstructions/vis_decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.initializer import initialize_from_cfg 3 | 4 | 5 | def conv3x3(inplanes, outplanes, stride=1, groups=1, dilation=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d( 8 | inplanes, 9 | outplanes, 10 | kernel_size=3, 11 | stride=stride, 12 | padding=dilation, 13 | groups=groups, 14 | bias=False, 15 | dilation=dilation, 16 | ) 17 | 18 | 19 | def conv1x1(inplanes, outplanes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__( 28 | self, 29 | inplanes, 30 | planes, 31 | stride=1, 32 | shortcut=None, 33 | groups=1, 34 | base_width=64, 35 | dilation=1, 36 | norm_layer=None, 37 | ): 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 43 | if dilation > 1: 44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | self.conv1 = conv3x3(inplanes, planes, stride=1) 46 | self.upsample = None 47 | if stride != 1: 48 | self.upsample = nn.Upsample(scale_factor=stride, mode="bilinear") 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | 54 | 
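        # shortcut is built by ResNet._make_layer (1x1 conv + upsample + norm) whenever
        # the identity must be reshaped to match the upsampled main path.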
self.shortcut = shortcut 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | if self.upsample is not None: 64 | out = self.upsample(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.shortcut is not None: 70 | identity = self.shortcut(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 80 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 81 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 82 | # This variant is also known as ResNet V1.5 and improves accuracy according to 83 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 84 | 85 | expansion = 4 86 | 87 | def __init__( 88 | self, 89 | inplanes, 90 | planes, 91 | stride=1, 92 | upsample=None, 93 | groups=1, 94 | base_width=64, 95 | dilation=1, 96 | norm_layer=None, 97 | ): 98 | super(Bottleneck, self).__init__() 99 | if norm_layer is None: 100 | norm_layer = nn.BatchNorm2d 101 | width = int(planes * (base_width / 64.0)) * groups 102 | # Both self.conv2 and self.upsample layers upsample the input when stride != 1 103 | self.conv1 = conv1x1(inplanes, width) 104 | self.bn1 = norm_layer(width) 105 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 106 | self.bn2 = norm_layer(width) 107 | self.conv3 = conv1x1(width, planes * self.expansion) 108 | self.bn3 = norm_layer(planes * self.expansion) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.upsample = upsample 111 | self.stride = stride 112 | 113 | def forward(self, x): 114 | identity = x 115 | 116 | out = self.conv1(x) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.upsample is not None: 128 | identity = self.upsample(x) 129 | 130 | out += identity 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class ResNet(nn.Module): 137 | def __init__( 138 | self, 139 | inplanes, 140 | instrides, 141 | block, 142 | layers, 143 | groups=1, 144 | width_per_group=64, 145 | norm_layer=None, 146 | initializer=None, 147 | ): 148 | super(ResNet, self).__init__() 149 | assert isinstance(inplanes, list) and len(inplanes) == 1 150 | assert isinstance(instrides, list) and len(instrides) == 1 151 | self.inplanes = inplanes[0] 152 | self.instrides = instrides[0] 153 | 154 | if norm_layer is None: 155 | norm_layer = nn.BatchNorm2d 156 | self._norm_layer = norm_layer 157 | self.dilation = 1 158 | layer_planes = [64, 128, 256, 512] 159 | if self.instrides == 32: 160 | layer_strides = [2, 2, 2, 1] 161 | elif self.instrides == 16: 162 | layer_strides = [1, 2, 2, 1] 163 | else: 164 | raise NotImplementedError 165 | 166 | self.groups = groups 167 | self.base_width = width_per_group 168 | self.layer4 = self._make_layer( 169 | block, layer_planes[3], layers[3], stride=layer_strides[3] 170 | ) 171 | self.layer3 = self._make_layer( 172 | block, layer_planes[2], layers[2], stride=layer_strides[2] 173 | ) 174 | self.layer2 = self._make_layer( 175 | block, layer_planes[1], layers[1], stride=layer_strides[1] 176 | ) 177 | self.layer1 = self._make_layer( 178 | block, layer_planes[1], layers[0], 
stride=layer_strides[0] 179 | ) 180 | self.upsample1 = nn.Upsample(scale_factor=2, mode="bilinear") 181 | self.conv1 = nn.Conv2d( 182 | self.inplanes, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False 183 | ) 184 | self.bn1 = norm_layer(self.inplanes) 185 | self.relu = nn.ReLU(inplace=True) 186 | self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear") 187 | self.conv2 = nn.Conv2d(self.inplanes, 3, kernel_size=1, stride=1, bias=False) 188 | initialize_from_cfg(self, initializer) 189 | 190 | def _make_layer(self, block, planes, blocks, stride=1): 191 | norm_layer = self._norm_layer 192 | shortcut = None 193 | previous_dilation = self.dilation 194 | if stride != 1 or self.inplanes != planes * block.expansion: 195 | shortcut = nn.Sequential( 196 | conv1x1(self.inplanes, planes * block.expansion, stride=1), 197 | nn.Upsample(scale_factor=stride, mode="bilinear"), 198 | norm_layer(planes * block.expansion), 199 | ) 200 | 201 | layers = [] 202 | layers.append( 203 | block( 204 | self.inplanes, 205 | planes, 206 | stride, 207 | shortcut, 208 | self.groups, 209 | self.base_width, 210 | previous_dilation, 211 | norm_layer, 212 | ) 213 | ) 214 | self.inplanes = planes * block.expansion 215 | for _ in range(1, blocks): 216 | layers.append( 217 | block( 218 | self.inplanes, 219 | planes, 220 | groups=self.groups, 221 | base_width=self.base_width, 222 | dilation=self.dilation, 223 | norm_layer=norm_layer, 224 | ) 225 | ) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | @property 230 | def layer0(self): 231 | return nn.Sequential( 232 | self.upsample1, self.conv1, self.bn1, self.relu, self.upsample2, self.conv2 233 | ) 234 | 235 | def forward(self, input): 236 | x = input["feature_align"] 237 | 238 | for layer_idx in range(4, -1, -1): 239 | layer = getattr(self, f"layer{layer_idx}", None) 240 | if layer is not None: 241 | x = layer(x) 242 | 243 | return {"image_rec": x} 244 | 245 | 246 | def VisDecoder(block_type, **kwargs): 247 | if block_type == "basic": 248 | return ResNet(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) 249 | elif block_type == "bottle": 250 | return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) 251 | else: 252 | raise NotImplementedError 253 | -------------------------------------------------------------------------------- /data/MVTec-AD/json_vis_decoder/test_grid.json: -------------------------------------------------------------------------------- 1 | {"filename": "grid/test/thread/004.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/004_mask.png"} 2 | {"filename": "grid/test/thread/007.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/007_mask.png"} 3 | {"filename": "grid/test/thread/006.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/006_mask.png"} 4 | {"filename": "grid/test/thread/002.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/002_mask.png"} 5 | {"filename": "grid/test/thread/010.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/010_mask.png"} 6 | {"filename": "grid/test/thread/003.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/003_mask.png"} 7 | {"filename": "grid/test/thread/009.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/009_mask.png"} 8 | {"filename": "grid/test/thread/000.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/000_mask.png"} 9 | {"filename": 
"grid/test/thread/008.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/008_mask.png"} 10 | {"filename": "grid/test/thread/005.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/005_mask.png"} 11 | {"filename": "grid/test/thread/001.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/thread/001_mask.png"} 12 | {"filename": "grid/test/glue/004.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/004_mask.png"} 13 | {"filename": "grid/test/glue/007.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/007_mask.png"} 14 | {"filename": "grid/test/glue/006.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/006_mask.png"} 15 | {"filename": "grid/test/glue/002.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/002_mask.png"} 16 | {"filename": "grid/test/glue/010.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/010_mask.png"} 17 | {"filename": "grid/test/glue/003.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/003_mask.png"} 18 | {"filename": "grid/test/glue/009.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/009_mask.png"} 19 | {"filename": "grid/test/glue/000.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/000_mask.png"} 20 | {"filename": "grid/test/glue/008.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/008_mask.png"} 21 | {"filename": "grid/test/glue/005.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/005_mask.png"} 22 | {"filename": "grid/test/glue/001.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/glue/001_mask.png"} 23 | {"filename": "grid/test/bent/004.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/004_mask.png"} 24 | {"filename": "grid/test/bent/007.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/007_mask.png"} 25 | {"filename": "grid/test/bent/006.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/006_mask.png"} 26 | {"filename": "grid/test/bent/002.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/002_mask.png"} 27 | {"filename": "grid/test/bent/010.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/010_mask.png"} 28 | {"filename": "grid/test/bent/003.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/003_mask.png"} 29 | {"filename": "grid/test/bent/009.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/009_mask.png"} 30 | {"filename": "grid/test/bent/000.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/000_mask.png"} 31 | {"filename": "grid/test/bent/011.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/011_mask.png"} 32 | {"filename": "grid/test/bent/008.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/008_mask.png"} 33 | {"filename": "grid/test/bent/005.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/005_mask.png"} 34 | {"filename": "grid/test/bent/001.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/bent/001_mask.png"} 35 | {"filename": "grid/test/good/012.png", "label": 0, "label_name": 
"good"} 36 | {"filename": "grid/test/good/004.png", "label": 0, "label_name": "good"} 37 | {"filename": "grid/test/good/007.png", "label": 0, "label_name": "good"} 38 | {"filename": "grid/test/good/006.png", "label": 0, "label_name": "good"} 39 | {"filename": "grid/test/good/002.png", "label": 0, "label_name": "good"} 40 | {"filename": "grid/test/good/010.png", "label": 0, "label_name": "good"} 41 | {"filename": "grid/test/good/017.png", "label": 0, "label_name": "good"} 42 | {"filename": "grid/test/good/019.png", "label": 0, "label_name": "good"} 43 | {"filename": "grid/test/good/003.png", "label": 0, "label_name": "good"} 44 | {"filename": "grid/test/good/015.png", "label": 0, "label_name": "good"} 45 | {"filename": "grid/test/good/018.png", "label": 0, "label_name": "good"} 46 | {"filename": "grid/test/good/009.png", "label": 0, "label_name": "good"} 47 | {"filename": "grid/test/good/016.png", "label": 0, "label_name": "good"} 48 | {"filename": "grid/test/good/000.png", "label": 0, "label_name": "good"} 49 | {"filename": "grid/test/good/011.png", "label": 0, "label_name": "good"} 50 | {"filename": "grid/test/good/020.png", "label": 0, "label_name": "good"} 51 | {"filename": "grid/test/good/008.png", "label": 0, "label_name": "good"} 52 | {"filename": "grid/test/good/013.png", "label": 0, "label_name": "good"} 53 | {"filename": "grid/test/good/014.png", "label": 0, "label_name": "good"} 54 | {"filename": "grid/test/good/005.png", "label": 0, "label_name": "good"} 55 | {"filename": "grid/test/good/001.png", "label": 0, "label_name": "good"} 56 | {"filename": "grid/test/metal_contamination/004.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/004_mask.png"} 57 | {"filename": "grid/test/metal_contamination/007.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/007_mask.png"} 58 | {"filename": "grid/test/metal_contamination/006.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/006_mask.png"} 59 | {"filename": "grid/test/metal_contamination/002.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/002_mask.png"} 60 | {"filename": "grid/test/metal_contamination/010.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/010_mask.png"} 61 | {"filename": "grid/test/metal_contamination/003.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/003_mask.png"} 62 | {"filename": "grid/test/metal_contamination/009.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/009_mask.png"} 63 | {"filename": "grid/test/metal_contamination/000.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/000_mask.png"} 64 | {"filename": "grid/test/metal_contamination/008.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/008_mask.png"} 65 | {"filename": "grid/test/metal_contamination/005.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/005_mask.png"} 66 | {"filename": "grid/test/metal_contamination/001.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/metal_contamination/001_mask.png"} 67 | {"filename": "grid/test/broken/004.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/004_mask.png"} 68 | {"filename": 
"grid/test/broken/007.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/007_mask.png"} 69 | {"filename": "grid/test/broken/006.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/006_mask.png"} 70 | {"filename": "grid/test/broken/002.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/002_mask.png"} 71 | {"filename": "grid/test/broken/010.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/010_mask.png"} 72 | {"filename": "grid/test/broken/003.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/003_mask.png"} 73 | {"filename": "grid/test/broken/009.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/009_mask.png"} 74 | {"filename": "grid/test/broken/000.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/000_mask.png"} 75 | {"filename": "grid/test/broken/011.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/011_mask.png"} 76 | {"filename": "grid/test/broken/008.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/008_mask.png"} 77 | {"filename": "grid/test/broken/005.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/005_mask.png"} 78 | {"filename": "grid/test/broken/001.png", "label": 1, "label_name": "defective", "maskname": "grid/ground_truth/broken/001_mask.png"} 79 | -------------------------------------------------------------------------------- /data/MVTec-AD/json_vis_decoder/test_wood.json: -------------------------------------------------------------------------------- 1 | {"filename": "wood/test/liquid/004.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/004_mask.png"} 2 | {"filename": "wood/test/liquid/007.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/007_mask.png"} 3 | {"filename": "wood/test/liquid/006.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/006_mask.png"} 4 | {"filename": "wood/test/liquid/002.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/002_mask.png"} 5 | {"filename": "wood/test/liquid/003.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/003_mask.png"} 6 | {"filename": "wood/test/liquid/009.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/009_mask.png"} 7 | {"filename": "wood/test/liquid/000.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/000_mask.png"} 8 | {"filename": "wood/test/liquid/008.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/008_mask.png"} 9 | {"filename": "wood/test/liquid/005.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/005_mask.png"} 10 | {"filename": "wood/test/liquid/001.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/liquid/001_mask.png"} 11 | {"filename": "wood/test/combined/004.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/004_mask.png"} 12 | {"filename": "wood/test/combined/007.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/007_mask.png"} 13 | {"filename": "wood/test/combined/006.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/006_mask.png"} 14 | {"filename": 
"wood/test/combined/002.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/002_mask.png"} 15 | {"filename": "wood/test/combined/010.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/010_mask.png"} 16 | {"filename": "wood/test/combined/003.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/003_mask.png"} 17 | {"filename": "wood/test/combined/009.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/009_mask.png"} 18 | {"filename": "wood/test/combined/000.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/000_mask.png"} 19 | {"filename": "wood/test/combined/008.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/008_mask.png"} 20 | {"filename": "wood/test/combined/005.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/005_mask.png"} 21 | {"filename": "wood/test/combined/001.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/combined/001_mask.png"} 22 | {"filename": "wood/test/scratch/012.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/012_mask.png"} 23 | {"filename": "wood/test/scratch/004.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/004_mask.png"} 24 | {"filename": "wood/test/scratch/007.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/007_mask.png"} 25 | {"filename": "wood/test/scratch/006.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/006_mask.png"} 26 | {"filename": "wood/test/scratch/002.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/002_mask.png"} 27 | {"filename": "wood/test/scratch/010.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/010_mask.png"} 28 | {"filename": "wood/test/scratch/017.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/017_mask.png"} 29 | {"filename": "wood/test/scratch/019.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/019_mask.png"} 30 | {"filename": "wood/test/scratch/003.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/003_mask.png"} 31 | {"filename": "wood/test/scratch/015.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/015_mask.png"} 32 | {"filename": "wood/test/scratch/018.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/018_mask.png"} 33 | {"filename": "wood/test/scratch/009.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/009_mask.png"} 34 | {"filename": "wood/test/scratch/016.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/016_mask.png"} 35 | {"filename": "wood/test/scratch/000.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/000_mask.png"} 36 | {"filename": "wood/test/scratch/011.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/011_mask.png"} 37 | {"filename": "wood/test/scratch/020.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/020_mask.png"} 38 | {"filename": "wood/test/scratch/008.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/008_mask.png"} 39 | {"filename": "wood/test/scratch/013.png", "label": 
1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/013_mask.png"} 40 | {"filename": "wood/test/scratch/014.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/014_mask.png"} 41 | {"filename": "wood/test/scratch/005.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/005_mask.png"} 42 | {"filename": "wood/test/scratch/001.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/scratch/001_mask.png"} 43 | {"filename": "wood/test/good/012.png", "label": 0, "label_name": "good"} 44 | {"filename": "wood/test/good/004.png", "label": 0, "label_name": "good"} 45 | {"filename": "wood/test/good/007.png", "label": 0, "label_name": "good"} 46 | {"filename": "wood/test/good/006.png", "label": 0, "label_name": "good"} 47 | {"filename": "wood/test/good/002.png", "label": 0, "label_name": "good"} 48 | {"filename": "wood/test/good/010.png", "label": 0, "label_name": "good"} 49 | {"filename": "wood/test/good/017.png", "label": 0, "label_name": "good"} 50 | {"filename": "wood/test/good/003.png", "label": 0, "label_name": "good"} 51 | {"filename": "wood/test/good/015.png", "label": 0, "label_name": "good"} 52 | {"filename": "wood/test/good/018.png", "label": 0, "label_name": "good"} 53 | {"filename": "wood/test/good/009.png", "label": 0, "label_name": "good"} 54 | {"filename": "wood/test/good/016.png", "label": 0, "label_name": "good"} 55 | {"filename": "wood/test/good/000.png", "label": 0, "label_name": "good"} 56 | {"filename": "wood/test/good/011.png", "label": 0, "label_name": "good"} 57 | {"filename": "wood/test/good/008.png", "label": 0, "label_name": "good"} 58 | {"filename": "wood/test/good/013.png", "label": 0, "label_name": "good"} 59 | {"filename": "wood/test/good/014.png", "label": 0, "label_name": "good"} 60 | {"filename": "wood/test/good/005.png", "label": 0, "label_name": "good"} 61 | {"filename": "wood/test/good/001.png", "label": 0, "label_name": "good"} 62 | {"filename": "wood/test/color/004.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/color/004_mask.png"} 63 | {"filename": "wood/test/color/007.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/color/007_mask.png"} 64 | {"filename": "wood/test/color/006.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/color/006_mask.png"} 65 | {"filename": "wood/test/color/002.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/color/002_mask.png"} 66 | {"filename": "wood/test/color/003.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/color/003_mask.png"} 67 | {"filename": "wood/test/color/000.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/color/000_mask.png"} 68 | {"filename": "wood/test/color/005.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/color/005_mask.png"} 69 | {"filename": "wood/test/color/001.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/color/001_mask.png"} 70 | {"filename": "wood/test/hole/004.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/004_mask.png"} 71 | {"filename": "wood/test/hole/007.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/007_mask.png"} 72 | {"filename": "wood/test/hole/006.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/006_mask.png"} 73 | {"filename": "wood/test/hole/002.png", "label": 1, "label_name": 
"defective", "maskname": "wood/ground_truth/hole/002_mask.png"} 74 | {"filename": "wood/test/hole/003.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/003_mask.png"} 75 | {"filename": "wood/test/hole/009.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/009_mask.png"} 76 | {"filename": "wood/test/hole/000.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/000_mask.png"} 77 | {"filename": "wood/test/hole/008.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/008_mask.png"} 78 | {"filename": "wood/test/hole/005.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/005_mask.png"} 79 | {"filename": "wood/test/hole/001.png", "label": 1, "label_name": "defective", "maskname": "wood/ground_truth/hole/001_mask.png"} 80 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import numbers 4 | import random 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from torchvision.transforms import functional as F 9 | from torchvision.transforms.functional import ( 10 | adjust_brightness, 11 | adjust_contrast, 12 | adjust_hue, 13 | adjust_saturation, 14 | ) 15 | 16 | 17 | # Horizontal Flip 18 | class RandomHFlip(object): 19 | def __init__(self, flip_p=0.5): 20 | self.flip_p = flip_p 21 | 22 | def __call__(self, img, mask): 23 | flip_flag = torch.rand(1)[0].item() < self.flip_p 24 | if flip_flag: 25 | return F.hflip(img), F.hflip(mask) 26 | else: 27 | return img, mask 28 | 29 | 30 | # Vertical Flip 31 | class RandomVFlip(object): 32 | def __init__(self, flip_p=0.5): 33 | self.flip_p = flip_p 34 | 35 | def __call__(self, img, mask): 36 | flip_flag = torch.rand(1)[0].item() < self.flip_p 37 | if flip_flag: 38 | return F.vflip(img), F.vflip(mask) 39 | else: 40 | return img, mask 41 | 42 | 43 | # from POD 44 | class RandomColorJitter(object): 45 | """ 46 | Randomly change the brightness, contrast and saturation of an image. 47 | 48 | Arguments: 49 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 50 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 51 | or the given [min, max]. Should be non negative numbers. 52 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 53 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 54 | or the given [min, max]. Should be non negative numbers. 55 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 56 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 57 | or the given [min, max]. Should be non negative numbers. 58 | hue (float or tuple of float (min, max)): How much to jitter hue. 59 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 60 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
61 | """ 62 | 63 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, prob=0): 64 | self.brightness = self._check_input(brightness, "brightness") 65 | self.contrast = self._check_input(contrast, "contrast") 66 | self.saturation = self._check_input(saturation, "saturation") 67 | self.hue = self._check_input( 68 | hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False 69 | ) 70 | self.prob = prob 71 | 72 | def _check_input( 73 | self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True 74 | ): 75 | if isinstance(value, numbers.Number): 76 | if value < 0: 77 | raise ValueError( 78 | "If {} is a single number, it must be non negative.".format(name) 79 | ) 80 | value = [center - value, center + value] 81 | if clip_first_on_zero: 82 | value[0] = max(value[0], 0) 83 | elif isinstance(value, (tuple, list)) and len(value) == 2: 84 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 85 | raise ValueError("{} values should be between {}".format(name, bound)) 86 | else: 87 | raise TypeError( 88 | "{} should be a single number or a list/tuple with lenght 2.".format( 89 | name 90 | ) 91 | ) 92 | 93 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 94 | # or (0., 0.) for hue, do nothing 95 | if value[0] == value[1] == center: 96 | value = None 97 | return value 98 | 99 | def get_params(self, brightness, contrast, saturation, hue): 100 | """ 101 | Get a randomized transform to be applied on image. 102 | 103 | Arguments are same as that of __init__. 104 | 105 | Returns: 106 | Transform which randomly adjusts brightness, contrast and 107 | saturation in a random order. 108 | """ 109 | img_transforms = [] 110 | 111 | if brightness is not None and random.random() < self.prob: 112 | brightness_factor = random.uniform(brightness[0], brightness[1]) 113 | img_transforms.append( 114 | transforms.Lambda(lambda img: adjust_brightness(img, brightness_factor)) 115 | ) 116 | 117 | if contrast is not None and random.random() < self.prob: 118 | contrast_factor = random.uniform(contrast[0], contrast[1]) 119 | img_transforms.append( 120 | transforms.Lambda(lambda img: adjust_contrast(img, contrast_factor)) 121 | ) 122 | 123 | if saturation is not None and random.random() < self.prob: 124 | saturation_factor = random.uniform(saturation[0], saturation[1]) 125 | img_transforms.append( 126 | transforms.Lambda(lambda img: adjust_saturation(img, saturation_factor)) 127 | ) 128 | 129 | if hue is not None and random.random() < self.prob: 130 | hue_factor = random.uniform(hue[0], hue[1]) 131 | img_transforms.append( 132 | transforms.Lambda(lambda img: adjust_hue(img, hue_factor)) 133 | ) 134 | 135 | random.shuffle(img_transforms) 136 | img_transforms = transforms.Compose(img_transforms) 137 | 138 | return img_transforms 139 | 140 | def __call__(self, img): 141 | """ 142 | Arguments: 143 | img (PIL Image): Input image. 144 | Returns: 145 | img (PIL Image): Color jittered image. 
146 | """ 147 | transform = self.get_params( 148 | self.brightness, self.contrast, self.saturation, self.hue 149 | ) 150 | img = transform(img) 151 | return img 152 | 153 | def __repr__(self): 154 | format_string = self.__class__.__name__ + "(" 155 | format_string += "brightness={0}".format(self.brightness) 156 | format_string += ", contrast={0}".format(self.contrast) 157 | format_string += ", saturation={0}".format(self.saturation) 158 | format_string += ", hue={0})".format(self.hue) 159 | return format_string 160 | 161 | @classmethod 162 | def from_params(cls, params): 163 | brightness = params.get("brightness", 0.1) 164 | contrast = params.get("contrast", 0.5) 165 | hue = params.get("hue", 0.07) 166 | saturation = params.get("saturation", 0.5) 167 | prob = params.get("prob", 0.5) 168 | return cls( 169 | brightness=brightness, 170 | contrast=contrast, 171 | hue=hue, 172 | saturation=saturation, 173 | prob=prob, 174 | ) 175 | 176 | 177 | class RandomRotation(object): 178 | """Rotate the image by angle. 179 | 180 | Args: 181 | degrees (sequence or float or int): Range of degrees to select from. 182 | If degrees is a number instead of sequence like (min, max), the range of degrees 183 | will be (-degrees, +degrees). 184 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 185 | An optional resampling filter. 186 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 187 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 188 | expand (bool, optional): Optional expansion flag. 189 | If true, expands the output to make it large enough to hold the entire rotated image. 190 | If false or omitted, make the output image the same size as the input image. 191 | Note that the expand flag assumes rotation around the center and no translation. 192 | center (2-tuple, optional): Optional center of rotation. 193 | Origin is the upper left corner. 194 | Default is the center of the image. 195 | """ 196 | 197 | def __init__(self, degrees, resample=False, expand=False, center=None): 198 | if isinstance(degrees, numbers.Number): 199 | degrees = [degrees] 200 | self.degrees = degrees 201 | self.resample = resample 202 | self.expand = expand 203 | self.center = center 204 | 205 | @staticmethod 206 | def get_params(degrees): 207 | """Get parameters for ``rotate`` for a random rotation. 208 | 209 | Returns: 210 | sequence: params to be passed to ``rotate`` for random rotation. 211 | """ 212 | angle = random.choice(degrees) 213 | 214 | return angle 215 | 216 | def __call__(self, img, mask): 217 | """ 218 | img, mask (PIL Image): Image to be rotated. 219 | Returns: 220 | img, mask (PIL Image): Rotated image. 
221 | """ 222 | angle = self.get_params(self.degrees) 223 | img = F.rotate(img, angle, self.resample, self.expand, self.center) 224 | mask = F.rotate(mask, angle, self.resample, self.expand, self.center) 225 | return img, mask 226 | 227 | def __repr__(self): 228 | format_string = self.__class__.__name__ + "(degrees={0}".format(self.degrees) 229 | format_string += ", resample={0}".format(self.resample) 230 | format_string += ", expand={0}".format(self.expand) 231 | if self.center is not None: 232 | format_string += ", center={0}".format(self.center) 233 | format_string += ")" 234 | return format_string 235 | -------------------------------------------------------------------------------- /data/MVTec-AD/json_vis_decoder/test_bottle.json: -------------------------------------------------------------------------------- 1 | {"filename": "bottle/test/contamination/012.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/012_mask.png"} 2 | {"filename": "bottle/test/contamination/004.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/004_mask.png"} 3 | {"filename": "bottle/test/contamination/007.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/007_mask.png"} 4 | {"filename": "bottle/test/contamination/006.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/006_mask.png"} 5 | {"filename": "bottle/test/contamination/002.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/002_mask.png"} 6 | {"filename": "bottle/test/contamination/010.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/010_mask.png"} 7 | {"filename": "bottle/test/contamination/017.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/017_mask.png"} 8 | {"filename": "bottle/test/contamination/019.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/019_mask.png"} 9 | {"filename": "bottle/test/contamination/003.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/003_mask.png"} 10 | {"filename": "bottle/test/contamination/015.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/015_mask.png"} 11 | {"filename": "bottle/test/contamination/018.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/018_mask.png"} 12 | {"filename": "bottle/test/contamination/009.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/009_mask.png"} 13 | {"filename": "bottle/test/contamination/016.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/016_mask.png"} 14 | {"filename": "bottle/test/contamination/000.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/000_mask.png"} 15 | {"filename": "bottle/test/contamination/011.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/011_mask.png"} 16 | {"filename": "bottle/test/contamination/020.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/020_mask.png"} 17 | {"filename": "bottle/test/contamination/008.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/008_mask.png"} 18 | {"filename": "bottle/test/contamination/013.png", "label": 1, "label_name": 
"defective", "maskname": "bottle/ground_truth/contamination/013_mask.png"} 19 | {"filename": "bottle/test/contamination/014.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/014_mask.png"} 20 | {"filename": "bottle/test/contamination/005.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/005_mask.png"} 21 | {"filename": "bottle/test/contamination/001.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/contamination/001_mask.png"} 22 | {"filename": "bottle/test/good/012.png", "label": 0, "label_name": "good"} 23 | {"filename": "bottle/test/good/004.png", "label": 0, "label_name": "good"} 24 | {"filename": "bottle/test/good/007.png", "label": 0, "label_name": "good"} 25 | {"filename": "bottle/test/good/006.png", "label": 0, "label_name": "good"} 26 | {"filename": "bottle/test/good/002.png", "label": 0, "label_name": "good"} 27 | {"filename": "bottle/test/good/010.png", "label": 0, "label_name": "good"} 28 | {"filename": "bottle/test/good/017.png", "label": 0, "label_name": "good"} 29 | {"filename": "bottle/test/good/019.png", "label": 0, "label_name": "good"} 30 | {"filename": "bottle/test/good/003.png", "label": 0, "label_name": "good"} 31 | {"filename": "bottle/test/good/015.png", "label": 0, "label_name": "good"} 32 | {"filename": "bottle/test/good/018.png", "label": 0, "label_name": "good"} 33 | {"filename": "bottle/test/good/009.png", "label": 0, "label_name": "good"} 34 | {"filename": "bottle/test/good/016.png", "label": 0, "label_name": "good"} 35 | {"filename": "bottle/test/good/000.png", "label": 0, "label_name": "good"} 36 | {"filename": "bottle/test/good/011.png", "label": 0, "label_name": "good"} 37 | {"filename": "bottle/test/good/008.png", "label": 0, "label_name": "good"} 38 | {"filename": "bottle/test/good/013.png", "label": 0, "label_name": "good"} 39 | {"filename": "bottle/test/good/014.png", "label": 0, "label_name": "good"} 40 | {"filename": "bottle/test/good/005.png", "label": 0, "label_name": "good"} 41 | {"filename": "bottle/test/good/001.png", "label": 0, "label_name": "good"} 42 | {"filename": "bottle/test/broken_large/012.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/012_mask.png"} 43 | {"filename": "bottle/test/broken_large/004.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/004_mask.png"} 44 | {"filename": "bottle/test/broken_large/007.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/007_mask.png"} 45 | {"filename": "bottle/test/broken_large/006.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/006_mask.png"} 46 | {"filename": "bottle/test/broken_large/002.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/002_mask.png"} 47 | {"filename": "bottle/test/broken_large/010.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/010_mask.png"} 48 | {"filename": "bottle/test/broken_large/017.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/017_mask.png"} 49 | {"filename": "bottle/test/broken_large/019.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/019_mask.png"} 50 | {"filename": "bottle/test/broken_large/003.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/003_mask.png"} 51 | {"filename": 
"bottle/test/broken_large/015.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/015_mask.png"} 52 | {"filename": "bottle/test/broken_large/018.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/018_mask.png"} 53 | {"filename": "bottle/test/broken_large/009.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/009_mask.png"} 54 | {"filename": "bottle/test/broken_large/016.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/016_mask.png"} 55 | {"filename": "bottle/test/broken_large/000.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/000_mask.png"} 56 | {"filename": "bottle/test/broken_large/011.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/011_mask.png"} 57 | {"filename": "bottle/test/broken_large/008.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/008_mask.png"} 58 | {"filename": "bottle/test/broken_large/013.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/013_mask.png"} 59 | {"filename": "bottle/test/broken_large/014.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/014_mask.png"} 60 | {"filename": "bottle/test/broken_large/005.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/005_mask.png"} 61 | {"filename": "bottle/test/broken_large/001.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_large/001_mask.png"} 62 | {"filename": "bottle/test/broken_small/012.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/012_mask.png"} 63 | {"filename": "bottle/test/broken_small/004.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/004_mask.png"} 64 | {"filename": "bottle/test/broken_small/007.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/007_mask.png"} 65 | {"filename": "bottle/test/broken_small/006.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/006_mask.png"} 66 | {"filename": "bottle/test/broken_small/002.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/002_mask.png"} 67 | {"filename": "bottle/test/broken_small/010.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/010_mask.png"} 68 | {"filename": "bottle/test/broken_small/017.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/017_mask.png"} 69 | {"filename": "bottle/test/broken_small/019.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/019_mask.png"} 70 | {"filename": "bottle/test/broken_small/003.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/003_mask.png"} 71 | {"filename": "bottle/test/broken_small/015.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/015_mask.png"} 72 | {"filename": "bottle/test/broken_small/018.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/018_mask.png"} 73 | {"filename": "bottle/test/broken_small/021.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/021_mask.png"} 74 | {"filename": 
"bottle/test/broken_small/009.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/009_mask.png"} 75 | {"filename": "bottle/test/broken_small/016.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/016_mask.png"} 76 | {"filename": "bottle/test/broken_small/000.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/000_mask.png"} 77 | {"filename": "bottle/test/broken_small/011.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/011_mask.png"} 78 | {"filename": "bottle/test/broken_small/020.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/020_mask.png"} 79 | {"filename": "bottle/test/broken_small/008.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/008_mask.png"} 80 | {"filename": "bottle/test/broken_small/013.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/013_mask.png"} 81 | {"filename": "bottle/test/broken_small/014.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/014_mask.png"} 82 | {"filename": "bottle/test/broken_small/005.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/005_mask.png"} 83 | {"filename": "bottle/test/broken_small/001.png", "label": 1, "label_name": "defective", "maskname": "bottle/ground_truth/broken_small/001_mask.png"} 84 | -------------------------------------------------------------------------------- /data/MVTec-AD/json_vis_decoder/test_transistor.json: -------------------------------------------------------------------------------- 1 | {"filename": "transistor/test/damaged_case/004.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/004_mask.png"} 2 | {"filename": "transistor/test/damaged_case/007.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/007_mask.png"} 3 | {"filename": "transistor/test/damaged_case/006.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/006_mask.png"} 4 | {"filename": "transistor/test/damaged_case/002.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/002_mask.png"} 5 | {"filename": "transistor/test/damaged_case/003.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/003_mask.png"} 6 | {"filename": "transistor/test/damaged_case/009.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/009_mask.png"} 7 | {"filename": "transistor/test/damaged_case/000.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/000_mask.png"} 8 | {"filename": "transistor/test/damaged_case/008.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/008_mask.png"} 9 | {"filename": "transistor/test/damaged_case/005.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/005_mask.png"} 10 | {"filename": "transistor/test/damaged_case/001.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/damaged_case/001_mask.png"} 11 | {"filename": "transistor/test/bent_lead/004.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/004_mask.png"} 12 | {"filename": 
"transistor/test/bent_lead/007.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/007_mask.png"} 13 | {"filename": "transistor/test/bent_lead/006.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/006_mask.png"} 14 | {"filename": "transistor/test/bent_lead/002.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/002_mask.png"} 15 | {"filename": "transistor/test/bent_lead/003.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/003_mask.png"} 16 | {"filename": "transistor/test/bent_lead/009.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/009_mask.png"} 17 | {"filename": "transistor/test/bent_lead/000.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/000_mask.png"} 18 | {"filename": "transistor/test/bent_lead/008.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/008_mask.png"} 19 | {"filename": "transistor/test/bent_lead/005.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/005_mask.png"} 20 | {"filename": "transistor/test/bent_lead/001.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/bent_lead/001_mask.png"} 21 | {"filename": "transistor/test/cut_lead/004.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/004_mask.png"} 22 | {"filename": "transistor/test/cut_lead/007.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/007_mask.png"} 23 | {"filename": "transistor/test/cut_lead/006.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/006_mask.png"} 24 | {"filename": "transistor/test/cut_lead/002.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/002_mask.png"} 25 | {"filename": "transistor/test/cut_lead/003.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/003_mask.png"} 26 | {"filename": "transistor/test/cut_lead/009.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/009_mask.png"} 27 | {"filename": "transistor/test/cut_lead/000.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/000_mask.png"} 28 | {"filename": "transistor/test/cut_lead/008.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/008_mask.png"} 29 | {"filename": "transistor/test/cut_lead/005.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/005_mask.png"} 30 | {"filename": "transistor/test/cut_lead/001.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/cut_lead/001_mask.png"} 31 | {"filename": "transistor/test/good/023.png", "label": 0, "label_name": "good"} 32 | {"filename": "transistor/test/good/043.png", "label": 0, "label_name": "good"} 33 | {"filename": "transistor/test/good/031.png", "label": 0, "label_name": "good"} 34 | {"filename": "transistor/test/good/054.png", "label": 0, "label_name": "good"} 35 | {"filename": "transistor/test/good/048.png", "label": 0, "label_name": "good"} 36 | {"filename": "transistor/test/good/012.png", "label": 0, "label_name": "good"} 37 | {"filename": "transistor/test/good/050.png", "label": 0, "label_name": "good"} 38 | {"filename": 
"transistor/test/good/024.png", "label": 0, "label_name": "good"} 39 | {"filename": "transistor/test/good/004.png", "label": 0, "label_name": "good"} 40 | {"filename": "transistor/test/good/007.png", "label": 0, "label_name": "good"} 41 | {"filename": "transistor/test/good/006.png", "label": 0, "label_name": "good"} 42 | {"filename": "transistor/test/good/044.png", "label": 0, "label_name": "good"} 43 | {"filename": "transistor/test/good/051.png", "label": 0, "label_name": "good"} 44 | {"filename": "transistor/test/good/049.png", "label": 0, "label_name": "good"} 45 | {"filename": "transistor/test/good/002.png", "label": 0, "label_name": "good"} 46 | {"filename": "transistor/test/good/022.png", "label": 0, "label_name": "good"} 47 | {"filename": "transistor/test/good/045.png", "label": 0, "label_name": "good"} 48 | {"filename": "transistor/test/good/038.png", "label": 0, "label_name": "good"} 49 | {"filename": "transistor/test/good/010.png", "label": 0, "label_name": "good"} 50 | {"filename": "transistor/test/good/034.png", "label": 0, "label_name": "good"} 51 | {"filename": "transistor/test/good/017.png", "label": 0, "label_name": "good"} 52 | {"filename": "transistor/test/good/019.png", "label": 0, "label_name": "good"} 53 | {"filename": "transistor/test/good/003.png", "label": 0, "label_name": "good"} 54 | {"filename": "transistor/test/good/027.png", "label": 0, "label_name": "good"} 55 | {"filename": "transistor/test/good/015.png", "label": 0, "label_name": "good"} 56 | {"filename": "transistor/test/good/018.png", "label": 0, "label_name": "good"} 57 | {"filename": "transistor/test/good/032.png", "label": 0, "label_name": "good"} 58 | {"filename": "transistor/test/good/021.png", "label": 0, "label_name": "good"} 59 | {"filename": "transistor/test/good/040.png", "label": 0, "label_name": "good"} 60 | {"filename": "transistor/test/good/059.png", "label": 0, "label_name": "good"} 61 | {"filename": "transistor/test/good/039.png", "label": 0, "label_name": "good"} 62 | {"filename": "transistor/test/good/030.png", "label": 0, "label_name": "good"} 63 | {"filename": "transistor/test/good/026.png", "label": 0, "label_name": "good"} 64 | {"filename": "transistor/test/good/037.png", "label": 0, "label_name": "good"} 65 | {"filename": "transistor/test/good/052.png", "label": 0, "label_name": "good"} 66 | {"filename": "transistor/test/good/047.png", "label": 0, "label_name": "good"} 67 | {"filename": "transistor/test/good/025.png", "label": 0, "label_name": "good"} 68 | {"filename": "transistor/test/good/041.png", "label": 0, "label_name": "good"} 69 | {"filename": "transistor/test/good/009.png", "label": 0, "label_name": "good"} 70 | {"filename": "transistor/test/good/033.png", "label": 0, "label_name": "good"} 71 | {"filename": "transistor/test/good/028.png", "label": 0, "label_name": "good"} 72 | {"filename": "transistor/test/good/016.png", "label": 0, "label_name": "good"} 73 | {"filename": "transistor/test/good/000.png", "label": 0, "label_name": "good"} 74 | {"filename": "transistor/test/good/011.png", "label": 0, "label_name": "good"} 75 | {"filename": "transistor/test/good/055.png", "label": 0, "label_name": "good"} 76 | {"filename": "transistor/test/good/058.png", "label": 0, "label_name": "good"} 77 | {"filename": "transistor/test/good/053.png", "label": 0, "label_name": "good"} 78 | {"filename": "transistor/test/good/020.png", "label": 0, "label_name": "good"} 79 | {"filename": "transistor/test/good/008.png", "label": 0, "label_name": "good"} 80 | {"filename": 
"transistor/test/good/046.png", "label": 0, "label_name": "good"} 81 | {"filename": "transistor/test/good/056.png", "label": 0, "label_name": "good"} 82 | {"filename": "transistor/test/good/036.png", "label": 0, "label_name": "good"} 83 | {"filename": "transistor/test/good/042.png", "label": 0, "label_name": "good"} 84 | {"filename": "transistor/test/good/035.png", "label": 0, "label_name": "good"} 85 | {"filename": "transistor/test/good/013.png", "label": 0, "label_name": "good"} 86 | {"filename": "transistor/test/good/014.png", "label": 0, "label_name": "good"} 87 | {"filename": "transistor/test/good/005.png", "label": 0, "label_name": "good"} 88 | {"filename": "transistor/test/good/029.png", "label": 0, "label_name": "good"} 89 | {"filename": "transistor/test/good/001.png", "label": 0, "label_name": "good"} 90 | {"filename": "transistor/test/good/057.png", "label": 0, "label_name": "good"} 91 | {"filename": "transistor/test/misplaced/004.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/004_mask.png"} 92 | {"filename": "transistor/test/misplaced/007.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/007_mask.png"} 93 | {"filename": "transistor/test/misplaced/006.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/006_mask.png"} 94 | {"filename": "transistor/test/misplaced/002.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/002_mask.png"} 95 | {"filename": "transistor/test/misplaced/003.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/003_mask.png"} 96 | {"filename": "transistor/test/misplaced/009.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/009_mask.png"} 97 | {"filename": "transistor/test/misplaced/000.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/000_mask.png"} 98 | {"filename": "transistor/test/misplaced/008.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/008_mask.png"} 99 | {"filename": "transistor/test/misplaced/005.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/005_mask.png"} 100 | {"filename": "transistor/test/misplaced/001.png", "label": 1, "label_name": "defective", "maskname": "transistor/ground_truth/misplaced/001_mask.png"} 101 | -------------------------------------------------------------------------------- /tools/train_vis_decoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pprint 5 | import time 6 | 7 | import cv2 8 | import torch 9 | import torch.distributed as dist 10 | import torch.optim 11 | import yaml 12 | from datasets.data_builder import build_dataloader 13 | from easydict import EasyDict 14 | from models.model_helper import ModelHelper 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from utils.criterion_helper import build_criterion 17 | from utils.dist_helper import setup_distributed 18 | from utils.lr_helper import get_scheduler 19 | from utils.misc_helper import ( 20 | AverageMeter, 21 | create_logger, 22 | get_current_time, 23 | load_state, 24 | save_checkpoint, 25 | set_random_seed, 26 | update_config, 27 | ) 28 | from utils.optimizer_helper import get_optimizer 29 | 30 | parser = argparse.ArgumentParser(description="UniAD Framework") 31 | 
parser.add_argument("--config", default="./config.yaml") 32 | parser.add_argument("--class_name", default="") 33 | parser.add_argument("-v", "--visualization", action="store_true") 34 | parser.add_argument("--local_rank", default=None, help="local rank for dist") 35 | 36 | 37 | class_name_list = [ 38 | "bottle", 39 | "cable", 40 | "capsule", 41 | "carpet", 42 | "grid", 43 | "hazelnut", 44 | "leather", 45 | "metal_nut", 46 | "pill", 47 | "screw", 48 | "tile", 49 | "toothbrush", 50 | "transistor", 51 | "wood", 52 | "zipper", 53 | ] 54 | 55 | 56 | def main(): 57 | global args, config, best_metric 58 | args = parser.parse_args() 59 | 60 | with open(args.config) as f: 61 | config = EasyDict(yaml.load(f, Loader=yaml.FullLoader)) 62 | 63 | config.dataset.train.meta_file = config.dataset.train.meta_file.replace( 64 | "{class_name}", args.class_name 65 | ) 66 | config.port = config["port"] + class_name_list.index(args.class_name) 67 | rank, world_size = setup_distributed(port=config.port) 68 | config = update_config(config) 69 | 70 | config.exp_path = os.path.join(os.path.dirname(args.config), args.class_name) 71 | config.save_path = os.path.join(config.exp_path, config.saver.save_dir) 72 | config.log_path = os.path.join(config.exp_path, config.saver.log_dir) 73 | if rank == 0: 74 | os.makedirs(config.save_path, exist_ok=True) 75 | os.makedirs(config.log_path, exist_ok=True) 76 | 77 | current_time = get_current_time() 78 | tb_logger = None # SummaryWriter(config.log_path + "/events_dec/" + current_time) 79 | logger = create_logger( 80 | "global_logger", config.log_path + "/dec_{}.log".format(current_time) 81 | ) 82 | logger.info("args: {}".format(pprint.pformat(args))) 83 | logger.info("config: {}".format(pprint.pformat(config))) 84 | else: 85 | tb_logger = None 86 | 87 | random_seed = config.get("random_seed", None) 88 | reproduce = config.get("reproduce", None) 89 | if random_seed: 90 | set_random_seed(random_seed, reproduce) 91 | # create model 92 | model = ModelHelper(config.net) 93 | model.cuda() 94 | local_rank = int(os.environ["LOCAL_RANK"]) 95 | model = DDP( 96 | model, 97 | device_ids=[local_rank], 98 | output_device=local_rank, 99 | find_unused_parameters=True, 100 | ) 101 | 102 | layers = [] 103 | for module in config.net: 104 | layers.append(module["name"]) 105 | frozen_layers = config.get("frozen_layers", []) 106 | active_layers = list(set(layers) ^ set(frozen_layers)) 107 | if rank == 0: 108 | logger.info("layers: {}".format(layers)) 109 | logger.info("active layers: {}".format(active_layers)) 110 | 111 | # parameters needed to be updated 112 | parameters = [ 113 | {"params": getattr(model.module, layer).parameters()} for layer in active_layers 114 | ] 115 | 116 | optimizer = get_optimizer(parameters, config.trainer.optimizer) 117 | lr_scheduler = get_scheduler(optimizer, config.trainer.lr_scheduler) 118 | 119 | best_metric = float("inf") 120 | last_epoch = 0 121 | 122 | # load model: auto_resume > resume_model > load_path 123 | auto_resume = config.saver.get("auto_resume", True) 124 | resume_model = config.saver.get("resume_model", None) 125 | load_path = config.saver.get("load_path", None) 126 | 127 | if resume_model and not resume_model.startswith("/"): 128 | resume_model = os.path.join(config.exp_path, resume_model) 129 | lastest_model = os.path.join(config.save_path, "ckpt.pth.tar") 130 | if auto_resume and os.path.exists(lastest_model): 131 | resume_model = lastest_model 132 | if resume_model: 133 | best_metric, last_epoch = load_state(resume_model, model, optimizer=optimizer) 
134 | elif load_path: 135 | if not load_path.startswith("/"): 136 | load_path = os.path.join(config.exp_path, load_path) 137 | load_state(load_path, model) 138 | 139 | train_loader, _ = build_dataloader(config.dataset, distributed=True) 140 | 141 | if args.visualization: 142 | vis_rec(train_loader, model) 143 | return 144 | 145 | criterion = build_criterion(config.criterion) 146 | 147 | for epoch in range(last_epoch, config.trainer.max_epoch): 148 | train_loader.sampler.set_epoch(epoch) 149 | last_iter = epoch * len(train_loader) 150 | train_loss = train_one_epoch( 151 | train_loader, 152 | model, 153 | optimizer, 154 | lr_scheduler, 155 | epoch, 156 | last_iter, 157 | tb_logger, 158 | criterion, 159 | frozen_layers, 160 | ) 161 | lr_scheduler.step(epoch) 162 | 163 | if rank == 0: 164 | is_best = train_loss <= best_metric 165 | best_metric = min(train_loss, best_metric) 166 | save_checkpoint( 167 | { 168 | "epoch": epoch + 1, 169 | "arch": config.net, 170 | "state_dict": model.state_dict(), 171 | "best_metric": best_metric, 172 | "optimizer": optimizer.state_dict(), 173 | }, 174 | is_best, 175 | config, 176 | ) 177 | 178 | if config.visualization: 179 | if (epoch + 1) % config.visualization.vis_freq_epoch == 0: 180 | vis_rec(train_loader, model) 181 | 182 | 183 | def train_one_epoch( 184 | train_loader, 185 | model, 186 | optimizer, 187 | lr_scheduler, 188 | epoch, 189 | start_iter, 190 | tb_logger, 191 | criterion, 192 | frozen_layers, 193 | ): 194 | 195 | batch_time = AverageMeter(config.trainer.print_freq_step) 196 | data_time = AverageMeter(config.trainer.print_freq_step) 197 | losses = AverageMeter(config.trainer.print_freq_step) 198 | 199 | # switch to train mode 200 | model.train() 201 | # freeze selected layers 202 | for layer in frozen_layers: 203 | module = getattr(model.module, layer) 204 | module.eval() 205 | for param in module.parameters(): 206 | param.requires_grad = False 207 | 208 | world_size = dist.get_world_size() 209 | rank = dist.get_rank() 210 | logger = logging.getLogger("global_logger") 211 | end = time.time() 212 | 213 | train_loss = 0 214 | for i, input in enumerate(train_loader): 215 | curr_step = start_iter + i 216 | current_lr = lr_scheduler.get_lr()[0] 217 | 218 | # measure data loading time 219 | data_time.update(time.time() - end) 220 | 221 | # forward 222 | outputs = model(input) 223 | loss = 0 224 | for name, criterion_loss in criterion.items(): 225 | weight = criterion_loss.weight 226 | loss += weight * criterion_loss(outputs) 227 | reduced_loss = loss.clone() 228 | dist.all_reduce(reduced_loss) 229 | reduced_loss = reduced_loss / world_size 230 | losses.update(reduced_loss.item()) 231 | train_loss += reduced_loss.item() 232 | 233 | # backward 234 | optimizer.zero_grad() 235 | loss.backward() 236 | # update 237 | if config.trainer.get("clip_max_norm", None): 238 | max_norm = config.trainer.clip_max_norm 239 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 240 | optimizer.step() 241 | # measure elapsed time 242 | batch_time.update(time.time() - end) 243 | 244 | if (curr_step + 1) % config.trainer.print_freq_step == 0 and rank == 0: 245 | # tb_logger.add_scalar("loss_train", losses.avg, curr_step + 1) 246 | # tb_logger.add_scalar("lr", current_lr, curr_step + 1) 247 | # tb_logger.flush() 248 | 249 | logger.info( 250 | "Epoch: [{0}/{1}]\t" 251 | "Iter: [{2}/{3}]\t" 252 | "Time {batch_time.val:.2f} ({batch_time.avg:.2f})\t" 253 | "Data {data_time.val:.2f} ({data_time.avg:.2f})\t" 254 | "Loss {loss.val:.5f} ({loss.avg:.5f})\t" 255 | "LR 
{lr:.5f}\t".format( 256 | epoch + 1, 257 | config.trainer.max_epoch, 258 | curr_step + 1, 259 | len(train_loader) * config.trainer.max_epoch, 260 | batch_time=batch_time, 261 | data_time=data_time, 262 | loss=losses, 263 | lr=current_lr, 264 | ) 265 | ) 266 | 267 | end = time.time() 268 | 269 | return train_loss / len(train_loader) 270 | 271 | 272 | def vis_rec(loader, model): 273 | model.eval() 274 | 275 | pixel_mean = config.dataset.pixel_mean 276 | pixel_mean = torch.tensor(pixel_mean).cuda().unsqueeze(1).unsqueeze(1) # 3 x 1 x 1 277 | pixel_std = config.dataset.pixel_std 278 | pixel_std = torch.tensor(pixel_std).cuda().unsqueeze(1).unsqueeze(1) # 3 x 1 x 1 279 | 280 | with torch.no_grad(): 281 | for i, input in enumerate(loader): 282 | # forward 283 | outputs = model(input) 284 | filenames = outputs["filename"] 285 | images = outputs["image"] 286 | image_recs = outputs["image_rec"] 287 | clsnames = outputs["clsname"] 288 | 289 | for filename, image, image_rec, clsname in zip( 290 | filenames, images, image_recs, clsnames 291 | ): 292 | filedir, filename = os.path.split(filename) 293 | _, defect_name = os.path.split(filedir) 294 | filename_, _ = os.path.splitext(filename) 295 | vis_dir = os.path.join(config.visualization.vis_dir, clsname, defect_name) 296 | os.makedirs(vis_dir, exist_ok=True) 297 | vis_path = os.path.join(vis_dir, filename_ + ".jpg") 298 | 299 | image = (image * pixel_std + pixel_mean) * 255 300 | image_rec = (image_rec * pixel_std + pixel_mean) * 255 301 | image = torch.cat([image, image_rec], dim=1).permute( 302 | 1, 2, 0 303 | ) # 2h x w x 3 304 | image = image.cpu().numpy() 305 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 306 | cv2.imwrite(vis_path, image) 307 | 308 | 309 | if __name__ == "__main__": 310 | main() 311 | --------------------------------------------------------------------------------
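Editor's note on the files dumped above. The json_vis_decoder meta files (e.g. data/MVTec-AD/json_vis_decoder/test_bottle.json) are JSON-lines: one record per line with "filename", "label" (0 = good, 1 = defective) and "label_name", plus a "maskname" pointing at the ground-truth mask for defective samples only. A minimal parsing sketch, not part of the repo; the variable names are illustrative:

import json

meta_file = "data/MVTec-AD/json_vis_decoder/test_bottle.json"
with open(meta_file) as f:
    for line in f:
        rec = json.loads(line)
        # "maskname" is absent for "good" samples, so .get() returns None there
        mask = rec.get("maskname")
        print(rec["filename"], rec["label"], rec["label_name"], mask)

Similarly, the geometric transforms in datasets/transforms.py take (img, mask) pairs so the ground-truth mask stays aligned with the augmented image, while RandomColorJitter is photometric and touches the image only. A usage sketch under the same caveat (the image path and the all-zero placeholder mask are hypothetical):

from PIL import Image
from datasets.transforms import RandomHFlip, RandomVFlip, RandomRotation

img = Image.open("bottle/test/good/000.png").convert("RGB")
mask = Image.new("L", img.size)  # all-zero mask for a "good" sample

# Each paired transform returns the image and mask transformed identically;
# RandomRotation with a sequence picks one of the listed angles at random.
for t in (RandomHFlip(flip_p=0.5), RandomVFlip(flip_p=0.5), RandomRotation([0, 90, 180, 270])):
    img, mask = t(img, mask)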