├── src └── teaser.png ├── requirements.txt ├── ksdd2_preprocess.py ├── README.md ├── .gitignore ├── generate_augmented_images.py ├── data └── ksdd2.py └── train_ResNet50.py /src/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligolabs/DIAG/HEAD/src/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2024.6.2 2 | charset-normalizer==3.3.2 3 | click==8.1.7 4 | diffusers==0.29.0 5 | docker-pycreds==0.4.0 6 | filelock==3.13.1 7 | fsspec==2024.2.0 8 | gitdb==4.0.11 9 | GitPython==3.1.43 10 | huggingface-hub==0.23.4 11 | idna==3.7 12 | importlib_metadata==7.1.0 13 | Jinja2==3.1.3 14 | joblib==1.4.2 15 | MarkupSafe==2.1.5 16 | mpmath==1.3.0 17 | networkx==3.2.1 18 | numpy==1.26.3 19 | opencv-python==4.10.0.84 20 | packaging==24.1 21 | pandas==2.2.2 22 | pillow==10.2.0 23 | platformdirs==4.2.2 24 | protobuf==5.27.1 25 | psutil==6.0.0 26 | python-dateutil==2.9.0.post0 27 | pytz==2024.1 28 | PyYAML==6.0.1 29 | regex==2024.5.15 30 | requests==2.32.3 31 | safetensors==0.4.3 32 | scikit-learn==1.5.0 33 | scipy==1.13.1 34 | sentry-sdk==2.6.0 35 | setproctitle==1.3.3 36 | six==1.16.0 37 | smmap==5.0.1 38 | sympy==1.12 39 | threadpoolctl==3.5.0 40 | tokenizers==0.19.1 41 | tqdm==4.66.4 42 | transformers==4.41.2 43 | triton==2.3.1 44 | typing_extensions==4.9.0 45 | tzdata==2024.1 46 | urllib3==2.2.2 47 | wandb==0.17.2 48 | zipp==3.19.2 49 | -------------------------------------------------------------------------------- /ksdd2_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from glob import glob 4 | import pandas as pd 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | 9 | def reshape_ksdd2(src_dir, dst_dir, RES=(224, 632)): 10 | # make dest directory 11 | splits = 
['train', 'test'] 12 | for split in splits: 13 | src_split_dir = os.path.join(src_dir, split) 14 | dst_split_dir = os.path.join(dst_dir, split) 15 | os.makedirs(dst_split_dir, exist_ok=True) 16 | all_imgs = os.listdir(src_split_dir) 17 | for img in tqdm(all_imgs, desc=f"Reshaping {split} images", unit="file", total=len(all_imgs)): 18 | img_path = os.path.join(src_split_dir, img) 19 | img_out_path = os.path.join(dst_split_dir, img) 20 | img = cv2.imread(img_path) 21 | img = cv2.resize(img, RES) 22 | cv2.imwrite(img_out_path, img) 23 | 24 | def copy_files(src_dir, dst_dir): 25 | target_files = glob(os.path.join(src_dir, '*.pyb')) 26 | for file in tqdm(target_files, desc="Copying .pyb files", unit="file", total=len(target_files)): 27 | file_name = os.path.basename(file) 28 | dst_file = os.path.join(dst_dir, file_name) 29 | os.system(f'cp {file} {dst_file}') 30 | 31 | def make_csv(dst_dir): 32 | splits = ['train', 'test'] 33 | for split in splits: 34 | img_dir = os.path.join(dst_dir, split) 35 | all_imgs = os.listdir(img_dir) 36 | all_masks = [img for img in all_imgs if "GT" in img] 37 | imgs_dict = {"path": [], "label": []} 38 | for img in tqdm(all_masks, desc=f"Creating {split}.csv", unit="file", total=len(all_masks)): 39 | imgs_dict["path"].append(img.replace("_GT.png", ".png")) 40 | img_path = os.path.join(img_dir, img) 41 | loaded = cv2.imread(img_path) 42 | # if there is a 1, it is positive, else negative 43 | if max(loaded.flatten()) == 0: 44 | imgs_dict["label"].append("negative") 45 | else: 46 | imgs_dict["label"].append("positive") 47 | df = pd.DataFrame(imgs_dict) 48 | df.to_csv(os.path.join(dst_dir, f"{split}.csv"), index=False) 49 | 50 | def main(args): 51 | src_dir = args.src_dir 52 | dst_dir = args.dst_dir 53 | RES = (224,632) # w x h 54 | # make directory 55 | print(f"Copying files from {src_dir} to {dst_dir}") 56 | os.makedirs(dst_dir, exist_ok=True) 57 | # copy .pyb files 58 | copy_files(src_dir, dst_dir) 59 | # reshape images (needed for batching) 60 
| reshape_ksdd2(src_dir, dst_dir, RES=RES) 61 | # make csv files 62 | make_csv(dst_dir) 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--src_dir", type=str, required=True, help="Path to the KSDD2 dataset root") 68 | parser.add_argument("--dst_dir", type=str, required=True, help="Path to the destination directory") 69 | args = parser.parse_args() 70 | main(args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Leveraging Latent Diffusion Models for Training-Free In-Distribution Data Augmentation for Surface Defect Detection # 2 | 3 | Official implementation of the paper [Leveraging Latent Diffusion Models for Training-Free In-Distribution Data Augmentation for Surface Defect Detection](https://intelligolabs.github.io/DIAG/) accepted at the 21st International Conference on Content-Based Multimedia Indexing (CBMI 2024). 4 |

