├── src
└── teaser.png
├── requirements.txt
├── ksdd2_preprocess.py
├── README.md
├── .gitignore
├── generate_augmented_images.py
├── data
└── ksdd2.py
└── train_ResNet50.py
/src/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelligolabs/DIAG/HEAD/src/teaser.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2024.6.2
2 | charset-normalizer==3.3.2
3 | click==8.1.7
4 | diffusers==0.29.0
5 | docker-pycreds==0.4.0
6 | filelock==3.13.1
7 | fsspec==2024.2.0
8 | gitdb==4.0.11
9 | GitPython==3.1.43
10 | huggingface-hub==0.23.4
11 | idna==3.7
12 | importlib_metadata==7.1.0
13 | Jinja2==3.1.3
14 | joblib==1.4.2
15 | MarkupSafe==2.1.5
16 | mpmath==1.3.0
17 | networkx==3.2.1
18 | numpy==1.26.3
19 | opencv-python==4.10.0.84
20 | packaging==24.1
21 | pandas==2.2.2
22 | pillow==10.2.0
23 | platformdirs==4.2.2
24 | protobuf==5.27.1
25 | psutil==6.0.0
26 | python-dateutil==2.9.0.post0
27 | pytz==2024.1
28 | PyYAML==6.0.1
29 | regex==2024.5.15
30 | requests==2.32.3
31 | safetensors==0.4.3
32 | scikit-learn==1.5.0
33 | scipy==1.13.1
34 | sentry-sdk==2.6.0
35 | setproctitle==1.3.3
36 | six==1.16.0
37 | smmap==5.0.1
38 | sympy==1.12
39 | threadpoolctl==3.5.0
40 | tokenizers==0.19.1
41 | tqdm==4.66.4
42 | transformers==4.41.2
43 | triton==2.3.1
44 | typing_extensions==4.9.0
45 | tzdata==2024.1
46 | urllib3==2.2.2
47 | wandb==0.17.2
48 | zipp==3.19.2
49 |
--------------------------------------------------------------------------------
/ksdd2_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from glob import glob
4 | import pandas as pd
5 | import argparse
6 | from tqdm import tqdm
7 |
8 |
def reshape_ksdd2(src_dir, dst_dir, RES=(224, 632)):
    """Resize every image of the `train` and `test` splits to a common size.

    Args:
        src_dir: Root of the source dataset; must contain `train` and `test` folders.
        dst_dir: Root of the output copy; the split folders are created if missing.
        RES: Target size handed to `cv2.resize`, i.e. `(width, height)`.
    """
    for split in ('train', 'test'):
        src_split_dir = os.path.join(src_dir, split)
        dst_split_dir = os.path.join(dst_dir, split)
        os.makedirs(dst_split_dir, exist_ok=True)
        all_imgs = os.listdir(src_split_dir)
        for img_name in tqdm(all_imgs, desc=f"Reshaping {split} images", unit="file", total=len(all_imgs)):
            img_path = os.path.join(src_split_dir, img_name)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals an unreadable / non-image entry by returning
                # None; cv2.resize would crash on it, so skip the entry instead.
                print(f"Skipping unreadable file: {img_path}")
                continue
            cv2.imwrite(os.path.join(dst_split_dir, img_name), cv2.resize(img, RES))
23 |
def copy_files(src_dir, dst_dir):
    """Copy every file matching `*.pyb` from `src_dir` into `dst_dir`.

    Args:
        src_dir: Directory searched (non-recursively) for `*.pyb` files.
        dst_dir: Existing directory that receives the copies under the same names.
    """
    # Local import so this block does not depend on the file's import header.
    import shutil

    target_files = glob(os.path.join(src_dir, '*.pyb'))
    for file in tqdm(target_files, desc="Copying .pyb files", unit="file", total=len(target_files)):
        # shutil.copy instead of a shell `cp`: safe with spaces or shell
        # metacharacters in paths, portable, and raises on failure.
        shutil.copy(file, os.path.join(dst_dir, os.path.basename(file)))
