├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── convert_AITEX.py
│   ├── convert_BrainMRI.py
│   ├── convert_HeadCT.py
│   ├── convert_MastCam.py
│   ├── convert_SDD.py
│   ├── convert_elpv.py
│   ├── convert_hyperkvasir.py
│   └── convert_optical.py
├── dataloaders
│   ├── dataloader.py
│   └── utlis.py
├── datasets
│   ├── base_dataset.py
│   ├── cutmix.py
│   └── mvtecad.py
├── loss
│   └── deviation_loss.py
├── modules
│   └── sb.py
├── networks
│   ├── backbone.py
│   ├── resnet.py
│   └── resnet18.py
└── utils
    └── quantize.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fuyunwang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DPDL 2 | > Distribution Prototype Diffusion Learning for Open-set Supervised Anomaly Detection 3 | -------------------------------------------------------------------------------- /data/convert_AITEX.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import cv2 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--dataset_root', type=str, help="dataset root") 9 | args = parser.parse_args() 10 | 11 | DEFEAT_CLASS = {'002': "Broken_end", '006': "Broken_yarn", '010': "Broken_pick", 12 | '016': "Weft_curling", '019': "Fuzzyball", '022': "Cut_selvage", 13 | '023': "Crease", '025': "Warp_ball", '027': "Knots", 14 | '029': "Contamination", '030': "Nep", '036': "Weft_crack"} 15 | 16 | normal_images = list() 17 | normal_fname = list() 18 | outlier_images = list() 19 | outlier_labels = list() 20 | outlier_fname = list() 21 | 22 | 23 | normal_root = os.path.join(args.dataset_root, 'NODefect_images') 24 | normal_dirs = os.listdir(normal_root) 25 | for dir in normal_dirs: 26 | files = os.listdir(os.path.join(normal_root, dir)) 27 | for image in files: 28 | image_name = image.split('.')[0] 29 | image_data = cv2.imread(os.path.join(normal_root, dir, image)) 30 | for i in range(16): 31 | normal_images.append(image_data[:, i*256:(i+1)*256 ,:]) 32 | normal_fname.append(dir + '_' + image_name + '_' + str(i)) 33 | 34 | outlier_root = os.path.join(args.dataset_root, 'Defect_images/Defect_images') 35 | label_root = os.path.join(args.dataset_root, 'Mask_images/Mask_images') 36 | files = os.listdir(os.path.join(outlier_root)) 37 | for image in files: 38 | split_images = list() 39 | split_labels = list() 40 | image_name = image.split('.')[0] 41 | image_data = cv2.imread(os.path.join(outlier_root, image)) 42 | label_data = cv2.imread(os.path.join(label_root, image_name + '_mask.png')) 43 | if image_data.shape[1] % image_data.shape[0] == 0: 44 | count = image_data.shape[1]//image_data.shape[0] 45 | else: 46 | count = image_data.shape[1] // image_data.shape[0] + 1 47 | for i in range(count): 48 | split_images.append(image_data[:, i * 256:(i + 1) * 256, :]) 49 | split_labels.append(label_data[:, i * 256:(i + 1) * 256, :]) 50 | for i, (im, la) in enumerate(zip(split_images, split_labels)): 51 | if np.max(la) != 0: 52 | outlier_images.append(im) 53 | outlier_labels.append(la) 54 | outlier_fname.append(image_name + '_' + str(i)) 55 | 56 | normal_train, normal_test, normal_name_train, normal_name_test = train_test_split(normal_images, normal_fname, test_size=0.25, random_state=42) 57 | 58 | target_root = './AITEX_anomaly_detection/AITEX' 59 | train_root = os.path.join(target_root, 'train/good') 60 | if not os.path.exists(train_root): 61 | os.makedirs(train_root) 62 | for image, name in zip(normal_train, normal_name_train): 63 | cv2.imwrite(os.path.join(train_root, name + '.png'), image) 64 | 65 | test_root = os.path.join(target_root, 'test/good') 66 | if not os.path.exists(test_root): 67 | os.makedirs(test_root) 68 | for image, name in zip(normal_test, normal_name_test): 69 | cv2.imwrite(os.path.join(test_root, name + '.png'), image) 70 | 71 | for image, label, name in zip(outlier_images, outlier_labels, outlier_fname): 72 | defect_class = DEFEAT_CLASS[name.split('_')[1]] 73 
| defect_root = os.path.join(target_root, 'test', defect_class) 74 | label_root = os.path.join(target_root, 'ground_truth', defect_class) 75 | if not os.path.exists(defect_root): 76 | os.makedirs(defect_root) 77 | if not os.path.exists(label_root): 78 | os.makedirs(label_root) 79 | cv2.imwrite(os.path.join(defect_root, name + '.png'), image) 80 | cv2.imwrite(os.path.join(label_root, name + '_mask.png'), label) 81 | 82 | print("Done") -------------------------------------------------------------------------------- /data/convert_BrainMRI.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sklearn.model_selection import train_test_split 3 | import shutil 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--dataset_root', type=str, help="dataset root") 8 | args = parser.parse_args() 9 | 10 | normal_root = os.path.join(args.dataset_root, 'no') 11 | outlier_root = os.path.join(args.dataset_root, 'yes') 12 | 13 | normal_fnames = os.listdir(normal_root) 14 | outlier_fnames = os.listdir(outlier_root) 15 | 16 | normal_train, normal_test, _, _ = train_test_split(normal_fnames, normal_fnames, test_size=0.25, random_state=42) 17 | 18 | target_root = './BrainMRI_anomaly_detection/brainmri' 19 | train_root = os.path.join(target_root, 'train/good') 20 | if not os.path.exists(train_root): 21 | os.makedirs(train_root) 22 | for f in normal_train: 23 | source = os.path.join(normal_root, f) 24 | shutil.copy(source, train_root) 25 | 26 | test_normal_root = os.path.join(target_root, 'test/good') 27 | if not os.path.exists(test_normal_root): 28 | os.makedirs(test_normal_root) 29 | for f in normal_test: 30 | source = os.path.join(normal_root, f) 31 | shutil.copy(source, test_normal_root) 32 | 33 | test_outlier_root = os.path.join(target_root, 'test/defect') 34 | if not os.path.exists(test_outlier_root): 35 | os.makedirs(test_outlier_root) 36 | for f in outlier_fnames: 37 | source = os.path.join(outlier_root, f) 38 | shutil.copy(source, test_outlier_root) 39 | 40 | print("Done") -------------------------------------------------------------------------------- /data/convert_HeadCT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import shutil 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--dataset_root', type=str, help="dataset root") 9 | args = parser.parse_args() 10 | 11 | label_file = os.path.join(args.dataset_root, 'labels.csv') 12 | 13 | data = np.loadtxt(label_file, dtype=int, delimiter=',', skiprows=1) 14 | 15 | fnames = data[:, 0] 16 | label = data[:, 1] 17 | 18 | normal_fnames = fnames[label==0] 19 | outlier_fnames = fnames[label==1] 20 | 21 | normal_train, normal_test, _, _ = train_test_split(normal_fnames, normal_fnames, test_size=0.25, random_state=42) 22 | 23 | target_root = './HeadCT_anomaly_detection/headct' 24 | train_root = os.path.join(target_root, 'train/good') 25 | if not os.path.exists(train_root): 26 | os.makedirs(train_root) 27 | for f in normal_train: 28 | source = os.path.join(args.dataset_root, 'head_ct/head_ct/', '{:0>3d}.png'.format(f)) 29 | shutil.copy(source, train_root) 30 | 31 | test_normal_root = os.path.join(target_root, 'test/good') 32 | if not os.path.exists(test_normal_root): 33 | os.makedirs(test_normal_root) 34 | for f in normal_test: 35 | source = os.path.join(args.dataset_root, 'head_ct/head_ct/', 
'{:0>3d}.png'.format(f)) 36 | shutil.copy(source, test_normal_root) 37 | 38 | test_outlier_root = os.path.join(target_root, 'test/defect') 39 | if not os.path.exists(test_outlier_root): 40 | os.makedirs(test_outlier_root) 41 | for f in outlier_fnames: 42 | source = os.path.join(args.dataset_root, 'head_ct/head_ct/', '{:0>3d}.png'.format(f)) 43 | shutil.copy(source, test_outlier_root) 44 | 45 | print('Done') -------------------------------------------------------------------------------- /data/convert_MastCam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--dataset_root', type=str, help="dataset root") 7 | args = parser.parse_args() 8 | 9 | normal_train = list() 10 | normal_test = list() 11 | outlier_fnames = list() 12 | normal_root = os.path.join(args.dataset_root, 'train_typical') 13 | for file in os.listdir(normal_root): 14 | normal_train.append(os.path.join(normal_root, file)) 15 | 16 | test_normal_root = os.path.join(args.dataset_root, 'test_typical') 17 | for file in os.listdir(test_normal_root): 18 | normal_test.append(os.path.join(test_normal_root, file)) 19 | 20 | outlier_root = os.path.join(args.dataset_root, 'test_novel') 21 | for dir in os.listdir(outlier_root): 22 | class_root = os.path.join(outlier_root, dir) 23 | for file in os.listdir(class_root): 24 | outlier_fnames.append(os.path.join(class_root, file)) 25 | 26 | target_root = './MastCam_anomaly_detection/mastcam' 27 | train_root = os.path.join(target_root, 'train/good') 28 | if not os.path.exists(train_root): 29 | os.makedirs(train_root) 30 | for f in normal_train: 31 | shutil.copy(f, train_root) 32 | 33 | test_normal_root = os.path.join(target_root, 'test/good') 34 | if not os.path.exists(test_normal_root): 35 | os.makedirs(test_normal_root) 36 | for f in normal_test: 37 | shutil.copy(f, test_normal_root) 38 | 39 | test_outlier_root = os.path.join(target_root, 'test') 40 | if not os.path.exists(test_outlier_root): 41 | os.makedirs(test_outlier_root) 42 | for f in outlier_fnames: 43 | class_name = f.split('/')[-2] 44 | target_root = os.path.join(test_outlier_root,class_name) 45 | if not os.path.exists(target_root): 46 | os.makedirs(target_root) 47 | shutil.copy(f, target_root) 48 | 49 | print('Done') -------------------------------------------------------------------------------- /data/convert_SDD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import cv2 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--dataset_root', type=str, help="dataset root") 9 | args = parser.parse_args() 10 | 11 | dirs = os.listdir(args.dataset_root) 12 | normal_images = list() 13 | normal_labels = list() 14 | normal_fname = list() 15 | outlier_images = list() 16 | outlier_labels = list() 17 | outlier_fname = list() 18 | for d in dirs: 19 | files = os.listdir(os.path.join(args.dataset_root, d)) 20 | images = list() 21 | for f in files: 22 | if 'jpg' in f[-3:]: 23 | images.append(f) 24 | 25 | for image in images: 26 | split_images = list() 27 | split_labels = list() 28 | image_name = image.split('.')[0] 29 | image_data = cv2.imread(os.path.join(args.dataset_root, d, image)) 30 | label_data = cv2.imread(os.path.join(args.dataset_root, d, image_name + '_label.bmp')) 31 | if image_data.shape != label_data.shape: 32 | raise 
ValueError 33 | image_length = image_data.shape[0] 34 | split_images.append(image_data[:image_length // 3, :, :]) 35 | split_images.append(image_data[image_length // 3:image_length * 2 // 3, :, :]) 36 | split_images.append(image_data[image_length * 2 // 3:, :, :]) 37 | split_labels.append(label_data[:image_length // 3, :, :]) 38 | split_labels.append(label_data[image_length // 3:image_length * 2 // 3, :, :]) 39 | split_labels.append(label_data[image_length * 2 // 3:, :, :]) 40 | for i, (im, la) in enumerate(zip(split_images, split_labels)): 41 | if np.max(la) != 0: 42 | outlier_images.append(im) 43 | outlier_labels.append(la) 44 | outlier_fname.append(d + '_' + image_name + '_' + str(i)) 45 | else: 46 | normal_images.append(im) 47 | normal_labels.append(la) 48 | normal_fname.append(d + '_' + image_name + '_' + str(i)) 49 | 50 | normal_train, normal_test, normal_name_train, normal_name_test = train_test_split(normal_images, normal_fname, test_size=0.25, random_state=42) 51 | 52 | target_root = './SDD_anomaly_detection/SDD' 53 | train_root = os.path.join(target_root, 'train/good') 54 | if not os.path.exists(train_root): 55 | os.makedirs(train_root) 56 | for image, name in zip(normal_train, normal_name_train): 57 | cv2.imwrite(os.path.join(train_root, name + '.png'), image) 58 | 59 | test_root = os.path.join(target_root, 'test/good') 60 | if not os.path.exists(test_root): 61 | os.makedirs(test_root) 62 | for image, name in zip(normal_test, normal_name_test): 63 | cv2.imwrite(os.path.join(test_root, name + '.png'), image) 64 | 65 | defect_root = os.path.join(target_root, 'test/defect') 66 | label_root = os.path.join(target_root, 'ground_truth/defect') 67 | if not os.path.exists(defect_root): 68 | os.makedirs(defect_root) 69 | if not os.path.exists(label_root): 70 | os.makedirs(label_root) 71 | for image, label, name in zip(outlier_images, outlier_labels, outlier_fname): 72 | cv2.imwrite(os.path.join(defect_root, name + '.png'), image) 73 | cv2.imwrite(os.path.join(label_root, name + '_mask.png'), label) 74 | 75 | print("Done") 76 | -------------------------------------------------------------------------------- /data/convert_elpv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import shutil 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--dataset_root', type=str, help="dataset root") 9 | args = parser.parse_args() 10 | 11 | label_file = os.path.join(args.dataset_root, 'labels.csv') 12 | 13 | data = np.genfromtxt(label_file, dtype=['|S19', '0: 36 | self.test_threshold = int((len(normal_data)/(1-self.args.test_rate)) * self.args.test_rate) + self.args.nAnomaly 37 | 38 | self.ood_data = self.get_ood_data() 39 | 40 | if self.train is False: 41 | normal_data = list() 42 | split = 'test' 43 | normal_files = os.listdir(os.path.join(self.root, split, 'good')) 44 | for file in normal_files: 45 | if 'png' in file[-3:] or 'PNG' in file[-3:] or 'jpg' in file[-3:] or 'npy' in file[-3:]: 46 | normal_data.append(split + '/good/' + file) 47 | 48 | outlier_data, pollution_data = self.split_outlier() 49 | outlier_data.sort() 50 | 51 | normal_data = normal_data + pollution_data 52 | 53 | normal_label = np.zeros(len(normal_data)).tolist() 54 | outlier_label = np.ones(len(outlier_data)).tolist() 55 | 56 | self.images = normal_data + outlier_data 57 | self.labels = np.array(normal_label + outlier_label) 58 | self.normal_idx = 
np.argwhere(self.labels == 0).flatten() 59 | self.outlier_idx = np.argwhere(self.labels == 1).flatten() 60 | 61 | def get_ood_data(self): 62 | ood_data = list() 63 | if self.args.outlier_root is None: 64 | return None 65 | dataset_classes = os.listdir(self.args.outlier_root) 66 | for cl in dataset_classes: 67 | if cl == self.args.classname: 68 | continue 69 | cl_root = os.path.join(self.args.outlier_root, cl, 'train', 'good') 70 | ood_file = os.listdir(cl_root) 71 | for file in ood_file: 72 | if 'png' in file[-3:] or 'PNG' in file[-3:] or 'jpg' in file[-3:] or 'npy' in file[-3:]: 73 | ood_data.append(os.path.join(cl_root, file)) 74 | return ood_data 75 | 76 | def split_outlier(self): 77 | outlier_data_dir = os.path.join(self.root, 'test') 78 | outlier_classes = os.listdir(outlier_data_dir) 79 | if self.know_class in outlier_classes: 80 | print("Know outlier class: " + self.know_class) 81 | outlier_data = list() 82 | know_class_data = list() 83 | for cl in outlier_classes: 84 | if cl == 'good': 85 | continue 86 | outlier_file = os.listdir(os.path.join(outlier_data_dir, cl)) 87 | for file in outlier_file: 88 | if 'png' in file[-3:] or 'PNG' in file[-3:] or 'jpg' in file[-3:] or 'npy' in file[-3:]: 89 | if cl == self.know_class: 90 | know_class_data.append('test/' + cl + '/' + file) 91 | else: 92 | outlier_data.append('test/' + cl + '/' + file) 93 | np.random.RandomState(self.args.ramdn_seed).shuffle(know_class_data) 94 | know_outlier = know_class_data[0:self.args.nAnomaly] 95 | unknow_outlier = outlier_data 96 | if self.train: 97 | return know_outlier, list() 98 | else: 99 | return unknow_outlier, list() 100 | 101 | 102 | outlier_data = list() 103 | for cl in outlier_classes: 104 | if cl == 'good': 105 | continue 106 | outlier_file = os.listdir(os.path.join(outlier_data_dir, cl)) 107 | for file in outlier_file: 108 | if 'png' in file[-3:] or 'PNG' in file[-3:] or 'jpg' in file[-3:] or 'npy' in file[-3:]: 109 | outlier_data.append('test/' + cl + '/' + file) 110 | np.random.RandomState(self.args.ramdn_seed).shuffle(outlier_data) 111 | if self.train: 112 | return outlier_data[0:self.args.nAnomaly], outlier_data[self.args.nAnomaly:self.args.nAnomaly + self.nPollution] 113 | else: 114 | return outlier_data[self.test_threshold:], list() 115 | 116 | def load_image(self, path): 117 | if 'npy' in path[-3:]: 118 | img = np.load(path).astype(np.uint8) 119 | img = img[:, :, :3] 120 | return Image.fromarray(img) 121 | return Image.open(path).convert('RGB') 122 | 123 | def transform_train(self): 124 | composed_transforms = transforms.Compose([ 125 | transforms.Resize((self.args.img_size,self.args.img_size)), 126 | transforms.RandomRotation(180), 127 | transforms.ToTensor(), 128 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 129 | return composed_transforms 130 | 131 | def transform_pseudo(self): 132 | composed_transforms = transforms.Compose([ 133 | transforms.Resize((self.args.img_size,self.args.img_size)), 134 | CutMix(), 135 | transforms.RandomRotation(180), 136 | transforms.ToTensor(), 137 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 138 | return composed_transforms 139 | 140 | def transform_test(self): 141 | composed_transforms = transforms.Compose([ 142 | transforms.Resize((self.args.img_size, self.args.img_size)), 143 | transforms.ToTensor(), 144 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 145 | return composed_transforms 146 | 147 | def __len__(self): 148 | return len(self.images) 149 | 150 | def __getitem__(self, index): 
151 | rnd = random.randint(0, 1) 152 | if index in self.normal_idx and rnd == 0 and self.train: 153 | if self.ood_data is None: 154 | index = random.choice(self.normal_idx) 155 | image = self.load_image(os.path.join(self.root, self.images[index])) 156 | transform = self.transform_pseudo 157 | else: 158 | image = self.load_image(random.choice(self.ood_data)) 159 | transform = self.transform 160 | label = 2 161 | else: 162 | image = self.load_image(os.path.join(self.root, self.images[index])) 163 | transform = self.transform 164 | label = self.labels[index] 165 | sample = {'image': transform(image), 'label': label} 166 | return sample 167 | -------------------------------------------------------------------------------- /loss/deviation_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DeviationLoss(nn.Module): 5 | 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, y_pred, y_true): 10 | confidence_margin = 5. 11 | ref = torch.normal(mean=0., std=torch.full([5000], 1.)).cuda() 12 | dev = (y_pred - torch.mean(ref)) / torch.std(ref) 13 | inlier_loss = torch.abs(dev) 14 | outlier_loss = torch.abs((confidence_margin - dev).clamp_(min=0.)) 15 | dev_loss = (1 - y_true) * inlier_loss + y_true * outlier_loss 16 | return torch.mean(dev_loss) 17 | -------------------------------------------------------------------------------- /modules/sb.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import math 3 | 4 | from torch import nn 5 | 6 | from torch.nn.functional import softmax, log_softmax 7 | import torch 8 | import geotorch 9 | 10 | from modeling.sb_modules.MyMixtureSameFamily import MixtureSameFamily 11 | from torch.distributions.categorical import Categorical 12 | from torch.distributions.multivariate_normal import MultivariateNormal 13 | from torch.distributions.independent import Independent 14 | from torch.distributions.normal import Normal 15 | 16 | from tqdm import tqdm 17 | 18 | 19 | class LightSB(nn.Module): 20 | def __init__(self, dim=512, n_potentials=10, epsilon=0.1, is_diagonal=True, 21 | sampling_batch_size=1, S_diagonal_init=0.1): 22 | super().__init__() 23 | self.is_diagonal = is_diagonal 24 | self.dim = dim 25 | self.n_potentials = n_potentials 26 | self.register_buffer("epsilon", torch.tensor(epsilon)) 27 | self.sampling_batch_size = sampling_batch_size 28 | 29 | self.log_alpha_raw = nn.Parameter(self.epsilon * torch.log(torch.ones(n_potentials) / n_potentials)) 30 | self.r = nn.Parameter(torch.randn(n_potentials, dim)) 31 | 32 | self.S_log_diagonal_matrix = nn.Parameter(torch.log(S_diagonal_init * torch.ones(n_potentials, self.dim))) 33 | self.S_rotation_matrix = nn.Parameter( 34 | torch.randn(n_potentials, self.dim, self.dim) 35 | ) 36 | geotorch.orthogonal(self, "S_rotation_matrix") 37 | 38 | def init_r_by_samples(self, samples): 39 | assert samples.shape[0] == self.r.shape[0] 40 | 41 | self.r.data = torch.clone(samples.to(self.r.device)) 42 | 43 | def get_S(self): 44 | if self.is_diagonal: 45 | S = torch.exp(self.S_log_diagonal_matrix) 46 | else: 47 | S = (self.S_rotation_matrix * (torch.exp(self.S_log_diagonal_matrix))[:, None, :]) @ torch.permute( 48 | self.S_rotation_matrix, (0, 2, 1)) 49 | return S 50 | 51 | def get_r(self): 52 | return self.r 53 | 54 | def get_log_alpha(self): 55 | return (1 / self.epsilon) * self.log_alpha_raw 56 | 57 | @torch.no_grad() 58 | def forward(self, x): 59 | S = self.get_S() 60 | 
r = self.get_r() 61 | epsilon = self.epsilon 62 | 63 | log_alpha = self.get_log_alpha() 64 | 65 | eps_S = epsilon * S 66 | 67 | samples = [] 68 | batch_size = x.shape[0] 69 | sampling_batch_size = self.sampling_batch_size 70 | 71 | num_sampling_iterations = ( 72 | batch_size // sampling_batch_size if batch_size % sampling_batch_size == 0 else ( 73 | batch_size // sampling_batch_size) + 1 74 | ) 75 | 76 | for i in range(num_sampling_iterations): 77 | sub_batch_x = x[sampling_batch_size * i:sampling_batch_size * (i + 1)] 78 | 79 | if self.is_diagonal: 80 | x_S_x = (sub_batch_x[:, None, :] * S[None, :, :] * sub_batch_x[:, None, :]).sum(dim=-1) 81 | x_r = (sub_batch_x[:, None, :] * r[None, :, :]).sum(dim=-1) 82 | r_x = r[None, :, :] + S[None, :] * sub_batch_x[:, None, :] 83 | else: 84 | x_S_x = (sub_batch_x[:, None, None, :] @ (S[None, :, :, :] @ sub_batch_x[:, None, :, None]))[:, :, 0, 0] 85 | x_r = (sub_batch_x[:, None, :] * r[None, :, :]).sum(dim=-1) 86 | r_x = r[None, :, :] + (S[None, :, :, :] @ sub_batch_x[:, None, :, None])[:, :, :, 0] 87 | 88 | exp_argument = (x_S_x + 2 * x_r) / (2 * epsilon) + log_alpha[None, :] 89 | 90 | if self.is_diagonal: 91 | mix = Categorical(logits=exp_argument) 92 | comp = Independent(Normal(loc=r_x, scale=torch.sqrt(epsilon * S)[None, :, :]), 1) 93 | gmm = MixtureSameFamily(mix, comp) 94 | 95 | else: 96 | mix = Categorical(logits=exp_argument) 97 | comp = MultivariateNormal(loc=r_x, covariance_matrix=epsilon * S) 98 | gmm = MixtureSameFamily(mix, comp) 99 | 100 | samples.append(gmm.sample()) 101 | 102 | samples = torch.cat(samples, dim=0) 103 | 104 | return samples 105 | 106 | def get_drift(self, x, t): 107 | x = torch.clone(x) 108 | x.requires_grad = True 109 | 110 | epsilon = self.epsilon 111 | r = self.get_r() 112 | 113 | S_diagonal = torch.exp(self.S_log_diagonal_matrix) # shape: potential*dim 114 | A_diagonal = (t / (epsilon * (1 - t)))[:, None, None] + 1 / (epsilon * S_diagonal)[None, :, 115 | :] # shape: batch*potential*dim 116 | 117 | S_log_det = torch.sum(self.S_log_diagonal_matrix, dim=-1) # shape: potential 118 | A_log_det = torch.sum(torch.log(A_diagonal), dim=-1) # shape: batch*potential 119 | 120 | log_alpha = self.get_log_alpha() # shape: potential 121 | 122 | if self.is_diagonal: 123 | S = S_diagonal # shape: potential*dim 124 | A = A_diagonal # shape: batch*potential*dim 125 | 126 | S_inv = 1 / S # shape: potential*dim 127 | A_inv = 1 / A # shape: batch*potential*dim 128 | 129 | c = ((1 / (epsilon * (1 - t)))[:, None] * x)[:, None, :] + (r / (epsilon * S_diagonal))[None, :, 130 | :] # shape: batch*potential*dim 131 | 132 | exp_arg = ( 133 | log_alpha[None, :] - 0.5 * S_log_det[None, :] - 0.5 * A_log_det 134 | - 0.5 * ((r * S_inv * r) / epsilon).sum(dim=-1)[None, :] + 0.5 * (c * A_inv * c).sum(dim=-1) 135 | ) 136 | else: 137 | S = (self.S_rotation_matrix * S_diagonal[:, None, :]) @ torch.permute(self.S_rotation_matrix, (0, 2, 1)) 138 | A = (self.S_rotation_matrix[None, :, :, :] * A_diagonal[:, :, None, :]) @ torch.permute( 139 | self.S_rotation_matrix, (0, 2, 1))[None, :, :, :] 140 | 141 | S_inv = (self.S_rotation_matrix * (1 / S_diagonal[:, None, :])) @ torch.permute(self.S_rotation_matrix, 142 | (0, 2, 1)) 143 | A_inv = (self.S_rotation_matrix[None, :, :, :] * (1 / A_diagonal[:, :, None, :])) @ torch.permute( 144 | self.S_rotation_matrix, (0, 2, 1))[None, :, :, :] 145 | 146 | c = ((1 / (epsilon * (1 - t)))[:, None] * x)[:, None, :] + (S_inv @ (r[:, :, None]))[None, :, :, 147 | 0] / epsilon # shape: batch*potential*dim 148 | 149 | c_A_inv_c = 
(c[:, :, None, :] @ A_inv @ c[:, :, :, None])[:, :, 0, 0] 150 | r_S_inv_r = (r[:, None, :] @ S_inv @ r[:, :, None])[None, :, 0, 0] 151 | 152 | exp_arg = ( 153 | log_alpha[None, :] - 0.5 * S_log_det[None, 154 | :] - 0.5 * A_log_det - 0.5 * r_S_inv_r / epsilon + 0.5 * c_A_inv_c 155 | ) 156 | 157 | lse = torch.logsumexp(exp_arg, dim=-1) 158 | drift = (-x / (1 - t[:, None]) + epsilon * 159 | torch.autograd.grad(lse, x, grad_outputs=torch.ones_like(lse, device=lse.device))[0]).detach() 160 | 161 | return drift 162 | 163 | def sample_euler_maruyama(self, x, n_steps): 164 | epsilon = self.epsilon 165 | t = torch.zeros(x.shape[0], device=x.device) 166 | dt = 1 / n_steps 167 | trajectory = [x] 168 | 169 | for i in range(n_steps): 170 | x = x + self.get_drift(x, t) * dt + math.sqrt(dt) * torch.sqrt(epsilon) * torch.randn_like(x, 171 | device=x.device) 172 | t += dt 173 | trajectory.append(x) 174 | 175 | return torch.stack(trajectory, dim=1) 176 | 177 | def sample_at_time_moment(self, x, t): 178 | t = t.to(x.device) 179 | y = self(x) 180 | 181 | return t * y + (1 - t) * x + torch.sqrt(t * (1 - t) * self.epsilon) * torch.randn_like(x) 182 | 183 | def get_log_potential(self, x): 184 | S = self.get_S() 185 | r = self.get_r() 186 | log_alpha = self.get_log_alpha() 187 | D = self.dim 188 | 189 | epsilon = self.epsilon 190 | 191 | if self.is_diagonal: 192 | mix = Categorical(logits=log_alpha) 193 | comp = Independent(Normal(loc=r, scale=torch.sqrt(self.epsilon * S)), 1) 194 | gmm = MixtureSameFamily(mix, comp) 195 | 196 | potential = gmm.log_prob(x) + torch.logsumexp(log_alpha, dim=-1) 197 | else: 198 | mix = Categorical(logits=log_alpha) 199 | comp = MultivariateNormal(loc=r, covariance_matrix=self.epsilon * S) 200 | gmm = MixtureSameFamily(mix, comp) 201 | 202 | potential = gmm.log_prob(x) + torch.logsumexp(log_alpha, dim=-1) 203 | 204 | return potential 205 | 206 | def get_log_C(self, x): 207 | S = self.get_S() 208 | r = self.get_r() 209 | epsilon = self.epsilon 210 | log_alpha = self.get_log_alpha() 211 | 212 | eps_S = epsilon * S 213 | 214 | if self.is_diagonal: 215 | x_S_x = (x[:, None, :] * S[None, :, :] * x[:, None, :]).sum(dim=-1) 216 | x_r = (x[:, None, :] * r[None, :, :]).sum(dim=-1) 217 | else: 218 | x_S_x = (x[:, None, None, :] @ (S[None, :, :, :] @ x[:, None, :, None]))[:, :, 0, 0] 219 | x_r = (x[:, None, :] * r[None, :, :]).sum(dim=-1) 220 | 221 | exp_argument = (x_S_x + 2 * x_r) / (2 * epsilon) + log_alpha[None, :] 222 | log_norm_const = torch.logsumexp(exp_argument, dim=-1) 223 | 224 | return log_norm_const 225 | 226 | def set_epsilon(self, new_epsilon): 227 | self.epsilon = torch.tensor(new_epsilon, device=self.epsilon.device) 228 | -------------------------------------------------------------------------------- /networks/backbone.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import alexnet 2 | from modeling.networks.resnet18 import FeatureRESNET18 3 | 4 | NET_OUT_DIM = {'alexnet': 256, 'resnet18': 512} 5 | 6 | def build_feature_extractor(backbone, cfg): 7 | if backbone == "alexnet": 8 | print("Feature extractor: AlexNet") 9 | return alexnet(pretrained=True).features 10 | elif backbone == "resnet18": 11 | print("Feature extractor: ResNet-18") 12 | return FeatureRESNET18() 13 | else: 14 | raise NotImplementedError -------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
torch import Tensor 3 | import torch.nn as nn 4 | from torchvision.models.utils import load_state_dict_from_url 5 | from typing import Type, Any, Callable, Union, List, Optional 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=dilation, groups=groups, bias=False, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion: int = 1 39 | 40 | def __init__( 41 | self, 42 | inplanes: int, 43 | planes: int, 44 | stride: int = 1, 45 | downsample: Optional[nn.Module] = None, 46 | groups: int = 1, 47 | base_width: int = 64, 48 | dilation: int = 1, 49 | norm_layer: Optional[Callable[..., nn.Module]] = None 50 | ) -> None: 51 | super(BasicBlock, self).__init__() 52 | if norm_layer is None: 53 | norm_layer = nn.BatchNorm2d 54 | if groups != 1 or base_width != 64: 55 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 56 | if dilation > 1: 57 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 58 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 59 | self.conv1 = conv3x3(inplanes, planes, stride) 60 | self.bn1 = norm_layer(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.conv2 = conv3x3(planes, planes) 63 | self.bn2 = norm_layer(planes) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x: Tensor) -> Tensor: 68 | identity = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 88 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 89 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 
90 | # This variant is also known as ResNet V1.5 and improves accuracy according to 91 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 92 | 93 | expansion: int = 4 94 | 95 | def __init__( 96 | self, 97 | inplanes: int, 98 | planes: int, 99 | stride: int = 1, 100 | downsample: Optional[nn.Module] = None, 101 | groups: int = 1, 102 | base_width: int = 64, 103 | dilation: int = 1, 104 | norm_layer: Optional[Callable[..., nn.Module]] = None 105 | ) -> None: 106 | super(Bottleneck, self).__init__() 107 | if norm_layer is None: 108 | norm_layer = nn.BatchNorm2d 109 | width = int(planes * (base_width / 64.)) * groups 110 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 111 | self.conv1 = conv1x1(inplanes, width) 112 | self.bn1 = norm_layer(width) 113 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 114 | self.bn2 = norm_layer(width) 115 | self.conv3 = conv1x1(width, planes * self.expansion) 116 | self.bn3 = norm_layer(planes * self.expansion) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.downsample = downsample 119 | self.stride = stride 120 | 121 | def forward(self, x: Tensor) -> Tensor: 122 | identity = x 123 | 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv3(out) 133 | out = self.bn3(out) 134 | 135 | if self.downsample is not None: 136 | identity = self.downsample(x) 137 | 138 | out += identity 139 | out = self.relu(out) 140 | 141 | return out 142 | 143 | 144 | class ResNet(nn.Module): 145 | 146 | def __init__( 147 | self, 148 | block: Type[Union[BasicBlock, Bottleneck]], 149 | layers: List[int], 150 | num_classes: int = 1000, 151 | zero_init_residual: bool = False, 152 | groups: int = 1, 153 | width_per_group: int = 64, 154 | replace_stride_with_dilation: Optional[List[bool]] = None, 155 | norm_layer: Optional[Callable[..., nn.Module]] = None 156 | ) -> None: 157 | super(ResNet, self).__init__() 158 | if norm_layer is None: 159 | norm_layer = nn.BatchNorm2d 160 | self._norm_layer = norm_layer 161 | 162 | self.inplanes = 64 163 | self.dilation = 1 164 | if replace_stride_with_dilation is None: 165 | # each element in the tuple indicates if we should replace 166 | # the 2x2 stride with a dilated convolution instead 167 | replace_stride_with_dilation = [False, False, False] 168 | if len(replace_stride_with_dilation) != 3: 169 | raise ValueError("replace_stride_with_dilation should be None " 170 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 171 | self.groups = groups 172 | self.base_width = width_per_group 173 | self.conv1 = nn.Conv2d(6, self.inplanes, kernel_size=7, stride=2, padding=3, 174 | bias=False) 175 | self.bn1 = norm_layer(self.inplanes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 178 | self.layer1 = self._make_layer(block, 64, layers[0]) 179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 180 | dilate=replace_stride_with_dilation[0]) 181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 182 | dilate=replace_stride_with_dilation[1]) 183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 184 | dilate=replace_stride_with_dilation[2]) 185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 186 | self.fc = nn.Linear(512 * block.expansion, num_classes) 187 | 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 
| nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 204 | 205 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 206 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 207 | norm_layer = self._norm_layer 208 | downsample = None 209 | previous_dilation = self.dilation 210 | if dilate: 211 | self.dilation *= stride 212 | stride = 1 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | downsample = nn.Sequential( 215 | conv1x1(self.inplanes, planes * block.expansion, stride), 216 | norm_layer(planes * block.expansion), 217 | ) 218 | 219 | layers = [] 220 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 221 | self.base_width, previous_dilation, norm_layer)) 222 | self.inplanes = planes * block.expansion 223 | for _ in range(1, blocks): 224 | layers.append(block(self.inplanes, planes, groups=self.groups, 225 | base_width=self.base_width, dilation=self.dilation, 226 | norm_layer=norm_layer)) 227 | 228 | return nn.Sequential(*layers) 229 | 230 | def _forward_impl(self, x: Tensor) -> Tensor: 231 | # See note [TorchScript super()] 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | x = self.maxpool(x) 236 | 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | x = self.layer3(x) 240 | x = self.layer4(x) 241 | 242 | x = self.avgpool(x) 243 | x = torch.flatten(x, 1) 244 | x = self.fc(x) 245 | 246 | return x 247 | 248 | def forward(self, x: Tensor) -> Tensor: 249 | return self._forward_impl(x) 250 | 251 | 252 | def _resnet( 253 | arch: str, 254 | block: Type[Union[BasicBlock, Bottleneck]], 255 | layers: List[int], 256 | pretrained: bool, 257 | progress: bool, 258 | **kwargs: Any 259 | ) -> ResNet: 260 | model = ResNet(block, layers, **kwargs) 261 | if pretrained: 262 | state_dict = load_state_dict_from_url(model_urls[arch], 263 | progress=progress) 264 | model.load_state_dict(state_dict) 265 | return model 266 | 267 | 268 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 269 | r"""ResNet-18 model from 270 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | progress (bool): If True, displays a progress bar of the download to stderr 274 | """ 275 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 276 | **kwargs) 277 | 278 | 279 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 280 | r"""ResNet-34 model from 281 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | progress (bool): If True, displays a progress bar of the download to stderr 285 | """ 286 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 287 | **kwargs) 288 | 289 | 290 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 291 | r"""ResNet-50 model from 292 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | progress (bool): If True, displays a progress bar of the download to stderr 296 | """ 297 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 298 | **kwargs) 299 | 300 | 301 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 302 | r"""ResNet-101 model from 303 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. 304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | progress (bool): If True, displays a progress bar of the download to stderr 307 | """ 308 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 309 | **kwargs) 310 | 311 | 312 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 313 | r"""ResNet-152 model from 314 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. 315 | Args: 316 | pretrained (bool): If True, returns a model pre-trained on ImageNet 317 | progress (bool): If True, displays a progress bar of the download to stderr 318 | """ 319 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 320 | **kwargs) 321 | 322 | 323 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 324 | r"""ResNeXt-50 32x4d model from 325 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_. 326 | Args: 327 | pretrained (bool): If True, returns a model pre-trained on ImageNet 328 | progress (bool): If True, displays a progress bar of the download to stderr 329 | """ 330 | kwargs['groups'] = 32 331 | kwargs['width_per_group'] = 4 332 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 333 | pretrained, progress, **kwargs) 334 | 335 | 336 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 337 | r"""ResNeXt-101 32x8d model from 338 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_. 339 | Args: 340 | pretrained (bool): If True, returns a model pre-trained on ImageNet 341 | progress (bool): If True, displays a progress bar of the download to stderr 342 | """ 343 | kwargs['groups'] = 32 344 | kwargs['width_per_group'] = 8 345 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 346 | pretrained, progress, **kwargs) 347 | 348 | 349 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 350 | r"""Wide ResNet-50-2 model from 351 | `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_. 352 | The model is the same as ResNet except for the bottleneck number of channels 353 | which is twice larger in every block. The number of channels in outer 1x1 354 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 355 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
356 | Args: 357 | pretrained (bool): If True, returns a model pre-trained on ImageNet 358 | progress (bool): If True, displays a progress bar of the download to stderr 359 | """ 360 | kwargs['width_per_group'] = 64 * 2 361 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 362 | pretrained, progress, **kwargs) 363 | 364 | 365 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 366 | r"""Wide ResNet-101-2 model from 367 | `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_. 368 | The model is the same as ResNet except for the bottleneck number of channels 369 | which is twice larger in every block. The number of channels in outer 1x1 370 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 371 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 372 | Args: 373 | pretrained (bool): If True, returns a model pre-trained on ImageNet 374 | progress (bool): If True, displays a progress bar of the download to stderr 375 | """ 376 | kwargs['width_per_group'] = 64 * 2 377 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 378 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /networks/resnet18.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | class FeatureRESNET18(nn.Module): 6 | def __init__(self): 7 | super(FeatureRESNET18, self).__init__() 8 | self.net = models.resnet18(pretrained=True) 9 | 10 | def forward(self, x): 11 | x = self.net.conv1(x) 12 | x = self.net.bn1(x) 13 | x = self.net.relu(x) 14 | x = self.net.maxpool(x) 15 | x = self.net.layer1(x) 16 | x = self.net.layer2(x) 17 | x = self.net.layer3(x) 18 | x = self.net.layer4(x) 19 | return x 20 | 21 | -------------------------------------------------------------------------------- /utils/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class Quantize(nn.Module): 7 | def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, thresh=1e-6): 8 | super().__init__() 9 | 10 | self.dim = dim 11 | self.n_embed = n_embed 12 | self.decay = decay 13 | self.eps = eps 14 | self.thresh = thresh 15 | 16 | embed = torch.randn(dim, n_embed) 17 | self.register_buffer("embed", embed) 18 | self.register_buffer("cluster_size", torch.zeros(n_embed)) 19 | self.register_buffer("embed_avg", embed.clone()) 20 | 21 | def forward(self, input): # [15 1536] 22 | # input = input.permute(0, 2, 3, 1).contiguous() 23 | flatten = input.reshape(-1, self.dim) # [15 1536] 24 | dist = ( 25 | flatten.pow(2).sum(1, keepdim=True) 26 | - 2 * flatten @ self.embed 27 | + self.embed.pow(2).sum(0, keepdim=True) 28 | ) # [15] [1536] 29 | _, embed_ind = (-dist).max(1) # pick the code with the smallest distance 30 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) # [15 1536] one-hot rows of length n_embed, one row per entry of embed_ind, with a 1 at the selected code 31 | embed_ind = embed_ind.view(*input.shape[:-1]) 32 | quantize = self.embed_code(embed_ind) 33 | 34 | if self.training: 35 | embed_onehot_sum = embed_onehot.sum(0) 36 | embed_sum = flatten.transpose(0, 1) @ embed_onehot 37 | 38 | # dist_fn.all_reduce(embed_onehot_sum) 39 | # dist_fn.all_reduce(embed_sum) 40 | 41 | self.cluster_size.data.mul_(self.decay).add_( 42 | embed_onehot_sum, alpha=1 - self.decay 43 | ) 44 | self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) 45 | n =
self.cluster_size.sum() 46 | cluster_size = ( 47 | (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n 48 | ) 49 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) 50 | self.embed.data.copy_(embed_normalized) 51 | 52 | diff = (quantize.detach() - input).pow(2).mean() 53 | quantize = input + (quantize - input).detach() 54 | # quantize = quantize.permute(0, 3, 1, 2).contiguous() 55 | return quantize, diff, embed_ind # quantized features, commitment loss, and prototype indices; the indices tell, for each embedding, which prototype is closest and which prototypes are used most 56 | 57 | def embed_code(self, embed_id): 58 | return F.embedding(embed_id, self.embed.transpose(0, 1)) # look up the representations in self.embed for the given indices and return them 59 | def reAssign(self, dist): 60 | _embed = self.embed.transpose(0, 1).clone().detach() 61 | dist = (dist / dist.sum()).detach() 62 | 63 | neverAssignedLoc = dist < self.thresh 64 | totalNeverAssigned = int(neverAssignedLoc.sum()) 65 | # More than half are never assigned 66 | if totalNeverAssigned > self.n_embed // 2: 67 | mask = torch.zeros((totalNeverAssigned, ), device=self.embed.device) 68 | maskIdx = torch.randperm(len(mask))[self.n_embed // 2:] 69 | # Randomly pick some never-assigned locations and drop them. 70 | mask[maskIdx] = 1. 71 | dist[neverAssignedLoc] = mask 72 | # Update 73 | neverAssignedLoc = dist < self.thresh 74 | totalNeverAssigned = int(neverAssignedLoc.sum()) 75 | argIdx = torch.argsort(dist, descending=True)[:(self.n_embed - totalNeverAssigned)] 76 | mostAssigned = _embed[argIdx] 77 | selectedIdx = torch.randperm(len(mostAssigned))[:totalNeverAssigned] 78 | _embed.data[neverAssignedLoc] = mostAssigned[selectedIdx] 79 | 80 | self.embed.data.copy_(_embed.transpose(0, 1)) 81 | 82 | return 83 | 84 | 85 | --------------------------------------------------------------------------------
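
To make the pieces above concrete, here is a minimal, hypothetical glue sketch (not a file from this repository) showing one way the released modules could be wired together. It assumes the repo root is on PYTHONPATH, that torch, torchvision, and geotorch are installed, and that the `modeling.*` import paths inside backbone.py and sb.py resolve in your checkout; the feature dimension, the prototype count, and the distance-to-prototype scoring head are illustrative assumptions, not the authors' method.

    # Hypothetical usage sketch -- hyperparameters and the scoring head are assumptions.
    import torch

    from networks.backbone import build_feature_extractor, NET_OUT_DIM
    from modules.sb import LightSB
    from utils.quantize import Quantize
    from loss.deviation_loss import DeviationLoss

    dim = NET_OUT_DIM["resnet18"]                        # 512-d features from FeatureRESNET18
    backbone = build_feature_extractor("resnet18", cfg=None)
    lightsb = LightSB(dim=dim, n_potentials=10, epsilon=0.1)
    quantizer = Quantize(dim=dim, n_embed=64)            # 64 prototypes is an arbitrary choice

    x = torch.randn(4, 3, 448, 448)                      # dummy batch standing in for dataset images
    feat = backbone(x).mean(dim=(2, 3))                  # global-average-pool layer4 maps -> [4, 512]
    diffused = lightsb(feat)                             # sample the learned bridge (runs under no_grad)
    proto, commit, idx = quantizer(diffused)             # snap each sample to its nearest prototype
    score = (diffused - proto).pow(2).sum(dim=1)         # illustrative distance-based anomaly score

    # DeviationLoss draws its 5000 reference scores with .cuda(), so it needs a GPU.
    # Labels follow the dataset convention: 0 normal, 1 known anomaly, 2 pseudo anomaly.
    if torch.cuda.is_available():
        criterion = DeviationLoss()
        loss = criterion(score.cuda(), torch.tensor([0., 0., 1., 1.]).cuda())

Under this reading, a large score marks a sample whose bridged feature sits far from every learned prototype, and the released DeviationLoss then pushes labeled anomalies past its confidence margin of 5 while pulling normal scores toward the reference distribution.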