5 | Teaser for DIAG 6 |

7 | 8 | ## Installation ## 9 | **1. Clone the repository** 10 | ```bash 11 | git clone https://github.com/intelligolabs/DIAG.git 12 | cd DIAG 13 | ``` 14 | 15 | **2. Create an environment with dependencies** 16 | ```bash 17 | conda create -n DIAG python=3.10 18 | conda activate DIAG 19 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 20 | ``` 21 | Note: the `pip install` command installs PyTorch 1.12 with CUDA 11.3. Change it according to your own version of CUDA. 22 | 23 | Then, install the requirements: 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | **3. Data preparation** 29 | Download KSDD2 from the [official KSDD2 website](https://www.vicos.si/resources/kolektorsdd2/). 30 | 31 | Run the `ksdd2_preprocess.py` script. This will create a pre-processed copy of KSDD2 in the `--dst_dir`. 32 | We will use this pre-processed copy for our augmentation and evaluation. 33 | ```bash 34 | python ksdd2_preprocess.py --src_dir="/ksdd2" --dst_dir="/ksdd2_preprocessed" 35 | ``` 36 | 37 | **4. [Optional] Set up wandb for logging** 38 | Optionally, you can log the training and evaluation to wandb. 39 | ``` 40 | wandb init 41 | ``` 42 | 43 | ## Part 1: Data augmentation ## 44 | This step generates the augmented positive images. 45 | The images are generated using the same prompts used in the original paper. 46 | The `src_dir` argument should point to the preprocessed data root (see **3. Data preparation** step). 47 | 48 | By default, this script will generate the augmented images in an `augmented_<N>` folder inside `src_dir`, where `N` is the total number of generated images (`imgs_per_prompt` times the number of prompts). 49 | This is needed for the dataloader during training. 50 | ```bash 51 | python generate_augmented_images.py --src_dir="/ksdd2_preprocessed" --imgs_per_prompt=50 --seed=0 52 | ``` 53 | 54 | ## Part 2: Training and evaluation ## 55 | This step will fine-tune a pre-trained ResNet-50 on the (augmented) KSDD2. 
56 | Different arguments handle the policy for training: 57 | - `--zero_shot` trains the model without GT positive images 58 | - `--add_augmented` adds the augmented images to the dataset (works in both the zero-shot and full-shot settings) 59 | - `--num_augmented` selects how many augmented images to add to the training set. 60 | Required if using `--add_augmented`. 61 | Note that this MUST be the TOTAL (`imgs_per_prompt * prompts`) number of images generated in the previous step. 62 | 63 | Example of DIAG training (zero-shot with augmentations): 64 | ```bash 65 | python train_ResNet50.py --seed=0 --epochs=30 --batch_size=32 --num_workers=8 --dataset_path="/ksdd2_preprocessed" --zero_shot --add_augmented --num_augmented=100 --logging 66 | ``` 67 | 68 | ## Credits ## 69 | - Kolektor Surface-Defect Dataset 2. More info [here](https://www.vicos.si/resources/kolektorsdd2/). 70 | - [Diffusers](https://huggingface.co/docs/diffusers/en/using-diffusers/sdxl) and [StabilityAI](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1) for their SDXL implementation and weights. 
71 | 72 | ## Authors ## 73 | Federico Girella, Ziyue Liu, Franco Fummi, Francesco Setti, Marco Cristani, Luigi Capogrosso 74 | 75 | *Department of Engineering for Innovation Medicine, University of Verona, Italy* 76 | 77 | `name.surname@univr.it` 78 | 79 | ## Citation ## 80 | If you use [**DIAG**](https://ieeexplore.ieee.org/abstract/document/10858875), please, cite the following paper: 81 | ``` 82 | @InProceedings{girella2024leveraging, 83 | author = {Girella, Federico and Liu, Ziyue and Fummi, Franco and Setti, Francesco and Cristani, Marco and Capogrosso, Luigi}, 84 | booktitle = {International Conference on Content-Based Multimedia Indexing (CBMI)}, 85 | title = {{Leveraging Latent Diffusion Models for Training-Free in-Distribution Data Augmentation for Surface Defect Detection}}, 86 | year = {2024}, 87 | doi = {10.1109/cbmi62980.2024.10858875}, 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | 165 | ### Python Patch ### 166 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 167 | poetry.toml 168 | 169 | # ruff 170 | .ruff_cache/ 171 | 172 | # LSP config files 173 | pyrightconfig.json 174 | 175 | ### VisualStudioCode ### 176 | .vscode/ 177 | !.vscode/settings.json 178 | !.vscode/tasks.json 179 | !.vscode/launch.json 180 | !.vscode/extensions.json 181 | !.vscode/*.code-snippets 182 | 183 | # Local History for Visual Studio Code 184 | .history/ 185 | 186 | # Built Visual Studio Code Extensions 187 | *.vsix 188 | 189 | ### VisualStudioCode Patch ### 190 | # Ignore all local history of files 191 | .history 192 | .ionide 193 | 194 | ### Custom ### 195 | wandb/ -------------------------------------------------------------------------------- /generate_augmented_images.py: -------------------------------------------------------------------------------- 1 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline 2 | from diffusers.utils import load_image 3 | import torch 4 | import pandas as pd 5 | import os 6 | import random 7 | import numpy as np 8 | import argparse 9 | 10 | def get_negative_images(df): 11 | return list(df[df["label"] == "negative"]["path"]) 12 | 13 | def get_positive_images_maks(df): 14 | return [p.replace('.png', '_GT.png') for p in list(df[df["label"] == "positive"]["path"])] 15 | 16 | def main(args): 17 | 18 | ### ARGUMENTS 19 | src_dir = args.src_dir 20 | imgs_per_prompt = args.imgs_per_prompt 21 | seed = args.seed 22 | 23 | # prompts used in the paper 24 | prompts = ["white marks on the wall", "copper metal scratches"] 25 | negative_prompt="smooth, plain, black, dark, shadow" 26 | 27 | dst_dir = os.path.join(src_dir, f"augmented_{imgs_per_prompt*len(prompts)}") 28 | os.makedirs(dst_dir, exist_ok=True) 29 | 30 | # hyperparameters used in the paper 31 | num_inference_steps = 30 32 | guidance_scale = 20.0 
33 | strength = 0.99 34 | padding_mask_crop = 2 35 | RES = (224, 632) 36 | # needed to ovecome the sdxl shape bias 37 | TARGET = (1024, 1024) 38 | 39 | 40 | ### MAIN 41 | # cuda if available 42 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 43 | print("Running on device: ", device) 44 | 45 | # seed everything 46 | random.seed(seed) 47 | torch.manual_seed(seed) 48 | np.random.seed(seed) 49 | generator = torch.Generator(device=device).manual_seed(seed) 50 | 51 | # create directories 52 | os.makedirs(os.path.join(dst_dir, 'imgs'), exist_ok=True) 53 | os.makedirs(os.path.join(dst_dir, 'masks'), exist_ok=True) 54 | 55 | df_path = os.path.join(src_dir, 'train.csv') 56 | df = pd.read_csv(df_path) 57 | negative_imgs = get_negative_images(df) 58 | positive_masks = get_positive_images_maks(df) 59 | print(f'Num negative images: {len(negative_imgs)}') 60 | print(f'Num positive masks: {len(positive_masks)}') 61 | 62 | model = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" 63 | print(f'Loading model {model}') 64 | pipe = StableDiffusionXLInpaintPipeline.from_pretrained(model, torch_dtype=torch.float16, variant="fp16").to(device) 65 | 66 | img_idx = 0 67 | for prompt in prompts: 68 | print(f'Generating images for prompt: {prompt}') 69 | cnt = 0 70 | for cnt in range(imgs_per_prompt): 71 | # by sampling 1 by 1 we can generate more anomalies than what we have in the dataset (246) 72 | mask_name = random.sample(positive_masks, 1)[0] 73 | mask_path = os.path.join(src_dir, 'train', mask_name) 74 | neg_img_name = random.sample(negative_imgs, 1)[0] 75 | neg_img_path = os.path.join(src_dir, 'train', neg_img_name) 76 | 77 | neg_img = load_image(neg_img_path).resize(TARGET) 78 | mask = load_image(mask_path).resize(TARGET) 79 | 80 | out_image = pipe( 81 | prompt=prompt, 82 | negative_prompt=negative_prompt, 83 | image=neg_img, 84 | mask_image=mask, 85 | guidance_scale=guidance_scale, 86 | num_inference_steps=num_inference_steps, # steps between 15 
and 30 work well for us 87 | strength=strength, # make sure to use `strength` below 1.0 88 | generator=generator, 89 | height=TARGET[1], 90 | width=TARGET[0], 91 | original_size = TARGET, 92 | target_size = TARGET, 93 | padding_mask_crop = padding_mask_crop 94 | ).images[0] 95 | out_image = out_image.resize(RES) 96 | mask = mask.resize(RES) 97 | # save the image with progressive name 98 | out_img_path = f'{dst_dir}/imgs/{str(img_idx + cnt).zfill(5)}.png' 99 | out_image.save(out_img_path) 100 | # save the mask with progressive name 101 | out_mask_path = f'{dst_dir}/masks/{str(img_idx + cnt).zfill(5)}.png' 102 | mask.save(out_mask_path) 103 | cnt += 1 104 | img_idx += cnt 105 | 106 | if __name__ == "__main__": 107 | # ARGUMENTS 108 | 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument("--src_dir", type=str, required=True, help="Directory containing the preprocessed dataset") 111 | parser.add_argument("--imgs_per_prompt", type=int, default=50, help="Number of images to generate per prompt") 112 | parser.add_argument("--seed", type=int, default=0, help="Seed for random generation") 113 | args = parser.parse_args() 114 | 115 | main(args) -------------------------------------------------------------------------------- /data/ksdd2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import re 6 | import torch 7 | 8 | import torchvision.transforms as T 9 | 10 | from PIL import Image 11 | from torch.utils.data import Dataset 12 | 13 | 14 | def c2chw(x): 15 | return x.unsqueeze(1).unsqueeze(2) 16 | 17 | 18 | def inverse_list(list): 19 | """ 20 | List to dict: index -> element 21 | """ 22 | dict = {} 23 | 24 | for idx, x in enumerate(list): 25 | dict[x] = idx 26 | 27 | return dict 28 | 29 | 30 | class KolektorSDD2(Dataset): 31 | """" 32 | Kolektor Surface-Defect 2 dataset 33 | 34 | Args: 35 | dataroot (string): path to the root directory of the dataset 36 
| split (string): data split ['train', 'test'] 37 | scale (string): input image scale 38 | debug (bool) : debug mode 39 | """ 40 | 41 | labels = ['ok', 'defect'] 42 | 43 | # ImageNet. 44 | mean = (0.485, 0.456, 0.406) 45 | std = (0.229, 0.224, 0.225) 46 | 47 | def __init__(self, 48 | dataroot='/path/to/dataset/' 49 | 'KolektorSDD2', 50 | split='train', negative_only=False, 51 | add_augmented=False, num_augmented=0, zero_shot=False): 52 | super(KolektorSDD2, self).__init__() 53 | 54 | self.fold = None 55 | self.dataroot = dataroot 56 | 57 | self.split_path = None 58 | self.split = 'train' if 'val' == split else split 59 | 60 | self.output_size = (704, 256) 61 | self.negative_only = negative_only 62 | self.add_augmented = add_augmented 63 | if self.add_augmented: 64 | assert self.split == 'train', 'Augmented images are only for the training set!' 65 | self.num_augmented = num_augmented 66 | self.zero_shot = zero_shot 67 | if self.zero_shot: 68 | assert self.add_augmented, 'Zero-shot learning requires augmented images!' 69 | assert self.split == 'train', 'Zero-shot learning is only for the training set!' 
70 | 71 | self.class_to_idx = inverse_list(self.labels) 72 | self.classes = self.labels 73 | self.transform = KolektorSDD2.get_transform(output_size=self.output_size) 74 | self.normalize = T.Normalize(KolektorSDD2.mean, KolektorSDD2.std) 75 | 76 | self.load_imgs() 77 | if negative_only: 78 | m = self.masks.sum(-1).sum(-1) == 0 79 | self.samples = self.samples[m] 80 | self.masks = self.masks[m] 81 | self.product_ids = [pid for flag, pid in zip(m, self.product_ids) 82 | if flag] 83 | 84 | 85 | def load_imgs(self): 86 | # Please remove this duplicated files in the official dataset: 87 | # -- 10301_GT (copy).png 88 | # -- 10301 (copy).png 89 | if self.num_augmented > 0: 90 | augmented_imgs_path = os.path.join(self.dataroot, f'augmented_{self.num_augmented}', 'imgs') 91 | augmented_masks_path = os.path.join(self.dataroot, f'augmented_{self.num_augmented}', 'masks') 92 | else: 93 | augmented_imgs_path = os.path.join(self.dataroot, f'augmented', 'imgs') 94 | augmented_masks_path = os.path.join(self.dataroot, f'augmented', 'masks') 95 | 96 | if self.split == 'test': 97 | N = 1004 98 | elif self.split == 'train' and self.zero_shot: 99 | # only augmented positives and original negatives 100 | N = 2085 # number of original negatives 101 | else: 102 | # all original data + augmented 103 | N = 2331 # number of original negatives and positives 104 | 105 | if self.add_augmented: 106 | N += len(os.listdir(augmented_imgs_path)) 107 | if self.num_augmented > 0: 108 | assert len(os.listdir(augmented_imgs_path)) == self.num_augmented, f'Number of augmented images requested ({self.num_augmented}) does not match with number found ({len(os.listdir(augmented_imgs_path))})!' 
109 | 110 | self.samples = torch.Tensor(N, 3, *self.output_size).zero_() 111 | self.masks = torch.LongTensor(N, *self.output_size).zero_() 112 | self.product_ids = [] 113 | 114 | cnt = 0 115 | path = "%s/%s/" % (self.dataroot, self.split) 116 | image_list = [f for f in os.listdir(path) 117 | if re.search(r'[0-9]+\.png$', f)] 118 | assert 0 < len(image_list), self.dataroot 119 | 120 | for img_name in image_list: 121 | product_id = img_name[:-4] 122 | img = self.transform(Image.open(path + img_name)) 123 | lab = self.transform( 124 | Image.open(path + product_id + '_GT.png').convert('L')) 125 | if self.zero_shot: 126 | # check that the mask is negative 127 | if lab.sum() == 0: 128 | self.samples[cnt] = img 129 | self.masks[cnt] = lab 130 | self.product_ids.append(product_id) 131 | cnt += 1 132 | else: 133 | # default 134 | self.samples[cnt] = img 135 | self.masks[cnt] = lab 136 | self.product_ids.append(product_id) 137 | cnt += 1 138 | 139 | # Add the augmented images. 140 | if self.add_augmented: 141 | if 'train' == self.split: 142 | image_list = os.listdir(augmented_imgs_path) 143 | 144 | for img_name in image_list: 145 | product_id = img_name[:-4] 146 | img = self.transform(Image.open(os.path.join(augmented_imgs_path, img_name))) 147 | lab = self.transform( 148 | Image.open(os.path.join(augmented_masks_path, img_name)).convert('L')) 149 | self.samples[cnt] = img 150 | self.masks[cnt] = lab 151 | self.product_ids.append(product_id) 152 | cnt += 1 153 | 154 | assert N == cnt, '{} should be {}!'.format(cnt, N) 155 | 156 | 157 | def __getitem__(self, index): 158 | x = self.samples[index] 159 | a = self.masks[index] > 0 160 | if self.normalize is not None: 161 | x = self.normalize(x) 162 | 163 | if 0 == a.sum(): 164 | y = self.class_to_idx['ok'] 165 | else: 166 | y = self.class_to_idx['defect'] 167 | 168 | return x, y, a, 0 169 | 170 | 171 | def __len__(self): 172 | return self.samples.size(0) 173 | 174 | 175 | @staticmethod 176 | def get_transform(output_size=(704, 
256)): 177 | transform = [ 178 | T.Resize(output_size), 179 | T.ToTensor() 180 | ] 181 | transform = T.Compose(transform) 182 | return transform 183 | 184 | 185 | @staticmethod 186 | def denorm(x): 187 | return x * c2chw(torch.Tensor(KolektorSDD2.std)) + c2chw(torch.Tensor(KolektorSDD2.mean)) 188 | -------------------------------------------------------------------------------- /train_ResNet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | import argparse 4 | 5 | import torch.nn as nn 6 | import torchvision.models as models 7 | 8 | from tqdm import tqdm 9 | from torch.utils.data import DataLoader 10 | from torchvision.models import ResNet50_Weights 11 | from sklearn.metrics import average_precision_score, precision_score, recall_score, precision_recall_curve, roc_curve, auc 12 | import numpy as np 13 | 14 | from data.ksdd2 import KolektorSDD2 15 | 16 | class KSDD2ResNet50(nn.Module): 17 | def __init__(self): 18 | super(KSDD2ResNet50, self).__init__() 19 | 20 | # Load the pre-trained ResNet-50 model from torchvision.models. 21 | self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) 22 | 23 | # Change the output layer to output 1 class score instead of 1000 classes. 24 | num_ftrs = self.model.fc.in_features 25 | self.model.fc = nn.Linear(num_ftrs, 1) 26 | 27 | def forward(self, x): 28 | return self.model(x) 29 | 30 | 31 | def evaluate(model, criterion, test_loader, device, log_dict): 32 | t_loss = 0 33 | correct = 0 34 | targets = [] 35 | predictions = [] 36 | 37 | model.eval() 38 | with torch.no_grad(): 39 | for _, data in (tepoch := tqdm(enumerate(test_loader), unit='batch', 40 | total=len(test_loader), desc='Validation')): 41 | x, y = data[0].to(device), data[1].to(device) 42 | 43 | # This gets the prediction from the network. 44 | output = model(x) 45 | output = output.squeeze(1) 46 | # Sum up batch loss. 
47 | t_loss += criterion(output, y.float()).item() 48 | 49 | # Get the prediction 50 | pred = output 51 | 52 | predictions.extend(pred.cpu().numpy()) 53 | targets.extend(y.cpu().numpy()) 54 | 55 | t_loss /= len(test_loader) 56 | 57 | precision_, recall_, thresholds = precision_recall_curve(targets, predictions) 58 | f_measures = 2 * (precision_ * recall_) / (precision_ + recall_ + 0.0000000001) 59 | 60 | # Select best threshold based on F2 score. Following previous works procedure. 61 | ix_best = np.argmax(f_measures) 62 | if ix_best > 0: 63 | best_threshold = (thresholds[ix_best] + thresholds[ix_best - 1]) / 2 64 | else: 65 | best_threshold = thresholds[ix_best] 66 | precision = precision_[ix_best] 67 | recall = recall_[ix_best] 68 | 69 | classifications = predictions > best_threshold 70 | 71 | FPR, TPR, _ = roc_curve(targets, predictions) 72 | AUC = auc(FPR, TPR) 73 | AP = average_precision_score(targets, predictions) 74 | 75 | # Calculate predictions based on best threshold. 76 | correct = np.sum(classifications == targets) 77 | accuracy = 100. * correct / len(classifications) 78 | 79 | print('AVG loss: {:.4f}, ACC: {}/{} ({:.0f}%), Precision: {:.4f}, Recall: {:.4f}, AP: {:.4f}'.format( 80 | t_loss, correct, len(test_loader.dataset), accuracy, precision, recall, AP)) 81 | 82 | # log metrics 83 | log_dict['val_ACC'] = accuracy 84 | log_dict['val_PRECISION'] = precision 85 | log_dict['val_RECALL'] = recall 86 | log_dict['val_AP'] = AP 87 | 88 | return log_dict 89 | 90 | 91 | def main(args): 92 | # Set the seed for reproducibility. 93 | torch.manual_seed(args.seed) 94 | # Set the device. 
95 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 96 | add_augmented = args.add_augmented 97 | num_augmented = args.num_augmented 98 | zero_shot = args.zero_shot 99 | logging = args.logging 100 | 101 | run_name = f'KSDD2ResNet50-zero_shot_{zero_shot}-add_augmented_{add_augmented}-num_augmented_{num_augmented}-bs_{args.batch_size}-epochs_{args.epochs}' 102 | tags = [f'{args.epochs}epochs', f'{num_augmented}augmented'] 103 | if args.zero_shot: 104 | tags.append('zero_shot') 105 | else: 106 | tags.append('full_shot') 107 | if args.add_augmented: 108 | tags.append('augmented') 109 | else: 110 | tags.append('not_augmented') 111 | 112 | if logging: 113 | # Start a new wandb run to track this script. 114 | wandb.init( 115 | name=run_name, 116 | config=args, 117 | tags=tags 118 | ) 119 | 120 | # Dataset. 121 | print('Loading KolektorSDD2 training set...') 122 | train_data = KolektorSDD2(dataroot=args.dataset_path, split='train', add_augmented=add_augmented, num_augmented=num_augmented, zero_shot=zero_shot) 123 | print('Number of samples:', len(train_data)) 124 | 125 | print('Loading KolektorSDD2 test set...') 126 | test_data = KolektorSDD2(dataroot=args.dataset_path, split='test') 127 | print('Number of samples:', len(test_data)) 128 | 129 | # DataLoaders. 130 | train_loader = DataLoader(train_data, batch_size=args.batch_size, 131 | shuffle=True, num_workers=args.num_workers) 132 | test_loader = DataLoader(test_data, batch_size=args.batch_size, 133 | shuffle=False, num_workers=args.num_workers) 134 | 135 | # Define the model. 136 | model = KSDD2ResNet50() 137 | model.to(device) 138 | 139 | # Define the loss function and the optimizer 140 | criterion = nn.BCEWithLogitsLoss(reduction='mean') 141 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 142 | 143 | # Training step. 
144 | print(f'Start training on {device} [...]') 145 | model.train() 146 | log_dict = {'train_loss': 0, 'val_ACC': 0, 'val_PRECISION': 0, 'val_RECALL': 0, 'val_AP': 0, 'epoch': 0} 147 | for e in range(args.epochs): 148 | epoch_loss = 0 149 | for _, data in (tepoch := tqdm(enumerate(train_loader), unit='batch', 150 | total=len(train_loader))): 151 | tepoch.set_description(f'Epoch {e}') 152 | x, y = data[0].to(device), data[1].to(device) 153 | 154 | # Training step for the single batch. 155 | model.zero_grad() 156 | outputs = model(x) 157 | outputs = outputs.squeeze(1) 158 | loss = criterion(outputs, y.float()) 159 | epoch_loss += loss.item() 160 | loss.backward() 161 | optimizer.step() 162 | 163 | # Print statistics. 164 | tepoch.set_postfix(loss=loss.item()) 165 | if logging: 166 | wandb.log({'train_loss':loss.item()}) 167 | epoch_loss /= len(train_loader) 168 | log_dict['epoch_loss'] = epoch_loss 169 | log_dict['epoch'] = e 170 | 171 | # Evaluation step after each epoch. 172 | eval_dict = evaluate(model, criterion, test_loader, device, log_dict) 173 | if logging: 174 | wandb.log(eval_dict) 175 | 176 | if logging: 177 | wandb.finish() 178 | print('Training finished.') 179 | 180 | if __name__ == '__main__': 181 | parser = argparse.ArgumentParser(description='DIAG training') 182 | parser.add_argument('--seed', type=int, default=1234) 183 | parser.add_argument('--epochs', type=int, default=30) 184 | parser.add_argument('--batch_size', type=int, default=32) 185 | parser.add_argument('--num_workers', type=int, default=8) 186 | parser.add_argument('--dataset_path', type=str, required=True) 187 | parser.add_argument('--add_augmented', action='store_true', help='Add augmented images to the training set') 188 | parser.add_argument('--num_augmented', type=int, default=120) 189 | parser.add_argument('--zero_shot', action='store_true', help='Train the model without true positives in the training set') 190 | parser.add_argument('--logging', action='store_true', help='Log the 
stats to wandb') 191 | 192 | 193 | args = parser.parse_args() 194 | main(args) --------------------------------------------------------------------------------