30 |
def make_csv(dst_dir):
    """Write `train.csv` and `test.csv` with one `path,label` row per mask file.

    For each split folder inside `dst_dir`, every file whose name contains "GT"
    is treated as a mask. The row's `path` is the mask name with `_GT.png`
    replaced by `.png`; the `label` is "positive" when the mask has at least one
    non-zero pixel and "negative" otherwise.

    Args:
        dst_dir: Root holding the `train` and `test` folders; the CSV files are
            written directly into it.

    Raises:
        ValueError: If a mask file cannot be read by OpenCV.
    """
    for split in ('train', 'test'):
        img_dir = os.path.join(dst_dir, split)
        all_masks = [name for name in os.listdir(img_dir) if "GT" in name]
        imgs_dict = {"path": [], "label": []}
        for mask_name in tqdm(all_masks, desc=f"Creating {split}.csv", unit="file", total=len(all_masks)):
            mask_path = os.path.join(img_dir, mask_name)
            loaded = cv2.imread(mask_path)
            if loaded is None:
                raise ValueError(f"Could not read mask image: {mask_path}")
            imgs_dict["path"].append(mask_name.replace("_GT.png", ".png"))
            # Vectorised "any non-zero pixel" test; the builtin max() over the
            # flattened array iterated every pixel at Python speed.
            imgs_dict["label"].append("positive" if loaded.any() else "negative")
        pd.DataFrame(imgs_dict).to_csv(os.path.join(dst_dir, f"{split}.csv"), index=False)
49 |
def main(args):
    """Create the pre-processed dataset copy described by the CLI arguments.

    Runs, in order: copy of the `.pyb` files, resizing of the split images
    (needed for batching), and creation of the per-split CSV files.
    """
    source_root, target_root = args.src_dir, args.dst_dir
    target_size = (224, 632)  # w x h

    print(f"Copying files from {source_root} to {target_root}")
    os.makedirs(target_root, exist_ok=True)

    copy_files(source_root, target_root)
    reshape_ksdd2(source_root, target_root, RES=target_size)
    make_csv(target_root)
