├── .gitignore ├── LICENSE ├── MWM_93.44_s40_m0.7.ckpt ├── README.md ├── best_models └── README.md ├── conda_requirements.txt ├── data.py ├── images ├── ProposedASD.jpg ├── ProposedASD.png ├── attentions.pdf ├── spec-heat-ave-std.pdf ├── spec-heat.pdf ├── spec-mel.pdf └── spec-wav.pdf ├── mobilenet.py ├── model.py ├── other_data ├── all_labels.npz ├── allerrors.npz ├── anomalouserrors.npz ├── cleanerrors.npz └── metadata.npz ├── playground.ipynb ├── training_val_test.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Michael Neri 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MWM_93.44_s40_m0.7.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/MWM_93.44_s40_m0.7.ckpt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Low-complexity Unsupervised Audio Anomaly Detection exploiting Separable Convolutions and Angular Loss 2 | 3 | 4 | 5 | Official repository of the work "Low-complexity Unsupervised Audio Anomaly Detection exploiting Separable Convolutions and Angular Loss" published to IEEE Sensors Letters. 
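
## Quick start (illustrative)

A minimal inference sketch, assuming the released checkpoint `MWM_93.44_s40_m0.7.ckpt` sits in the repository root and the environment from `conda_requirements.txt` is installed; the constructor arguments (`h = 128`, a dummy `lr`) mirror how checkpoints are reloaded in `training_val_test.py`.

```python
import torch
from model import Wavegram_AttentionMap

# Reload the released Lightning checkpoint; lr is only needed to rebuild the module.
model = Wavegram_AttentionMap.load_from_checkpoint(
    "MWM_93.44_s40_m0.7.ckpt", lr=1e-3, h=128, map_location="cpu"
)
model.eval()

waveform = torch.rand(1, 160000)   # dummy 10 s clip at 16 kHz
metadata = torch.tensor([0])       # machine-type/ID class index in [0, 40]
with torch.no_grad():
    logits, out, reppr, features, heatmap = model(waveform, metadata)
print(logits.shape)  # (1, 41); test_step scores anomalies via the cross-entropy of these logits
```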
6 | 7 | ## Authors 8 | Michael Neri, Marco Carli 9 | 10 | Department of Industrial, Electronic, and Mechanical Engineering, Roma Tre University, Rome, Italy 11 | 12 | 13 | ---------------------------------------------------------------------------- 14 | 15 | If you use any part of this work please cite the following reference: 16 | 17 | ``` 18 | @ARTICLE{Neri_LSENS_2024, 19 | author={Neri, M. and Carli, M.}, 20 | journal={IEEE Sensors Letters}, 21 | title={{Low-complexity Unsupervised Audio Anomaly Detection exploiting Separable Convolutions and Angular Loss}}, 22 | year={2024}, 23 | volume={}, 24 | number={}, 25 | pages={}, 26 | doi={10.1109/LSENS.2024.3480450} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /best_models/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain all the intermediate models. Remove this file. -------------------------------------------------------------------------------- /conda_requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: win-64 4 | aiohttp=3.8.5=pypi_0 5 | aiosignal=1.3.1=pypi_0 6 | ansicon=1.89.0=pypi_0 7 | anyio=3.7.1=pypi_0 8 | appdirs=1.4.4=pypi_0 9 | arrow=1.2.3=pypi_0 10 | async-timeout=4.0.2=pypi_0 11 | attrs=23.1.0=pypi_0 12 | beautifulsoup4=4.12.2=pypi_0 13 | blas=1.0=mkl 14 | blessed=1.20.0=pypi_0 15 | brotlipy=0.7.0=py39h2bbff1b_1003 16 | ca-certificates=2023.05.30=haa95532_0 17 | certifi=2023.7.22=py39haa95532_0 18 | cffi=1.15.1=py39h2bbff1b_3 19 | charset-normalizer=2.0.4=pyhd3eb1b0_0 20 | click=8.1.6=pypi_0 21 | colorama=0.4.6=pypi_0 22 | contourpy=1.1.0=pypi_0 23 | croniter=1.3.15=pypi_0 24 | cryptography=41.0.2=py39hac1b9e3_0 25 | cuda-cccl=12.2.128=0 26 | cuda-cudart=11.7.99=0 27 | cuda-cudart-dev=11.7.99=0 28 | cuda-cupti=11.7.101=0 29 | cuda-libraries=11.7.1=0 30 | cuda-libraries-dev=11.7.1=0 31 | cuda-nvrtc=11.7.99=0 32 | cuda-nvrtc-dev=11.7.99=0 33 | cuda-nvtx=11.7.91=0 34 | cuda-runtime=11.7.1=0 35 | cycler=0.11.0=pypi_0 36 | dateutils=0.6.12=pypi_0 37 | deepdiff=6.3.1=pypi_0 38 | docker-pycreds=0.4.0=pypi_0 39 | exceptiongroup=1.1.2=pypi_0 40 | fastapi=0.88.0=pypi_0 41 | filelock=3.9.0=py39haa95532_0 42 | fonttools=4.41.1=pypi_0 43 | freetype=2.12.1=ha860e81_0 44 | frozenlist=1.4.0=pypi_0 45 | fsspec=2023.6.0=pypi_0 46 | giflib=5.2.1=h8cc25b3_3 47 | gitdb=4.0.10=pypi_0 48 | gitpython=3.1.32=pypi_0 49 | h11=0.14.0=pypi_0 50 | idna=3.4=py39haa95532_0 51 | importlib-resources=6.0.0=pypi_0 52 | inquirer=3.1.3=pypi_0 53 | intel-openmp=2023.1.0=h59b6b97_46319 54 | itsdangerous=2.1.2=pypi_0 55 | jinja2=3.1.2=py39haa95532_0 56 | jinxed=1.2.0=pypi_0 57 | joblib=1.3.1=pypi_0 58 | jpeg=9e=h2bbff1b_1 59 | kiwisolver=1.4.4=pypi_0 60 | lame=3.100=hcfcfb64_1003 61 | lerc=3.0=hd77b12b_0 62 | libcublas=11.10.3.66=0 63 | libcublas-dev=11.10.3.66=0 64 | libcufft=10.7.2.124=0 65 | libcufft-dev=10.7.2.124=0 66 | libcurand=10.3.3.129=0 67 | libcurand-dev=10.3.3.129=0 68 | libcusolver=11.4.0.1=0 69 | libcusolver-dev=11.4.0.1=0 70 | libcusparse=11.7.4.91=0 71 | libcusparse-dev=11.7.4.91=0 72 | libdeflate=1.17=h2bbff1b_0 73 | libflac=1.4.3=h63175ca_0 74 | libnpp=11.7.4.75=0 75 | libnpp-dev=11.7.4.75=0 76 | libnvjpeg=11.8.0.2=0 77 | libnvjpeg-dev=11.8.0.2=0 78 | libogg=1.3.4=h8ffe710_1 79 | libopus=1.3.1=h8ffe710_1 80 | libpng=1.6.39=h8cc25b3_0 81 | libsndfile=1.2.0=h2628c91_0 82 | 
libtiff=4.5.0=h6c2663c_2 83 | libuv=1.44.2=h2bbff1b_0 84 | libvorbis=1.3.7=h0e60522_0 85 | libwebp=1.2.4=hbc33d0d_1 86 | libwebp-base=1.2.4=h2bbff1b_1 87 | lightning=1.9.5=pypi_0 88 | lightning-cloud=0.5.37=pypi_0 89 | lightning-utilities=0.9.0=pypi_0 90 | lz4-c=1.9.4=h2bbff1b_0 91 | markdown-it-py=3.0.0=pypi_0 92 | markupsafe=2.1.1=py39h2bbff1b_0 93 | matplotlib=3.7.2=pypi_0 94 | mdurl=0.1.2=pypi_0 95 | mkl=2023.1.0=h8bd8f75_46356 96 | mkl-service=2.4.0=py39h2bbff1b_1 97 | mkl_fft=1.3.6=py39hf11a4ad_1 98 | mkl_random=1.2.2=py39hf11a4ad_1 99 | mpg123=1.31.3=h63175ca_0 100 | mpmath=1.3.0=py39haa95532_0 101 | multidict=6.0.4=pypi_0 102 | networkx=3.1=py39haa95532_0 103 | numpy=1.25.0=py39h055cbcc_0 104 | numpy-base=1.25.0=py39h65a83cf_0 105 | opencv-python=4.8.0.74=pypi_0 106 | openssl=3.0.10=h2bbff1b_2 107 | ordered-set=4.1.0=pypi_0 108 | packaging=23.1=pypi_0 109 | pandas=2.0.3=pypi_0 110 | pathtools=0.1.2=pypi_0 111 | pillow=9.4.0=py39hd77b12b_0 112 | pip=23.2.1=py39haa95532_0 113 | protobuf=4.23.4=pypi_0 114 | psutil=5.9.5=pypi_0 115 | pycparser=2.21=pyhd3eb1b0_0 116 | pydantic=1.10.12=pypi_0 117 | pygments=2.15.1=pypi_0 118 | pyjwt=2.8.0=pypi_0 119 | pyopenssl=23.2.0=py39haa95532_0 120 | pyparsing=3.0.9=pypi_0 121 | pysocks=1.7.1=py39haa95532_0 122 | pysoundfile=0.12.1=pyhd8ed1ab_0 123 | python=3.9.17=h1aa4202_0 124 | python-dateutil=2.8.2=pypi_0 125 | python-editor=1.0.4=pypi_0 126 | python-multipart=0.0.6=pypi_0 127 | pytorch=2.0.1=py3.9_cuda11.7_cudnn8_0 128 | pytorch-cuda=11.7=h16d0643_5 129 | pytorch-lightning=1.9.5=pypi_0 130 | pytorch-mutex=1.0=cuda 131 | pytz=2023.3=pypi_0 132 | pyyaml=6.0.1=pypi_0 133 | readchar=4.0.5=pypi_0 134 | requests=2.31.0=py39haa95532_0 135 | rich=13.5.2=pypi_0 136 | scikit-learn=1.3.0=pypi_0 137 | scipy=1.11.1=pypi_0 138 | seaborn=0.12.2=pypi_0 139 | sentry-sdk=1.29.2=pypi_0 140 | setproctitle=1.3.2=pypi_0 141 | setuptools=68.0.0=py39haa95532_0 142 | six=1.16.0=pypi_0 143 | sklearn=0.0.post7=pypi_0 144 | smmap=5.0.0=pypi_0 145 | sniffio=1.3.0=pypi_0 146 | soupsieve=2.4.1=pypi_0 147 | sqlite=3.41.2=h2bbff1b_0 148 | starlette=0.22.0=pypi_0 149 | starsessions=1.3.0=pypi_0 150 | sympy=1.11.1=py39haa95532_0 151 | tbb=2021.8.0=h59b6b97_0 152 | threadpoolctl=3.2.0=pypi_0 153 | tk=8.6.12=h2bbff1b_0 154 | torchaudio=2.0.2=pypi_0 155 | torchinfo=1.8.0=pypi_0 156 | torchmetrics=1.0.1=pypi_0 157 | torchvision=0.15.2=pypi_0 158 | tqdm=4.65.0=pypi_0 159 | traitlets=5.9.0=pypi_0 160 | typing_extensions=4.7.1=py39haa95532_0 161 | tzdata=2023.3=pypi_0 162 | ucrt=10.0.22621.0=h57928b3_0 163 | urllib3=1.26.16=py39haa95532_0 164 | uvicorn=0.23.2=pypi_0 165 | vc=14.2=h21ff451_1 166 | vc14_runtime=14.36.32532=hfdfe4a8_17 167 | vs2015_runtime=14.36.32532=h05e6639_17 168 | wandb=0.15.8=pypi_0 169 | wcwidth=0.2.6=pypi_0 170 | websocket-client=1.6.1=pypi_0 171 | websockets=11.0.3=pypi_0 172 | wheel=0.38.4=py39haa95532_0 173 | win_inet_pton=1.1.0=py39haa95532_0 174 | xz=5.4.2=h8cc25b3_0 175 | yarl=1.9.2=pypi_0 176 | zipp=3.16.2=pypi_0 177 | zlib=1.2.13=h8cc25b3_0 178 | zstd=1.5.5=hd43e919_0 179 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from pytorch_lightning import LightningDataModule 4 | import librosa as lb 5 | import numpy as np 6 | import os 7 | import random 8 | 9 | TUT_LABELS = ["fan", "pump", "slider", "ToyCar", "ToyConveyor", "valve"] 10 | TUT_LABELS_INT = [0, 7, 13, 
20, 27, 34] 11 | 12 | class TUTDataset(Dataset): 13 | 14 | def __init__(self, path_data, list_files, sample_rate, duration): 15 | # Initialization for path, fs, list of files etc. 16 | self.sample_rate = sample_rate 17 | self.path_data = path_data 18 | self.list_files = list_files 19 | self.duration = duration 20 | 21 | def __getitem__(self, index): 22 | # select an audio from the list and return metadata and label 23 | file_name = self.list_files[index] 24 | if ".npy" in file_name: 25 | audio_data = np.load(file_name) 26 | else: 27 | audio_data , _ = lb.load(file_name, sr = self.sample_rate, res_type = "polyphase") 28 | if len(audio_data) > int(self.duration * self.sample_rate): 29 | audio_data = audio_data[:int(self.duration * self.sample_rate)] 30 | 31 | metadata = 0 32 | numerical_label = 0 33 | for i, name_class in enumerate(TUT_LABELS): 34 | if name_class in file_name: 35 | metadata = TUT_LABELS_INT[i] 36 | numerical_label = TUT_LABELS_INT[i] 37 | break 38 | 39 | metadata += int(file_name.split('_')[2]) # in this way I have "name_id", for example fan_01 where 01 is the id of the machine 40 | label = 0 if "normal" in file_name else 1 41 | return audio_data , metadata, label, numerical_label 42 | 43 | def __len__(self): 44 | # return length of the list 45 | return len(self.list_files) 46 | 47 | class TUTDatamodule(LightningDataModule): 48 | 49 | def __init__(self, path_train, path_test, sample_rate, duration, percentage_val, batch_size): 50 | super().__init__() 51 | self.path_train = path_train 52 | self.path_test = path_test 53 | self.sample_rate = sample_rate 54 | self.duration = duration 55 | self.percentage_val = percentage_val 56 | self.batch_size = batch_size 57 | # split of the dataset 58 | self.train_list = self.scan_all_dir(self.path_train) 59 | self.val_list = random.sample(self.train_list, int(len(self.train_list) * percentage_val)) 60 | self.train_list = set(self.train_list) 61 | self.val_list = set(self.val_list) 62 | self.train_list -= self.val_list 63 | self.train_list = list(self.train_list) 64 | self.val_list = list(self.val_list) 65 | 66 | self.test_list = self.scan_all_dir(self.path_test) 67 | 68 | 69 | def scan_all_dir(self, path): 70 | list_all_files = [] 71 | for root, dirs, files in os.walk(path): 72 | for file in files: 73 | list_all_files.append(str(root + "\\" + file)) 74 | return list_all_files 75 | 76 | def setup(self, stage = None): 77 | # Nothing to do 78 | pass 79 | 80 | def prepare_data(self): 81 | # Nothing to do 82 | pass 83 | 84 | def train_dataloader(self): 85 | # return the dataloader containing training data 86 | train_split = TUTDataset(path_data = self.path_train, list_files = self.train_list, sample_rate = self.sample_rate, duration = self.duration) 87 | return DataLoader(train_split, batch_size = self.batch_size, shuffle = True) 88 | 89 | def val_dataloader(self): 90 | # return the dataloader containing validation data 91 | val_split = TUTDataset(path_data = self.path_train, list_files = self.val_list, sample_rate = self.sample_rate, duration = self.duration) 92 | return DataLoader(val_split, batch_size = self.batch_size, shuffle = False) 93 | 94 | def test_dataloader(self): 95 | # return the dataloader containing testing data 96 | test_split = TUTDataset(path_data = self.path_test, list_files = self.test_list, sample_rate = self.sample_rate, duration = self.duration) 97 | return DataLoader(test_split, batch_size = self.batch_size, shuffle = True) 98 | 99 | ## TEST FUNCTION ## 100 | if __name__ == "__main__": 101 | path_train = "TUT Anomaly 
detection/train" # path for training audio 102 | path_test = "TUT Anomaly detection/test" # path for test audio 103 | percentage_val = 0.2 104 | sample_rate = 16000 105 | batch_size = 64 106 | duration = 10 107 | # create lightning datamodule 108 | datamodule = TUTDatamodule(path_train = path_train, path_test = path_test, sample_rate = sample_rate, duration = duration, 109 | percentage_val = percentage_val, batch_size = batch_size) 110 | dataloader_train = datamodule.train_dataloader() 111 | print(len(dataloader_train)) 112 | dataloader_val = datamodule.val_dataloader() 113 | print(len(dataloader_val)) 114 | dataloader_test = datamodule.test_dataloader() 115 | print(len(dataloader_test)) 116 | print(next(iter(dataloader_test))) -------------------------------------------------------------------------------- /images/ProposedASD.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/ProposedASD.jpg -------------------------------------------------------------------------------- /images/ProposedASD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/ProposedASD.png -------------------------------------------------------------------------------- /images/attentions.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/attentions.pdf -------------------------------------------------------------------------------- /images/spec-heat-ave-std.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/spec-heat-ave-std.pdf -------------------------------------------------------------------------------- /images/spec-heat.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/spec-heat.pdf -------------------------------------------------------------------------------- /images/spec-mel.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/spec-mel.pdf -------------------------------------------------------------------------------- /images/spec-wav.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/spec-wav.pdf -------------------------------------------------------------------------------- /mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | class Bottleneck(nn.Module): 6 | def __init__(self, inp, oup, stride, expansion): 7 | super(Bottleneck, self).__init__() 8 | self.connect = stride == 1 and inp == oup 9 | # 10 | self.conv = nn.Sequential( 11 | # pw 12 | nn.Conv2d(inp, inp * expansion, 
1, 1, 0, bias=False), 13 | nn.BatchNorm2d(inp * expansion), 14 | nn.PReLU(inp * expansion), 15 | # dw 16 | nn.Conv2d(inp * expansion, inp * expansion, 3, stride, 1, groups=inp * expansion, bias=False), 17 | nn.BatchNorm2d(inp * expansion), 18 | nn.PReLU(inp * expansion), 19 | 20 | # pw-linear 21 | nn.Conv2d(inp * expansion, oup, 1, 1, 0, bias=False), 22 | nn.BatchNorm2d(oup), 23 | ) 24 | 25 | def forward(self, x): 26 | if self.connect: 27 | return x + self.conv(x) 28 | else: 29 | return self.conv(x) 30 | 31 | 32 | class ConvBlock(nn.Module): 33 | def __init__(self, inp, oup, k, s, p, dw=False, linear=False): 34 | super(ConvBlock, self).__init__() 35 | self.linear = linear 36 | if dw: 37 | self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False) 38 | else: 39 | self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False) 40 | self.bn = nn.BatchNorm2d(oup) 41 | if not linear: 42 | self.prelu = nn.PReLU(oup) 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.bn(x) 47 | if self.linear: 48 | return x 49 | else: 50 | return self.prelu(x) 51 | 52 | 53 | #https://dcase.community/documents/challenge2022/technical_reports/DCASE2022_Liu_8_t2.pdf 54 | Mobilefacenet_bottleneck_setting = [ 55 | # t, c , n ,s 56 | [2, 128, 2, 2], 57 | [4, 128, 2, 2], 58 | [4, 128, 2, 2], 59 | ] 60 | 61 | 62 | class MobileFaceNet(nn.Module): 63 | def __init__(self, 64 | num_class, 65 | bottleneck_setting=Mobilefacenet_bottleneck_setting): 66 | super(MobileFaceNet, self).__init__() 67 | 68 | self.conv1 = ConvBlock(2, 64, 3, 2, 1) 69 | 70 | self.dw_conv1 = ConvBlock(64, 64, 3, 1, 1, dw=True) 71 | 72 | self.inplanes = 64 73 | block = Bottleneck 74 | self.blocks = self._make_layer(block, bottleneck_setting) 75 | 76 | self.conv2 = ConvBlock(bottleneck_setting[-1][1], 512, 1, 1, 0) 77 | # 20(10), 4(2), 8(4) 78 | self.linear7 = ConvBlock(512, 512, (8, 20), 1, 0, dw=True, linear=True) 79 | 80 | self.linear1 = ConvBlock(512, 128, 1, 1, 0, linear=True) 81 | 82 | self.fc_out = nn.Linear(128, num_class) 83 | # init 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 87 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 88 | elif isinstance(m, nn.BatchNorm2d): 89 | m.weight.data.fill_(1) 90 | m.bias.data.zero_() 91 | 92 | def _make_layer(self, block, setting): 93 | layers = [] 94 | for t, c, n, s in setting: 95 | for i in range(n): 96 | if i == 0: 97 | layers.append(block(self.inplanes, c, s, t)) 98 | else: 99 | layers.append(block(self.inplanes, c, 1, t)) 100 | self.inplanes = c 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | x = self.conv1(x) 105 | x = self.dw_conv1(x) 106 | x = self.blocks(x) 107 | x = self.conv2(x) 108 | x = self.linear7(x) 109 | x = self.linear1(x) 110 | feature = x.view(x.size(0), -1) 111 | out = self.fc_out(feature) 112 | return out, feature 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_lightning import LightningModule 4 | from torchmetrics import Accuracy 5 | from torchinfo import summary 6 | import torchaudio.transforms as T 7 | from utils import ArcMarginProduct 8 | import separableconv.nn as sep 9 | import numpy as np 10 | from mobilenet import MobileFaceNet 11 | 12 | class AE_DCASEBaseline(nn.Module): 13 | def __init__(self) -> None: 14 | super(AE_DCASEBaseline, self).__init__() 15 | self.frames = 5 16 | self.n_mels = 128 17 | self.t_bins = 1 + (10 * 16000) // 512 18 | self.vector_array_size = self.t_bins - self.frames + 1 19 | self.transform_tf = T.MelSpectrogram(sample_rate=16000, 20 | n_fft=1024, 21 | win_length=1024, 22 | hop_length=512, 23 | center=True, 24 | pad_mode="reflect", 25 | power=2.0, 26 | norm="slaney", 27 | n_mels=self.n_mels, 28 | mel_scale="htk", 29 | ) 30 | 31 | self.encoder = nn.Sequential( 32 | nn.Linear(in_features = 640, out_features = 128), 33 | nn.BatchNorm1d(self.vector_array_size), 34 | nn.ReLU(), 35 | nn.Linear(in_features = 128, out_features = 128), 36 | nn.BatchNorm1d(self.vector_array_size), 37 | nn.ReLU(), 38 | nn.Linear(in_features = 128, out_features = 128), 39 | nn.BatchNorm1d(self.vector_array_size), 40 | nn.ReLU(), 41 | nn.Linear(in_features = 128, out_features = 128), 42 | nn.BatchNorm1d(self.vector_array_size), 43 | nn.ReLU() 44 | ) 45 | 46 | self.bottleneck = nn.Sequential( 47 | nn.Linear(in_features = 128, out_features = 8), 48 | nn.BatchNorm1d(self.vector_array_size), 49 | nn.ReLU(), 50 | ) 51 | 52 | self.decoder = nn.Sequential( 53 | nn.Linear(in_features = 8, out_features = 128), 54 | nn.BatchNorm1d(self.vector_array_size), 55 | nn.ReLU(), 56 | nn.Linear(in_features = 128, out_features = 128), 57 | nn.BatchNorm1d(self.vector_array_size), 58 | nn.ReLU(), 59 | nn.Linear(in_features = 128, out_features = 128), 60 | nn.BatchNorm1d(self.vector_array_size), 61 | nn.ReLU(), 62 | nn.Linear(in_features = 128, out_features = 128), 63 | nn.BatchNorm1d(self.vector_array_size), 64 | nn.ReLU(), 65 | nn.Linear(in_features = 128, out_features = 640) 66 | ) 67 | 68 | # WEIGHTS INIT LIKE KERAS 69 | nn.init.xavier_uniform_(self.encoder[0].weight) 70 | nn.init.xavier_uniform_(self.encoder[3].weight) 71 | nn.init.xavier_uniform_(self.encoder[6].weight) 72 | nn.init.xavier_uniform_(self.encoder[9].weight) 73 | nn.init.xavier_uniform_(self.bottleneck[0].weight) 74 | nn.init.xavier_uniform_(self.decoder[0].weight) 75 | nn.init.xavier_uniform_(self.decoder[3].weight) 76 | nn.init.xavier_uniform_(self.decoder[6].weight) 77 | nn.init.xavier_uniform_(self.decoder[9].weight) 78 | 
nn.init.xavier_uniform_(self.decoder[-1].weight) 79 | 80 | # BIAS INIT LIKE KERAS 81 | nn.init.zeros_(self.encoder[0].bias) 82 | nn.init.zeros_(self.encoder[3].bias) 83 | nn.init.zeros_(self.encoder[6].bias) 84 | nn.init.zeros_(self.encoder[9].bias) 85 | nn.init.zeros_(self.bottleneck[0].bias) 86 | nn.init.zeros_(self.decoder[0].bias) 87 | nn.init.zeros_(self.decoder[3].bias) 88 | nn.init.zeros_(self.decoder[6].bias) 89 | nn.init.zeros_(self.decoder[9].bias) 90 | nn.init.zeros_(self.decoder[-1].bias) 91 | 92 | def forward(self, x): 93 | x = self.preprocessing(x) 94 | original = x 95 | x = self.encoder(x) 96 | x = self.bottleneck(x) 97 | x = self.decoder(x) 98 | return original, x 99 | 100 | def preprocessing(self, x): 101 | # compute mel spectrogram 102 | batch_size = x.size()[0] 103 | x = self.transform_tf(x) 104 | x = 10 * torch.log10(x + 1e-8) 105 | vector_dim = self.frames * self.n_mels 106 | 107 | feature_vector = torch.zeros((batch_size, self.vector_array_size, vector_dim)).to(x.device) 108 | for batch in range(batch_size): 109 | for i in range(self.frames): 110 | feature_vector[batch, :, self.n_mels * i: self.n_mels * (i + 1)] = x[batch, :, i: i + self.vector_array_size].T 111 | 112 | return feature_vector 113 | 114 | 115 | ##### WAVEGRAM + MEL SPEC + ATTENTION MODULE + MOBILENETV2 116 | class Wavegram_AttentionModule(nn.Module): 117 | def __init__(self, h): 118 | super(Wavegram_AttentionModule, self).__init__() 119 | self.h = h 120 | # sep-wavegram 121 | self.wavegram = sep.SeparableConv1d(in_channels = 1, out_channels = 128, kernel_size = 1024, stride = 512, padding = 512) 122 | # Mel filterbank 123 | self.transform_tf = T.MelSpectrogram(sample_rate=16000, 124 | n_fft=1024, 125 | win_length=1024, 126 | hop_length=512, 127 | center=True, 128 | pad_mode="reflect", 129 | power=2.0, 130 | norm="slaney", 131 | n_mels=128, 132 | mel_scale="htk", 133 | ) 134 | 135 | # attention module 136 | self.heatmap = nn.Sequential( 137 | sep.SeparableConv2d(in_channels = 2, out_channels = 16, 138 | kernel_size = (3,3), padding = "same", bias = False), 139 | nn.BatchNorm2d(16), 140 | nn.ELU(), 141 | sep.SeparableConv2d(in_channels = 16, out_channels = 64, 142 | kernel_size = (3,3), padding = "same", bias = False), 143 | nn.BatchNorm2d(64), 144 | nn.ELU(), 145 | sep.SeparableConv2d(in_channels = 64, out_channels = 2, kernel_size = 1, padding = "same"), 146 | nn.Sigmoid() 147 | ) 148 | # classifier 149 | self.classifier = MobileFaceNet(num_class = 41) 150 | self.arcface = ArcMarginProduct(in_features = self.h, out_features = 41, s = 40, m = 0.7) 151 | 152 | def forward(self, x, metadata): 153 | # compute mel spectrogram 154 | x_spec = self.transform_tf(x) 155 | x_spec = 10*torch.log10(x_spec + 1e-8) 156 | # compute wavegram 157 | x = x.unsqueeze(1) 158 | x = self.wavegram(x) 159 | x = torch.stack((x_spec, x), dim = 1) 160 | reppr = x 161 | heatmap = self.heatmap(x) 162 | x = x * heatmap 163 | out, features = self.classifier(x) 164 | x = self.arcface(features, metadata) 165 | return x, out, reppr, features, heatmap 166 | 167 | class Wavegram_AttentionMap(LightningModule): 168 | 169 | def __init__(self, h, lr): 170 | super().__init__() 171 | self.h = h 172 | self.model = Wavegram_AttentionModule(self.h) 173 | self.lr = lr 174 | 175 | self.accuracy_training = Accuracy(task="multiclass", num_classes=41) 176 | self.accuracy_val = Accuracy(task="multiclass", num_classes=41) 177 | self.accuracy_test = Accuracy(task="multiclass", num_classes=41) 178 | self.criterion = nn.CrossEntropyLoss() 179 | # to save 
threshold and errors at init 180 | self.errors_list = [] 181 | self.clean_errors = [] 182 | self.anomaly_errors = [] 183 | self.labels = [] 184 | self.classes = [] 185 | 186 | def mixup_data(self, x, y, alpha=0.2): 187 | y = torch.nn.functional.one_hot(y, num_classes = 41) 188 | if alpha > 0: 189 | lam = np.random.beta(alpha, alpha) 190 | else: 191 | lam = 1 192 | batch_size = x.size()[0] 193 | index = torch.randperm(batch_size) 194 | mixed_x = lam * x + (1 - lam) * x[index, :] 195 | y_a, y_b = y, y[index] 196 | return mixed_x, y_a.float(), y_b.float(), lam 197 | 198 | def forward(self, x, labels): 199 | return self.model(x, labels) 200 | 201 | def mixup_criterion_arcmix(self, pred, y_a, y_b, lam): 202 | loss1 = lam * self.criterion(pred, y_a) 203 | loss2 = (1 - lam) * self.criterion(pred, y_b) 204 | return loss1+loss2 205 | 206 | def training_step(self, batch, batch_idx): 207 | x, metadata, _, _ = batch 208 | # for training step 209 | mixed_x, y_a, y_b, lam = self.mixup_data(x, metadata) 210 | predicted, _, _, _, _ = self.forward(mixed_x, metadata) 211 | loss = self.mixup_criterion_arcmix(predicted, y_a, y_b, lam) 212 | self.log("train/loss_class", loss, on_epoch = True, on_step = True, prog_bar = True) 213 | self.accuracy_training(predicted, metadata) 214 | self.log("train/acc", self.accuracy_training, on_epoch = True, on_step = False) 215 | return loss 216 | 217 | def validation_step(self, batch, batch_idx): 218 | x, metadata, _, _ = batch 219 | predicted, _, _, _, _ = self.forward(x, metadata) 220 | loss = torch.nn.functional.cross_entropy(predicted, metadata, reduction = "mean") 221 | self.log("val/loss_class", loss, on_epoch = True, on_step = False, prog_bar = True) 222 | self.accuracy_val(predicted, metadata) 223 | self.log("val/acc", self.accuracy_val, on_epoch = True, on_step = False) 224 | return loss 225 | 226 | def test_step(self, batch, batch_idx): 227 | x, metadata, label, _ = batch 228 | predicted, _, _, _, _ = self.forward(x, metadata) 229 | loss = torch.nn.functional.cross_entropy(predicted, metadata, reduction = "mean") 230 | self.log("test/loss_class", loss, on_epoch = True, on_step = False, prog_bar = True) 231 | self.accuracy_test(predicted, metadata) 232 | self.log("test/acc", self.accuracy_test, on_epoch = True, on_step = False) 233 | class_loss_batchwise = nn.functional.cross_entropy(predicted, metadata, reduction = "none") 234 | errors = class_loss_batchwise 235 | self.errors_list.append(errors) 236 | self.clean_errors.append(errors[label == 0]) 237 | self.anomaly_errors.append(errors[label == 1]) 238 | self.labels.append(label) 239 | self.classes.append(metadata) 240 | return loss 241 | 242 | def configure_optimizers(self): 243 | opt = torch.optim.AdamW(self.parameters(), lr = self.lr) 244 | # return opt 245 | return { 246 | "optimizer": opt, 247 | "lr_scheduler": { 248 | "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=300, eta_min=0.1*float(self.lr)) 249 | }, 250 | } 251 | 252 | 253 | # TEST FUNCTION 254 | if __name__ == "__main__": 255 | example_input = torch.rand(16, 160000) # dummy audio 256 | model = Wavegram_AttentionModule() 257 | metadata = torch.nn.functional.one_hot(torch.randint(low = 0, high = 41, size =(16,)), num_classes=41) 258 | output = model(example_input, metadata) 259 | print(output) 260 | summary(model, input_data = example_input) 261 | 262 | -------------------------------------------------------------------------------- /other_data/all_labels.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/all_labels.npz -------------------------------------------------------------------------------- /other_data/allerrors.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/allerrors.npz -------------------------------------------------------------------------------- /other_data/anomalouserrors.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/anomalouserrors.npz -------------------------------------------------------------------------------- /other_data/cleanerrors.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/cleanerrors.npz -------------------------------------------------------------------------------- /other_data/metadata.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/metadata.npz -------------------------------------------------------------------------------- /training_val_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from data import TUTDatamodule 3 | from pytorch_lightning.loggers import WandbLogger 4 | from pytorch_lightning import Trainer 5 | from pytorch_lightning.callbacks import ModelCheckpoint 6 | import wandb 7 | import torch 8 | from optparse import OptionParser 9 | from sklearn import metrics 10 | from model import Wavegram_AttentionMap 11 | 12 | 13 | def transform_string_list(option, opt, value, parser): 14 | list_of_int = [int(i) for i in value.split(",")] 15 | setattr(parser.values, option.dest, list_of_int) 16 | 17 | 18 | def train(configs): 19 | # definition of the logger 20 | tags = ["MobileNetWavegram"] 21 | 22 | # Machine type if we want to train and test on a single machine 23 | if "valve" in configs.path_train: 24 | tags.append("Valve") 25 | elif "pump" in configs.path_train: 26 | tags.append("Pump") 27 | elif "slider" in configs.path_train: 28 | tags.append("Slider") 29 | elif "ToyCar" in configs.path_train: 30 | tags.append("ToyCar") 31 | elif "ToyConveyor" in configs.path_train: 32 | tags.append("ToyConveyor") 33 | elif "fan" in configs.path_train: 34 | tags.append("Fan") 35 | else: 36 | tags.append("All") 37 | wandb_logger = WandbLogger(project="Unsupervised Audio Anomaly Detection", config = configs, name=configs.name, tags = tags) 38 | 39 | checkpoint_callback = ModelCheckpoint( 40 | dirpath = "best_models/", 41 | filename = "MobileNetWavegram-Mixup-{epoch:02d}-{val/loss_class:.4f}", # <--- note epoch suffix here 42 | save_last = True, 43 | every_n_epochs = 20, 44 | save_top_k = -1, 45 | auto_insert_metric_name=False 46 | ) 47 | 48 | # definition of the trainer 49 | trainer = Trainer(accelerator="gpu", devices = 1 , max_epochs = configs.epochs, logger=wandb_logger, callbacks = [checkpoint_callback]) 50 
| 51 | # definition of the datamodule 52 | datamodule = TUTDatamodule(path_train = configs.path_train, path_test = configs.path_test, sample_rate = configs.sr, duration = configs.duration, 53 | percentage_val = configs.percentage, batch_size = configs.batch_size) 54 | 55 | # definition of the model 56 | model = Wavegram_AttentionMap(lr = configs.lr, h = 128) 57 | 58 | wandb_logger.watch(model, log_graph=False) 59 | trainer.fit(model, datamodule) 60 | return model, trainer, datamodule 61 | 62 | def val(trained_model:Wavegram_AttentionMap, trainer:Trainer, datamodule:TUTDatamodule): 63 | # inference on the validation set 64 | trainer.validate(trained_model, datamodule) 65 | return trained_model, trainer, datamodule 66 | 67 | def test(trained_model:Wavegram_AttentionMap, trainer:Trainer, datamodule:TUTDatamodule): 68 | # inference for tests 69 | # make list of files insider "best_models" folder 70 | list_checkpoints = datamodule.scan_all_dir("best_models/") 71 | best_performance = 0 72 | best_checkpoint = None 73 | for i, checkpoint in enumerate(list_checkpoints): 74 | print("{}) {}".format(i, checkpoint)) 75 | trained_model = Wavegram_AttentionMap.load_from_checkpoint(checkpoint, lr = 0.001, h = 128) # dummy lr 76 | trained_model.errors_list = [] 77 | trained_model.anomaly_errors = [] 78 | trained_model.clean_errors = [] 79 | trainer.test(trained_model, datamodule) 80 | 81 | all_errors = torch.cat(trained_model.errors_list).flatten().cpu().numpy() 82 | all_labels = torch.cat(trained_model.labels).flatten().cpu().numpy() 83 | metadata = np.array(torch.cat(trained_model.classes).tolist()) 84 | performance = metrics.roc_auc_score(all_labels, all_errors) 85 | if performance > best_performance: 86 | best_performance = performance 87 | best_checkpoint = checkpoint 88 | print("Best : {} %".format(best_performance)) 89 | 90 | # here we select the best model 91 | trained_model = Wavegram_AttentionMap.load_from_checkpoint(best_checkpoint, lr = 0.001, h = 128) # dummy lr 92 | trained_model.errors_list = [] 93 | trained_model.anomaly_errors = [] 94 | trained_model.clean_errors = [] 95 | trainer.test(trained_model, datamodule) 96 | 97 | all_errors = torch.cat(trained_model.errors_list).flatten().cpu().numpy() 98 | all_labels = torch.cat(trained_model.labels).flatten().cpu().numpy() 99 | metadata = np.array(torch.cat(trained_model.classes).tolist()) 100 | 101 | np.savez("metadata", metadata) 102 | np.savez("allerrors", all_errors) 103 | np.savez("all_labels", all_labels) 104 | wandb.log({"test/AUROC": metrics.roc_auc_score(all_labels, all_errors)}) 105 | wandb.log({"test/pAUROC": metrics.roc_auc_score(all_labels, all_errors, max_fpr = 0.1)}) 106 | 107 | print("Global AUC : {}".format(metrics.roc_auc_score(all_labels, all_errors))) 108 | print("Global pAUC: {}".format(metrics.roc_auc_score(all_labels, all_errors, max_fpr = 0.1))) 109 | 110 | # fan class 111 | mask = metadata < 7 112 | selected_label = all_labels[mask] 113 | selected_errors = all_errors[mask] 114 | 115 | print("Fan AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors))) 116 | print("Fan pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1))) 117 | 118 | # pump class 119 | mask = np.logical_and(metadata >= 7, metadata < 13) 120 | selected_label = all_labels[mask] 121 | selected_errors = all_errors[mask] 122 | 123 | print("Pump AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors))) 124 | print("Pump pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, 
max_fpr = 0.1))) 125 | 126 | 127 | # slider class 128 | mask = np.logical_and(metadata >= 13, metadata < 20) 129 | selected_label = all_labels[mask] 130 | selected_errors = all_errors[mask] 131 | 132 | print("Slider AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors))) 133 | print("Slider pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1))) 134 | 135 | # ToyCar class 136 | mask = np.logical_and(metadata >= 20, metadata < 27) 137 | selected_label = all_labels[mask] 138 | selected_errors = all_errors[mask] 139 | 140 | print("ToyCar AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors))) 141 | print("ToyCar pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1))) 142 | 143 | 144 | # ToyConveyor class 145 | mask = np.logical_and(metadata >= 27, metadata < 34) 146 | selected_label = all_labels[mask] 147 | selected_errors = all_errors[mask] 148 | 149 | print("ToyConveyor AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors))) 150 | print("ToyConveyor pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1))) 151 | 152 | # valve class 153 | mask = metadata >= 34 154 | selected_label = all_labels[mask] 155 | selected_errors = all_errors[mask] 156 | 157 | print("Valve AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors))) 158 | print("Valve pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1))) 159 | 160 | clean_errors = np.array(torch.cat(trained_model.clean_errors).tolist()) 161 | anomalous_errors = np.array(torch.cat(trained_model.anomaly_errors).tolist()) 162 | np.savez("cleanerrors", clean_errors) 163 | np.savez("anomalouserrors", anomalous_errors) 164 | wandb.finish() 165 | 166 | import matplotlib.pyplot as plt 167 | plt.rcParams["font.family"] = "Times New Roman" 168 | plt.rcParams['font.size'] = 28 169 | # plt.rcParams['text.usetex'] = True <--- If you have latex installed in your machine, you can use it on matplotlib 170 | errors_clean = np.load("cleanerrors.npz")['arr_0'] 171 | errors_anomalous = np.load("anomalouserrors.npz")['arr_0'] 172 | print(errors_clean.shape) 173 | print(errors_anomalous.shape) 174 | plt.figure(figsize = (8, 12)) 175 | plt.hist(errors_clean, bins = 200, alpha = 0.4, label = "Clear", color = "g") 176 | plt.hist(errors_anomalous, bins = 200, alpha = 0.4, label = "Anomalous", color = "r") 177 | plt.grid() 178 | plt.legend() 179 | plt.title("Histograms of errors") 180 | plt.xlabel("Error") 181 | plt.ylabel("Occurrences") 182 | plt.show() 183 | 184 | 185 | 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = OptionParser() 190 | # dataset parameters 191 | parser.add_option("--pathtrain", dest="path_train", 192 | help="path containing training files", default = "TUT Anomaly detection/train") 193 | parser.add_option("--pathtest", dest="path_test", 194 | help="path containing training files", default = "TUT Anomaly detection/test") 195 | parser.add_option("--percentage", dest = "percentage", 196 | help = "percentage of validation samples", default = 0.05, type = float) 197 | parser.add_option("--sr", dest = "sr", 198 | help = "target sample rate", default = 16000, type = int) 199 | parser.add_option("--duration", dest = "duration", 200 | help = "duration, in seconds, of audios", default = 10, type = float) 201 | 202 | 203 | # training parameters 204 | parser.add_option("--lr", dest = "lr", 205 | help = "learning rate", default = 0.0001, type = float) 206 | 
parser.add_option("--epochs", dest = "epochs", 207 | help = "number of epochs", default = 300, type = int) 208 | parser.add_option("--name", dest = "name", 209 | help = "name of the run on wandb", default = "MobileFaceNet + separable wavegram + attention module + mixup 0.2 Noisy-ArcMix s 40 m 0.7") 210 | parser.add_option("--batch_size", dest = "batch_size", 211 | help = "batch size for dataloader", default = 64, type = int) 212 | 213 | 214 | options, remainder = parser.parse_args() 215 | print(options) 216 | trained_model, trainer, datamodule = train(options) 217 | trained_model, trainer, datamodule = val(trained_model, trainer, datamodule) 218 | test(trained_model, trainer, datamodule) 219 | 220 | 221 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import separableconv.nn as sep 4 | import math 5 | 6 | 7 | class ArcMarginProduct(nn.Module): 8 | def __init__(self, in_features=128, out_features=200, s=40.0, m=0.3, sub=1, easy_margin=False): 9 | super(ArcMarginProduct, self).__init__() 10 | self.in_features = in_features 11 | self.out_features = out_features 12 | self.s = s 13 | self.m = m 14 | self.sub = sub 15 | self.weight = torch.nn.Parameter(torch.Tensor(out_features * sub, in_features)) 16 | nn.init.xavier_uniform_(self.weight) 17 | 18 | self.easy_margin = easy_margin 19 | self.cos_m = math.cos(m) 20 | self.sin_m = math.sin(m) 21 | 22 | # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°] 23 | self.th = math.cos(math.pi - m) 24 | self.mm = math.sin(math.pi - m) * m 25 | 26 | def forward(self, x, label): 27 | # label = torch.nn.functional.one_hot(label, num_classes=self.out_features) 28 | cosine = torch.nn.functional.linear(torch.nn.functional.normalize(x), torch.nn.functional.normalize(self.weight)) 29 | 30 | if self.sub > 1: 31 | cosine = cosine.view(-1, self.out_features, self.sub) 32 | cosine, _ = torch.max(cosine, dim=2) 33 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 34 | phi = cosine * self.cos_m - sine * self.sin_m 35 | 36 | if self.easy_margin: 37 | phi = torch.where(cosine > 0, phi, cosine) 38 | else: 39 | phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm) 40 | 41 | if len(label.shape) == 1: 42 | one_hot = torch.zeros(cosine.size(), device=x.device) 43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 44 | else: 45 | one_hot = label 46 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 47 | output = output * self.s 48 | return output --------------------------------------------------------------------------------