├── .gitignore
├── README.md
├── cutpaste.py
├── dataset.py
├── density.py
├── doc
│   └── imgs
│       ├── 3way_acc.png
│       ├── 3way_eval_auc.png
│       ├── 3way_loss.png
│       ├── author_vs_thisimpl_CutPaste.png
│       ├── author_vs_thisimpl_CutPaste_3way.png
│       ├── author_vs_thisimpl_CutPaste_scar.png
│       ├── compare_all.png
│       ├── normal_acc.png
│       ├── normal_eval_auc.png
│       ├── normal_loss.png
│       ├── scar_acc.png
│       ├── scar_eval_auc.png
│       └── scar_loss.png
├── eval.py
├── model.py
├── requirements.txt
├── run_training.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
venv*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Implementation of CutPaste

This is an **unofficial**, work-in-progress PyTorch reimplementation of [CutPaste: Self-Supervised Learning for Anomaly Detection and Localization](https://arxiv.org/abs/2104.04015) and is in no way affiliated with the original authors. Use at your own risk. Pull requests and feedback are appreciated.

## Setup
Download the MVTec Anomaly Detection dataset from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad) and extract it into a new folder named `Data`.

Install the following requirements:
1. PyTorch and torchvision
2. scikit-learn
3. pandas
4. seaborn
5. tqdm
6. tensorboard

For example with [Anaconda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/download.html):
```
conda create -n cutpaste pytorch torchvision torchaudio cudatoolkit=10.2 seaborn pandas tqdm tensorboard scikit-learn -c pytorch
conda activate cutpaste
```

## Run Training
```
python run_training.py --model_dir models --head_layer 2
```
The script will train a model for each defect type and save it in the `model_dir` folder.

To enable training on an Nvidia GPU, use the `--cuda 1` flag.
```
python run_training.py --model_dir models --head_layer 2 --cuda 1
```

One can track the training progress of the models with tensorboard:
```
tensorboard --logdir logdirs
```

## Run Evaluation
```
python eval.py --model_dir models --head_layer 2
```
This will create a new directory `eval` with plots for each defect type/model.

# Implementation details

### CutPaste Location
The pasted image patch always originates from the same image it is pasted onto. I'm not sure whether this is a problem, or whether the original paper/code does the same.

### Epochs
Li et al. define "256 parameter update steps" as one epoch. The `--epochs` parameter takes the number of update steps, not their definition of epochs.

### Batch Size
Li et al. use a "batch size of 64 (or 96 for 3-way)". Because the number of images fed into the model changes from the normal to the 3-way variant, I suspect that they always start with 32 images that then get augmented. The `--batch_size` parameter specifies the number of images read from disk, so for all variants `--batch_size=32` should correspond to the batch size used by Li et al.

### Projection head
I did not find a description of the projection head Li et al. use.
The `--head_layer` parameter varies the number of layers used in this implementation.
In total, `head_layer + 2` fully connected layers are used: `head_layer` layers with 512 neurons, followed by a layer with 128 neurons and an output layer with 2 or 3 neurons. The number of output neurons depends on the variant: 2 for `normal` and `scar`, 3 for `3way`.
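
To make the layer count concrete, here is a minimal sketch of how such a head can be assembled (it mirrors `ProjectionNet` in `model.py`; the helper name `build_head` is illustrative and not part of this repo):
```
import torch.nn as nn

def build_head(head_layer=2, num_classes=2, in_features=512):
    # hidden dims as used in this repo: head_layer x 512, then 128
    dims = [512] * head_layer + [128]
    layers = []
    last = in_features
    for num_neurons in dims:
        layers.append(nn.Linear(last, num_neurons))
        layers.append(nn.BatchNorm1d(num_neurons))
        layers.append(nn.ReLU(inplace=True))
        last = num_neurons
    head = nn.Sequential(*layers)
    out = nn.Linear(last, num_classes)  # final layer, no activation
    return head, out
```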

### Augmentations used before CutPaste
Li et al. "apply random translation and color jitters for data augmentation".
This implementation only applies color jitter before the CutPaste augmentation. I tried to use [torchvision.transforms.RandomResizedCrop](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.RandomResizedCrop) as translation, but in a brief test I did not find that it improves performance.

### TensorFlow vs PyTorch
Li et al. use TensorFlow for their implementation. This implementation uses PyTorch.

### Kernel Density Estimation
I implemented two scoring pipelines: a Gaussian density estimation with Mahalanobis distance and a Kernel Density Estimation.
Li et al. use sklearn for the density estimation, but [Rippel et al.](https://github.com/ORippler/gaussian-ad-mvtec) have their own.
`eval.py` has a `--density` flag that can be toggled between `torch` for the Rippel et al. style implementation and `sklearn` for my sklearn implementation.
In my limited testing, both implementations show only small differences in the resulting ROC AUCs:
```
> python eval.py --density torch --cuda 1 --head_layer 2 --save_plots 0 | grep AUC
bottle AUC: 0.9944444444444445
cable AUC: 0.8549475262368815
capsule AUC: 0.8232947746310331
carpet AUC: 0.9329855537720706
grid AUC: 0.982456140350877
hazelnut AUC: 0.9160714285714285
leather AUC: 1.0
metal_nut AUC: 0.9403714565004888
pill AUC: 0.8046917621385706
screw AUC: 0.701988112318098
tile AUC: 0.9430014430014431
toothbrush AUC: 0.8972222222222221
transistor AUC: 0.9008333333333334
wood AUC: 0.9815789473684211
zipper AUC: 0.9997373949579832

> python eval.py --density sklearn --cuda 1 --head_layer 2 --save_plots 0 | grep AUC
bottle AUC: 0.9944444444444445
cable AUC: 0.8549475262368815
capsule AUC: 0.8232947746310331
carpet AUC: 0.9329855537720706
grid AUC: 0.982456140350877
hazelnut AUC: 0.9160714285714285
leather AUC: 1.0
metal_nut AUC: 0.9403714565004888
pill AUC: 0.8046917621385706
screw AUC: 0.701988112318098
tile AUC: 0.9430014430014431
toothbrush AUC: 0.8972222222222221
transistor AUC: 0.9008333333333334
wood AUC: 0.9815789473684211
zipper AUC: 0.9997373949579832
```
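
For illustration, the `torch` pipeline is used roughly like this (a minimal sketch on top of `density.py`; the helper `anomaly_scores` is hypothetical, but the normalization and fit/predict calls follow `eval.py`):
```
import torch.nn.functional as F
from density import GaussianDensityTorch

def anomaly_scores(train_embed, test_embed):
    # fit a Gaussian on L2-normalized train embeddings, then score the test
    # embeddings by their Mahalanobis distance (higher = more anomalous)
    train_embed = F.normalize(train_embed, p=2, dim=1)
    test_embed = F.normalize(test_embed, p=2, dim=1)
    density = GaussianDensityTorch()
    density.fit(train_embed)
    return density.predict(test_embed)
```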

# Results
This implementation only tries to recreate the main results from section 4.1, shown in Table 1 of the paper.
## CutPaste
```
python run_training.py --epochs 10000 --test_epochs 32 --no-pretrained --cuda 1 --head_layer 1 --batch_size 32 --variant normal
```
![training loss](doc/imgs/normal_loss.png)
The blue line shows the raw value; the orange line is a running average over 100 epochs.
![training accuracy](doc/imgs/normal_acc.png)
![validation accuracy](doc/imgs/normal_eval_auc.png)
The ROC AUC is only computed every 32nd update step; here the orange line is an average over 320 update steps (10 ROC AUC values).
Note: The validation accuracy (named test set ROC AUC) uses the Mahalanobis distance as anomaly score. It cannot be directly compared with the accuracy during training.

![comparison with Li et al.](doc/imgs/author_vs_thisimpl_CutPaste.png)

Note that for readability, the y-axis starts at 40% ROC AUC.
## CutPaste (scar)
```
python run_training.py --epochs 10000 --test_epochs 32 --no-pretrained --cuda 1 --head_layer 1 --batch_size 32 --variant scar
```
![training loss](doc/imgs/scar_loss.png)
![training accuracy](doc/imgs/scar_acc.png)
![validation accuracy](doc/imgs/scar_eval_auc.png)

![comparison with Li et al.](doc/imgs/author_vs_thisimpl_CutPaste_scar.png)
## CutPaste (3-way)
Due to limited computing resources, the evaluation during training is disabled.
```
python run_training.py --epochs 10000 --test_epochs -1 --no-pretrained --cuda 1 --head_layer 1 --batch_size 32 --variant 3way
```
![training loss](doc/imgs/3way_loss.png)
![training accuracy](doc/imgs/3way_acc.png)
![comparison with Li et al.](doc/imgs/author_vs_thisimpl_CutPaste_3way.png)

# Comparison to Li et al.
| defect_type | CutPaste | Li et al. CutPaste | CutPaste (scar) | Li et al. CutPaste (scar) | CutPaste (3-way) | Li et al. CutPaste (3-way) |
|:------------|---------:|-------------------:|----------------:|--------------------------:|-----------------:|---------------------------:|
| bottle      |     99.7 |               99.2 |            97.9 |                      98.0 |             99.6 |                       98.3 |
| cable       |     92.3 |               87.1 |            75.0 |                      78.8 |             77.2 |                       80.6 |
| capsule     |     86.2 |               87.9 |            84.5 |                      95.3 |             92.4 |                       96.2 |
| carpet      |     59.8 |               67.9 |            88.6 |                      94.6 |             60.1 |                       93.1 |
| grid        |    100.0 |               99.9 |            99.9 |                      95.5 |            100.0 |                       99.9 |
| hazelnut    |     83.7 |               91.3 |            87.5 |                      96.7 |             86.8 |                       97.3 |
| leather     |     99.5 |               99.7 |            99.5 |                     100.0 |            100.0 |                      100.0 |
| metal_nut   |     91.5 |               96.8 |            80.6 |                      97.9 |             87.8 |                       99.3 |
| pill        |     89.4 |               93.4 |            78.4 |                      85.8 |             91.7 |                       92.4 |
| screw       |     44.1 |               54.4 |            80.7 |                      83.7 |             86.8 |                       86.3 |
| tile        |     88.7 |               95.9 |            95.3 |                      89.4 |             97.2 |                       93.4 |
| toothbrush  |     96.7 |               99.2 |            88.3 |                      96.7 |             94.7 |                       98.3 |
| transistor  |     95.1 |               96.4 |            86.8 |                      91.1 |             93.0 |                       95.5 |
| wood        |     98.6 |               94.9 |            98.0 |                      98.7 |             99.4 |                       98.6 |
| zipper      |     99.6 |               99.4 |            95.9 |                      99.5 |             98.8 |                       99.4 |
| average     |     88.3 |               90.9 |            89.1 |                      93.4 |             91.0 |                       95.2 |

![comparison with Li et al.](doc/imgs/compare_all.png)
# TODOs
- [x] implement CutPaste scar
- [ ] implement GradCAM
- [ ] implement localization variant
- [ ] add option to fine-tune on EfficientNet(B4)
- [ ] clean up parameters and move them into the arguments of the scripts
- [ ] compare results of this reimplementation with the results of the paper
--------------------------------------------------------------------------------
/cutpaste.py:
--------------------------------------------------------------------------------
import random
import math
from torchvision import transforms
import torch

def cut_paste_collate_fn(batch):
    # each dataset item is a tuple of images; regroup the batch by image type
    # and stack each group into a batch tensor
    img_types = list(zip(*batch))
    return [torch.stack(imgs) for imgs in img_types]
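
# Example (hypothetical shapes): with CutPasteNormal each item is a tuple
# (original, augmented), so a batch of 32 items collates to
# [orig_batch, aug_batch], each of shape (32, 3, H, W); with CutPaste3Way
# the list contains three stacked tensors instead of two.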

class CutPaste(object):
    """Base class for both CutPaste variants with common operations."""
    def __init__(self, colorJitter=0.1, transform=None):
        self.transform = transform

        if colorJitter is None:
            self.colorJitter = None
        else:
            self.colorJitter = transforms.ColorJitter(brightness=colorJitter,
                                                      contrast=colorJitter,
                                                      saturation=colorJitter,
                                                      hue=colorJitter)

    def __call__(self, org_img, img):
        # apply transforms to both images
        if self.transform:
            img = self.transform(img)
            org_img = self.transform(org_img)
        return org_img, img

class CutPasteNormal(CutPaste):
    """Randomly copy one patch from the image and paste it somewhere else.
    Args:
        area_ratio (list): [min, max] fraction of the image area to cut out
        aspect_ratio (float): minimum aspect ratio. The ratio is sampled in log space between aspect_ratio and 1/aspect_ratio.
    """
    def __init__(self, area_ratio=[0.02, 0.15], aspect_ratio=0.3, **kwargs):
        super(CutPasteNormal, self).__init__(**kwargs)
        self.area_ratio = area_ratio
        self.aspect_ratio = aspect_ratio

    def __call__(self, img):
        # TODO: we might want to use the pytorch implementation to calculate the patches from https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#RandomErasing
        # PIL image size is (width, height)
        w, h = img.size

        # patch area between area_ratio[0] and area_ratio[1] of the image area
        ratio_area = random.uniform(self.area_ratio[0], self.area_ratio[1]) * w * h

        # sample the aspect ratio in log space
        log_ratio = torch.log(torch.tensor((self.aspect_ratio, 1 / self.aspect_ratio)))
        aspect = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        cut_w = int(round(math.sqrt(ratio_area * aspect)))
        cut_h = int(round(math.sqrt(ratio_area / aspect)))

        # one might also want to sample from other images. currently we only sample from the image itself
        from_location_h = int(random.uniform(0, h - cut_h))
        from_location_w = int(random.uniform(0, w - cut_w))

        box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h]
        patch = img.crop(box)

        if self.colorJitter:
            patch = self.colorJitter(patch)

        to_location_h = int(random.uniform(0, h - cut_h))
        to_location_w = int(random.uniform(0, w - cut_w))

        insert_box = [to_location_w, to_location_h, to_location_w + cut_w, to_location_h + cut_h]
        augmented = img.copy()
        augmented.paste(patch, insert_box)

        return super().__call__(img, augmented)
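
# Worked example (hypothetical 256x256 input): the patch area is sampled
# between 0.02 * 65536 ~ 1311 and 0.15 * 65536 ~ 9830 px^2, and the aspect
# ratio log-uniformly between 0.3 and 1/0.3 ~ 3.33; e.g. aspect=2 with
# ratio_area=5000 yields a patch of roughly 100x50 pixels.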

class CutPasteScar(CutPaste):
    """Randomly cut a thin, scar-like patch from the image and paste it somewhere else.
    Args:
        width (list): width to sample from. List of [min, max]
        height (list): height to sample from. List of [min, max]
        rotation (list): rotation to sample from. List of [min, max]
    """
    def __init__(self, width=[2, 16], height=[10, 25], rotation=[-45, 45], **kwargs):
        super(CutPasteScar, self).__init__(**kwargs)
        self.width = width
        self.height = height
        self.rotation = rotation

    def __call__(self, img):
        # PIL image size is (width, height)
        w, h = img.size

        # cut region
        cut_w = random.uniform(*self.width)
        cut_h = random.uniform(*self.height)

        from_location_h = int(random.uniform(0, h - cut_h))
        from_location_w = int(random.uniform(0, w - cut_w))

        box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h]
        patch = img.crop(box)

        if self.colorJitter:
            patch = self.colorJitter(patch)

        # rotate
        rot_deg = random.uniform(*self.rotation)
        patch = patch.convert("RGBA").rotate(rot_deg, expand=True)

        # paste; patch.size is (width, height) after rotation
        to_location_h = int(random.uniform(0, h - patch.size[1]))
        to_location_w = int(random.uniform(0, w - patch.size[0]))

        # use the alpha channel as paste mask so the rotated corners stay transparent
        mask = patch.split()[-1]
        patch = patch.convert("RGB")

        augmented = img.copy()
        augmented.paste(patch, (to_location_w, to_location_h), mask=mask)

        return super().__call__(img, augmented)

class CutPasteUnion(object):
    """Randomly apply either the normal or the scar variant with equal probability."""
    def __init__(self, **kwargs):
        self.normal = CutPasteNormal(**kwargs)
        self.scar = CutPasteScar(**kwargs)

    def __call__(self, img):
        r = random.uniform(0, 1)
        if r < 0.5:
            return self.normal(img)
        else:
            return self.scar(img)

class CutPaste3Way(object):
    """Return the original image plus a normal and a scar augmentation (3 classes)."""
    def __init__(self, **kwargs):
        self.normal = CutPasteNormal(**kwargs)
        self.scar = CutPasteScar(**kwargs)

    def __call__(self, img):
        org, cutpaste_normal = self.normal(img)
        _, cutpaste_scar = self.scar(img)

        return org, cutpaste_normal, cutpaste_scar
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
from joblib import Parallel, delayed

class Repeat(Dataset):
    """Wrap a dataset to present a larger virtual length; indices wrap around."""
    def __init__(self, org_dataset, new_length):
        self.org_dataset = org_dataset
        self.org_length = len(self.org_dataset)
        self.new_length = new_length

    def __len__(self):
        return self.new_length

    def __getitem__(self, idx):
        return self.org_dataset[idx % self.org_length]

class MVTecAT(Dataset):
    """MVTec anomaly detection dataset.
    Link: https://www.mvtec.com/company/research/datasets/mvtec-ad
    """

    def __init__(self, root_dir, defect_name, size, transform=None, mode="train"):
        """
        Args:
            root_dir (string): directory with the MVTec AD dataset.
            defect_name (string): defect type to load.
            size (int): edge length the images are resized to.
            transform: transform to apply to the data.
            mode: "train" loads the training samples, "test" the test samples (default: "train").
        """
        self.root_dir = Path(root_dir)
        self.defect_name = defect_name
        self.transform = transform
        self.mode = mode
        self.size = size

        # find images
        if self.mode == "train":
            self.image_names = list((self.root_dir / defect_name / "train" / "good").glob("*.png"))
            print("loading images")
            # during training we cache the resized images in memory for performance reasons (not a good coding style)
            self.imgs = Parallel(n_jobs=10)(delayed(lambda file: Image.open(file).resize((size, size)).convert("RGB"))(file) for file in self.image_names)
            print(f"loaded {len(self.imgs)} images")
        else:
            # test mode
            self.image_names = list((self.root_dir / defect_name / "test").glob(str(Path("*") / "*.png")))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        if self.mode == "train":
            img = self.imgs[idx].copy()
            if self.transform is not None:
                img = self.transform(img)
            return img
        else:
            filename = self.image_names[idx]
            label = filename.parts[-2]
            img = Image.open(filename)
            img = img.resize((self.size, self.size)).convert("RGB")
            if self.transform is not None:
                img = self.transform(img)
            # everything that is not in the "good" folder is an anomaly
            return img, label != "good"
--------------------------------------------------------------------------------
/density.py:
--------------------------------------------------------------------------------
from sklearn.covariance import LedoitWolf
from sklearn.neighbors import KernelDensity
import torch


class Density(object):
    """Interface for the density estimators used as anomaly scorers."""
    def fit(self, embeddings):
        raise NotImplementedError

    def predict(self, embeddings):
        raise NotImplementedError


class GaussianDensityTorch(Density):
    """Gaussian density estimation similar to the implementation used by Rippel et al.
    The code of Rippel et al. can be found here: https://github.com/ORippler/gaussian-ad-mvtec.
    """
    def fit(self, embeddings):
        self.mean = torch.mean(embeddings, axis=0)
        # Ledoit-Wolf gives a shrinkage-regularized covariance estimate; we store its inverse
        self.inv_cov = torch.tensor(LedoitWolf().fit(embeddings.cpu().numpy()).precision_, dtype=torch.float32)

    def predict(self, embeddings):
        distances = self.mahalanobis_distance(embeddings, self.mean, self.inv_cov)
        return distances

    @staticmethod
    def mahalanobis_distance(
        values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor
    ) -> torch.Tensor:
        """Compute the batched Mahalanobis distance.
        values is a batch of feature vectors.
        mean is either the mean of the distribution to compare, or a second
        batch of feature vectors.
        inv_covariance is the inverse covariance of the target distribution.

        from https://github.com/ORippler/gaussian-ad-mvtec/blob/4e85fb5224eee13e8643b684c8ef15ab7d5d016e/src/gaussian/model.py#L308
        """
        assert values.dim() == 2
        assert 1 <= mean.dim() <= 2
        assert len(inv_covariance.shape) == 2
        assert values.shape[1] == mean.shape[-1]
        assert mean.shape[-1] == inv_covariance.shape[0]
        assert inv_covariance.shape[0] == inv_covariance.shape[1]

        if mean.dim() == 1:  # Distribution mean.
            mean = mean.unsqueeze(0)
        x_mu = values - mean  # batch x features
        # same as dist = x_mu.t() * inv_covariance * x_mu, batch-wise
        dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu)
        return dist.sqrt()

class GaussianDensitySklearn(Density):
    """Li et al. use sklearn for density estimation.
    This implementation uses the sklearn KernelDensity module for fitting and predicting.
    """
    def fit(self, embeddings):
        # estimate the KDE parameters; the bandwidth is fixed here, but one
        # could use grid-search cross-validation to tune it
        self.kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(embeddings)

    def predict(self, embeddings):
        scores = self.kde.score_samples(embeddings)

        # invert scores so they fit the class labels for the AUC calculation
        scores = -scores

        return scores
--------------------------------------------------------------------------------
/doc/imgs/3way_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/3way_acc.png
--------------------------------------------------------------------------------
/doc/imgs/3way_eval_auc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/3way_eval_auc.png
--------------------------------------------------------------------------------
/doc/imgs/3way_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/3way_loss.png
--------------------------------------------------------------------------------
/doc/imgs/author_vs_thisimpl_CutPaste.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/author_vs_thisimpl_CutPaste.png
--------------------------------------------------------------------------------
/doc/imgs/author_vs_thisimpl_CutPaste_3way.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/author_vs_thisimpl_CutPaste_3way.png
--------------------------------------------------------------------------------
/doc/imgs/author_vs_thisimpl_CutPaste_scar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/author_vs_thisimpl_CutPaste_scar.png
--------------------------------------------------------------------------------
/doc/imgs/compare_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/compare_all.png
--------------------------------------------------------------------------------
/doc/imgs/normal_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/normal_acc.png
--------------------------------------------------------------------------------
/doc/imgs/normal_eval_auc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/normal_eval_auc.png
--------------------------------------------------------------------------------
/doc/imgs/normal_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/normal_loss.png
--------------------------------------------------------------------------------
/doc/imgs/scar_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/scar_acc.png
--------------------------------------------------------------------------------
/doc/imgs/scar_eval_auc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/scar_eval_auc.png
--------------------------------------------------------------------------------
/doc/imgs/scar_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Runinho/pytorch-cutpaste/10d8bf71df76d3a97f0106efee1d76f81d983149/doc/imgs/scar_loss.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
from sklearn.metrics import roc_curve, auc
from sklearn.manifold import TSNE
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from dataset import MVTecAT
from cutpaste import CutPaste, cut_paste_collate_fn
from model import ProjectionNet
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
from sklearn.utils import shuffle
from collections import defaultdict
from density import GaussianDensitySklearn, GaussianDensityTorch
import pandas as pd
from utils import str2bool

test_data_eval = None
test_transform = None
cached_type = None

def get_train_embeds(model, size, defect_type, transform, device):
    # embed the training split; the embeddings are used to fit the density model
    train_data = MVTecAT("Data", defect_type, size, transform=transform, mode="train")

    dataloader_train = DataLoader(train_data, batch_size=64,
                                  shuffle=False, num_workers=0)
    train_embed = []
    with torch.no_grad():
        for x in dataloader_train:
            embed, logit = model(x.to(device))

            train_embed.append(embed.cpu())
    train_embed = torch.cat(train_embed)
    return train_embed
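
# Evaluation flow (summary): embed the test split with the trained encoder,
# fit the chosen density model on the L2-normalized train embeddings, and use
# its score on the test embeddings as the anomaly score for the ROC AUC.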

def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256, show_training_data=True, model=None, train_embed=None, head_layer=8, density=GaussianDensityTorch()):
    # create test dataset
    global test_data_eval, test_transform, cached_type

    # TODO: cache is only nice during training. do we need it?
    if test_data_eval is None or cached_type != defect_type:
        cached_type = defect_type
        test_transform = transforms.Compose([])
        test_transform.transforms.append(transforms.Resize((size, size)))
        test_transform.transforms.append(transforms.ToTensor())
        test_transform.transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                              std=[0.229, 0.224, 0.225]))
        test_data_eval = MVTecAT("Data", defect_type, size, transform=test_transform, mode="test")

    dataloader_test = DataLoader(test_data_eval, batch_size=64,
                                 shuffle=False, num_workers=0)

    # create model
    if model is None:
        print(f"loading model {modelname}")
        head_layers = [512] * head_layer + [128]
        print(head_layers)
        weights = torch.load(modelname)
        classes = weights["out.weight"].shape[0]
        model = ProjectionNet(pretrained=False, head_layers=head_layers, num_classes=classes)
        model.load_state_dict(weights)
        model.to(device)
        model.eval()

    # get embeddings for the test data
    labels = []
    embeds = []
    with torch.no_grad():
        for x, label in dataloader_test:
            embed, logit = model(x.to(device))

            # save
            embeds.append(embed.cpu())
            labels.append(label.cpu())
    labels = torch.cat(labels)
    embeds = torch.cat(embeds)

    if train_embed is None:
        train_embed = get_train_embeds(model, size, defect_type, test_transform, device)

    # L2-normalize the embeddings
    embeds = torch.nn.functional.normalize(embeds, p=2, dim=1)
    train_embed = torch.nn.functional.normalize(train_embed, p=2, dim=1)

    # create eval plot dir
    if save_plots:
        eval_dir = Path("eval") / modelname
        eval_dir.mkdir(parents=True, exist_ok=True)

        # plot tsne; optionally also show some of the training data
        # note: the training-data branch below is currently disabled
        show_training_data = False
        if show_training_data:
            # augmentation setting
            # TODO: do all of this in a separate function that we can call in training and evaluation.
            # very ugly to just copy the code lol
            min_scale = 0.5

            # create training dataset and dataloader
            after_cutpaste_transform = transforms.Compose([])
            after_cutpaste_transform.transforms.append(transforms.ToTensor())
            after_cutpaste_transform.transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                            std=[0.229, 0.224, 0.225]))

            train_transform = transforms.Compose([])
            train_transform.transforms.append(CutPaste(transform=after_cutpaste_transform))

            train_data = MVTecAT("Data", defect_type, transform=train_transform, size=size)
            dataloader_train = DataLoader(train_data, batch_size=32,
                                          shuffle=True, num_workers=8, collate_fn=cut_paste_collate_fn,
                                          persistent_workers=True)
            # inference on training data
            train_labels = []
            train_embeds = []
            with torch.no_grad():
                for x1, x2 in dataloader_train:
                    x = torch.cat([x1, x2], axis=0)
                    embed, logit = model(x.to(device))

                    # generate labels:
                    y = torch.tensor([0, 1])
                    y = y.repeat_interleave(x1.size(0))

                    # save
                    train_embeds.append(embed.cpu())
                    train_labels.append(y)
                    # only use one batch of data
                    break
            train_labels = torch.cat(train_labels)
            train_embeds = torch.cat(train_embeds)

            # for tsne we encode training data as 2, and augmented data as 3
            tsne_labels = torch.cat([labels, train_labels + 2])
            tsne_embeds = torch.cat([embeds, train_embeds])
        else:
            tsne_labels = labels
            tsne_embeds = embeds
        plot_tsne(tsne_labels, tsne_embeds, eval_dir / "tsne.png")
    else:
        eval_dir = Path("unused")

    print(f"using density estimation {density.__class__.__name__}")
    density.fit(train_embed)
    distances = density.predict(embeds)
    # TODO: set a threshold on the Mahalanobis distances and use "real" probabilities

    roc_auc = plot_roc(labels, distances, eval_dir / "roc_plot.png", modelname=modelname, save_plots=save_plots)

    return roc_auc


def plot_roc(labels, scores, filename, modelname="", save_plots=False):
    fpr, tpr, _ = roc_curve(labels, scores)
    roc_auc = auc(fpr, tpr)

    # plot roc
    if save_plots:
        plt.figure()
        lw = 2
        plt.plot(fpr, tpr, color='darkorange',
                 lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'Receiver operating characteristic {modelname}')
        plt.legend(loc="lower right")
        plt.savefig(filename)
        plt.close()

    return roc_auc

def plot_tsne(labels, embeds, filename):
    tsne = TSNE(n_components=2, verbose=1, perplexity=30, n_iter=500)
    embeds, labels = shuffle(embeds, labels)
    tsne_results = tsne.fit_transform(embeds)
    fig, ax = plt.subplots(1)
    colormap = ["b", "r", "c", "y"]

    ax.scatter(tsne_results[:, 0], tsne_results[:, 1], color=[colormap[l] for l in labels])
    fig.savefig(filename)
    plt.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='eval models')
    parser.add_argument('--type', default="all",
                        help='MVTec defect type(s) to evaluate, separated by "," (default: "all": evaluate all defect types)')

    parser.add_argument('--model_dir', default="models",
                        help='directory containing the models to evaluate (default: models)')

    parser.add_argument('--cuda', default=False, type=str2bool,
                        help='use cuda for model predictions (default: False)')

    parser.add_argument('--head_layer', default=8, type=int,
                        help='number of layers in the projection head (default: 8)')

    parser.add_argument('--density', default="torch", choices=["torch", "sklearn"],
                        help='density implementation to use. See `density.py` for both implementations. (default: torch)')

    parser.add_argument('--save_plots', default=True, type=str2bool,
                        help='save TSNE and ROC plots')


    args = parser.parse_args()
    print(args)
    all_types = ['bottle',
                 'cable',
                 'capsule',
                 'carpet',
                 'grid',
                 'hazelnut',
                 'leather',
                 'metal_nut',
                 'pill',
                 'screw',
                 'tile',
                 'toothbrush',
                 'transistor',
                 'wood',
                 'zipper']

    if args.type == "all":
        types = all_types
    else:
        types = args.type.split(",")

    device = "cuda" if args.cuda else "cpu"

    density_mapping = {
        "torch": GaussianDensityTorch,
        "sklearn": GaussianDensitySklearn
    }
    density = density_mapping[args.density]

    # find models; keep defect type and model path aligned
    model_names = []
    types_found = []
    for data_type in types:
        candidates = sorted(Path(args.model_dir).glob(f"model-{data_type}*"))
        if candidates:
            model_names.append(candidates[0])
            types_found.append(data_type)
    if len(model_names) < len(types):
        print("warning: not all types present in folder")

    obj = defaultdict(list)
    for model_name, data_type in zip(model_names, types_found):
        print(f"evaluating {data_type}")

        roc_auc = eval_model(model_name, data_type, save_plots=args.save_plots, device=device, head_layer=args.head_layer, density=density())
        print(f"{data_type} AUC: {roc_auc}")
        obj["defect_type"].append(data_type)
        obj["roc_auc"].append(roc_auc)

    # save pandas dataframe
    eval_dir = Path("eval") / args.model_dir
    eval_dir.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(obj)
    df.to_csv(str(eval_dir) + "_perf.csv")
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.models import resnet18


class ProjectionNet(nn.Module):
    def __init__(self, pretrained=True, head_layers=[512, 512, 512, 512, 512, 512, 512, 512, 128], num_classes=2):
        super(ProjectionNet, self).__init__()
        # self.resnet18 = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=pretrained)
        self.resnet18 = resnet18(pretrained=pretrained)

        # create MLP head as seen in the code in: https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
        # TODO: check if this is really the right architecture
        last_layer = 512
        sequential_layers = []
        for num_neurons in head_layers:
            sequential_layers.append(nn.Linear(last_layer, num_neurons))
            sequential_layers.append(nn.BatchNorm1d(num_neurons))
            sequential_layers.append(nn.ReLU(inplace=True))
            last_layer = num_neurons

        head = nn.Sequential(
            *sequential_layers
        )
        self.resnet18.fc = nn.Identity()
        self.head = head
        # the last layer, without activation
        self.out = nn.Linear(last_layer, num_classes)

    def forward(self, x):
        embeds = self.resnet18(x)
        tmp = self.head(embeds)
        logits = self.out(tmp)
        return embeds, logits

    def freeze_resnet(self):
        # freeze the full resnet18 backbone
        for param in self.resnet18.parameters():
            param.requires_grad = False
        # note: resnet18.fc is nn.Identity() and has no parameters; the
        # projection head (self.head, self.out) is never frozen and stays trainable

    def unfreeze(self):
        # unfreeze all parameters
        for param in self.parameters():
            param.requires_grad = True
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
scikit-learn
pandas
seaborn
tqdm
tensorboard
matplotlib
joblib
--------------------------------------------------------------------------------
/run_training.py:
--------------------------------------------------------------------------------
# head dims: 512,512,512,512,512,512,512,512,128
# based on: https://github.com/google-research/deep_representation_one_class
from pathlib import Path
from tqdm import tqdm
import datetime
import argparse

import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms


from dataset import MVTecAT, Repeat
from cutpaste import CutPasteNormal, CutPasteScar, CutPaste3Way, CutPasteUnion, cut_paste_collate_fn
from model import ProjectionNet
from eval import eval_model
from utils import str2bool

def run_training(data_type="screw",
                 model_dir="models",
                 epochs=256,
                 pretrained=True,
                 test_epochs=10,
                 freeze_resnet=20,
                 learning_rate=0.03,
                 optim_name="sgd",
                 batch_size=64,
                 head_layer=8,
                 cutpaste_type=CutPasteNormal,
                 device="cuda",
                 workers=8,
                 size=256):
    torch.multiprocessing.freeze_support()
    # TODO: use script params for hyperparameter
    # temperature hyperparameter, currently not used
    temperature = 0.2

    weight_decay = 0.00003
    momentum = 0.9
    # TODO: use f strings also for the date LOL
    model_name = f"model-{data_type}" + '-{date:%Y-%m-%d_%H_%M_%S}'.format(date=datetime.datetime.now())

    # augmentation:
    min_scale = 1

    # create training dataset and dataloader
    after_cutpaste_transform = transforms.Compose([])
    after_cutpaste_transform.transforms.append(transforms.ToTensor())
    after_cutpaste_transform.transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                    std=[0.229, 0.224, 0.225]))

    train_transform = transforms.Compose([])
    # train_transform.transforms.append(transforms.RandomResizedCrop(size, scale=(min_scale, 1)))
    train_transform.transforms.append(transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1))
    # train_transform.transforms.append(transforms.GaussianBlur(int(size / 10), sigma=(0.1, 2.0)))
    train_transform.transforms.append(transforms.Resize((size, size)))
    train_transform.transforms.append(cutpaste_type(transform=after_cutpaste_transform))

    train_data = MVTecAT("Data", data_type, transform=train_transform, size=int(size * (1 / min_scale)))
    dataloader = DataLoader(Repeat(train_data, 3000), batch_size=batch_size, drop_last=True,
                            shuffle=True, num_workers=workers, collate_fn=cut_paste_collate_fn,
                            persistent_workers=True, pin_memory=True, prefetch_factor=5)
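    # Note: Repeat() above gives the training set a virtual length of 3000, so
    # the dataloader can draw many batches without restarting its workers;
    # indices simply wrap around the real dataset (see dataset.py).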

    # the writer will output to the ./logdirs/ directory
    writer = SummaryWriter(Path("logdirs") / model_name)

    # create model:
    head_layers = [512] * head_layer + [128]
    num_classes = 2 if cutpaste_type is not CutPaste3Way else 3
    model = ProjectionNet(pretrained=pretrained, head_layers=head_layers, num_classes=num_classes)
    model.to(device)

    if freeze_resnet > 0 and pretrained:
        model.freeze_resnet()

    loss_fn = torch.nn.CrossEntropyLoss()
    if optim_name == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
        scheduler = CosineAnnealingWarmRestarts(optimizer, epochs)
    elif optim_name == "adam":
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = None
    else:
        raise ValueError(f"unknown optimizer: {optim_name}")

    step = 0
    num_batches = len(dataloader)
    def get_data_inf():
        # yield batches forever by restarting the dataloader
        while True:
            for out in enumerate(dataloader):
                yield out
    dataloader_inf = get_data_inf()
    # From the paper: "Note that, unlike conventional definition for an epoch,
    # we define 256 parameter update steps as one epoch."
    for step in tqdm(range(epochs)):
        # one parameter update step is treated as one "epoch" here (see README)
        epoch = int(step / 1)
        if epoch == freeze_resnet:
            model.unfreeze()

        batch_embeds = []
        batch_idx, data = next(dataloader_inf)
        xs = [x.to(device) for x in data]

        # zero the parameter gradients
        optimizer.zero_grad()

        xc = torch.cat(xs, axis=0)
        embeds, logits = model(xc)

        # generate labels: class i for the i-th image type returned by the CutPaste variant
        y = torch.arange(len(xs), device=device)
        y = y.repeat_interleave(xs[0].size(0))
        loss = loss_fn(logits, y)
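        # Example (hypothetical batch_size=2 with the 3-way variant):
        # xs = [org, cutpaste_normal, cutpaste_scar], each holding 2 images,
        # so xc stacks 6 images and y = [0, 0, 1, 1, 2, 2].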

        # backpropagate and update the weights
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step(epoch)

        writer.add_scalar('loss', loss.item(), step)

        predicted = torch.argmax(logits, axis=1)
        accuracy = torch.true_divide(torch.sum(predicted == y), predicted.size(0))
        writer.add_scalar('acc', accuracy, step)
        if scheduler is not None:
            writer.add_scalar('lr', scheduler.get_last_lr()[0], step)

        # save embeddings for validation:
        if test_epochs > 0 and epoch % test_epochs == 0:
            batch_embeds.append(embeds.cpu().detach())

        writer.add_scalar('epoch', epoch, step)

        # run tests
        if test_epochs > 0 and epoch % test_epochs == 0:
            # run the ROC AUC calculation
            # TODO: create the dataset only once.
            # TODO: train the predictor here or in the model class itself. It should not be in the eval part.
            # TODO: we might not want to use the training data because of dropout etc., but it should give an indication of the model performance???
            model.eval()
            roc_auc = eval_model(model_name, data_type, device=device,
                                 save_plots=False,
                                 size=size,
                                 show_training_data=False,
                                 model=model)
            model.train()
            writer.add_scalar('eval_auc', roc_auc, step)


    torch.save(model.state_dict(), model_dir / f"{model_name}.tch")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training defect detection as described in the CutPaste paper.')
    parser.add_argument('--type', default="all",
                        help='MVTec defect type(s) to train, separated by "," (default: "all": train all defect types)')

    parser.add_argument('--epochs', default=256, type=int,
                        help='number of parameter update steps to train the model (default: 256)')

    parser.add_argument('--model_dir', default="models",
                        help='output folder of the models (default: models)')

    parser.add_argument('--no-pretrained', dest='pretrained', default=True, action='store_false',
                        help='do not use pretrained weights to initialize ResNet18 (default: pretrained weights are used)')

    parser.add_argument('--test_epochs', default=10, type=int,
                        help='interval to calculate the AUC during training; if -1, do not calculate test scores (default: 10)')

    parser.add_argument('--freeze_resnet', default=20, type=int,
                        help='number of epochs to freeze resnet (default: 20)')

    parser.add_argument('--lr', default=0.03, type=float,
                        help='learning rate (default: 0.03)')

    parser.add_argument('--optim', default="sgd",
                        help='optimizing algorithm, values: [sgd, adam] (default: "sgd")')

    parser.add_argument('--batch_size', default=64, type=int,
                        help='number of images loaded per batch; the effective batch size depends on the CutPaste variant, e.g. 2x for normal and 3x for 3way (default: 64)')

    parser.add_argument('--head_layer', default=1, type=int,
                        help='number of layers in the projection head (default: 1)')

    parser.add_argument('--variant', default="3way", choices=['normal', 'scar', '3way', 'union'],
                        help='CutPaste variant to use (default: "3way")')

    parser.add_argument('--cuda', default=False, type=str2bool,
                        help='use cuda for training (default: False)')

    parser.add_argument('--workers', default=8, type=int,
                        help='number of workers to use for data loading (default: 8)')


    args = parser.parse_args()
    print(args)
    all_types = ['bottle',
                 'cable',
                 'capsule',
                 'carpet',
                 'grid',
                 'hazelnut',
                 'leather',
                 'metal_nut',
                 'pill',
                 'screw',
                 'tile',
                 'toothbrush',
                 'transistor',
                 'wood',
                 'zipper']

    if args.type == "all":
        types = all_types
    else:
        types = args.type.split(",")

    variant_map = {'normal': CutPasteNormal, 'scar': CutPasteScar, '3way': CutPaste3Way, 'union': CutPasteUnion}
    variant = variant_map[args.variant]

    device = "cuda" if args.cuda else "cpu"
    print(f"using device: {device}")

    # create model dir
    Path(args.model_dir).mkdir(exist_ok=True, parents=True)
    # save config.
    with open(Path(args.model_dir) / "run_config.txt", "w") as f:
        f.write(str(args))

    for data_type in types:
        print(f"training {data_type}")
        run_training(data_type,
                     model_dir=Path(args.model_dir),
                     epochs=args.epochs,
                     pretrained=args.pretrained,
                     test_epochs=args.test_epochs,
                     freeze_resnet=args.freeze_resnet,
                     learning_rate=args.lr,
                     optim_name=args.optim,
                     batch_size=args.batch_size,
                     head_layer=args.head_layer,
                     device=device,
                     cutpaste_type=variant,
                     workers=args.workers)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import argparse

def str2bool(v):
    """argparse handles type=bool in a weird way.
    See this Stack Overflow question: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    We can use this function as a type converter for boolean values.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
--------------------------------------------------------------------------------