63 |
64 |
if __name__ == "__main__":
    # Both options are mandatory string paths.
    cli = argparse.ArgumentParser()
    for flag, description in (
        ("--src_dir", "Path to the KSDD2 dataset root"),
        ("--dst_dir", "Path to the destination directory"),
    ):
        cli.add_argument(flag, type=str, required=True, help=description)
    main(cli.parse_args())
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Leveraging Latent Diffusion Models for Training-Free In-Distribution Data Augmentation for Surface Defect Detection #
2 |
3 | Official implementation of the paper [Leveraging Latent Diffusion Models for Training-Free In-Distribution Data Augmentation for Surface Defect Detection](https://intelligolabs.github.io/DIAG/) accepted at the 21st International Conference on Content-Based Multimedia Indexing (CBMI 2024).
4 |
5 |
6 |
7 |
8 | ## Installation ##
9 | **1. Clone the repository**
10 | ```bash
11 | git clone https://github.com/intelligolabs/DIAG.git
12 | cd DIAG
13 | ```
14 |
15 | **2. Create an environment with dependencies**
16 | ```bash
17 | conda create -n DIAG python=3.10
18 | conda activate DIAG
19 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
20 | ```
21 | Note: the `pip install` command installs PyTorch 1.12 with CUDA 11.3. Adjust it to match your own CUDA version.
22 |
23 | Then, install the requirements:
24 | ```bash
25 | pip install -r requirements.txt
26 | ```
27 |
28 | **3. Data preparation**
29 | Download KSDD2 from the [official KSDD2 website](https://www.vicos.si/resources/kolektorsdd2/).
30 |
31 | Run the `ksdd2_preprocess.py` script. This will create a pre-processed copy of KSDD2 in the `--dst_dir`.
32 | We will use this pre-processed copy for our augmentation and evaluation.
33 | ```bash
34 | python ksdd2_preprocess.py --src_dir="/ksdd2" --dst_dir="/ksdd2_preprocessed"
35 | ```
36 |
37 | **4. [Optional] Set up wandb for logging**
38 | Optionally, you can log the training and evaluation metrics to wandb.
39 | ```
40 | wandb init
41 | ```
42 |
43 | ## Part 1: Data augmentation ##
44 | This step generates the augmented positive images.
45 | The images are generated using the same prompts used in the original paper.
46 | The `src_dir` argument should point to the preprocessed data root (see the **3. Data preparation** step).
47 |
48 | By default, this script will generate the augmented images in an `augmented_<N>` folder inside `src_dir`, where `<N>` is the total number of generated images (`imgs_per_prompt * prompts`).
49 | This is needed for the dataloader during training.
50 | ```bash
51 | python generate_augmented_images.py --src_dir="/ksdd2_preprocessed" --imgs_per_prompt=50 --seed=0
52 | ```
53 |
54 | ## Part 2: Training and evaluation ##
55 | This step will fine-tune a pre-trained ResNet-50 on the (augmented) KSDD2.
56 | Different arguments handle the policy for training:
57 | - `--zero_shot` trains the model without GT positive images
58 | - `--add_augmented` adds the augmented images to the dataset (can be both zero or full shot)
59 | - `--num_augmented` selects how many augmented images to add to the training set.
60 | Required if using `--add_augmented`.
61 | Note that this MUST be the TOTAL (`imgs_per_prompt * prompts`) number of images generated in the previous step.
62 |
63 | Example of DIAG training (zero-shot with augmentations):
64 | ```bash
65 | python train_ResNet50.py --seed=0 --epochs=30 --batch_size=32 --num_workers=8 --dataset_path="/ksdd2_preprocessed" --zero_shot --add_augmented --num_augmented=100 --logging
66 | ```
67 |
68 | ## Credits ##
69 | - Kolektor Surface-Defect Dataset 2. More info [here](https://www.vicos.si/resources/kolektorsdd2/).
70 | - [Diffusers](https://huggingface.co/docs/diffusers/en/using-diffusers/sdxl) and [StabilityAI](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1) for their SDXL implementation and weights.
71 |
72 | ## Authors ##
73 | Federico Girella, Ziyue Liu, Franco Fummi, Francesco Setti, Marco Cristani, Luigi Capogrosso
74 |
75 | *Department of Engineering for Innovation Medicine, University of Verona, Italy*
76 |
77 | `name.surname@univr.it`
78 |
79 | ## Citation ##
80 | If you use [**DIAG**](https://ieeexplore.ieee.org/abstract/document/10858875), please, cite the following paper:
81 | ```
82 | @InProceedings{girella2024leveraging,
83 | author = {Girella, Federico and Liu, Ziyue and Fummi, Franco and Setti, Francesco and Cristani, Marco and Capogrosso, Luigi},
84 | booktitle = {International Conference on Content-Based Multimedia Indexing (CBMI)},
85 | title = {{Leveraging Latent Diffusion Models for Training-Free in-Distribution Data Augmentation for Surface Defect Detection}},
86 | year = {2024},
87 | doi = {10.1109/cbmi62980.2024.10858875},
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
2 |
3 | ### Python ###
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
165 | ### Python Patch ###
166 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
167 | poetry.toml
168 |
169 | # ruff
170 | .ruff_cache/
171 |
172 | # LSP config files
173 | pyrightconfig.json
174 |
175 | ### VisualStudioCode ###
176 | .vscode/
177 | !.vscode/settings.json
178 | !.vscode/tasks.json
179 | !.vscode/launch.json
180 | !.vscode/extensions.json
181 | !.vscode/*.code-snippets
182 |
183 | # Local History for Visual Studio Code
184 | .history/
185 |
186 | # Built Visual Studio Code Extensions
187 | *.vsix
188 |
189 | ### VisualStudioCode Patch ###
190 | # Ignore all local history of files
191 | .history
192 | .ionide
193 |
194 | ### Custom ###
195 | wandb/
--------------------------------------------------------------------------------
/generate_augmented_images.py:
--------------------------------------------------------------------------------
1 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
2 | from diffusers.utils import load_image
3 | import torch
4 | import pandas as pd
5 | import os
6 | import random
7 | import numpy as np
8 | import argparse
9 |
def get_negative_images(df):
    """Return, as a list, the `path` values of the rows labelled "negative"."""
    is_negative = df["label"] == "negative"
    return df.loc[is_negative, "path"].tolist()
12 |
def get_positive_images_maks(df):
    """Return the mask file names (`*_GT.png`) of the rows labelled "positive"."""
    positive_paths = df.loc[df["label"] == "positive", "path"]
    mask_names = []
    for img_name in positive_paths:
        mask_names.append(img_name.replace('.png', '_GT.png'))
    return mask_names
15 |
def main(args):
    """Generate augmented defect images with the SDXL inpainting pipeline.

    For every prompt, `args.imgs_per_prompt` images are produced: a random
    "negative" training image and a random "positive" training mask (both taken
    from `<src_dir>/train.csv`) are fed to the inpainting pipeline. Results are
    written to `<src_dir>/augmented_<total>/imgs` and `.../masks` under the same
    zero-padded progressive file name, where `<total>` is
    `imgs_per_prompt * number of prompts`.

    Args:
        args: Namespace with `src_dir`, `imgs_per_prompt` and `seed`.

    Raises:
        ValueError: If images are requested but `train.csv` lists no negative
            image or no positive mask to sample from.
    """
    ### ARGUMENTS
    src_dir = args.src_dir
    imgs_per_prompt = args.imgs_per_prompt
    seed = args.seed

    # prompts used in the paper
    prompts = ["white marks on the wall", "copper metal scratches"]
    negative_prompt = "smooth, plain, black, dark, shadow"

    dst_dir = os.path.join(src_dir, f"augmented_{imgs_per_prompt*len(prompts)}")
    imgs_dir = os.path.join(dst_dir, 'imgs')
    masks_dir = os.path.join(dst_dir, 'masks')
    os.makedirs(imgs_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    # hyperparameters used in the paper
    num_inference_steps = 30
    guidance_scale = 20.0
    strength = 0.99
    padding_mask_crop = 2
    RES = (224, 632)
    # needed to overcome the sdxl shape bias
    TARGET = (1024, 1024)

    ### MAIN
    # cuda if available
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("Running on device: ", device)

    # seed everything
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    df = pd.read_csv(os.path.join(src_dir, 'train.csv'))
    negative_imgs = get_negative_images(df)
    positive_masks = get_positive_images_maks(df)
    print(f'Num negative images: {len(negative_imgs)}')
    print(f'Num positive masks: {len(positive_masks)}')
    if imgs_per_prompt > 0 and not (negative_imgs and positive_masks):
        # Fail before loading the model, with a message clearer than the
        # "sample larger than population" error random.sample would raise.
        raise ValueError(
            f"train.csv in {src_dir} must list at least one negative image and one positive mask"
        )

    model = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    print(f'Loading model {model}')
    pipe = StableDiffusionXLInpaintPipeline.from_pretrained(model, torch_dtype=torch.float16, variant="fp16").to(device)

    # Single running index shared by all prompts -> progressive file names.
    img_idx = 0
    for prompt in prompts:
        print(f'Generating images for prompt: {prompt}')
        for _ in range(imgs_per_prompt):
            # by sampling 1 by 1 we can generate more anomalies than what we have in the dataset (246)
            # NOTE: random.sample(..., 1)[0] is kept (rather than random.choice) so
            # the random stream, and hence the seeded output, stays unchanged.
            mask_name = random.sample(positive_masks, 1)[0]
            mask_path = os.path.join(src_dir, 'train', mask_name)
            neg_img_name = random.sample(negative_imgs, 1)[0]
            neg_img_path = os.path.join(src_dir, 'train', neg_img_name)

            neg_img = load_image(neg_img_path).resize(TARGET)
            mask = load_image(mask_path).resize(TARGET)

            out_image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=neg_img,
                mask_image=mask,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,  # steps between 15 and 30 work well for us
                strength=strength,  # make sure to use `strength` below 1.0
                generator=generator,
                height=TARGET[1],
                width=TARGET[0],
                original_size=TARGET,
                target_size=TARGET,
                padding_mask_crop=padding_mask_crop,
            ).images[0]

            # image and mask are saved at the dataset resolution under the same name
            file_name = f'{str(img_idx).zfill(5)}.png'
            out_image.resize(RES).save(os.path.join(imgs_dir, file_name))
            mask.resize(RES).save(os.path.join(masks_dir, file_name))
            img_idx += 1
105 |
if __name__ == "__main__":
    # Command-line interface of the generation script.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--src_dir",
        type=str,
        required=True,
        help="Directory containing the preprocessed dataset",
    )
    for flag, default, description in (
        ("--imgs_per_prompt", 50, "Number of images to generate per prompt"),
        ("--seed", 0, "Seed for random generation"),
    ):
        cli.add_argument(flag, type=int, default=default, help=description)

    main(cli.parse_args())
--------------------------------------------------------------------------------
/data/ksdd2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import re
6 | import torch
7 |
8 | import torchvision.transforms as T
9 |
10 | from PIL import Image
11 | from torch.utils.data import Dataset
12 |
13 |
def c2chw(x):
    """Insert two singleton dimensions right after the first one of `x`.

    A tensor of shape `(C,)` becomes `(C, 1, 1)`.
    """
    return x[:, None, None]
16 |
17 |
def inverse_list(list):
    """
    List to dict: element -> index

    Returns a dictionary mapping every element of the given sequence to its
    position. If an element occurs more than once, its last position wins.
    The parameter keeps its original name (`list`) for keyword-call
    compatibility.
    """
    return {x: idx for idx, x in enumerate(list)}
28 |
29 |
class KolektorSDD2(Dataset):
    """
    Kolektor Surface-Defect 2 dataset, loaded entirely into memory.

    Args:
        dataroot (string): path to the root directory of the dataset
        split (string): data split ['train', 'test'] ('val' is mapped to 'train')
        negative_only (bool): keep only the samples whose stored mask is empty
        add_augmented (bool): also load the images of the augmented folder
            (training split only)
        num_augmented (int): expected number of augmented images; selects the
            `augmented_<num_augmented>` folder (plain `augmented` when 0)
        zero_shot (bool): skip the original samples whose mask is not empty
            (requires `add_augmented`)
    """

    labels = ['ok', 'defect']

    # ImageNet.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    def __init__(self,
                 dataroot='/path/to/dataset/'
                          'KolektorSDD2',
                 split='train', negative_only=False,
                 add_augmented=False, num_augmented=0, zero_shot=False):
        super().__init__()

        self.fold = None
        self.dataroot = dataroot

        self.split_path = None
        self.split = 'train' if 'val' == split else split

        self.output_size = (704, 256)
        self.negative_only = negative_only
        self.add_augmented = add_augmented
        if self.add_augmented:
            assert self.split == 'train', 'Augmented images are only for the training set!'
        self.num_augmented = num_augmented
        self.zero_shot = zero_shot
        if self.zero_shot:
            assert self.add_augmented, 'Zero-shot learning requires augmented images!'
            assert self.split == 'train', 'Zero-shot learning is only for the training set!'

        self.class_to_idx = inverse_list(self.labels)
        self.classes = self.labels
        self.transform = KolektorSDD2.get_transform(output_size=self.output_size)
        self.normalize = T.Normalize(KolektorSDD2.mean, KolektorSDD2.std)

        self.load_imgs()
        if negative_only:
            # Boolean selector of the samples with an all-zero mask.
            m = self.masks.sum(-1).sum(-1) == 0
            self.samples = self.samples[m]
            self.masks = self.masks[m]
            self.product_ids = [pid for flag, pid in zip(m, self.product_ids)
                                if flag]

    def load_imgs(self):
        """Fill `self.samples`, `self.masks` and `self.product_ids` from disk.

        The number of samples is pre-computed from hard-coded split sizes (plus
        the augmented folder size, when requested) and checked at the end.
        """
        # Please remove this duplicated files in the official dataset:
        #   -- 10301_GT (copy).png
        #   -- 10301 (copy).png
        if self.num_augmented > 0:
            augmented_dir = f'augmented_{self.num_augmented}'
        else:
            augmented_dir = 'augmented'
        augmented_imgs_path = os.path.join(self.dataroot, augmented_dir, 'imgs')
        augmented_masks_path = os.path.join(self.dataroot, augmented_dir, 'masks')

        if self.split == 'test':
            N = 1004
        elif self.split == 'train' and self.zero_shot:
            # only augmented positives and original negatives
            N = 2085  # number of original negatives
        else:
            # all original data + augmented
            N = 2331  # number of original negatives and positives

        # The folder is listed once; sorting makes the sample order independent
        # of the filesystem's directory order.
        augmented_list = []
        if self.add_augmented:
            augmented_list = sorted(os.listdir(augmented_imgs_path))
            N += len(augmented_list)
            if self.num_augmented > 0:
                assert len(augmented_list) == self.num_augmented, f'Number of augmented images requested ({self.num_augmented}) does not match with number found ({len(augmented_list)})!'

        self.samples = torch.Tensor(N, 3, *self.output_size).zero_()
        self.masks = torch.LongTensor(N, *self.output_size).zero_()
        self.product_ids = []

        cnt = 0
        path = "%s/%s/" % (self.dataroot, self.split)
        image_list = sorted(f for f in os.listdir(path)
                            if re.search(r'[0-9]+\.png$', f))
        assert 0 < len(image_list), self.dataroot

        for img_name in image_list:
            product_id = img_name[:-4]
            # Context managers close the file handles as soon as the tensors exist.
            with Image.open(path + img_name) as raw_img:
                img = self.transform(raw_img)
            with Image.open(path + product_id + '_GT.png') as raw_mask:
                lab = self.transform(raw_mask.convert('L'))
            if self.zero_shot and lab.sum() != 0:
                # zero-shot: original samples with a non-empty mask are skipped
                continue
            self.samples[cnt] = img
            # NOTE(review): `lab` comes from T.ToTensor while `self.masks` is a
            # LongTensor, so fractional values appear to be truncated on
            # assignment - presumably intended; confirm.
            self.masks[cnt] = lab
            self.product_ids.append(product_id)
            cnt += 1

        # Add the augmented images.
        if self.add_augmented and 'train' == self.split:
            for img_name in augmented_list:
                product_id = img_name[:-4]
                with Image.open(os.path.join(augmented_imgs_path, img_name)) as raw_img:
                    img = self.transform(raw_img)
                with Image.open(os.path.join(augmented_masks_path, img_name)) as raw_mask:
                    lab = self.transform(raw_mask.convert('L'))
                self.samples[cnt] = img
                self.masks[cnt] = lab
                self.product_ids.append(product_id)
                cnt += 1

        assert N == cnt, '{} should be {}!'.format(cnt, N)

    def __getitem__(self, index):
        """Return `(image, class index, boolean mask, 0)` for the given index."""
        x = self.samples[index]
        a = self.masks[index] > 0
        if self.normalize is not None:
            x = self.normalize(x)

        # A sample is 'defect' as soon as one mask pixel is set.
        if 0 == a.sum():
            y = self.class_to_idx['ok']
        else:
            y = self.class_to_idx['defect']

        return x, y, a, 0

    def __len__(self):
        """Number of loaded samples."""
        return self.samples.size(0)

    @staticmethod
    def get_transform(output_size=(704, 256)):
        """Build the resize + to-tensor transform applied to images and masks."""
        transform = [
            T.Resize(output_size),
            T.ToTensor()
        ]
        transform = T.Compose(transform)
        return transform

    @staticmethod
    def denorm(x):
        """Undo the mean/std normalization applied in `__getitem__`."""
        return x * c2chw(torch.Tensor(KolektorSDD2.std)) + c2chw(torch.Tensor(KolektorSDD2.mean))
188 |
--------------------------------------------------------------------------------
/train_ResNet50.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import wandb
3 | import argparse
4 |
5 | import torch.nn as nn
6 | import torchvision.models as models
7 |
8 | from tqdm import tqdm
9 | from torch.utils.data import DataLoader
10 | from torchvision.models import ResNet50_Weights
11 | from sklearn.metrics import average_precision_score, precision_score, recall_score, precision_recall_curve, roc_curve, auc
12 | import numpy as np
13 |
14 | from data.ksdd2 import KolektorSDD2
15 |
class KSDD2ResNet50(nn.Module):
    """ResNet-50 with ImageNet weights and a single-output classification head."""

    def __init__(self):
        super().__init__()

        # Pre-trained backbone from torchvision.models.
        backbone = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        # Swap the 1000-class head for a layer emitting one class score.
        backbone.fc = nn.Linear(backbone.fc.in_features, 1)
        self.model = backbone

    def forward(self, x):
        scores = self.model(x)
        return scores
29 |
30 |
def evaluate(model, criterion, test_loader, device, log_dict):
    """Evaluate `model` on `test_loader` and store the metrics in `log_dict`.

    The raw model outputs are used as scores. The decision threshold is the one
    maximising the F1 measure on the precision-recall curve; accuracy, precision
    and recall are reported at that threshold, together with the average
    precision.

    Args:
        model: Network returning one score per sample, shaped `(batch, 1)`.
        criterion: Loss called as `criterion(scores, targets.float())`.
        test_loader: Iterable of batches whose first two items are inputs and labels.
        device: Device the batches are moved to.
        log_dict: Dictionary updated in place with `val_ACC`, `val_PRECISION`,
            `val_RECALL` and `val_AP`.

    Returns:
        The same `log_dict` object.
    """
    t_loss = 0
    targets = []
    predictions = []

    # Remember the current mode: the caller keeps training after this call, and
    # leaving the model in eval mode would freeze BatchNorm for later epochs.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in tqdm(test_loader, unit='batch',
                             total=len(test_loader), desc='Validation'):
                x, y = data[0].to(device), data[1].to(device)

                # This gets the prediction from the network.
                output = model(x).squeeze(1)
                # Sum up batch loss.
                t_loss += criterion(output, y.float()).item()

                predictions.extend(output.cpu().numpy())
                targets.extend(y.cpu().numpy())
    finally:
        model.train(was_training)

    t_loss /= len(test_loader)

    # Explicit arrays: the comparisons below are element-wise by construction.
    targets = np.asarray(targets)
    predictions = np.asarray(predictions)

    precision_, recall_, thresholds = precision_recall_curve(targets, predictions)
    f_measures = 2 * (precision_ * recall_) / (precision_ + recall_ + 0.0000000001)

    # Select best threshold based on the F1 measure. Following previous works procedure.
    ix_best = np.argmax(f_measures)
    if ix_best > 0:
        best_threshold = (thresholds[ix_best] + thresholds[ix_best - 1]) / 2
    else:
        best_threshold = thresholds[ix_best]
    precision = precision_[ix_best]
    recall = recall_[ix_best]

    classifications = predictions > best_threshold

    AP = average_precision_score(targets, predictions)

    # Calculate predictions based on best threshold.
    correct = np.sum(classifications == targets)
    accuracy = 100. * correct / len(classifications)

    print('AVG loss: {:.4f}, ACC: {}/{} ({:.0f}%), Precision: {:.4f}, Recall: {:.4f}, AP: {:.4f}'.format(
        t_loss, correct, len(test_loader.dataset), accuracy, precision, recall, AP))

    # log metrics
    log_dict['val_ACC'] = accuracy
    log_dict['val_PRECISION'] = precision
    log_dict['val_RECALL'] = recall
    log_dict['val_AP'] = AP

    return log_dict
89 |
90 |
def main(args):
    """Fine-tune the ResNet-50 classifier on KSDD2 and evaluate after every epoch.

    Args:
        args: Namespace with `seed`, `epochs`, `batch_size`, `num_workers`,
            `dataset_path`, `add_augmented`, `num_augmented`, `zero_shot` and
            `logging` (when set, metrics are sent to wandb).
    """
    # Set the seed for reproducibility.
    torch.manual_seed(args.seed)
    # Set the device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    add_augmented = args.add_augmented
    num_augmented = args.num_augmented
    zero_shot = args.zero_shot
    logging = args.logging

    run_name = f'KSDD2ResNet50-zero_shot_{zero_shot}-add_augmented_{add_augmented}-num_augmented_{num_augmented}-bs_{args.batch_size}-epochs_{args.epochs}'
    tags = [f'{args.epochs}epochs', f'{num_augmented}augmented']
    tags.append('zero_shot' if args.zero_shot else 'full_shot')
    tags.append('augmented' if args.add_augmented else 'not_augmented')

    if logging:
        # Start a new wandb run to track this script.
        wandb.init(
            name=run_name,
            config=args,
            tags=tags
        )

    # Dataset.
    print('Loading KolektorSDD2 training set...')
    train_data = KolektorSDD2(dataroot=args.dataset_path, split='train', add_augmented=add_augmented, num_augmented=num_augmented, zero_shot=zero_shot)
    print('Number of samples:', len(train_data))

    print('Loading KolektorSDD2 test set...')
    test_data = KolektorSDD2(dataroot=args.dataset_path, split='test')
    print('Number of samples:', len(test_data))

    # DataLoaders.
    train_loader = DataLoader(train_data, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_workers)
    test_loader = DataLoader(test_data, batch_size=args.batch_size,
                             shuffle=False, num_workers=args.num_workers)

    # Define the model.
    model = KSDD2ResNet50()
    model.to(device)

    # Define the loss function and the optimizer
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Training step.
    print(f'Start training on {device} [...]')
    # Per-epoch metrics only: the per-batch 'train_loss' is logged on its own,
    # so it must not appear here with a constant value.
    log_dict = {'epoch_loss': 0, 'val_ACC': 0, 'val_PRECISION': 0, 'val_RECALL': 0, 'val_AP': 0, 'epoch': 0}
    for e in range(args.epochs):
        # Re-enter training mode every epoch: the evaluation at the end of the
        # previous epoch switches the model to eval mode.
        model.train()
        epoch_loss = 0
        tepoch = tqdm(train_loader, unit='batch', total=len(train_loader))
        tepoch.set_description(f'Epoch {e}')
        for data in tepoch:
            x, y = data[0].to(device), data[1].to(device)

            # Training step for the single batch.
            optimizer.zero_grad()
            outputs = model(x).squeeze(1)
            loss = criterion(outputs, y.float())
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

            # Print statistics.
            tepoch.set_postfix(loss=loss.item())
            if logging:
                wandb.log({'train_loss': loss.item()})
        epoch_loss /= len(train_loader)
        log_dict['epoch_loss'] = epoch_loss
        log_dict['epoch'] = e

        # Evaluation step after each epoch.
        eval_dict = evaluate(model, criterion, test_loader, device, log_dict)
        if logging:
            wandb.log(eval_dict)

    if logging:
        wandb.finish()
    print('Training finished.')
179 |
if __name__ == '__main__':
    # Command-line interface of the training script.
    cli = argparse.ArgumentParser(description='DIAG training')
    for flag, default in (
        ('--seed', 1234),
        ('--epochs', 30),
        ('--batch_size', 32),
        ('--num_workers', 8),
    ):
        cli.add_argument(flag, type=int, default=default)
    cli.add_argument('--dataset_path', type=str, required=True)
    cli.add_argument('--add_augmented', action='store_true',
                     help='Add augmented images to the training set')
    cli.add_argument('--num_augmented', type=int, default=120)
    cli.add_argument('--zero_shot', action='store_true',
                     help='Train the model without true positives in the training set')
    cli.add_argument('--logging', action='store_true',
                     help='Log the stats to wandb')

    main(cli.parse_args())
--------------------------------------------------------------------------------