├── .gitignore
├── LICENSE
├── MWM_93.44_s40_m0.7.ckpt
├── README.md
├── best_models
│   └── README.md
├── conda_requirements.txt
├── data.py
├── images
│   ├── ProposedASD.jpg
│   ├── ProposedASD.png
│   ├── attentions.pdf
│   ├── spec-heat-ave-std.pdf
│   ├── spec-heat.pdf
│   ├── spec-mel.pdf
│   └── spec-wav.pdf
├── mobilenet.py
├── model.py
├── other_data
│   ├── all_labels.npz
│   ├── allerrors.npz
│   ├── anomalouserrors.npz
│   ├── cleanerrors.npz
│   └── metadata.npz
├── playground.ipynb
├── training_val_test.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Michael Neri
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MWM_93.44_s40_m0.7.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/MWM_93.44_s40_m0.7.ckpt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Low-complexity Unsupervised Audio Anomaly Detection exploiting Separable Convolutions and Angular Loss
2 |
3 |
4 |
5 | Official repository of the work "Low-complexity Unsupervised Audio Anomaly Detection exploiting Separable Convolutions and Angular Loss", published in IEEE Sensors Letters.
6 |
7 | ## Authors
8 | Michael Neri, Marco Carli
9 |
10 | Department of Industrial, Electronic, and Mechanical Engineering, Roma Tre University, Rome, Italy
11 |
12 |
13 | ----------------------------------------------------------------------------
14 |
15 | If you use any part of this work, please cite the following reference:
16 |
17 | ```
18 | @ARTICLE{Neri_LSENS_2024,
19 | author={Neri, M. and Carli, M.},
20 | journal={IEEE Sensors Letters},
21 | title={{Low-complexity Unsupervised Audio Anomaly Detection exploiting Separable Convolutions and Angular Loss}},
22 | year={2024},
23 | volume={},
24 | number={},
25 | pages={},
26 | doi={10.1109/LSENS.2024.3480450}
27 | }
28 | ```
29 |
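30 | ## Pretrained model
31 | 
32 | The checkpoint `MWM_93.44_s40_m0.7.ckpt` in the repository root corresponds to the ArcFace settings `s = 40` and `m = 0.7` used in `model.py`. A minimal sketch of how it could be loaded for inference, assuming the environment from `conda_requirements.txt` and a 10 s clip sampled at 16 kHz (the waveform below is a dummy input):
33 | 
34 | ```python
35 | import torch
36 | from model import Wavegram_AttentionMap
37 | 
38 | # lr and h mirror the values passed to load_from_checkpoint in training_val_test.py;
39 | # lr only matters for training, so any value works here
40 | model = Wavegram_AttentionMap.load_from_checkpoint("MWM_93.44_s40_m0.7.ckpt", map_location="cpu", lr=0.001, h=128)
41 | model.eval()
42 | 
43 | audio = torch.rand(1, 160000)   # dummy 10 s waveform at 16 kHz
44 | metadata = torch.tensor([0])    # machine type/id class index in [0, 40]
45 | with torch.no_grad():
46 |     logits, out, reppr, features, heatmap = model(audio, metadata)
47 | ```
48 | 
49 | Training from scratch follows `training_val_test.py`, e.g. `python training_val_test.py --pathtrain "TUT Anomaly detection/train" --pathtest "TUT Anomaly detection/test"`.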
--------------------------------------------------------------------------------
/best_models/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain the intermediate model checkpoints saved during training. This placeholder file can be removed.
--------------------------------------------------------------------------------
/conda_requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: win-64
4 | aiohttp=3.8.5=pypi_0
5 | aiosignal=1.3.1=pypi_0
6 | ansicon=1.89.0=pypi_0
7 | anyio=3.7.1=pypi_0
8 | appdirs=1.4.4=pypi_0
9 | arrow=1.2.3=pypi_0
10 | async-timeout=4.0.2=pypi_0
11 | attrs=23.1.0=pypi_0
12 | beautifulsoup4=4.12.2=pypi_0
13 | blas=1.0=mkl
14 | blessed=1.20.0=pypi_0
15 | brotlipy=0.7.0=py39h2bbff1b_1003
16 | ca-certificates=2023.05.30=haa95532_0
17 | certifi=2023.7.22=py39haa95532_0
18 | cffi=1.15.1=py39h2bbff1b_3
19 | charset-normalizer=2.0.4=pyhd3eb1b0_0
20 | click=8.1.6=pypi_0
21 | colorama=0.4.6=pypi_0
22 | contourpy=1.1.0=pypi_0
23 | croniter=1.3.15=pypi_0
24 | cryptography=41.0.2=py39hac1b9e3_0
25 | cuda-cccl=12.2.128=0
26 | cuda-cudart=11.7.99=0
27 | cuda-cudart-dev=11.7.99=0
28 | cuda-cupti=11.7.101=0
29 | cuda-libraries=11.7.1=0
30 | cuda-libraries-dev=11.7.1=0
31 | cuda-nvrtc=11.7.99=0
32 | cuda-nvrtc-dev=11.7.99=0
33 | cuda-nvtx=11.7.91=0
34 | cuda-runtime=11.7.1=0
35 | cycler=0.11.0=pypi_0
36 | dateutils=0.6.12=pypi_0
37 | deepdiff=6.3.1=pypi_0
38 | docker-pycreds=0.4.0=pypi_0
39 | exceptiongroup=1.1.2=pypi_0
40 | fastapi=0.88.0=pypi_0
41 | filelock=3.9.0=py39haa95532_0
42 | fonttools=4.41.1=pypi_0
43 | freetype=2.12.1=ha860e81_0
44 | frozenlist=1.4.0=pypi_0
45 | fsspec=2023.6.0=pypi_0
46 | giflib=5.2.1=h8cc25b3_3
47 | gitdb=4.0.10=pypi_0
48 | gitpython=3.1.32=pypi_0
49 | h11=0.14.0=pypi_0
50 | idna=3.4=py39haa95532_0
51 | importlib-resources=6.0.0=pypi_0
52 | inquirer=3.1.3=pypi_0
53 | intel-openmp=2023.1.0=h59b6b97_46319
54 | itsdangerous=2.1.2=pypi_0
55 | jinja2=3.1.2=py39haa95532_0
56 | jinxed=1.2.0=pypi_0
57 | joblib=1.3.1=pypi_0
58 | jpeg=9e=h2bbff1b_1
59 | kiwisolver=1.4.4=pypi_0
60 | lame=3.100=hcfcfb64_1003
61 | lerc=3.0=hd77b12b_0
62 | libcublas=11.10.3.66=0
63 | libcublas-dev=11.10.3.66=0
64 | libcufft=10.7.2.124=0
65 | libcufft-dev=10.7.2.124=0
66 | libcurand=10.3.3.129=0
67 | libcurand-dev=10.3.3.129=0
68 | libcusolver=11.4.0.1=0
69 | libcusolver-dev=11.4.0.1=0
70 | libcusparse=11.7.4.91=0
71 | libcusparse-dev=11.7.4.91=0
72 | libdeflate=1.17=h2bbff1b_0
73 | libflac=1.4.3=h63175ca_0
74 | libnpp=11.7.4.75=0
75 | libnpp-dev=11.7.4.75=0
76 | libnvjpeg=11.8.0.2=0
77 | libnvjpeg-dev=11.8.0.2=0
78 | libogg=1.3.4=h8ffe710_1
79 | libopus=1.3.1=h8ffe710_1
80 | libpng=1.6.39=h8cc25b3_0
81 | libsndfile=1.2.0=h2628c91_0
82 | libtiff=4.5.0=h6c2663c_2
83 | libuv=1.44.2=h2bbff1b_0
84 | libvorbis=1.3.7=h0e60522_0
85 | libwebp=1.2.4=hbc33d0d_1
86 | libwebp-base=1.2.4=h2bbff1b_1
87 | lightning=1.9.5=pypi_0
88 | lightning-cloud=0.5.37=pypi_0
89 | lightning-utilities=0.9.0=pypi_0
90 | lz4-c=1.9.4=h2bbff1b_0
91 | markdown-it-py=3.0.0=pypi_0
92 | markupsafe=2.1.1=py39h2bbff1b_0
93 | matplotlib=3.7.2=pypi_0
94 | mdurl=0.1.2=pypi_0
95 | mkl=2023.1.0=h8bd8f75_46356
96 | mkl-service=2.4.0=py39h2bbff1b_1
97 | mkl_fft=1.3.6=py39hf11a4ad_1
98 | mkl_random=1.2.2=py39hf11a4ad_1
99 | mpg123=1.31.3=h63175ca_0
100 | mpmath=1.3.0=py39haa95532_0
101 | multidict=6.0.4=pypi_0
102 | networkx=3.1=py39haa95532_0
103 | numpy=1.25.0=py39h055cbcc_0
104 | numpy-base=1.25.0=py39h65a83cf_0
105 | opencv-python=4.8.0.74=pypi_0
106 | openssl=3.0.10=h2bbff1b_2
107 | ordered-set=4.1.0=pypi_0
108 | packaging=23.1=pypi_0
109 | pandas=2.0.3=pypi_0
110 | pathtools=0.1.2=pypi_0
111 | pillow=9.4.0=py39hd77b12b_0
112 | pip=23.2.1=py39haa95532_0
113 | protobuf=4.23.4=pypi_0
114 | psutil=5.9.5=pypi_0
115 | pycparser=2.21=pyhd3eb1b0_0
116 | pydantic=1.10.12=pypi_0
117 | pygments=2.15.1=pypi_0
118 | pyjwt=2.8.0=pypi_0
119 | pyopenssl=23.2.0=py39haa95532_0
120 | pyparsing=3.0.9=pypi_0
121 | pysocks=1.7.1=py39haa95532_0
122 | pysoundfile=0.12.1=pyhd8ed1ab_0
123 | python=3.9.17=h1aa4202_0
124 | python-dateutil=2.8.2=pypi_0
125 | python-editor=1.0.4=pypi_0
126 | python-multipart=0.0.6=pypi_0
127 | pytorch=2.0.1=py3.9_cuda11.7_cudnn8_0
128 | pytorch-cuda=11.7=h16d0643_5
129 | pytorch-lightning=1.9.5=pypi_0
130 | pytorch-mutex=1.0=cuda
131 | pytz=2023.3=pypi_0
132 | pyyaml=6.0.1=pypi_0
133 | readchar=4.0.5=pypi_0
134 | requests=2.31.0=py39haa95532_0
135 | rich=13.5.2=pypi_0
136 | scikit-learn=1.3.0=pypi_0
137 | scipy=1.11.1=pypi_0
138 | seaborn=0.12.2=pypi_0
139 | sentry-sdk=1.29.2=pypi_0
140 | setproctitle=1.3.2=pypi_0
141 | setuptools=68.0.0=py39haa95532_0
142 | six=1.16.0=pypi_0
143 | sklearn=0.0.post7=pypi_0
144 | smmap=5.0.0=pypi_0
145 | sniffio=1.3.0=pypi_0
146 | soupsieve=2.4.1=pypi_0
147 | sqlite=3.41.2=h2bbff1b_0
148 | starlette=0.22.0=pypi_0
149 | starsessions=1.3.0=pypi_0
150 | sympy=1.11.1=py39haa95532_0
151 | tbb=2021.8.0=h59b6b97_0
152 | threadpoolctl=3.2.0=pypi_0
153 | tk=8.6.12=h2bbff1b_0
154 | torchaudio=2.0.2=pypi_0
155 | torchinfo=1.8.0=pypi_0
156 | torchmetrics=1.0.1=pypi_0
157 | torchvision=0.15.2=pypi_0
158 | tqdm=4.65.0=pypi_0
159 | traitlets=5.9.0=pypi_0
160 | typing_extensions=4.7.1=py39haa95532_0
161 | tzdata=2023.3=pypi_0
162 | ucrt=10.0.22621.0=h57928b3_0
163 | urllib3=1.26.16=py39haa95532_0
164 | uvicorn=0.23.2=pypi_0
165 | vc=14.2=h21ff451_1
166 | vc14_runtime=14.36.32532=hfdfe4a8_17
167 | vs2015_runtime=14.36.32532=h05e6639_17
168 | wandb=0.15.8=pypi_0
169 | wcwidth=0.2.6=pypi_0
170 | websocket-client=1.6.1=pypi_0
171 | websockets=11.0.3=pypi_0
172 | wheel=0.38.4=py39haa95532_0
173 | win_inet_pton=1.1.0=py39haa95532_0
174 | xz=5.4.2=h8cc25b3_0
175 | yarl=1.9.2=pypi_0
176 | zipp=3.16.2=pypi_0
177 | zlib=1.2.13=h8cc25b3_0
178 | zstd=1.5.5=hd43e919_0
179 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | from pytorch_lightning import LightningDataModule
4 | import librosa as lb
5 | import numpy as np
6 | import os
7 | import random
8 |
9 | TUT_LABELS = ["fan", "pump", "slider", "ToyCar", "ToyConveyor", "valve"]
10 | TUT_LABELS_INT = [0, 7, 13, 20, 27, 34]
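# Per-machine-type base offsets: the machine id parsed from the file name is added to the
# offset in __getitem__, so each (machine type, machine id) pair maps to a distinct class in [0, 40]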
11 |
12 | class TUTDataset(Dataset):
13 |
14 | def __init__(self, path_data, list_files, sample_rate, duration):
15 | # Initialization for path, fs, list of files etc.
16 | self.sample_rate = sample_rate
17 | self.path_data = path_data
18 | self.list_files = list_files
19 | self.duration = duration
20 |
21 | def __getitem__(self, index):
22 | # select an audio from the list and return metadata and label
23 | file_name = self.list_files[index]
24 | if ".npy" in file_name:
25 | audio_data = np.load(file_name)
26 | else:
27 | audio_data, _ = lb.load(file_name, sr = self.sample_rate, res_type = "polyphase")
28 | if len(audio_data) > int(self.duration * self.sample_rate):
29 | audio_data = audio_data[:int(self.duration * self.sample_rate)]
30 |
31 | metadata = 0
32 | numerical_label = 0
33 | for i, name_class in enumerate(TUT_LABELS):
34 | if name_class in file_name:
35 | metadata = TUT_LABELS_INT[i]
36 | numerical_label = TUT_LABELS_INT[i]
37 | break
38 |
39 | metadata += int(file_name.split('_')[2]) # add the machine id parsed from the file name, so each (machine type, machine id) pair gets a unique class index
40 | label = 0 if "normal" in file_name else 1
41 | return audio_data, metadata, label, numerical_label
42 |
43 | def __len__(self):
44 | # return length of the list
45 | return len(self.list_files)
46 |
47 | class TUTDatamodule(LightningDataModule):
48 |
49 | def __init__(self, path_train, path_test, sample_rate, duration, percentage_val, batch_size):
50 | super().__init__()
51 | self.path_train = path_train
52 | self.path_test = path_test
53 | self.sample_rate = sample_rate
54 | self.duration = duration
55 | self.percentage_val = percentage_val
56 | self.batch_size = batch_size
57 | # split of the dataset
58 | self.train_list = self.scan_all_dir(self.path_train)
59 | self.val_list = random.sample(self.train_list, int(len(self.train_list) * percentage_val))
60 | self.train_list = set(self.train_list)
61 | self.val_list = set(self.val_list)
62 | self.train_list -= self.val_list
63 | self.train_list = list(self.train_list)
64 | self.val_list = list(self.val_list)
65 |
66 | self.test_list = self.scan_all_dir(self.path_test)
67 |
68 |
69 | def scan_all_dir(self, path):
70 | list_all_files = []
71 | for root, dirs, files in os.walk(path):
72 | for file in files:
73 | list_all_files.append(os.path.join(root, file))
74 | return list_all_files
75 |
76 | def setup(self, stage = None):
77 | # Nothing to do
78 | pass
79 |
80 | def prepare_data(self):
81 | # Nothing to do
82 | pass
83 |
84 | def train_dataloader(self):
85 | # return the dataloader containing training data
86 | train_split = TUTDataset(path_data = self.path_train, list_files = self.train_list, sample_rate = self.sample_rate, duration = self.duration)
87 | return DataLoader(train_split, batch_size = self.batch_size, shuffle = True)
88 |
89 | def val_dataloader(self):
90 | # return the dataloader containing validation data
91 | val_split = TUTDataset(path_data = self.path_train, list_files = self.val_list, sample_rate = self.sample_rate, duration = self.duration)
92 | return DataLoader(val_split, batch_size = self.batch_size, shuffle = False)
93 |
94 | def test_dataloader(self):
95 | # return the dataloader containing testing data
96 | test_split = TUTDataset(path_data = self.path_test, list_files = self.test_list, sample_rate = self.sample_rate, duration = self.duration)
97 | return DataLoader(test_split, batch_size = self.batch_size, shuffle = True)
98 |
99 | ## TEST FUNCTION ##
100 | if __name__ == "__main__":
101 | path_train = "TUT Anomaly detection/train" # path for training audio
102 | path_test = "TUT Anomaly detection/test" # path for test audio
103 | percentage_val = 0.2
104 | sample_rate = 16000
105 | batch_size = 64
106 | duration = 10
107 | # create lightning datamodule
108 | datamodule = TUTDatamodule(path_train = path_train, path_test = path_test, sample_rate = sample_rate, duration = duration,
109 | percentage_val = percentage_val, batch_size = batch_size)
110 | dataloader_train = datamodule.train_dataloader()
111 | print(len(dataloader_train))
112 | dataloader_val = datamodule.val_dataloader()
113 | print(len(dataloader_val))
114 | dataloader_test = datamodule.test_dataloader()
115 | print(len(dataloader_test))
116 | print(next(iter(dataloader_test)))
--------------------------------------------------------------------------------
/images/ProposedASD.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/ProposedASD.jpg
--------------------------------------------------------------------------------
/images/ProposedASD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/ProposedASD.png
--------------------------------------------------------------------------------
/images/attentions.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/attentions.pdf
--------------------------------------------------------------------------------
/images/spec-heat-ave-std.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/spec-heat-ave-std.pdf
--------------------------------------------------------------------------------
/images/spec-heat.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/spec-heat.pdf
--------------------------------------------------------------------------------
/images/spec-mel.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/spec-mel.pdf
--------------------------------------------------------------------------------
/images/spec-wav.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/images/spec-wav.pdf
--------------------------------------------------------------------------------
/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | class Bottleneck(nn.Module):
6 | def __init__(self, inp, oup, stride, expansion):
7 | super(Bottleneck, self).__init__()
8 | self.connect = stride == 1 and inp == oup
9 | #
10 | self.conv = nn.Sequential(
11 | # pw
12 | nn.Conv2d(inp, inp * expansion, 1, 1, 0, bias=False),
13 | nn.BatchNorm2d(inp * expansion),
14 | nn.PReLU(inp * expansion),
15 | # dw
16 | nn.Conv2d(inp * expansion, inp * expansion, 3, stride, 1, groups=inp * expansion, bias=False),
17 | nn.BatchNorm2d(inp * expansion),
18 | nn.PReLU(inp * expansion),
19 |
20 | # pw-linear
21 | nn.Conv2d(inp * expansion, oup, 1, 1, 0, bias=False),
22 | nn.BatchNorm2d(oup),
23 | )
24 |
25 | def forward(self, x):
26 | if self.connect:
27 | return x + self.conv(x)
28 | else:
29 | return self.conv(x)
30 |
31 |
32 | class ConvBlock(nn.Module):
33 | def __init__(self, inp, oup, k, s, p, dw=False, linear=False):
34 | super(ConvBlock, self).__init__()
35 | self.linear = linear
36 | if dw:
37 | self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False)
38 | else:
39 | self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False)
40 | self.bn = nn.BatchNorm2d(oup)
41 | if not linear:
42 | self.prelu = nn.PReLU(oup)
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 | x = self.bn(x)
47 | if self.linear:
48 | return x
49 | else:
50 | return self.prelu(x)
51 |
52 |
53 | #https://dcase.community/documents/challenge2022/technical_reports/DCASE2022_Liu_8_t2.pdf
54 | Mobilefacenet_bottleneck_setting = [
55 | # t: expansion factor, c: output channels, n: number of blocks, s: stride of the first block
56 | [2, 128, 2, 2],
57 | [4, 128, 2, 2],
58 | [4, 128, 2, 2],
59 | ]
60 |
61 |
62 | class MobileFaceNet(nn.Module):
63 | def __init__(self,
64 | num_class,
65 | bottleneck_setting=Mobilefacenet_bottleneck_setting):
66 | super(MobileFaceNet, self).__init__()
67 |
68 | self.conv1 = ConvBlock(2, 64, 3, 2, 1)
69 |
70 | self.dw_conv1 = ConvBlock(64, 64, 3, 1, 1, dw=True)
71 |
72 | self.inplanes = 64
73 | block = Bottleneck
74 | self.blocks = self._make_layer(block, bottleneck_setting)
75 |
76 | self.conv2 = ConvBlock(bottleneck_setting[-1][1], 512, 1, 1, 0)
77 | # 20(10), 4(2), 8(4)
78 | self.linear7 = ConvBlock(512, 512, (8, 20), 1, 0, dw=True, linear=True)
79 |
80 | self.linear1 = ConvBlock(512, 128, 1, 1, 0, linear=True)
81 |
82 | self.fc_out = nn.Linear(128, num_class)
83 | # init
84 | for m in self.modules():
85 | if isinstance(m, nn.Conv2d):
86 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
87 | m.weight.data.normal_(0, math.sqrt(2. / n))
88 | elif isinstance(m, nn.BatchNorm2d):
89 | m.weight.data.fill_(1)
90 | m.bias.data.zero_()
91 |
92 | def _make_layer(self, block, setting):
93 | layers = []
94 | for t, c, n, s in setting:
95 | for i in range(n):
96 | if i == 0:
97 | layers.append(block(self.inplanes, c, s, t))
98 | else:
99 | layers.append(block(self.inplanes, c, 1, t))
100 | self.inplanes = c
101 | return nn.Sequential(*layers)
102 |
103 | def forward(self, x):
104 | x = self.conv1(x)
105 | x = self.dw_conv1(x)
106 | x = self.blocks(x)
107 | x = self.conv2(x)
108 | x = self.linear7(x)
109 | x = self.linear1(x)
110 | feature = x.view(x.size(0), -1)
111 | out = self.fc_out(feature)
112 | return out, feature
113 |
114 |
115 |
116 |
117 |
118 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_lightning import LightningModule
4 | from torchmetrics import Accuracy
5 | from torchinfo import summary
6 | import torchaudio.transforms as T
7 | from utils import ArcMarginProduct
8 | import separableconv.nn as sep
9 | import numpy as np
10 | from mobilenet import MobileFaceNet
11 |
12 | class AE_DCASEBaseline(nn.Module):
13 | def __init__(self) -> None:
14 | super(AE_DCASEBaseline, self).__init__()
15 | self.frames = 5
16 | self.n_mels = 128
17 | self.t_bins = 1 + (10 * 16000) // 512
18 | self.vector_array_size = self.t_bins - self.frames + 1
19 | self.transform_tf = T.MelSpectrogram(sample_rate=16000,
20 | n_fft=1024,
21 | win_length=1024,
22 | hop_length=512,
23 | center=True,
24 | pad_mode="reflect",
25 | power=2.0,
26 | norm="slaney",
27 | n_mels=self.n_mels,
28 | mel_scale="htk",
29 | )
30 |
31 | self.encoder = nn.Sequential(
32 | nn.Linear(in_features = 640, out_features = 128),
33 | nn.BatchNorm1d(self.vector_array_size),
34 | nn.ReLU(),
35 | nn.Linear(in_features = 128, out_features = 128),
36 | nn.BatchNorm1d(self.vector_array_size),
37 | nn.ReLU(),
38 | nn.Linear(in_features = 128, out_features = 128),
39 | nn.BatchNorm1d(self.vector_array_size),
40 | nn.ReLU(),
41 | nn.Linear(in_features = 128, out_features = 128),
42 | nn.BatchNorm1d(self.vector_array_size),
43 | nn.ReLU()
44 | )
45 |
46 | self.bottleneck = nn.Sequential(
47 | nn.Linear(in_features = 128, out_features = 8),
48 | nn.BatchNorm1d(self.vector_array_size),
49 | nn.ReLU(),
50 | )
51 |
52 | self.decoder = nn.Sequential(
53 | nn.Linear(in_features = 8, out_features = 128),
54 | nn.BatchNorm1d(self.vector_array_size),
55 | nn.ReLU(),
56 | nn.Linear(in_features = 128, out_features = 128),
57 | nn.BatchNorm1d(self.vector_array_size),
58 | nn.ReLU(),
59 | nn.Linear(in_features = 128, out_features = 128),
60 | nn.BatchNorm1d(self.vector_array_size),
61 | nn.ReLU(),
62 | nn.Linear(in_features = 128, out_features = 128),
63 | nn.BatchNorm1d(self.vector_array_size),
64 | nn.ReLU(),
65 | nn.Linear(in_features = 128, out_features = 640)
66 | )
67 |
68 | # WEIGHTS INIT LIKE KERAS
69 | nn.init.xavier_uniform_(self.encoder[0].weight)
70 | nn.init.xavier_uniform_(self.encoder[3].weight)
71 | nn.init.xavier_uniform_(self.encoder[6].weight)
72 | nn.init.xavier_uniform_(self.encoder[9].weight)
73 | nn.init.xavier_uniform_(self.bottleneck[0].weight)
74 | nn.init.xavier_uniform_(self.decoder[0].weight)
75 | nn.init.xavier_uniform_(self.decoder[3].weight)
76 | nn.init.xavier_uniform_(self.decoder[6].weight)
77 | nn.init.xavier_uniform_(self.decoder[9].weight)
78 | nn.init.xavier_uniform_(self.decoder[-1].weight)
79 |
80 | # BIAS INIT LIKE KERAS
81 | nn.init.zeros_(self.encoder[0].bias)
82 | nn.init.zeros_(self.encoder[3].bias)
83 | nn.init.zeros_(self.encoder[6].bias)
84 | nn.init.zeros_(self.encoder[9].bias)
85 | nn.init.zeros_(self.bottleneck[0].bias)
86 | nn.init.zeros_(self.decoder[0].bias)
87 | nn.init.zeros_(self.decoder[3].bias)
88 | nn.init.zeros_(self.decoder[6].bias)
89 | nn.init.zeros_(self.decoder[9].bias)
90 | nn.init.zeros_(self.decoder[-1].bias)
91 |
92 | def forward(self, x):
93 | x = self.preprocessing(x)
94 | original = x
95 | x = self.encoder(x)
96 | x = self.bottleneck(x)
97 | x = self.decoder(x)
98 | return original, x
99 |
100 | def preprocessing(self, x):
101 | # compute mel spectrogram
102 | batch_size = x.size()[0]
103 | x = self.transform_tf(x)
104 | x = 10 * torch.log10(x + 1e-8)
105 | vector_dim = self.frames * self.n_mels
106 |
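# stack self.frames consecutive log-mel frames into one (frames * n_mels = 640)-dim vector
# per time step, matching the 640 input features expected by the encoder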
107 | feature_vector = torch.zeros((batch_size, self.vector_array_size, vector_dim)).to(x.device)
108 | for batch in range(batch_size):
109 | for i in range(self.frames):
110 | feature_vector[batch, :, self.n_mels * i: self.n_mels * (i + 1)] = x[batch, :, i: i + self.vector_array_size].T
111 |
112 | return feature_vector
113 |
114 |
115 | ##### WAVEGRAM + MEL SPEC + ATTENTION MODULE + MOBILENETV2
116 | class Wavegram_AttentionModule(nn.Module):
117 | def __init__(self, h):
118 | super(Wavegram_AttentionModule, self).__init__()
119 | self.h = h
120 | # sep-wavegram
121 | self.wavegram = sep.SeparableConv1d(in_channels = 1, out_channels = 128, kernel_size = 1024, stride = 512, padding = 512)
122 | # Mel filterbank
123 | self.transform_tf = T.MelSpectrogram(sample_rate=16000,
124 | n_fft=1024,
125 | win_length=1024,
126 | hop_length=512,
127 | center=True,
128 | pad_mode="reflect",
129 | power=2.0,
130 | norm="slaney",
131 | n_mels=128,
132 | mel_scale="htk",
133 | )
134 |
135 | # attention module
136 | self.heatmap = nn.Sequential(
137 | sep.SeparableConv2d(in_channels = 2, out_channels = 16,
138 | kernel_size = (3,3), padding = "same", bias = False),
139 | nn.BatchNorm2d(16),
140 | nn.ELU(),
141 | sep.SeparableConv2d(in_channels = 16, out_channels = 64,
142 | kernel_size = (3,3), padding = "same", bias = False),
143 | nn.BatchNorm2d(64),
144 | nn.ELU(),
145 | sep.SeparableConv2d(in_channels = 64, out_channels = 2, kernel_size = 1, padding = "same"),
146 | nn.Sigmoid()
147 | )
148 | # classifier
149 | self.classifier = MobileFaceNet(num_class = 41)
150 | self.arcface = ArcMarginProduct(in_features = self.h, out_features = 41, s = 40, m = 0.7)
151 |
152 | def forward(self, x, metadata):
153 | # compute mel spectrogram
154 | x_spec = self.transform_tf(x)
155 | x_spec = 10*torch.log10(x_spec + 1e-8)
156 | # compute wavegram
157 | x = x.unsqueeze(1)
158 | x = self.wavegram(x)
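# stack the log-mel spectrogram and the learned wavegram into a 2-channel representation (batch, 2, 128, T)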
159 | x = torch.stack((x_spec, x), dim = 1)
160 | reppr = x
161 | heatmap = self.heatmap(x)
162 | x = x * heatmap
163 | out, features = self.classifier(x)
164 | x = self.arcface(features, metadata)
165 | return x, out, reppr, features, heatmap
166 |
167 | class Wavegram_AttentionMap(LightningModule):
168 |
169 | def __init__(self, h, lr):
170 | super().__init__()
171 | self.h = h
172 | self.model = Wavegram_AttentionModule(self.h)
173 | self.lr = lr
174 |
175 | self.accuracy_training = Accuracy(task="multiclass", num_classes=41)
176 | self.accuracy_val = Accuracy(task="multiclass", num_classes=41)
177 | self.accuracy_test = Accuracy(task="multiclass", num_classes=41)
178 | self.criterion = nn.CrossEntropyLoss()
179 | # to save threshold and errors at init
180 | self.errors_list = []
181 | self.clean_errors = []
182 | self.anomaly_errors = []
183 | self.labels = []
184 | self.classes = []
185 |
186 | def mixup_data(self, x, y, alpha=0.2):
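# mixup on raw waveforms: draw lam ~ Beta(alpha, alpha), blend each sample with a randomly
# permuted partner, and return both one-hot label sets so the loss can be mixed with the same lam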
187 | y = torch.nn.functional.one_hot(y, num_classes = 41)
188 | if alpha > 0:
189 | lam = np.random.beta(alpha, alpha)
190 | else:
191 | lam = 1
192 | batch_size = x.size()[0]
193 | index = torch.randperm(batch_size)
194 | mixed_x = lam * x + (1 - lam) * x[index, :]
195 | y_a, y_b = y, y[index]
196 | return mixed_x, y_a.float(), y_b.float(), lam
197 |
198 | def forward(self, x, labels):
199 | return self.model(x, labels)
200 |
201 | def mixup_criterion_arcmix(self, pred, y_a, y_b, lam):
202 | loss1 = lam * self.criterion(pred, y_a)
203 | loss2 = (1 - lam) * self.criterion(pred, y_b)
204 | return loss1+loss2
205 |
206 | def training_step(self, batch, batch_idx):
207 | x, metadata, _, _ = batch
208 | # for training step
209 | mixed_x, y_a, y_b, lam = self.mixup_data(x, metadata)
210 | predicted, _, _, _, _ = self.forward(mixed_x, metadata)
211 | loss = self.mixup_criterion_arcmix(predicted, y_a, y_b, lam)
212 | self.log("train/loss_class", loss, on_epoch = True, on_step = True, prog_bar = True)
213 | self.accuracy_training(predicted, metadata)
214 | self.log("train/acc", self.accuracy_training, on_epoch = True, on_step = False)
215 | return loss
216 |
217 | def validation_step(self, batch, batch_idx):
218 | x, metadata, _, _ = batch
219 | predicted, _, _, _, _ = self.forward(x, metadata)
220 | loss = torch.nn.functional.cross_entropy(predicted, metadata, reduction = "mean")
221 | self.log("val/loss_class", loss, on_epoch = True, on_step = False, prog_bar = True)
222 | self.accuracy_val(predicted, metadata)
223 | self.log("val/acc", self.accuracy_val, on_epoch = True, on_step = False)
224 | return loss
225 |
226 | def test_step(self, batch, batch_idx):
227 | x, metadata, label, _ = batch
228 | predicted, _, _, _, _ = self.forward(x, metadata)
229 | loss = torch.nn.functional.cross_entropy(predicted, metadata, reduction = "mean")
230 | self.log("test/loss_class", loss, on_epoch = True, on_step = False, prog_bar = True)
231 | self.accuracy_test(predicted, metadata)
232 | self.log("test/acc", self.accuracy_test, on_epoch = True, on_step = False)
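# the per-sample cross-entropy w.r.t. the machine-id class is used as the anomaly score:
# normal clips should be classified confidently, anomalous clips should not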
233 | class_loss_batchwise = nn.functional.cross_entropy(predicted, metadata, reduction = "none")
234 | errors = class_loss_batchwise
235 | self.errors_list.append(errors)
236 | self.clean_errors.append(errors[label == 0])
237 | self.anomaly_errors.append(errors[label == 1])
238 | self.labels.append(label)
239 | self.classes.append(metadata)
240 | return loss
241 |
242 | def configure_optimizers(self):
243 | opt = torch.optim.AdamW(self.parameters(), lr = self.lr)
244 | # return opt
245 | return {
246 | "optimizer": opt,
247 | "lr_scheduler": {
248 | "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=300, eta_min=0.1*float(self.lr))
249 | },
250 | }
251 |
252 |
253 | # TEST FUNCTION
254 | if __name__ == "__main__":
255 | example_input = torch.rand(16, 160000) # dummy audio
256 | model = Wavegram_AttentionModule(h = 128)
257 | metadata = torch.nn.functional.one_hot(torch.randint(low = 0, high = 41, size =(16,)), num_classes=41)
258 | output = model(example_input, metadata)
259 | print(output)
260 | summary(model, input_data = (example_input, metadata))
261 |
262 |
--------------------------------------------------------------------------------
/other_data/all_labels.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/all_labels.npz
--------------------------------------------------------------------------------
/other_data/allerrors.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/allerrors.npz
--------------------------------------------------------------------------------
/other_data/anomalouserrors.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/anomalouserrors.npz
--------------------------------------------------------------------------------
/other_data/cleanerrors.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/cleanerrors.npz
--------------------------------------------------------------------------------
/other_data/metadata.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaelneri/unsupervised-audio-anomaly-detection/f8eb19635d2d727c242ed16cf5b01030d885158d/other_data/metadata.npz
--------------------------------------------------------------------------------
/training_val_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from data import TUTDatamodule
3 | from pytorch_lightning.loggers import WandbLogger
4 | from pytorch_lightning import Trainer
5 | from pytorch_lightning.callbacks import ModelCheckpoint
6 | import wandb
7 | import torch
8 | from optparse import OptionParser
9 | from sklearn import metrics
10 | from model import Wavegram_AttentionMap
11 |
12 |
13 | def transform_string_list(option, opt, value, parser):
14 | list_of_int = [int(i) for i in value.split(",")]
15 | setattr(parser.values, option.dest, list_of_int)
16 |
17 |
18 | def train(configs):
19 | # definition of the logger
20 | tags = ["MobileNetWavegram"]
21 |
22 | # Machine type if we want to train and test on a single machine
23 | if "valve" in configs.path_train:
24 | tags.append("Valve")
25 | elif "pump" in configs.path_train:
26 | tags.append("Pump")
27 | elif "slider" in configs.path_train:
28 | tags.append("Slider")
29 | elif "ToyCar" in configs.path_train:
30 | tags.append("ToyCar")
31 | elif "ToyConveyor" in configs.path_train:
32 | tags.append("ToyConveyor")
33 | elif "fan" in configs.path_train:
34 | tags.append("Fan")
35 | else:
36 | tags.append("All")
37 | wandb_logger = WandbLogger(project="Unsupervised Audio Anomaly Detection", config = configs, name=configs.name, tags = tags)
38 |
39 | checkpoint_callback = ModelCheckpoint(
40 | dirpath = "best_models/",
41 | filename = "MobileNetWavegram-Mixup-{epoch:02d}-{val/loss_class:.4f}", # <--- note epoch suffix here
42 | save_last = True,
43 | every_n_epochs = 20,
44 | save_top_k = -1,
45 | auto_insert_metric_name=False
46 | )
47 |
48 | # definition of the trainer
49 | trainer = Trainer(accelerator="gpu", devices = 1 , max_epochs = configs.epochs, logger=wandb_logger, callbacks = [checkpoint_callback])
50 |
51 | # definition of the datamodule
52 | datamodule = TUTDatamodule(path_train = configs.path_train, path_test = configs.path_test, sample_rate = configs.sr, duration = configs.duration,
53 | percentage_val = configs.percentage, batch_size = configs.batch_size)
54 |
55 | # definition of the model
56 | model = Wavegram_AttentionMap(lr = configs.lr, h = 128)
57 |
58 | wandb_logger.watch(model, log_graph=False)
59 | trainer.fit(model, datamodule)
60 | return model, trainer, datamodule
61 |
62 | def val(trained_model:Wavegram_AttentionMap, trainer:Trainer, datamodule:TUTDatamodule):
63 | # inference on the validation set
64 | trainer.validate(trained_model, datamodule)
65 | return trained_model, trainer, datamodule
66 |
67 | def test(trained_model:Wavegram_AttentionMap, trainer:Trainer, datamodule:TUTDatamodule):
68 | # inference for tests
69 | # build the list of checkpoint files inside the "best_models" folder
70 | list_checkpoints = datamodule.scan_all_dir("best_models/")
71 | best_performance = 0
72 | best_checkpoint = None
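# evaluate every saved checkpoint on the test set and keep the one with the highest AUROC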
73 | for i, checkpoint in enumerate(list_checkpoints):
74 | print("{}) {}".format(i, checkpoint))
75 | trained_model = Wavegram_AttentionMap.load_from_checkpoint(checkpoint, lr = 0.001, h = 128) # dummy lr
76 | trained_model.errors_list = []
77 | trained_model.anomaly_errors = []
78 | trained_model.clean_errors = []
79 | trainer.test(trained_model, datamodule)
80 |
81 | all_errors = torch.cat(trained_model.errors_list).flatten().cpu().numpy()
82 | all_labels = torch.cat(trained_model.labels).flatten().cpu().numpy()
83 | metadata = np.array(torch.cat(trained_model.classes).tolist())
84 | performance = metrics.roc_auc_score(all_labels, all_errors)
85 | if performance > best_performance:
86 | best_performance = performance
87 | best_checkpoint = checkpoint
88 | print("Best : {} %".format(best_performance))
89 |
90 | # here we select the best model
91 | trained_model = Wavegram_AttentionMap.load_from_checkpoint(best_checkpoint, lr = 0.001, h = 128) # dummy lr
92 | trained_model.errors_list = []
93 | trained_model.anomaly_errors = []
94 | trained_model.clean_errors = []
95 | trainer.test(trained_model, datamodule)
96 |
97 | all_errors = torch.cat(trained_model.errors_list).flatten().cpu().numpy()
98 | all_labels = torch.cat(trained_model.labels).flatten().cpu().numpy()
99 | metadata = np.array(torch.cat(trained_model.classes).tolist())
100 |
101 | np.savez("metadata", metadata)
102 | np.savez("allerrors", all_errors)
103 | np.savez("all_labels", all_labels)
104 | wandb.log({"test/AUROC": metrics.roc_auc_score(all_labels, all_errors)})
105 | wandb.log({"test/pAUROC": metrics.roc_auc_score(all_labels, all_errors, max_fpr = 0.1)})
106 |
107 | print("Global AUC : {}".format(metrics.roc_auc_score(all_labels, all_errors)))
108 | print("Global pAUC: {}".format(metrics.roc_auc_score(all_labels, all_errors, max_fpr = 0.1)))
109 |
110 | # fan class
111 | mask = metadata < 7
112 | selected_label = all_labels[mask]
113 | selected_errors = all_errors[mask]
114 |
115 | print("Fan AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors)))
116 | print("Fan pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1)))
117 |
118 | # pump class
119 | mask = np.logical_and(metadata >= 7, metadata < 13)
120 | selected_label = all_labels[mask]
121 | selected_errors = all_errors[mask]
122 |
123 | print("Pump AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors)))
124 | print("Pump pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1)))
125 |
126 |
127 | # slider class
128 | mask = np.logical_and(metadata >= 13, metadata < 20)
129 | selected_label = all_labels[mask]
130 | selected_errors = all_errors[mask]
131 |
132 | print("Slider AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors)))
133 | print("Slider pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1)))
134 |
135 | # ToyCar class
136 | mask = np.logical_and(metadata >= 20, metadata < 27)
137 | selected_label = all_labels[mask]
138 | selected_errors = all_errors[mask]
139 |
140 | print("ToyCar AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors)))
141 | print("ToyCar pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1)))
142 |
143 |
144 | # ToyConveyor class
145 | mask = np.logical_and(metadata >= 27, metadata < 34)
146 | selected_label = all_labels[mask]
147 | selected_errors = all_errors[mask]
148 |
149 | print("ToyConveyor AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors)))
150 | print("ToyConveyor pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1)))
151 |
152 | # valve class
153 | mask = metadata >= 34
154 | selected_label = all_labels[mask]
155 | selected_errors = all_errors[mask]
156 |
157 | print("Valve AUC : {}".format(metrics.roc_auc_score(selected_label, selected_errors)))
158 | print("Valve pAUC: {}".format(metrics.roc_auc_score(selected_label, selected_errors, max_fpr = 0.1)))
159 |
160 | clean_errors = np.array(torch.cat(trained_model.clean_errors).tolist())
161 | anomalous_errors = np.array(torch.cat(trained_model.anomaly_errors).tolist())
162 | np.savez("cleanerrors", clean_errors)
163 | np.savez("anomalouserrors", anomalous_errors)
164 | wandb.finish()
165 |
166 | import matplotlib.pyplot as plt
167 | plt.rcParams["font.family"] = "Times New Roman"
168 | plt.rcParams['font.size'] = 28
169 | # plt.rcParams['text.usetex'] = True  # if LaTeX is installed on your machine, it can be enabled in matplotlib
170 | errors_clean = np.load("cleanerrors.npz")['arr_0']
171 | errors_anomalous = np.load("anomalouserrors.npz")['arr_0']
172 | print(errors_clean.shape)
173 | print(errors_anomalous.shape)
174 | plt.figure(figsize = (8, 12))
175 | plt.hist(errors_clean, bins = 200, alpha = 0.4, label = "Clean", color = "g")
176 | plt.hist(errors_anomalous, bins = 200, alpha = 0.4, label = "Anomalous", color = "r")
177 | plt.grid()
178 | plt.legend()
179 | plt.title("Histograms of errors")
180 | plt.xlabel("Error")
181 | plt.ylabel("Occurrences")
182 | plt.show()
183 |
184 |
185 |
186 |
187 |
188 | if __name__ == "__main__":
189 | parser = OptionParser()
190 | # dataset parameters
191 | parser.add_option("--pathtrain", dest="path_train",
192 | help="path containing training files", default = "TUT Anomaly detection/train")
193 | parser.add_option("--pathtest", dest="path_test",
194 | help="path containing training files", default = "TUT Anomaly detection/test")
195 | parser.add_option("--percentage", dest = "percentage",
196 | help = "percentage of validation samples", default = 0.05, type = float)
197 | parser.add_option("--sr", dest = "sr",
198 | help = "target sample rate", default = 16000, type = int)
199 | parser.add_option("--duration", dest = "duration",
200 | help = "duration, in seconds, of audios", default = 10, type = float)
201 |
202 |
203 | # training parameters
204 | parser.add_option("--lr", dest = "lr",
205 | help = "learning rate", default = 0.0001, type = float)
206 | parser.add_option("--epochs", dest = "epochs",
207 | help = "number of epochs", default = 300, type = int)
208 | parser.add_option("--name", dest = "name",
209 | help = "name of the run on wandb", default = "MobileFaceNet + separable wavegram + attention module + mixup 0.2 Noisy-ArcMix s 40 m 0.7")
210 | parser.add_option("--batch_size", dest = "batch_size",
211 | help = "batch size for dataloader", default = 64, type = int)
212 |
213 |
214 | options, remainder = parser.parse_args()
215 | print(options)
216 | trained_model, trainer, datamodule = train(options)
217 | trained_model, trainer, datamodule = val(trained_model, trainer, datamodule)
218 | test(trained_model, trainer, datamodule)
219 |
220 |
221 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import separableconv.nn as sep
4 | import math
5 |
6 |
7 | class ArcMarginProduct(nn.Module):
8 | def __init__(self, in_features=128, out_features=200, s=40.0, m=0.3, sub=1, easy_margin=False):
9 | super(ArcMarginProduct, self).__init__()
10 | self.in_features = in_features
11 | self.out_features = out_features
12 | self.s = s
13 | self.m = m
14 | self.sub = sub
15 | self.weight = torch.nn.Parameter(torch.Tensor(out_features * sub, in_features))
16 | nn.init.xavier_uniform_(self.weight)
17 |
18 | self.easy_margin = easy_margin
19 | self.cos_m = math.cos(m)
20 | self.sin_m = math.sin(m)
21 |
22 | # keep cos(theta + m) monotonically decreasing for theta in [0°, 180°]
23 | self.th = math.cos(math.pi - m)
24 | self.mm = math.sin(math.pi - m) * m
25 |
26 | def forward(self, x, label):
27 | # label = torch.nn.functional.one_hot(label, num_classes=self.out_features)
28 | cosine = torch.nn.functional.linear(torch.nn.functional.normalize(x), torch.nn.functional.normalize(self.weight))
29 |
30 | if self.sub > 1:
31 | cosine = cosine.view(-1, self.out_features, self.sub)
32 | cosine, _ = torch.max(cosine, dim=2)
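# cos(theta + m) via the angle-addition formula, where theta is the angle between the
# embedding and its class weight vector and m is the additive angular margin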
33 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
34 | phi = cosine * self.cos_m - sine * self.sin_m
35 |
36 | if self.easy_margin:
37 | phi = torch.where(cosine > 0, phi, cosine)
38 | else:
39 | phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
40 |
41 | if len(label.shape) == 1:
42 | one_hot = torch.zeros(cosine.size(), device=x.device)
43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
44 | else:
45 | one_hot = label
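# apply the angular margin only to the target-class logit, then scale all logits by s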
46 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
47 | output = output * self.s
48 | return output
--------------------------------------------------------------------------------