├── lib ├── __init__.py ├── eval.py ├── dataset.py ├── coverhunter │ ├── ch_layers.py │ └── ch_losses.py ├── layers.py ├── augmentations.py └── tensor_ops.py ├── models ├── __init__.py ├── coverhunterc.py ├── cqtnet.py ├── bytecover2x.py ├── bytecover3x.py ├── clews.py └── dvinetp.py ├── utils ├── __init__.py ├── file_utils.py ├── print_utils.py ├── audio_utils.py └── pytorch_utils.py ├── .gitignore ├── install_requirements.sh ├── LICENSE ├── config ├── dvi-cqtnet.yaml ├── shs-cqtnet.yaml ├── shs-dvinetp.yaml ├── dvi-dvinetp.yaml ├── shs-coverhunterc.yaml ├── dvi-coverhunterc.yaml ├── dvi-bytecover2x.yaml ├── shs-bytecover2x.yaml ├── shs-bytecover3x.yaml ├── dvi-bytecover3x.yaml ├── shs-clews.yaml └── dvi-clews.yaml ├── inference.py ├── README.md ├── test.py ├── data_preproc.py └── train.py /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | 4 | .vscode/ 5 | .venv/ 6 | myslurm/ 7 | 8 | cache 9 | data 10 | logs 11 | 12 | *.wav 13 | *.flac 14 | *.pt 15 | *.pdf -------------------------------------------------------------------------------- /install_requirements.sh: -------------------------------------------------------------------------------- 1 | # We assume python>=3.10 2 | 3 | # --- Create environment and activate --- 4 | python3 -m venv .venv 5 | source .venv/bin/activate 6 | 7 | # --- Pytorch --- 8 | pip install torch==2.3.1 torchaudio==2.3.1 
torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121 9 | pip install lightning==2.3.0 tensorboard==2.17.0 einops==0.8.0 torchinfo==1.8.0 10 | 11 | # --- Generic stuff --- 12 | pip install omegaconf==2.3.0 tqdm==4.66.4 13 | 14 | # --- For metadata proc --- 15 | pip install joblib==1.4.2 16 | 17 | # --- For my audio_utils --- 18 | pip install soundfile==0.12.1 soxr==0.3.7 nnAudio==0.3.3 19 | # Need to downgrade numpy for soxr 20 | pip install numpy==1.26.4 21 | # ffmpeg needs to be installed. Can do it through conda if you do not have sudo privileges: 22 | conda install -c conda-forge 'ffmpeg<7' 23 | # Alternatively, use: 24 | #sudo apt install ffmpeg 25 | 26 | # --- For data augmentation (may be imported but unused) --- 27 | pip install julius==0.2.7 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Sony Research Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import csv 3 | import json 4 | 5 | 6 | def load_txt(fn): 7 | with open(fn, "r") as fh: 8 | lines = fh.readlines() 9 | for i in range(len(lines)): 10 | lines[i] = lines[i].replace("\n", "") 11 | lines[i] = lines[i].replace("\r", "") 12 | return lines 13 | 14 | 15 | def load_csv(fn, sep=",", header=0, quotechar='"'): 16 | lines = load_txt(fn) 17 | csv_reader = csv.reader( 18 | lines, 19 | quotechar=quotechar, 20 | quoting=csv.QUOTE_ALL, 21 | delimiter=sep, 22 | ) 23 | data = [] 24 | for i, l in enumerate(csv_reader): 25 | if i == 0: 26 | desc = l[:] 27 | for _ in l: 28 | data.append([]) 29 | elif len(l) != len(data): 30 | print("Error reading " + fn) 31 | sys.exit() 32 | if i < header: 33 | continue 34 | for j, item in enumerate(l): 35 | data[j].append(item) 36 | return desc, data, len(data[0]) 37 | 38 | 39 | def load_json(fn): 40 | with open(fn, "r") as fh: 41 | d = json.load(fh) 42 | return d 43 | 44 | 45 | def load_jsons(fn, limit_lines=None): 46 | with open(fn, "r") as fh: 47 | d = [] 48 | for line in fh: 49 | aux = json.loads(line) 50 | d.append(aux) 51 | if limit_lines is not None and len(d) == limit_lines: 52 | break 53 | return d 54 | -------------------------------------------------------------------------------- /config/dvi-cqtnet.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | 
path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/DiscogsVI/audio/" 10 | meta: "cache/metadata-dvi.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.4 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.3 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.5 66 | r: [-12,12] 67 | 68 | model: 69 | name: "cqtnet" 70 | shingling: 71 | len: 150 # in secs 72 | hop: 150 73 | cqt: 74 | hoplen: 0.02 # in seconds 75 | noctaves: 7 # 8 76 | nbinsoct: 12 77 | pool: 78 | len: 17 # in hops 79 | hop: 15 80 | zdim: 300 81 | maxcliques: 100000 82 | 83 | training: 84 | batchsize: 25 # using 2 GPUs 85 | numepochs: 1000 86 | save_freq: null # in epochs 87 | optim: 88 | name: "adam" 89 | lr: 3e-4 90 | wd: 0 91 | sched: "plateau_10" 92 | min_lr: 1e-6 93 | monitor: 94 | quantity: "m_MAP" 95 | mode: "max" 96 | -------------------------------------------------------------------------------- /config/shs-cqtnet.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | 
cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/SHS100K/audio/" 10 | meta: "cache/metadata-shs.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.1 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.1 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.1 66 | r: [-12,12] 67 | 68 | model: 69 | name: "cqtnet" 70 | shingling: 71 | len: 150 # in secs 72 | hop: 150 73 | cqt: 74 | hoplen: 0.02 # in seconds 75 | noctaves: 7 # 8 76 | nbinsoct: 12 77 | pool: 78 | len: 17 # in hops 79 | hop: 15 80 | zdim: 300 81 | maxcliques: 100000 82 | 83 | training: 84 | batchsize: 25 # using 2 GPUs 85 | numepochs: 1000 86 | save_freq: null # in epochs 87 | optim: 88 | name: "adam" 89 | lr: 3e-4 90 | wd: 0 91 | sched: "plateau_10" 92 | min_lr: 1e-6 93 | monitor: 94 | quantity: "m_COMP" 95 | mode: "min" 96 | -------------------------------------------------------------------------------- /config/shs-dvinetp.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | cache: 
"cache/" 8 | logs: "logs/" 9 | audio: "data/SHS100K/audio/" 10 | meta: "cache/metadata-shs.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.4 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.3 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.5 66 | r: [-12,12] 67 | 68 | model: 69 | name: "dvinetp" 70 | shingling: 71 | len: 150 # in secs 72 | hop: 150 73 | cqt: 74 | hoplen: 0.02 # in seconds 75 | noctaves: 7 # 8 76 | nbinsoct: 12 77 | pool: 78 | len: 17 # in hops 79 | hop: 15 80 | ncha_in: 48 81 | zdim: 512 82 | margin: 0.3 83 | lamb: 0.02 84 | 85 | training: 86 | batchsize: 25 # using 2 GPUs 87 | numepochs: 1000 88 | save_freq: null # in epochs 89 | optim: 90 | name: "adam" 91 | lr: 3e-4 92 | wd: 0 93 | sched: "plateau_10" 94 | min_lr: 1e-6 95 | monitor: 96 | quantity: "m_MAP" 97 | mode: "max" 98 | -------------------------------------------------------------------------------- /config/dvi-dvinetp.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | 
path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/DiscogsVI/audio/" 10 | meta: "cache/metadata-dvi.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.1 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.1 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.1 66 | r: [-12,12] 67 | 68 | model: 69 | name: "dvinetp" 70 | shingling: 71 | len: 150 # in secs 72 | hop: 150 73 | cqt: 74 | hoplen: 0.02 # in seconds 75 | noctaves: 7 # 8 76 | nbinsoct: 12 77 | pool: 78 | len: 17 # in hops 79 | hop: 15 80 | ncha_in: 48 81 | zdim: 512 82 | margin: 0.3 83 | lamb: 0.02 84 | 85 | training: 86 | batchsize: 25 # using 2 GPUs 87 | numepochs: 1000 88 | save_freq: null # in epochs 89 | optim: 90 | name: "adam" 91 | lr: 3e-4 92 | wd: 0 93 | sched: "plateau_10" 94 | min_lr: 1e-6 95 | monitor: 96 | quantity: "m_COMP" 97 | mode: "min" 98 | -------------------------------------------------------------------------------- /config/shs-coverhunterc.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | 
limit_batches: null 5 | 6 | path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/SHS100K/audio/" 10 | meta: "cache/metadata-shs.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.4 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.3 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.5 66 | r: [-12,12] 67 | 68 | model: 69 | name: "coverhunterc" 70 | shingling: 71 | len: 20 72 | hop: 20 73 | cqt: 74 | hoplen: 0.04 # in seconds 75 | noctaves: 8 # 8 76 | nbinsoct: 12 77 | pool: 78 | len: 1 79 | hop: 1 80 | ncha: 128 81 | ncha_attn: 256 82 | nblocks: 6 83 | maxcliques: 100000 84 | gamma: 2 85 | margin: 0.3 86 | 87 | training: 88 | batchsize: 25 # using 2 GPUs 89 | numepochs: 1000 90 | save_freq: null # in epochs 91 | optim: 92 | name: "adam" 93 | lr: 3e-4 94 | wd: 0 95 | sched: "plateau_10" 96 | min_lr: 1e-6 97 | monitor: 98 | quantity: "m_MAP" 99 | mode: "max" 100 | -------------------------------------------------------------------------------- /config/dvi-coverhunterc.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | 
checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/DiscogsVI/audio/" 10 | meta: "cache/metadata-dvi.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.1 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.1 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.1 66 | r: [-12,12] 67 | 68 | model: 69 | name: "coverhunterc" 70 | shingling: 71 | len: 20 72 | hop: 20 73 | cqt: 74 | hoplen: 0.04 # in seconds 75 | noctaves: 8 # 8 76 | nbinsoct: 12 77 | pool: 78 | len: 1 79 | hop: 1 80 | ncha: 128 81 | ncha_attn: 256 82 | nblocks: 6 83 | maxcliques: 100000 84 | gamma: 2 85 | margin: 0.3 86 | 87 | training: 88 | batchsize: 25 # using 2 GPUs 89 | numepochs: 1000 90 | save_freq: null # in epochs 91 | optim: 92 | name: "adam" 93 | lr: 3e-4 94 | wd: 0 95 | sched: "plateau_10" 96 | min_lr: 1e-6 97 | monitor: 98 | quantity: "m_COMP" 99 | mode: "min" 100 | -------------------------------------------------------------------------------- /config/dvi-bytecover2x.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts 
it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/DiscogsVI/audio/" 10 | meta: "cache/metadata-dvi.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.1 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.1 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.1 66 | r: [-12,12] 67 | 68 | model: 69 | name: "bytecover2x" 70 | shingling: 71 | len: 150 # in secs 72 | hop: 150 73 | cqt: 74 | hoplen: 0.023 # in secs 75 | noctaves: 7 76 | nbinsoct: 12 77 | pool: 78 | len: 5 # in hops 79 | hop: 5 80 | ncha: 2048 81 | zdim: 1536 82 | maxcliques: 100000 83 | smooth: 0.1 84 | margin: 0.3 85 | lamb: 0 86 | 87 | training: 88 | batchsize: 25 # using 2 GPUs 89 | numepochs: 1000 90 | save_freq: null # in epochs 91 | optim: 92 | name: "adam" 93 | lr: 3e-4 94 | wd: 0 95 | sched: "plateau_10" 96 | min_lr: 1e-6 97 | monitor: 98 | quantity: "m_COMP" 99 | mode: "min" 100 | -------------------------------------------------------------------------------- /config/shs-bytecover2x.yaml: -------------------------------------------------------------------------------- 1 
| jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/SHS100K/audio/" 10 | meta: "cache/metadata-shs.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.4 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.3 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.5 66 | r: [-12,12] 67 | 68 | model: 69 | name: "bytecover2x" 70 | shingling: 71 | len: 150 # in secs 72 | hop: 150 73 | cqt: 74 | hoplen: 0.023 # in secs 75 | noctaves: 7 76 | nbinsoct: 12 77 | pool: 78 | len: 5 # in hops 79 | hop: 5 80 | ncha: 2048 81 | zdim: 1536 82 | maxcliques: 100000 83 | smooth: 0.1 84 | margin: 0.3 85 | lamb: 0 86 | 87 | training: 88 | batchsize: 25 # using 2 GPUs 89 | numepochs: 1000 90 | save_freq: null # in epochs 91 | optim: 92 | name: "adam" 93 | lr: 3e-4 94 | wd: 0 95 | sched: "plateau_10" 96 | min_lr: 1e-6 97 | monitor: 98 | quantity: "m_MAP" 99 | mode: "max" 100 | -------------------------------------------------------------------------------- /config/shs-bytecover3x.yaml: 
-------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/SHS100K/audio/" 10 | meta: "cache/metadata-shs.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.4 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.3 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.5 66 | r: [-12,12] 67 | 68 | model: 69 | name: "bytecover3x" 70 | shingling: 71 | len: 20 # in secs 72 | hop: 20 73 | cqt: 74 | hoplen: 0.023 # in secs 75 | noctaves: 7 76 | nbinsoct: 12 77 | pool: 78 | len: 6 # in hops 79 | hop: 6 80 | ncha: 2048 81 | zdim: 1024 82 | maxcliques: 100000 83 | nsub: 9 84 | smooth: 0.1 85 | margin: 0.1 86 | lamb: 0.1 87 | 88 | training: 89 | batchsize: 25 # using 2 GPUs 90 | numepochs: 1000 91 | save_freq: null # in epochs 92 | optim: 93 | name: "adam" 94 | lr: 3e-4 95 | wd: 0 96 | sched: "plateau_10" 97 | min_lr: 1e-6 98 | monitor: 99 | quantity: "m_MAP" 100 | mode: "max" 101 | 
-------------------------------------------------------------------------------- /config/dvi-bytecover3x.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/DiscogsVI/audio/" 10 | meta: "cache/metadata-dvi.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.1 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.1 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.1 66 | r: [-12,12] 67 | 68 | model: 69 | name: "bytecover3x" 70 | shingling: 71 | len: 20 # in secs 72 | hop: 20 73 | cqt: 74 | hoplen: 0.023 # in secs 75 | noctaves: 7 76 | nbinsoct: 12 77 | pool: 78 | len: 6 # in hops 79 | hop: 6 80 | ncha: 2048 81 | zdim: 1024 82 | maxcliques: 100000 83 | nsub: 9 84 | smooth: 0.1 85 | margin: 0.1 86 | lamb: 0.1 87 | 88 | training: 89 | batchsize: 18 # using 2 GPUs 90 | numepochs: 1000 91 | save_freq: null # in epochs 92 | optim: 93 | name: "adam" 94 | lr: 3e-4 95 | wd: 0 96 | sched: "plateau_10" 97 | 
min_lr: 1e-6 98 | monitor: 99 | quantity: "m_COMP" 100 | mode: "min" 101 | -------------------------------------------------------------------------------- /config/shs-clews.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/SHS100K/audio/" 10 | meta: "cache/metadata-shs.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.4 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.3 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.5 66 | r: [-12,12] 67 | 68 | model: 69 | name: "clews" 70 | shingling: 71 | len: 20 # in secs 72 | hop: 20 73 | cqt: 74 | hoplen: 0.02 # in secs 75 | noctaves: 7 76 | nbinsoct: 12 77 | fscale: 1 78 | pool: 5 # in hops 79 | frontend: 80 | cqtpow: 0.5 81 | channels: [128,256] 82 | backbone: 83 | blocks: [3,4,6,3] 84 | channels: [256,512,1024,2048] 85 | down: [1,2,2,1] 86 | zdim: 2048 87 | loss: 88 | redux: 89 | pos: "bpwr-5" 90 | neg: "min" 91 | gamma: 5 92 | epsilon: 1e-6 93 | 
94 | training: 95 | batchsize: 25 # using 2 GPUs 96 | numepochs: 1000 97 | save_freq: null # in epochs 98 | optim: 99 | name: "adam" 100 | lr: 2e-4 101 | wd: 0 102 | sched: "plateau_10" 103 | min_lr: 1e-6 104 | monitor: 105 | quantity: "m_MAP" 106 | mode: "max" 107 | -------------------------------------------------------------------------------- /config/dvi-clews.yaml: -------------------------------------------------------------------------------- 1 | jobname: null # training script automatically inserts it here 2 | seed: 43 # must be positive 3 | checkpoint: null # if we want to load 4 | limit_batches: null 5 | 6 | path: 7 | cache: "cache/" 8 | logs: "logs/" 9 | audio: "data/DiscogsVI/audio/" 10 | meta: "cache/metadata-dvi.pt" 11 | 12 | fabric: 13 | nnodes: 1 14 | ngpus: 1 15 | precision: "32" 16 | 17 | data: 18 | nworkers: 16 # 16, w/ 4 gpus 19 | samplerate: 16000 20 | audiolen: 150 # in seconds 21 | maxlen: null 22 | pad_mode: "repeat" 23 | n_per_class: 4 24 | p_samesong: 0 25 | 26 | augmentations: 27 | # -- Time domain -- 28 | # length: 29 | # p: 1.0 30 | # rmin: 0.6 31 | # polarity: 32 | # p: 0.5 33 | # compexp: 34 | # p: 0.02 35 | # r: [0.6,1.4] 36 | # reqtime: 37 | # p: 0.5 38 | # nfreqs: [1,3] 39 | # gains: [-8,8] 40 | # qrange: [0.5,10.0] 41 | # gain: 42 | # p: 0.9 43 | # r: [0.02,1] # in absolute amplitude 44 | # clipping: 45 | # p: 0.01 46 | # max_qtl: 0.3 47 | # p_soft: 0.75 48 | # -- CQT domain -- 49 | specaugment: 50 | p: 0.1 51 | n: 1 52 | full: true 53 | f_pc: 0.15 54 | t_pc: 0.15 55 | timestretch: 56 | p: 0.1 57 | r: [0.6,1.8] 58 | pad_mode: "repeat" 59 | cut_mode: "random" 60 | # reqcqt: 61 | # p: 0.1 62 | # lpf: 0.02 63 | # r: [-1,1] 64 | pitchtranspose: 65 | p: 0.1 66 | r: [-12,12] 67 | 68 | model: 69 | name: "clews" 70 | shingling: 71 | len: 20 # in secs 72 | hop: 20 73 | cqt: 74 | hoplen: 0.02 # in secs 75 | noctaves: 7 76 | nbinsoct: 12 77 | fscale: 1 78 | pool: 5 # in hops 79 | frontend: 80 | cqtpow: 0.5 81 | channels: [128,256] 82 | 
backbone: 83 | blocks: [3,4,6,3] 84 | channels: [256,512,1024,2048] 85 | down: [1,2,2,1] 86 | zdim: 1024 87 | loss: 88 | redux: 89 | pos: "bpwr-5" 90 | neg: "min" 91 | gamma: 5 92 | epsilon: 1e-6 93 | 94 | training: 95 | batchsize: 25 # using 2 GPUs 96 | numepochs: 1000 97 | save_freq: null # in epochs 98 | optim: 99 | name: "adam" 100 | lr: 2e-4 101 | wd: 0 102 | sched: "plateau_10" 103 | min_lr: 1e-6 104 | monitor: 105 | quantity: "m_COMP" 106 | mode: "min" 107 | -------------------------------------------------------------------------------- /utils/print_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from tqdm import tqdm 4 | 5 | ################################################################################################### 6 | 7 | 8 | def myprint(s, end="\n", doit=True, flush=True): 9 | if doit: 10 | print(s, end=end, flush=flush) 11 | 12 | 13 | def myprogbar(iterator, desc=None, doit=True, ncols=80, ascii=True, leave=True): 14 | return tqdm( 15 | iterator, 16 | desc=desc, 17 | ascii=ascii, 18 | ncols=ncols, 19 | disable=not doit, 20 | leave=leave, 21 | file=sys.stdout, 22 | mininterval=0.2, 23 | maxinterval=2, 24 | ) 25 | 26 | 27 | def flush(doit=True): 28 | if doit: 29 | sys.stdout.flush() 30 | 31 | 32 | ################################################################################################### 33 | 34 | 35 | def report( 36 | dict, 37 | desc=None, 38 | ncols=120, 39 | fmt=None, 40 | fmt_default={ 41 | "loss": ".3f", 42 | "l_main": ".3f", 43 | "MAP": "5.3f", 44 | "m_MAP": "5.3f", 45 | "MR1": "7.1f", 46 | "m_MR1": "7.1f", 47 | "ARP": "5.2f", 48 | "m_ARP": "5.2f", 49 | }, 50 | fmt_base=".3f", 51 | clean_line=True, 52 | ): 53 | if clean_line: 54 | s = "\r" + " " * ncols + "\r" 55 | else: 56 | s = "" 57 | if desc is not None: 58 | s += desc + ": " 59 | keys = list(dict.keys()) 60 | keys.sort() 61 | for i, key in enumerate(keys): 62 | value = dict[key] 63 | if i > 0: 64 | s 
+= ", " 65 | s += key + " = " 66 | if type(value) == str: 67 | s += value 68 | else: 69 | if fmt is not None and key in fmt: 70 | ff = fmt[key] 71 | elif key in fmt_default: 72 | ff = fmt_default[key] 73 | else: 74 | ff = fmt_base 75 | aux = "{:" + ff + "}" 76 | s += aux.format(value) 77 | return s 78 | 79 | 80 | ################################################################################################### 81 | 82 | 83 | class Timer: 84 | def __init__(self, use_milliseconds=False): 85 | self.use_milliseconds = use_milliseconds 86 | self.reset() 87 | 88 | def reset(self): 89 | self.tstart = time.time() 90 | 91 | def time(self): 92 | elapsed = time.time() - self.tstart 93 | msecs = elapsed % 60 94 | secs = int(elapsed) % 60 95 | mins = (int(elapsed) // 60) % 60 96 | hours = (int(elapsed) // (60 * 60)) % 24 97 | days = int(elapsed) // (60 * 60 * 24) 98 | if self.use_milliseconds: 99 | s = f"{msecs:04.1f}" 100 | else: 101 | s = f"{secs:02d}" 102 | s = f"{hours:02d}:{mins:02d}:" + s 103 | if days > 0: 104 | s = f"{days:02d}:" + s 105 | return s 106 | 107 | 108 | ################################################################################################### 109 | -------------------------------------------------------------------------------- /utils/audio_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import torch 4 | import torchaudio 5 | import soxr 6 | import warnings 7 | 8 | ################################################################################################### 9 | 10 | 11 | def get_backend(filename): 12 | if filename.lower().endswith(".mp3"): 13 | backend = "ffmpeg" 14 | else: 15 | backend = "soundfile" 16 | return backend 17 | 18 | 19 | def get_info(filename, backend=None): 20 | if backend is None: 21 | backend = get_backend(filename) 22 | info = argparse.Namespace() 23 | with warnings.catch_warnings(): 24 | warnings.simplefilter("ignore") 25 | aux = 
def load_audio(
    filename,
    sample_rate=None,
    n_channels=None,
    start=0,  # in samples
    length=None,  # in samples
    backend=None,
    resample_method="soxr",
    pad_till_length=False,
    pad_mode="zeros",
    safe_load=True,
    return_numpy=False,
):
    """Load (a segment of) an audio file as a normalized float32 tensor.

    Args:
        filename: path to the audio file.
        sample_rate: target sample rate; None keeps the file's native rate.
        n_channels: target channel count (1 = downmix, 2 = duplicate mono);
            None keeps channels as-is.
        start: read offset, in samples.
        length: number of samples to read; None reads to the end of file.
        backend: torchaudio backend name; None selects it from the extension.
        resample_method: "soxr" or "torchaudio" (see resample()).
        pad_till_length: if True (and length is given), pad short reads up to
            `length` samples.
        pad_mode: "zeros" or "repeat" padding.
        safe_load: if True, return None on load errors instead of raising.
        return_numpy: if True, return a numpy array instead of a tensor.

    Returns:
        (C,T) float tensor/array, or None when safe_load is True and loading
        failed.
    """

    def load():
        return torchaudio.load(
            filename,
            frame_offset=start,
            num_frames=length if length is not None else -1,
            normalize=True,  # to float32
            channels_first=True,
            backend=get_backend(filename) if backend is None else backend,
        )

    if safe_load:
        try:
            x, sr = load()
        except Exception:  # was a bare `except:`; let KeyboardInterrupt etc. propagate
            print("\nWARNING: Could not load " + filename, flush=True)
            print(start, length, flush=True)
            return None
    else:
        try:
            x, sr = load()
        except Exception:
            print("\nERROR: Could not load " + filename, flush=True)
            print(start, length, flush=True)
            # BUGFIX: re-raise the original error instead of calling load() a
            # second time (duplicated work and masked the original traceback).
            raise
    # Adjust channels
    if n_channels is None or n_channels == x.size(0):
        pass
    elif n_channels == 1:
        x = x.mean(0, keepdim=True)
    elif n_channels == 2:
        if x.size(0) == 1:
            x = torch.cat([x, x], dim=0)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
    # Adjust sample rate
    if sample_rate is None:
        sample_rate = sr
    elif sr != sample_rate:
        x = resample(x, sr, sample_rate, method=resample_method)
    # Pad length (guard against length=None to avoid a TypeError in `>` below)
    if pad_till_length and length is not None and length > x.size(1):
        if pad_mode == "zeros":
            x = torch.nn.functional.pad(
                x, (0, length - x.size(1)), mode="constant", value=0
            )
        elif pad_mode == "repeat":
            aux = torch.cat([x, x], dim=-1)
            while aux.size(-1) < length:
                aux = torch.cat([aux, x], dim=-1)
            x = aux[:, :length]
        else:
            raise NotImplementedError
    # Done
    if return_numpy:
        return x.numpy()
    return x


###################################################################################################


def resample(audio, in_sr, out_sr, method="soxr", prevent_clip=True):
    """Resample audio from in_sr to out_sr.

    audio is (C,T) or (B,T). NOTE(review): the input tensor is scaled in
    place for headroom; callers should not rely on its contents afterwards.
    """
    # Leave headroom so resampler ringing cannot clip intermediate values
    audio *= 0.5
    if method == "soxr":
        audio = (
            torch.FloatTensor(soxr.resample(audio.T.numpy(), in_sr, out_sr))
            .to(audio.device)
            .T
        )
    elif method == "torchaudio":
        # BUGFIX: keyword was misspelled "orig_freg", which raised a TypeError
        # whenever this path was taken (correct name is orig_freq).
        audio = torchaudio.functional.resample(audio, orig_freq=in_sr, new_freq=out_sr)
    else:
        raise NotImplementedError
    audio *= 2
    if prevent_clip:
        # Normalize only if the peak exceeds 1 (never adds gain)
        mx = audio.abs().max(-1, keepdim=True)[0]
        audio /= torch.clamp(mx, min=1)
    else:
        audio = torch.clamp(audio, -1, 1)
    return audio


###################################################################################################


def get_frames(x, win=10, hop=1, dimstack=1):
    """Slice the last dimension of x into (possibly overlapping) frames.

    Returns a tensor with a new frame dimension inserted at `dimstack`; each
    frame has size `win`, taken every `hop` samples. Trailing samples that do
    not fill a full window are dropped.
    """
    frames = []
    for i in range(0, x.size(-1) - win + 1, hop):
        frames.append(x[..., i : i + win])
    return torch.stack(frames, dim=dimstack)
@torch.inference_mode()
def compute(
    model,
    queries_c,  # clique index (B)
    queries_i,  # song index (B)
    queries_z,  # embedding (B,S,C)
    candidates_c,  # clique index (B')
    candidates_i,  # song index (B')
    candidates_z,  # embedding (B',S,C)
    queries_m=None,
    candidates_m=None,
    redux_strategy=None,
    batch_size_candidates=None,
):
    """Per-query retrieval metrics against a candidate set.

    For every query, computes distances to all candidates via
    model.distances() (optionally chunked to bound memory), removes the query
    itself from the candidate pool, and evaluates average precision, rank of
    the first correct item, and rank percentile.

    Returns:
        (aps, r1s, rpcs): three 1-D tensors, one value per query.
    """
    model.eval()
    aps, r1s, rpcs = [], [], []
    n_cand = len(candidates_i)
    for n in range(len(queries_i)):
        qz = queries_z[n : n + 1].float()
        qm = None if queries_m is None else queries_m[n : n + 1]
        # Distances from query n to every candidate
        if batch_size_candidates is None or batch_size_candidates >= n_cand:
            dist = model.distances(
                qz,
                candidates_z.float(),
                qmask=qm,
                cmask=candidates_m,
                redux_strategy=redux_strategy,
            ).squeeze(0)
        else:
            # Chunk over candidates to keep the distance matrix small
            chunks = []
            for lo in range(0, n_cand, batch_size_candidates):
                hi = min(lo + batch_size_candidates, n_cand)
                cm = None if candidates_m is None else candidates_m[lo:hi]
                chunks.append(
                    model.distances(
                        qz,
                        candidates_z[lo:hi].float(),
                        qmask=qm,
                        cmask=cm,
                        redux_strategy=redux_strategy,
                    ).squeeze(0)
                )
            dist = torch.cat(chunks, dim=-1)
        # Ground truth: same clique; drop the query itself if present
        is_rel = candidates_c == queries_c[n]
        is_self = candidates_i == queries_i[n]
        dist = torch.where(is_self, torch.inf, dist)
        is_rel = torch.where(is_self, False, is_rel)
        # Accumulate metrics
        aps.append(average_precision(dist, is_rel))
        r1s.append(rank_of_first_correct(dist, is_rel))
        rpcs.append(rank_percentile(dist, is_rel))
    return torch.stack(aps), torch.stack(r1s), torch.stack(rpcs)


###################################################################################################


@torch.inference_mode()
def average_precision(distances, ismatch):
    """Average precision of a single ranked list (lower distance = higher rank)."""
    assert distances.ndim == 1 and ismatch.ndim == 1 and len(distances) == len(ismatch)
    rel = ismatch.type_as(distances)
    assert rel.sum() >= 1, "There should be at least 1 relevant item"
    rel = rel[torch.argsort(distances)]
    ranks = torch.arange(len(rel), device=distances.device) + 1
    precision_at = torch.cumsum(rel, 0) / ranks
    return (precision_at * rel).sum() / rel.sum()


@torch.inference_mode()
def rank_of_first_correct(distances, ismatch):
    """1-based rank of the best-ranked relevant item."""
    assert distances.ndim == 1 and ismatch.ndim == 1 and len(distances) == len(ismatch)
    rel = ismatch.type_as(distances)
    assert rel.sum() >= 1, "There should be at least 1 relevant item"
    rel = rel[torch.argsort(distances)]
    # torch.argmax returns the index of the first occurrence of the maximum
    return (torch.argmax(rel) + 1).type_as(distances)


@torch.inference_mode()
def rank_percentile(distances, ismatch, biased=False):
    """Mean normalized rank of the relevant items, in percent (0 is best).

    See https://publications.hevs.ch/index.php/publications/show/125
    """
    assert distances.ndim == 1 and ismatch.ndim == 1 and len(distances) == len(ismatch)
    rel = ismatch.type_as(distances)
    assert rel.sum() >= 1, "There should be at least 1 relevant item"
    rel = rel[torch.argsort(distances)]
    if biased:
        # Clique size affects the measure: no perfect 0 when clique size > 1
        normrank = torch.linspace(0, 1, len(rel), device=distances.device)
    else:
        # Counting zeros preceding each relevant item allows a perfect 0 score
        normrank = torch.cumsum(1 - rel, 0) / torch.sum(1 - rel)
    return 100 * (rel * normrank).sum() / rel.sum()
import torch
from lightning import Fabric
from lightning.fabric.strategies import DDPStrategy
from tqdm import tqdm

from utils import pytorch_utils, audio_utils

ACCEPTED_AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".ogg")

# --- Command-line interface ---
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, required=True)
parser.add_argument("--path_in", type=str, default=None)
parser.add_argument("--path_out", type=str, default=None)
parser.add_argument("--fn_in", type=str, default=None)
parser.add_argument("--fn_out", type=str, default=None)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--hop_size", type=float, default=5.0)
parser.add_argument("--win_len", type=float, default=-1)  # will use model's default
args = parser.parse_args()
if args.win_len <= 0:
    args.win_len = None  # fall back to the model's default shingle length
# Exactly one of the two I/O conventions must be used: folder mode
# (path_in/path_out) or single-file mode (fn_in/fn_out), and never mixed.
using_paths = args.path_in is not None and args.path_out is not None
using_filenames = args.fn_in is not None and args.fn_out is not None
conflicting = (using_paths and (args.fn_in is not None or args.fn_out is not None)) or (
    using_filenames and (args.path_in is not None or args.path_out is not None)
)
# BUGFIX: this condition used "and", which could never be True (a conflicting
# combination implies one of using_paths/using_filenames is set), so invalid
# argument combinations silently slipped through. It must be "or".
if not (using_paths or using_filenames) or conflicting:
    print(
        "ERROR: You should provide either path_in/path_out or fn_in/fn_out (and only these combinations)."
    )
    print(
        '    Use either "--path_in=xxx --path_out=yyy" or "--fn_in=xxx.wav --fn_out=yyy.pt".'
    )
    sys.exit()
print("=" * 100)
print(args)
print("=" * 100)

###############################################################################

# Init output path (interactive confirmation before erasing existing output)
if using_paths and os.path.exists(args.path_out):
    print("*** Output path exists (" + args.path_out + ") ***")
    print("By hitting enter it will be erased and the script will continue. ")
    input("[Enter to continue/CTRL-C to exit]")
    # NOTE(review): shell string built from user input; path_out comes from the
    # operator's own CLI, but shutil.rmtree would be safer than os.system here.
    os.system("rm -rf " + args.path_out)

# Init pytorch/Fabric
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.set_float32_matmul_precision("medium")
torch.autograd.set_detect_anomaly(False)
fabric = Fabric(
    accelerator=args.device,
    devices=1,
    num_nodes=1,
    strategy=DDPStrategy(broadcast_buffers=False),
    precision="32",
)
fabric.launch()

# Load conf (stored next to the checkpoint at training time)
print("Load model conf...")
path_checkpoint, _ = os.path.split(args.checkpoint)
conf = OmegaConf.load(os.path.join(path_checkpoint, "configuration.yaml"))

# Init model (dynamically import the model module named in the configuration)
print("Init model...")
module = importlib.import_module("models." + conf.model.name)
with fabric.init_module():
    model = module.Model(conf.model, sr=conf.data.samplerate)
model = fabric.setup(model)

# Load model weights from the checkpoint into the state container
print("Load checkpoint...")
state = pytorch_utils.get_state(model, None, None, conf, None, None, None)
fabric.load(args.checkpoint, state)
model, _, _, conf, _, _, _ = pytorch_utils.set_state(state)
model.eval()

###############################################################################

# Get all files: (input filename, output folder, output filename) triples
print("Get filenames...")
if using_paths:
    filenames = []
    for path, dirs, files in os.walk(args.path_in):
        for file in files:
            # Filter audio files
            _, ext = os.path.splitext(file)
            if ext.lower() not in ACCEPTED_AUDIO_EXTENSIONS:
                continue
            # Get full filename; mirror the input tree under path_out
            fn_in = os.path.join(path, file)
            fn_out = os.path.join(args.path_out, os.path.relpath(fn_in, args.path_in))
            fn_out = os.path.splitext(fn_out)[0] + ".pt"
            path_out, _ = os.path.split(fn_out)
            filenames.append([fn_in, path_out, fn_out])
else:
    path_out, _ = os.path.split(args.fn_out)
    filenames = [[args.fn_in, path_out, args.fn_out]]
106 | # Extract 107 | with torch.inference_mode(): 108 | for fn_in, path_out, fn_out in tqdm( 109 | filenames, ascii=True, ncols=100, desc="Extract embeddings" 110 | ): 111 | # Load mono audio 112 | x = audio_utils.load_audio(fn_in, sample_rate=model.sr, n_channels=1) 113 | if x is None: 114 | continue 115 | # Compute embeddings 116 | z = model(x, shingle_hop=args.hop_size, shingle_len=args.win_len) 117 | z = z.squeeze(0).cpu() 118 | # Save 119 | os.makedirs(path_out, exist_ok=True) 120 | torch.save(z, fn_out) 121 | 122 | ############################################################################### 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Supervised Contrastive Learning from Weakly-Labeled Audio Segments for Musical Version Matching 2 | 3 | _This is the repository for the CLEWS paper. It includes the code to train and evaluate the main models we consider (including baselines), checkpoints for DVI and SHS data sets, and a basic inference script. We do not include the ablation experiments._ 4 | 5 | ### Abstract 6 | 7 | Detecting musical versions (different renditions of the same piece) is a challenging task with important applications. Because of the ground truth nature, existing approaches match musical versions at the track level (e.g., whole song). However, most applications require to match them at the segment level (e.g., 20s chunks). In addition, existing approaches resort to classification and triplet losses, disregarding more recent losses that could bring meaningful improvements. In this paper, we propose a method to learn from weakly annotated segments, together with a contrastive loss variant that outperforms well-studied alternatives. 
The former is based on pairwise segment distance reductions, while the latter modifies an existing loss following decoupling, hyper-parameter, and geometric considerations. With these two elements, we do not only achieve state-of-the-art results in the standard track-level evaluation, but we also obtain a breakthrough performance in a segment-level evaluation. We believe that, due to the generality of the challenges addressed here, the proposed methods may find utility in domains beyond audio or musical version matching. 8 | 9 | ### Authors 10 | 11 | Joan Serrà, R. Oguz Araz, Dmitry Bogdanov, & Yuki Mitsufuji. 12 | 13 | ### Reference and links 14 | 15 | J. Serrà, R. O. Araz, D. Bogdanov, & Y. Mitsufuji (2025). Supervised Contrastive Learning from Weakly-Labeled Audio Segments for Musical Version Matching. ArXiv: 2502.16936. 16 | 17 | [[`arxiv`](https://arxiv.org/abs/2502.16936)] [[`checkpoints`](https://zenodo.org/records/15045900)] 18 | 19 | ## Preparation 20 | 21 | ### Environment 22 | 23 | CLEWS requires python>=3.10. We used python 3.10.13. 24 | 25 | You should be able to create the environment by running [install_requirements.sh](install_requirements.sh). However, we recommend to just check inside that file and do it step by step. 26 | 27 | ## Operation 28 | 29 | ### Inference 30 | 31 | We provide a basic inference script to extract embeddings using a pre-trained checkpoint: 32 | 33 | ```bash 34 | OMP_NUM_THREADS=1 python inference.py --checkpoint=logs/model/checkpoint_best.ckpt --path_in=data/audio_files/ --path_out=cache/extracted_embeddings/ 35 | ``` 36 | 37 | It will go through all audio files in the folder and subfolders (recursive) and create the same structure in the output folder. 
Alternatively, you can use the following arguments for processing just a single file: 38 | 39 | ```bash 40 | OMP_NUM_THREADS=1 python inference.py --checkpoint=logs/model/checkpoint_best.ckpt --fn_in=data/audio_files/filename.mp3 --fn_out=cache/extracted_embeddings/filename.pt 41 | ``` 42 | 43 | ## Training and testing 44 | 45 | Note: Training and testing assume you have at least one GPU. 46 | 47 | ### Folder structure 48 | 49 | Apart from the structure of this repo, we used the following folders: 50 | * `data`: folder pointing to original audio and metadata files (can be a symbolic link). 51 | * `cache`: folder where to store preprocessed metadata files. 52 | * `logs`: folder where to output checkpoints and tensorboard files. 53 | 54 | You should create/organize those folders prior to running any training/testing script. The folders are not necessary for regular operation/inference. 55 | 56 | ### Preprocessing 57 | 58 | To launch the data preprocessing script, you can run, for instance: 59 | 60 | ```bash 61 | OMP_NUM_THREADS=1 python data_preproc.py --njobs=16 --dataset=SHS100K --path_meta=data/SHS100K/meta/ --path_audio=data/SHS100K/audio/ --ext_in=mp3 --fn_out=cache/metadata-shs.pt 62 | OMP_NUM_THREADS=1 python data_preproc.py --njobs=16 --dataset=DiscogsVI --path_meta=data/DiscogsVI/meta/ --path_audio=data/DiscogsVI/audio/ --ext_in=mp3 --fn_out=cache/metadata-dvi.pt 63 | ``` 64 | 65 | This script takes time as it reads/checks every audio file (so that you do not need to run checks while training or in your dataloader). You just do this once and save the corresponding metadata file. Depending on the path names/organization of your data set it is possible that you have to modify some minor portions of the `data_preproc.py` script. 
66 | 67 | ### Training 68 | 69 | Before every training run, you need to clean the logs path and copy the configuration file (with the specific name `configuration.yaml`): 70 | ```bash 71 | rm -rf logs/shs-clews/ ; mkdir logs/shs-clews/ ; cp config/shs-clews.yaml logs/shs-clews/configuration.yaml 72 | rm -rf logs/dvi-clews/ ; mkdir logs/dvi-clews/ ; cp config/dvi-clews.yaml logs/dvi-clews/configuration.yaml 73 | ``` 74 | 75 | Next, launch the training script using, for instance: 76 | 77 | ```bash 78 | OMP_NUM_THREADS=1 python train.py jobname=shs-clews conf=config/shs-clews.yaml fabric.nnodes=1 fabric.ngpus=2 79 | OMP_NUM_THREADS=1 python train.py jobname=dvi-clews conf=config/dvi-clews.yaml fabric.nnodes=1 fabric.ngpus=2 80 | ``` 81 | 82 | ### Testing 83 | 84 | To launch the testing script, you can run, for instance: 85 | 86 | ```bash 87 | OMP_NUM_THREADS=1 python test.py jobname=test-script checkpoint=logs/shs-clews/checkpoint_best.ckpt nnodes=1 ngpus=4 redux=bpwr-10 88 | OMP_NUM_THREADS=1 python test.py jobname=test-script checkpoint=logs/dvi-clews/checkpoint_best.ckpt nnodes=1 ngpus=4 redux=bpwr-10 maxlen=300 89 | ``` 90 | 91 | ## License 92 | 93 | The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE). 94 | 95 | ## Notes 96 | 97 | * If using this code, parts of it, or developments from it, please cite the reference above. 98 | * We do not provide any support or assistance for the supplied code nor we offer any other compilation/variant of it. 99 | * We assume no responsibility regarding the provided code. 
100 | -------------------------------------------------------------------------------- /lib/dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | 5 | from utils import audio_utils 6 | from lib import tensor_ops as tops 7 | 8 | LIMIT_CLIQUES = None 9 | 10 | 11 | class Dataset(torch.utils.data.Dataset): 12 | 13 | def __init__( 14 | self, 15 | conf, 16 | split, 17 | augment=False, 18 | fullsongs=False, 19 | checks=True, 20 | verbose=False, 21 | ): 22 | assert split in ("train", "valid", "test") 23 | # Params 24 | self.augment = augment 25 | self.samplerate = conf.samplerate 26 | self.fullsongs = fullsongs 27 | self.audiolen = conf.audiolen if not self.fullsongs else None 28 | self.maxlen = conf.maxlen if not self.fullsongs else None 29 | self.pad_mode = conf.pad_mode 30 | self.n_per_class = conf.n_per_class 31 | self.p_samesong = conf.p_samesong 32 | self.verbose = verbose 33 | # Load metadata 34 | self.info, splitdict = torch.load(conf.path.meta) 35 | if LIMIT_CLIQUES is None: 36 | self.clique = splitdict[split] 37 | else: 38 | if self.verbose: 39 | print(f"[Limiting cliques to {LIMIT_CLIQUES}]") 40 | self.clique = {} 41 | for key, item in splitdict[split].items(): 42 | self.clique[key] = item 43 | if len(self.clique) == LIMIT_CLIQUES: 44 | break 45 | # Update filename with audio_path 46 | for ver in self.info.keys(): 47 | self.info[ver]["filename"] = os.path.join( 48 | conf.path.audio, self.info[ver]["filename"] 49 | ) 50 | # Checks 51 | if checks: 52 | self.perform_checks(splitdict, split) 53 | # Get clique id 54 | self.clique2id = {} 55 | if split == "train": 56 | offset = 0 57 | elif split == "valid": 58 | offset = len(splitdict["train"]) 59 | else: 60 | offset = len(splitdict["train"]) + len(splitdict["valid"]) 61 | for i, cl in enumerate(self.clique.keys()): 62 | self.clique2id[cl] = offset + i 63 | # Get idx2version 64 | self.versions = [] 65 | for vers in 
def __getitem__(self, idx):
    """Return one example: [clique_id, id_1, audio_1, ..., id_n, audio_n].

    The anchor is versions[idx]; the remaining n_per_class-1 items are drawn
    from the same clique, cycling over the available versions if the clique
    is small. With probability p_samesong, the anchor itself may also appear
    among the "other" versions.
    """
    # Get v1 (anchor) and clique
    v1 = self.versions[idx]
    i1 = self.info[v1]["id"]
    cl = self.info[v1]["clique"]
    icl = self.clique2id[cl]
    # Get other versions from the same clique (anchor kept with prob. p_samesong)
    otherversions = []
    for v in self.clique[cl]:
        if v != v1 or torch.rand(1).item() < self.p_samesong:
            otherversions.append(v)
    if self.augment:
        # Shuffle so the cycling below samples a random subset/order
        perm = torch.randperm(len(otherversions)).tolist()
        otherversions = [otherversions[k] for k in perm]
    # Construct v1..vn array (n_per_class)
    v_n = [v1]
    i_n = [i1]
    for k in range(self.n_per_class - 1):
        v = otherversions[k % len(otherversions)]
        v_n.append(v)
        i_n.append(self.info[v]["id"])
    # Time augment? Random crop offsets (in seconds) when augmenting, else 0.
    # assumes info[v]["length"] is the song duration in seconds -- TODO confirm
    s_n = []
    for v in v_n:
        if self.augment:
            dur = self.info[v]["length"]
            if self.maxlen is not None:
                dur = min(self.maxlen, dur)
            start = max(0, torch.rand(1).item() * (dur - self.audiolen))
        else:
            start = 0
        s_n.append(start)
    # Load audio and create output
    output = [icl]
    for i, v, s in zip(i_n, v_n, s_n):
        fn = self.info[v]["filename"]
        x = self.get_audio(fn, start=s, length=self.audiolen)
        output += [i, x]
    # BUGFIX(cleanup): the original had `if self.fullsongs: return output`
    # immediately followed by the same `return output` (dead branch); a single
    # return is equivalent.
    return output
partition: 169 | continue 170 | if cl in splitdict[partition]: 171 | msg += ( 172 | f"\n {split}: Clique {cl} is both in {split} and {partition}" 173 | ) 174 | # errors=True 175 | if self.verbose and len(msg) > 1: 176 | print(msg[1:]) 177 | if errors: 178 | sys.exit() 179 | -------------------------------------------------------------------------------- /models/coverhunterc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from nnAudio import features # type: ignore 4 | from einops import rearrange, repeat 5 | 6 | from lib.coverhunter import ch_conformer, ch_layers, ch_losses 7 | from lib import tensor_ops as tops 8 | 9 | 10 | class Model(torch.nn.Module): 11 | 12 | def __init__(self, conf, sr=16000, eps=1e-6): 13 | super().__init__() 14 | self.conf = conf 15 | self.sr = sr 16 | self.eps = eps 17 | self.minlen = conf.shingling.len 18 | # CQT 19 | self.cqtbins = self.conf.cqt.noctaves * self.conf.cqt.nbinsoct 20 | self.cqt = features.CQT1992v2( 21 | sr=self.sr, 22 | hop_length=int(self.conf.cqt.hoplen * sr), 23 | n_bins=self.cqtbins, 24 | bins_per_octave=self.conf.cqt.nbinsoct, 25 | trainable=False, 26 | verbose=False, 27 | ) 28 | self.cqtpool = torch.nn.AvgPool1d( 29 | self.conf.cqt.pool.len, stride=self.conf.cqt.pool.hop 30 | ) 31 | # Model 32 | self.preproc = torch.nn.BatchNorm1d(self.cqtbins) 33 | self.backbone = ch_conformer.ConformerEncoder( 34 | input_size=self.cqtbins, 35 | output_size=self.conf.ncha, 36 | linear_units=self.conf.ncha_attn, 37 | num_blocks=self.conf.nblocks, 38 | ) 39 | self.pool_layer = ch_layers.AttentiveStatisticsPooling( 40 | self.conf.ncha, output_channels=self.conf.ncha 41 | ) 42 | self.bottleneck = torch.nn.BatchNorm1d(self.conf.ncha) 43 | self.bottleneck.bias.requires_grad_(False) # no shift 44 | # Loss 45 | self.ce_layer = torch.nn.Linear( 46 | self.conf.ncha, self.conf.maxcliques, bias=False 47 | ) 48 | self.ce_loss = ch_losses.FocalLoss(alpha=None, 
gamma=self.conf.gamma) 49 | # self.ce_loss = torch.nn.CrossEntropyLoss() 50 | self.center_loss = ch_losses.CenterLoss( 51 | num_classes=self.conf.maxcliques, feat_dim=self.conf.ncha 52 | ) 53 | self.triplet_loss = ch_losses.HardTripletLoss(margin=self.conf.margin) 54 | 55 | def get_shingle_params(self): 56 | return self.conf.shingling.len, self.conf.shingling.hop 57 | 58 | ########################################################################### 59 | 60 | def forward( 61 | self, 62 | h, # (B,T) 63 | shingle_len=None, 64 | shingle_hop=None, 65 | ): 66 | with torch.inference_mode(): 67 | h = self.prepare(h, shingle_len=shingle_len, shingle_hop=shingle_hop) 68 | h = h.clone() 69 | h, _ = self.embed(h) 70 | return h # (B,C) 71 | 72 | def prepare( 73 | self, 74 | h, # (B,T) 75 | shingle_len=None, 76 | shingle_hop=None, 77 | ): 78 | assert h.ndim == 2 79 | assert shingle_len is None or shingle_len > 0 80 | assert shingle_hop is None or shingle_hop > 0 81 | slen = self.conf.shingling.len if shingle_len is None else shingle_len 82 | shop = self.conf.shingling.hop / 2 if shingle_hop is None else shingle_hop 83 | # Shingle 84 | h = tops.get_frames( 85 | h, int(self.sr * slen), int(self.sr * shop), pad_mode="repeat" 86 | ) 87 | # Check audio length 88 | h = tops.force_length( 89 | h, int(self.sr * self.minlen), dim=-1, pad_mode="repeat", allow_longer=True 90 | ) 91 | # CQT 92 | s = h.size(1) 93 | h = rearrange(h, "b s t -> (b s) t") 94 | h = self.cqt(h) 95 | h = self.cqtpool(h) 96 | h = rearrange(h, "(b s) c t -> b s c t", s=s) 97 | return h # (B,S,C,T) 98 | 99 | def embed( 100 | self, 101 | h, # (B,S,C,T) 102 | ): 103 | assert h.ndim == 4 104 | s = h.size(1) 105 | h = rearrange(h, "b s c t -> (b s) c t") 106 | h = self.preproc(h).transpose(1, 2) 107 | lens = torch.full( 108 | [h.size(0)], fill_value=h.size(1), dtype=torch.long, device=h.device 109 | ) 110 | h, _ = self.backbone(h, xs_lens=lens, decoding_chunk_size=-1) 111 | f_t = self.pool_layer(h) 112 | f_i = 
self.bottleneck(f_t) 113 | f_t = rearrange(f_t, "(b s) c -> b s c", s=s) 114 | f_i = rearrange(f_i, "(b s) c -> b s c", s=s) 115 | return f_i, f_t # (B,S,C) both 116 | 117 | ########################################################################### 118 | 119 | def loss( 120 | self, 121 | label, # (B) 122 | idx, # (B) 123 | f_i, # (B,S,C) 124 | extra=None, 125 | ): 126 | f_t = extra 127 | assert len(label) == len(idx) and len(label) == len(f_t) 128 | s = f_t.size(1) 129 | f_t = rearrange(f_t, "b s c -> (b s) c") 130 | f_i = rearrange(f_i, "b s c -> (b s) c") 131 | label = rearrange(label.unsqueeze(-1).expand(-1, s), "b s -> (b s)") 132 | idx = rearrange(idx.unsqueeze(-1).expand(-1, s), "b s -> (b s)") 133 | loss_focal = self.ce_loss(self.ce_layer(f_i), label) 134 | loss_center = self.center_loss(f_t, label) 135 | loss_triplet = self.triplet_loss(f_t, label, ids=idx) 136 | 137 | loss = loss_focal + 0.01 * loss_center + 0.1 * loss_triplet 138 | logdict = { 139 | "l_main": loss, 140 | "l_cent": loss_focal, 141 | "l_cont": loss_triplet, 142 | } 143 | return loss, logdict 144 | 145 | ########################################################################### 146 | 147 | def distances( 148 | self, 149 | q, # (B,C) 150 | c, # (B',C) 151 | qmask=None, 152 | cmask=None, 153 | redux_strategy=None, 154 | ): 155 | assert q.ndim == 3 and c.ndim == 3 and q.size(-1) == c.size(-1) 156 | if redux_strategy is None: 157 | redux_strategy = "min" 158 | s1, s2 = q.size(1), c.size(1) 159 | q = rearrange(q, "b s c -> (b s) c") 160 | c = rearrange(c, "b s c -> (b s) c") 161 | dist = tops.pairwise_distance_matrix(q, c, mode="cos") 162 | dist = rearrange(dist, "(b1 s1) (b2 s2) -> b1 b2 s1 s2", s1=s1, s2=s2) 163 | if qmask is not None and cmask is not None: 164 | qmask = rearrange(qmask, "b s -> (b s)") 165 | cmask = rearrange(cmask, "b s -> (b s)") 166 | mask = qmask.view(-1, 1) | cmask.view(1, -1) 167 | mask = rearrange(mask, "(bq sq) (bc sc) -> bq bc sq sc", sq=s1, sc=s2) 168 | else: 169 | 
def get_optimizer(conf, model):
    """Build the optimizer named in conf.name ("adam", "adamw", or "sgd").

    Note: conf.wd is only honored by AdamW; the other optimizers are created
    without built-in weight decay.
    """
    optname = conf.name.lower()
    if optname == "adam":
        return torch.optim.Adam(model.parameters(), lr=conf.lr)
    if optname == "adamw":
        return torch.optim.AdamW(model.parameters(), lr=conf.lr, weight_decay=conf.wd)
    if optname == "sgd":
        return torch.optim.SGD(model.parameters(), lr=conf.lr)
    raise NotImplementedError


def get_scheduler(
    conf,
    optim,
    epochs=None,
    mode="min",
    warm_factor=0.005,
    plateau_factor=0.2,
):
    """Build a learning-rate scheduler from the conf.sched spec string.

    Supported specs (None is treated as "flat"):
        "flat"                constant LR
        "plateau_<P>"         ReduceLROnPlateau with patience P-1
        "poly_<pow>"          polynomial decay over `epochs`
        "warmpoly_<W>_<pow>"  W warmup epochs, then polynomial decay
        "sd_<D>"              constant, then quadratic decay for the last D+1 epochs
        "wsd_<W>_<D>"         warmup, constant, then quadratic decay

    Returns:
        (scheduler, sched_on_epoch) where sched_on_epoch is always True,
        i.e., all schedulers here are meant to step once per epoch.
    """
    spec = "flat" if conf.sched is None else conf.sched.lower()
    sched_on_epoch = True
    if spec == "flat":
        sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda epoch: 1.0)
    elif spec.startswith("plateau"):
        patience = max(0, int(spec.split("_")[1]) - 1)
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, mode=mode, factor=plateau_factor, patience=patience
        )
    elif spec.startswith("warmpoly"):
        _, nwarm, power = spec.split("_")
        nwarm = max(1, int(nwarm))
        power = float(power)
        assert epochs > nwarm
        warm = torch.optim.lr_scheduler.LinearLR(
            optim, start_factor=warm_factor, end_factor=1.0, total_iters=nwarm
        )
        decay = torch.optim.lr_scheduler.PolynomialLR(
            optim, total_iters=epochs - nwarm, power=power
        )
        sched = torch.optim.lr_scheduler.SequentialLR(optim, [warm, decay], [nwarm])
    elif spec.startswith("poly"):
        power = float(spec.split("_")[1])
        sched = torch.optim.lr_scheduler.PolynomialLR(
            optim, total_iters=epochs, power=power
        )
    elif spec.startswith("wsd"):
        _, nwarm, ndec = spec.split("_")
        nwarm = max(1, int(nwarm))
        ndec = max(1, int(ndec)) + 1
        assert epochs > nwarm + ndec
        warm = torch.optim.lr_scheduler.LinearLR(
            optim, start_factor=warm_factor, end_factor=1.0, total_iters=nwarm
        )
        flat = torch.optim.lr_scheduler.ConstantLR(
            optim, factor=1.0, total_iters=epochs - nwarm - ndec
        )
        decay = torch.optim.lr_scheduler.PolynomialLR(optim, power=2, total_iters=ndec)
        sched = torch.optim.lr_scheduler.SequentialLR(
            optim, [warm, flat, decay], [nwarm, epochs - ndec]
        )
    elif spec.startswith("sd"):
        ndec = max(1, int(spec.split("_")[1])) + 1
        assert epochs > ndec
        flat = torch.optim.lr_scheduler.ConstantLR(
            optim, factor=1.0, total_iters=epochs - ndec
        )
        decay = torch.optim.lr_scheduler.PolynomialLR(optim, power=2, total_iters=ndec)
        sched = torch.optim.lr_scheduler.SequentialLR(
            optim, [flat, decay], [epochs - ndec]
        )
    else:
        raise NotImplementedError
    return sched, sched_on_epoch
isinstance(m, considered_layers): 116 | w = m.weight 117 | n = m.weight.numel() 118 | if form == "l1": 119 | w = w.abs() 120 | elif form == "l2": 121 | w = w.pow(2) 122 | num += w.sum() 123 | den += n 124 | wd = num / den 125 | return lamb * wd, wd 126 | 127 | 128 | ################################################################################################### 129 | 130 | 131 | def get_logger(path): 132 | return TensorBoardLogger( 133 | root_dir=path, 134 | name="", 135 | version="", 136 | default_hp_metric=False, 137 | ) 138 | 139 | 140 | ################################################################################################### 141 | 142 | 143 | def set_state(state): 144 | return ( 145 | state.model, 146 | state.optim, 147 | state.sched, 148 | state.conf, 149 | state.epoch, 150 | state.lr, 151 | state.cost_best, 152 | ) 153 | 154 | 155 | def get_state(model, optim, sched, conf, epoch, lr, cost_best): 156 | return AttributeDict( 157 | model=model, 158 | optim=optim, 159 | sched=sched, 160 | conf=conf, 161 | epoch=epoch, 162 | lr=lr, 163 | cost_best=cost_best, 164 | ) 165 | 166 | 167 | ################################################################################################### 168 | 169 | 170 | class LogDict: 171 | 172 | def __init__(self, d=None): 173 | self.reset() 174 | if d is not None: 175 | self.append(d) 176 | 177 | def reset(self): 178 | self.d = {} 179 | 180 | def get(self, keys=None, prefix="", suffix=""): 181 | if keys is None: 182 | keys = list(self.d.keys()) 183 | elif type(keys) != list: 184 | return self.d[keys] 185 | d = {} 186 | for key in keys: 187 | new_key = prefix + key + suffix 188 | d[new_key] = self.d[key] 189 | return d 190 | 191 | def append(self, newd): 192 | assert type(newd) == dict 193 | for key, value in newd.items(): 194 | value = value.cpu() 195 | if value.ndim == 0: 196 | value = torch.FloatTensor([value]) 197 | if key not in self.d: 198 | self.d[key] = value 199 | else: 200 | self.d[key] = 
torch.cat([self.d[key], value], dim=0) 201 | 202 | def sync_and_mean(self, fabric): 203 | fabric.barrier() 204 | for key in self.d.keys(): 205 | self.d[key] = fabric.all_gather(self.d[key]).mean().item() 206 | 207 | 208 | ################################################################################################### 209 | -------------------------------------------------------------------------------- /models/cqtnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from nnAudio import features # type: ignore 4 | from einops import rearrange, repeat 5 | 6 | from lib import tensor_ops as tops 7 | 8 | 9 | class Model(torch.nn.Module): 10 | 11 | def __init__(self, conf, sr=16000, eps=1e-6): 12 | super().__init__() 13 | self.conf = conf 14 | self.sr = sr 15 | self.eps = eps 16 | self.minlen = conf.shingling.len 17 | # CQT 18 | self.cqtbins = self.conf.cqt.noctaves * self.conf.cqt.nbinsoct 19 | self.cqt = features.CQT1992v2( 20 | sr=self.sr, 21 | hop_length=int(self.conf.cqt.hoplen * sr), 22 | n_bins=self.cqtbins, 23 | bins_per_octave=self.conf.cqt.nbinsoct, 24 | trainable=False, 25 | verbose=False, 26 | ) 27 | self.cqtpool = torch.nn.AvgPool1d( 28 | self.conf.cqt.pool.len, stride=self.conf.cqt.pool.hop 29 | ) 30 | # Model 31 | self.block1 = torch.nn.Sequential( 32 | torch.nn.Conv2d(1, 32, (12, 3), padding=(6, 0), bias=False), 33 | torch.nn.BatchNorm2d(32), 34 | torch.nn.ReLU(), 35 | torch.nn.Conv2d(32, 64, (13, 3), dilation=(1, 2), bias=False), 36 | torch.nn.BatchNorm2d(64), 37 | torch.nn.ReLU(), 38 | torch.nn.MaxPool2d((1, 2), (1, 2)), 39 | ) 40 | self.block2 = torch.nn.Sequential( 41 | torch.nn.Conv2d(64, 64, (13, 3), bias=False), 42 | torch.nn.BatchNorm2d(64), 43 | torch.nn.ReLU(), 44 | torch.nn.Conv2d(64, 64, (3, 3), dilation=(1, 2), bias=False), 45 | torch.nn.BatchNorm2d(64), 46 | torch.nn.ReLU(), 47 | torch.nn.MaxPool2d((1, 2), (1, 2)), 48 | ) 49 | self.block3 = torch.nn.Sequential( 50 | 
torch.nn.Conv2d(64, 128, (3, 3), bias=False), 51 | torch.nn.BatchNorm2d(128), 52 | torch.nn.ReLU(), 53 | torch.nn.Conv2d(128, 128, (3, 3), dilation=(1, 2), bias=False), 54 | torch.nn.BatchNorm2d(128), 55 | torch.nn.ReLU(), 56 | torch.nn.MaxPool2d((1, 2), (1, 2)), 57 | ) 58 | self.block4 = torch.nn.Sequential( 59 | torch.nn.Conv2d(128, 256, (3, 3), bias=False), 60 | torch.nn.BatchNorm2d(256), 61 | torch.nn.ReLU(), 62 | torch.nn.Conv2d(256, 256, (3, 3), dilation=(1, 2), bias=False), 63 | torch.nn.BatchNorm2d(256), 64 | torch.nn.ReLU(), 65 | torch.nn.MaxPool2d((1, 2), (1, 2)), 66 | ) 67 | self.block5 = torch.nn.Sequential( 68 | torch.nn.Conv2d(256, 512, (3, 3), bias=False), 69 | torch.nn.BatchNorm2d(512), 70 | torch.nn.ReLU(), 71 | torch.nn.Conv2d(512, 512, (3, 3), dilation=(1, 2), bias=False), 72 | torch.nn.BatchNorm2d(512), 73 | torch.nn.ReLU(), 74 | torch.nn.AdaptiveMaxPool2d((1, 1)), 75 | ) 76 | self.fc0 = torch.nn.Linear(512, self.conf.zdim) 77 | # Loss 78 | self.fc1 = torch.nn.Linear(self.conf.zdim, self.conf.maxcliques) 79 | 80 | def get_shingle_params(self): 81 | return self.conf.shingling.len, self.conf.shingling.hop 82 | 83 | ########################################################################### 84 | 85 | def forward( 86 | self, 87 | h, # (B,T) 88 | shingle_len=None, 89 | shingle_hop=None, 90 | ): 91 | with torch.inference_mode(): 92 | h = self.prepare(h, shingle_len=shingle_len, shingle_hop=shingle_hop) 93 | h = h.clone() 94 | h, _ = self.embed(h) 95 | return h # (B,C) 96 | 97 | def prepare( 98 | self, 99 | h, # (B,T) 100 | shingle_len=None, 101 | shingle_hop=None, 102 | ): 103 | assert h.ndim == 2 104 | assert shingle_len is None or shingle_len > 0 105 | assert shingle_hop is None or shingle_hop > 0 106 | slen = self.conf.shingling.len if shingle_len is None else shingle_len 107 | shop = self.conf.shingling.hop if shingle_hop is None else shingle_hop 108 | # Shingle 109 | h = tops.get_frames( 110 | h, int(self.sr * slen), int(self.sr * shop), 
pad_mode="zeros" 111 | ) 112 | # Check audio length 113 | h = tops.force_length( 114 | h, int(self.sr * self.minlen), dim=-1, pad_mode="zeros", allow_longer=True 115 | ) 116 | # CQT 117 | s = h.size(1) 118 | h = rearrange(h, "b s t -> (b s) t") 119 | h = self.cqt(h) 120 | h = self.cqtpool(h) 121 | h = rearrange(h, "(b s) c t -> b s c t", s=s) 122 | return h 123 | 124 | def embed( 125 | self, 126 | h, 127 | ): 128 | assert h.ndim == 4 129 | s = h.size(1) 130 | h = rearrange(h, "b s c t -> (b s) c t") 131 | h = h / (h.abs().max(1, keepdim=True)[0].max(2, keepdim=True)[0] + self.eps) 132 | h = h.unsqueeze(1) 133 | h = self.block1(h) 134 | h = self.block2(h) 135 | h = self.block3(h) 136 | h = self.block4(h) 137 | h = self.block5(h) 138 | h = h.squeeze(-1).squeeze(-1) 139 | h = self.fc0(h) 140 | h = rearrange(h, "(b s) c -> b s c", s=s) 141 | return h, None # (B,C) 142 | 143 | ########################################################################### 144 | 145 | def loss( 146 | self, 147 | label, # (B) 148 | idx, # (B) 149 | z, # (B,S,C) 150 | extra=None, 151 | ): 152 | assert len(label) == len(idx) and len(label) == len(z) 153 | z = rearrange(z, "b s t -> (b s) t") 154 | logits = self.fc1(z) 155 | loss = torch.nn.functional.cross_entropy(logits, label) 156 | logd = { 157 | "l_main": loss, 158 | "l_cent": loss, 159 | } 160 | return loss, logd 161 | 162 | ########################################################################### 163 | 164 | def distances( 165 | self, 166 | q, # (B,C) 167 | c, # (B',C) 168 | qmask=None, 169 | cmask=None, 170 | redux_strategy=None, 171 | ): 172 | assert q.ndim == 3 and c.ndim == 3 and q.size(-1) == c.size(-1) 173 | if redux_strategy is None: 174 | redux_strategy = "min" 175 | s1, s2 = q.size(1), c.size(1) 176 | q = rearrange(q, "b s c -> (b s) c") 177 | c = rearrange(c, "b s c -> (b s) c") 178 | dist = tops.pairwise_distance_matrix(q, c, mode="cos") 179 | dist = rearrange(dist, "(b1 s1) (b2 s2) -> b1 b2 s1 s2", s1=s1, s2=s2) 180 | if 
qmask is not None and cmask is not None: 181 | qmask = rearrange(qmask, "b s -> (b s)") 182 | cmask = rearrange(cmask, "b s -> (b s)") 183 | mask = qmask.view(-1, 1) | cmask.view(1, -1) 184 | mask = rearrange(mask, "(bq sq) (bc sc) -> bq bc sq sc", sq=s1, sc=s2) 185 | else: 186 | mask = None 187 | dist = tops.distance_tensor_redux(dist, redux_strategy, mask=mask) 188 | return dist 189 | -------------------------------------------------------------------------------- /lib/coverhunter/ch_layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author:liufeng 4 | # datetime:2022/7/18 8:00 PM 5 | # software: PyCharm 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from typing import Tuple, Dict, Optional 10 | 11 | 12 | class Linear(torch.nn.Module): 13 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"): 14 | super(Linear, self).__init__() 15 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 16 | 17 | torch.nn.init.xavier_uniform_( 18 | self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain) 19 | ) 20 | return 21 | 22 | def forward(self, x): 23 | return self.linear_layer(x) 24 | 25 | 26 | class Conv1d(torch.nn.Module): 27 | def __init__( 28 | self, 29 | in_channels, 30 | out_channels, 31 | kernel_size=1, 32 | stride=1, 33 | padding=None, 34 | dilation=1, 35 | bias=True, 36 | w_init_gain="linear", 37 | ): 38 | super(Conv1d, self).__init__() 39 | if padding is None: 40 | assert kernel_size % 2 == 1 41 | padding = int(dilation * (kernel_size - 1) / 2) 42 | 43 | self.conv = torch.nn.Conv1d( 44 | in_channels, 45 | out_channels, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | dilation=dilation, 50 | bias=bias, 51 | ) 52 | 53 | torch.nn.init.xavier_uniform_( 54 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain) 55 | ) 56 | return 57 | 58 | def forward(self, signal): 59 | 
conv_signal = self.conv(signal) 60 | return conv_signal 61 | 62 | 63 | class AttentiveStatisticsPooling(torch.nn.Module): 64 | """This class implements an attentive statistic pooling layer for each channel. 65 | It returns the concatenated mean and std of the input tensor. 66 | 67 | Arguments 68 | --------- 69 | channels: int 70 | The number of input channels. 71 | output_channels: int 72 | The number of output channels. 73 | """ 74 | 75 | def __init__(self, channels, output_channels): 76 | super().__init__() 77 | 78 | self._eps = 1e-12 79 | self._linear = Linear(channels * 3, channels) 80 | self._tanh = torch.nn.Tanh() 81 | self._conv = Conv1d(in_channels=channels, out_channels=channels, kernel_size=1) 82 | self._final_layer = torch.nn.Linear(channels * 2, output_channels, bias=False) 83 | return 84 | 85 | @staticmethod 86 | def _compute_statistics(x: torch.Tensor, m: torch.Tensor, eps: float, dim: int = 2): 87 | mean = (m * x).sum(dim) 88 | std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)) 89 | return mean, std 90 | 91 | def forward(self, x: torch.Tensor): 92 | """Calculates mean and std for a batch (input tensor). 93 | 94 | Args: 95 | x : torch.Tensor 96 | Tensor of shape [N, L, C]. 
97 | """ 98 | 99 | x = x.transpose(1, 2) 100 | L = x.shape[-1] 101 | lengths = torch.ones(x.shape[0], device=x.device) 102 | mask = self.length_to_mask(lengths * L, max_len=L, device=x.device) 103 | mask = mask.unsqueeze(1) 104 | total = mask.sum(dim=2, keepdim=True).float() 105 | 106 | mean, std = self._compute_statistics(x, mask / total, self._eps) 107 | mean = mean.unsqueeze(2).repeat(1, 1, L) 108 | std = std.unsqueeze(2).repeat(1, 1, L) 109 | attn = torch.cat([x, mean, std], dim=1) 110 | attn = self._conv( 111 | self._tanh(self._linear(attn.transpose(1, 2)).transpose(1, 2)) 112 | ) 113 | 114 | attn = attn.masked_fill(mask == 0, float("-inf")) # Filter out zero-padding 115 | attn = F.softmax(attn, dim=2) 116 | mean, std = self._compute_statistics(x, attn, self._eps) 117 | pooled_stats = self._final_layer(torch.cat((mean, std), dim=1)) 118 | return pooled_stats 119 | 120 | def forward_with_mask( 121 | self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None 122 | ): 123 | """Calculates mean and std for a batch (input tensor). 124 | 125 | Args: 126 | x : torch.Tensor 127 | Tensor of shape [N, C, L]. 128 | lengths: 129 | """ 130 | L = x.shape[-1] 131 | 132 | if lengths is None: 133 | lengths = torch.ones(x.shape[0], device=x.device) 134 | 135 | # Make binary mask of shape [N, 1, L] 136 | mask = self.length_to_mask(lengths * L, max_len=L, device=x.device) 137 | mask = mask.unsqueeze(1) 138 | 139 | # Expand the temporal context of the pooling layer by allowing the 140 | # self-attention to look at global properties of the utterance. 
141 | 142 | # torch.std is unstable for backward computation 143 | # https://github.com/pytorch/pytorch/issues/4320 144 | total = mask.sum(dim=2, keepdim=True).float() 145 | mean, std = self._compute_statistics(x, mask / total, self._eps) 146 | 147 | mean = mean.unsqueeze(2).repeat(1, 1, L) 148 | std = std.unsqueeze(2).repeat(1, 1, L) 149 | attn = torch.cat([x, mean, std], dim=1) 150 | 151 | # Apply layers 152 | attn = self.conv(self._tanh(self._linear(attn, lengths))) 153 | 154 | # Filter out zero-paddings 155 | attn = attn.masked_fill(mask == 0, float("-inf")) 156 | 157 | attn = F.softmax(attn, dim=2) 158 | mean, std = self._compute_statistics(x, attn, self._eps) 159 | # Append mean and std of the batch 160 | pooled_stats = torch.cat((mean, std), dim=1) 161 | pooled_stats = pooled_stats.unsqueeze(2) 162 | return pooled_stats 163 | 164 | @staticmethod 165 | def length_to_mask( 166 | length: torch.Tensor, 167 | max_len: Optional[int] = None, 168 | dtype: Optional[torch.dtype] = None, 169 | device: Optional[torch.device] = None, 170 | ): 171 | """Creates a binary mask for each sequence. 172 | 173 | Arguments 174 | --------- 175 | length : torch.LongTensor 176 | Containing the length of each sequence in the batch. Must be 1D. 177 | max_len : int 178 | Max length for the mask, also the size of the second dimension. 179 | dtype : torch.dtype, default: None 180 | The dtype of the generated mask. 181 | device: torch.device, default: None 182 | The device to put the mask variable. 183 | 184 | Returns 185 | ------- 186 | mask : tensor 187 | The binary mask. 
188 | 189 | Example 190 | ------- 191 | """ 192 | assert len(length.shape) == 1 193 | 194 | if max_len is None: 195 | max_len = length.max().long().item() # using arange to generate mask 196 | mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( 197 | len(length), max_len 198 | ) < length.unsqueeze(1) 199 | 200 | if dtype is None: 201 | dtype = length.dtype 202 | 203 | if device is None: 204 | device = length.device 205 | 206 | mask = torch.as_tensor(mask, dtype=dtype, device=device) 207 | return mask 208 | -------------------------------------------------------------------------------- /models/bytecover2x.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch, math 3 | from nnAudio import features # type: ignore 4 | from einops import rearrange, repeat 5 | 6 | from lib import layers 7 | from lib import tensor_ops as tops 8 | 9 | 10 | class Model(torch.nn.Module): 11 | 12 | def __init__(self, conf, sr=16000, eps=1e-6): 13 | super().__init__() 14 | self.conf = conf 15 | self.sr = sr 16 | self.eps = eps 17 | self.minlen = conf.shingling.len 18 | # CQT 19 | self.cqtbins = self.conf.cqt.noctaves * self.conf.cqt.nbinsoct 20 | self.cqt = features.CQT1992v2( 21 | sr=self.sr, 22 | hop_length=int(self.conf.cqt.hoplen * sr), 23 | n_bins=self.cqtbins, 24 | bins_per_octave=self.conf.cqt.nbinsoct, 25 | trainable=False, 26 | verbose=False, 27 | ) 28 | self.cqtpool = torch.nn.AvgPool1d( 29 | self.conf.cqt.pool.len, stride=self.conf.cqt.pool.hop 30 | ) 31 | # Model 32 | nc1 = conf.ncha // 8 33 | nc2 = conf.ncha // 4 34 | nc3 = conf.ncha // 2 35 | nc4 = conf.ncha # 2048 36 | self.frontend = torch.nn.Sequential( 37 | layers.Unsqueeze(1), 38 | torch.nn.Conv2d(1, nc1, 7, stride=(1, 2), bias=False), 39 | torch.nn.BatchNorm2d(nc1), 40 | torch.nn.ReLU(inplace=True), 41 | torch.nn.MaxPool2d(3, 2), 42 | ) 43 | aux = [layers.ResNet50BottBlock(nc1, nc1, ibn=True)] 44 | for _ in range(2): 45 | aux += 
[layers.ResNet50BottBlock(nc1, nc1, ibn=True)] 46 | aux += [layers.ResNet50BottBlock(nc1, nc2, ibn=True, stride=2)] 47 | for _ in range(3): 48 | aux += [layers.ResNet50BottBlock(nc2, nc2, ibn=True)] 49 | aux += [layers.ResNet50BottBlock(nc2, nc3, ibn=True, stride=2)] 50 | for _ in range(5): 51 | aux += [layers.ResNet50BottBlock(nc3, nc3, ibn=True)] 52 | aux += [layers.ResNet50BottBlock(nc3, nc4)] 53 | for _ in range(2): 54 | aux += [layers.ResNet50BottBlock(nc4, nc4)] 55 | self.resblocks = torch.nn.Sequential(*aux) 56 | self.embpool = torch.nn.Sequential( 57 | layers.GeMPool(), 58 | torch.nn.Linear(conf.ncha, conf.zdim), 59 | ) 60 | self.bn = torch.nn.BatchNorm1d(conf.zdim) 61 | # Loss 62 | self.fc = torch.nn.Linear(conf.zdim, conf.maxcliques, bias=False) 63 | self.smooth = conf.smooth 64 | self.margin = conf.margin 65 | self.lamb = conf.lamb 66 | 67 | def get_shingle_params(self): 68 | return self.conf.shingling.len, self.conf.shingling.hop 69 | 70 | ########################################################################### 71 | 72 | def forward( 73 | self, 74 | h, # (B,T) 75 | shingle_len=None, 76 | shingle_hop=None, 77 | ): 78 | with torch.inference_mode(): 79 | h = self.prepare(h, shingle_len=shingle_len, shingle_hop=shingle_hop) 80 | h = h.clone() 81 | h, _ = self.embed(h) 82 | return h # (B,C) 83 | 84 | def prepare( 85 | self, 86 | h, # (B,T) 87 | shingle_len=None, 88 | shingle_hop=None, 89 | ): 90 | assert h.ndim == 2 91 | assert shingle_len is None or shingle_len > 0 92 | assert shingle_hop is None or shingle_hop > 0 93 | slen = self.conf.shingling.len if shingle_len is None else shingle_len 94 | shop = self.conf.shingling.hop if shingle_hop is None else shingle_hop 95 | # Shingle 96 | h = tops.get_frames( 97 | h, int(self.sr * slen), int(self.sr * shop), pad_mode="zeros" 98 | ) 99 | # Check audio length 100 | h = tops.force_length( 101 | h, int(self.sr * self.minlen), dim=-1, pad_mode="repeat", allow_longer=True 102 | ) 103 | # CQT 104 | s = h.size(1) 105 
| h = rearrange(h, "b s t -> (b s) t") 106 | h = self.cqt(h) 107 | h = self.cqtpool(h) 108 | h = rearrange(h, "(b s) c t -> b s c t", s=s) 109 | return h # (B,S,C,T) 110 | 111 | def embed( 112 | self, 113 | h, # (B,S,C,T) 114 | ): 115 | assert h.ndim == 4 116 | s = h.size(1) 117 | h = rearrange(h, "b s c t -> (b s) c t") 118 | h = h / (h.abs().max(1, keepdim=True)[0].max(2, keepdim=True)[0] + self.eps) 119 | h = self.frontend(h) 120 | h = self.resblocks(h) 121 | ft = self.embpool(h) 122 | fc = self.bn(ft) 123 | ft = rearrange(ft, "(b s) c -> b s c", s=s) 124 | fc = rearrange(fc, "(b s) c -> b s c", s=s) 125 | return fc, ft # (B,C) 126 | 127 | ########################################################################### 128 | 129 | def loss( 130 | self, 131 | label, # (B) 132 | idx, # (B) 133 | fc, # (B,S,C) 134 | extra=None, 135 | ): 136 | assert len(label) == len(idx) and len(label) == len(fc) 137 | fc = fc[:, 0, :] 138 | ft = extra[:, 0, :] 139 | 140 | # Logits ByteCover 141 | logits = self.fc(fc) 142 | loss_cla = torch.nn.functional.cross_entropy( 143 | logits, label, label_smoothing=self.smooth 144 | ) 145 | 146 | # Triplet ByteCover 147 | dist = tops.pairwise_distance_matrix(ft, ft, mode="euc") 148 | samecla = label.view(-1, 1) == label.view(1, -1) 149 | diffid = idx.view(-1, 1) != idx.view(1, -1) 150 | pos = samecla & diffid 151 | neg = ~samecla 152 | posdist = torch.where(pos, dist, -torch.inf).max(1)[0] 153 | negdist = torch.where(neg, dist, torch.inf).min(1)[0] 154 | loss_dist = (posdist - negdist + self.conf.margin).clamp(min=0).mean() 155 | 156 | # Reg 157 | loss_reg = (pos.type_as(fc) * dist).sum() / (pos.type_as(fc).sum() + self.eps) 158 | 159 | loss = loss_cla + loss_dist + self.lamb * loss_reg 160 | logdict = { 161 | "l_main": loss, 162 | "l_cent": loss_cla, 163 | "l_cont": loss_dist, 164 | "v_dpos": (pos.type_as(fc) * dist).sum() 165 | / (pos.type_as(fc).sum() + self.eps), 166 | "v_dneg": (neg.type_as(fc) * dist).sum() 167 | / (neg.type_as(fc).sum() + 
self.eps), 168 | } 169 | return loss, logdict 170 | 171 | ########################################################################### 172 | 173 | def distances( 174 | self, 175 | q, # (B,S,C) 176 | c, # (B',S',C) 177 | qmask=None, 178 | cmask=None, 179 | redux_strategy=None, 180 | ): 181 | assert q.ndim == 3 and c.ndim == 3 and q.size(-1) == c.size(-1) 182 | if redux_strategy is None: 183 | redux_strategy = "min" 184 | s1, s2 = q.size(1), c.size(1) 185 | q = rearrange(q, "b s c -> (b s) c") 186 | c = rearrange(c, "b s c -> (b s) c") 187 | dist = tops.pairwise_distance_matrix(q, c, mode="cos") 188 | dist = rearrange(dist, "(b1 s1) (b2 s2) -> b1 b2 s1 s2", s1=s1, s2=s2) 189 | if qmask is not None and cmask is not None: 190 | qmask = rearrange(qmask, "b s -> (b s)") 191 | cmask = rearrange(cmask, "b s -> (b s)") 192 | mask = qmask.view(-1, 1) | cmask.view(1, -1) 193 | mask = rearrange(mask, "(bq sq) (bc sc) -> bq bc sq sc", sq=s1, sc=s2) 194 | else: 195 | mask = None 196 | dist = tops.distance_tensor_redux(dist, redux_strategy, mask=mask) 197 | return dist 198 | -------------------------------------------------------------------------------- /models/bytecover3x.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch, math 3 | from nnAudio import features # type: ignore 4 | from einops import rearrange, repeat 5 | 6 | from lib import layers 7 | from lib import tensor_ops as tops 8 | 9 | 10 | class Model(torch.nn.Module): 11 | 12 | def __init__(self, conf, sr=16000, eps=1e-6): 13 | super().__init__() 14 | self.conf = conf 15 | self.sr = sr 16 | self.eps = eps 17 | self.minlen = conf.shingling.len 18 | # CQT 19 | self.cqtbins = self.conf.cqt.noctaves * self.conf.cqt.nbinsoct 20 | self.cqt = features.CQT1992v2( 21 | sr=self.sr, 22 | hop_length=int(self.conf.cqt.hoplen * sr), 23 | n_bins=self.cqtbins, 24 | bins_per_octave=self.conf.cqt.nbinsoct, 25 | trainable=False, 26 | verbose=False, 27 | ) 28 | self.cqtpool = 
torch.nn.AvgPool1d( 29 | self.conf.cqt.pool.len, stride=self.conf.cqt.pool.hop 30 | ) 31 | # Model 32 | nc1 = conf.ncha // 8 33 | nc2 = conf.ncha // 4 34 | nc3 = conf.ncha // 2 35 | nc4 = conf.ncha # 2048 36 | self.frontend = torch.nn.Sequential( 37 | layers.Unsqueeze(1), 38 | torch.nn.Conv2d(1, nc1, 7, stride=(1, 2), bias=False), 39 | torch.nn.BatchNorm2d(nc1), 40 | torch.nn.ReLU(inplace=True), 41 | torch.nn.MaxPool2d(3, 2), 42 | ) 43 | aux = [layers.ResNet50BottBlock(nc1, nc1, ibn=True)] 44 | for _ in range(2): 45 | aux += [layers.ResNet50BottBlock(nc1, nc1, ibn=True)] 46 | aux += [layers.ResNet50BottBlock(nc1, nc2, ibn=True, stride=2)] 47 | for _ in range(3): 48 | aux += [layers.ResNet50BottBlock(nc2, nc2, ibn=True)] 49 | aux += [layers.ResNet50BottBlock(nc2, nc3, ibn=True, stride=2)] 50 | for _ in range(5): 51 | aux += [layers.ResNet50BottBlock(nc3, nc3, ibn=True)] 52 | aux += [layers.ResNet50BottBlock(nc3, nc4)] 53 | for _ in range(2): 54 | aux += [layers.ResNet50BottBlock(nc4, nc4)] 55 | self.resblocks = torch.nn.Sequential(*aux) 56 | self.embpool = layers.GeMPool() 57 | self.proj = torch.nn.Sequential( 58 | torch.nn.BatchNorm1d(conf.ncha), 59 | torch.nn.Linear(conf.ncha, conf.zdim), 60 | ) 61 | # Loss 62 | self.w = torch.nn.Parameter( 63 | 0.076 * (2 * torch.rand(conf.maxcliques, conf.nsub, conf.zdim) - 1) 64 | ) 65 | self.tau = torch.nn.Parameter(torch.ones(1)) 66 | self.smooth = conf.smooth 67 | self.relu = torch.nn.ReLU() 68 | self.margin = conf.margin 69 | self.lamb = conf.lamb 70 | 71 | def get_shingle_params(self): 72 | return self.conf.shingling.len, self.conf.shingling.hop 73 | 74 | ########################################################################### 75 | 76 | def forward( 77 | self, 78 | h, # (B,T) 79 | shingle_len=None, 80 | shingle_hop=None, 81 | ): 82 | with torch.inference_mode(): 83 | h = self.prepare(h, shingle_len=shingle_len, shingle_hop=shingle_hop) 84 | h = h.clone() 85 | h, _ = self.embed(h) 86 | return h # (B,C) 87 | 88 | def 
prepare( 89 | self, 90 | h, # (B,T) 91 | shingle_len=None, 92 | shingle_hop=None, 93 | ): 94 | assert h.ndim == 2 95 | assert shingle_len is None or shingle_len > 0 96 | assert shingle_hop is None or shingle_hop > 0 97 | slen = self.conf.shingling.len if shingle_len is None else shingle_len 98 | shop = self.conf.shingling.hop if shingle_hop is None else shingle_hop 99 | # Shingle 100 | h = tops.get_frames( 101 | h, int(self.sr * slen), int(self.sr * shop), pad_mode="zeros" 102 | ) 103 | # Check audio length 104 | h = tops.force_length( 105 | h, int(self.sr * self.minlen), dim=-1, pad_mode="repeat", allow_longer=True 106 | ) 107 | # CQT 108 | s = h.size(1) 109 | h = rearrange(h, "b s t -> (b s) t") 110 | h = self.cqt(h) 111 | h = self.cqtpool(h) 112 | h = rearrange(h, "(b s) c t -> b s c t", s=s) 113 | return h # (B,S,C,T) 114 | 115 | def embed( 116 | self, 117 | h, # (B,S,C,T) 118 | ): 119 | assert h.ndim == 4 120 | s = h.size(1) 121 | h = rearrange(h, "b s c t -> (b s) c t") 122 | h = h / (h.abs().max(1, keepdim=True)[0].max(2, keepdim=True)[0] + self.eps) 123 | h = self.frontend(h) 124 | h = self.resblocks(h) 125 | h = self.embpool(h) 126 | h = self.proj(h) 127 | h = rearrange(h, "(b s) c -> b s c", s=s) 128 | return h, None # (B,S,C) 129 | 130 | ########################################################################### 131 | 132 | def loss( 133 | self, 134 | label, # (B) 135 | idx, # (B) 136 | z, # (B,S,C) 137 | extra=None, 138 | ): 139 | assert len(label) == len(idx) and len(label) == len(z) 140 | 141 | # Logits ByteCover 142 | logits = self.tau.exp() * self.maxmean(z, self.w) 143 | loss_cla = torch.nn.functional.cross_entropy( 144 | logits, label, label_smoothing=self.smooth 145 | ) 146 | 147 | # Triplet ByteCover 148 | sim = self.maxmean(z, z) 149 | samecla = label.view(-1, 1) == label.view(1, -1) 150 | diffid = idx.view(-1, 1) != idx.view(1, -1) 151 | pos = samecla & diffid 152 | neg = ~samecla 153 | possim = torch.where(pos, sim, 2).min(1)[0] 154 | negsim 
= torch.where(neg, sim, -2).max(1)[0] 155 | loss_dist = self.relu((1 + self.margin) * negsim - possim).mean() 156 | 157 | # Reg 158 | pos, neg = pos.type_as(z), neg.type_as(z) 159 | loss_reg = (pos * (1 - sim)).sum() / (pos.sum() + self.eps) 160 | 161 | loss = loss_cla + loss_dist + self.lamb * loss_reg 162 | logdict = { 163 | "l_main": loss, 164 | "l_cent": loss_cla, 165 | "l_cont": loss_dist, 166 | "v_dpos": (pos * sim).sum() / (pos.sum() + self.eps), 167 | "v_dneg": (neg * sim).sum() / (neg.sum() + self.eps), 168 | } 169 | return loss, logdict 170 | 171 | def maxmean(self, x, y): 172 | assert x.ndim == 3 and y.ndim == 3 and x.size(-1) == y.size(-1) 173 | s, l = x.size(1), y.size(1) 174 | x = rearrange(x, "b s c -> (b s) c") 175 | y = rearrange(y, "k l c -> (k l) c") 176 | sim = tops.pairwise_distance_matrix(x, y, mode="cossim") 177 | sim = rearrange(sim, "(b s) (k l) -> b k s l", s=s, l=l) 178 | res = sim.max(-1)[0].mean(-1) 179 | return res 180 | 181 | ########################################################################### 182 | 183 | def distances( 184 | self, 185 | q, # (B,S,C) 186 | c, # (B',S',C) 187 | qmask=None, 188 | cmask=None, 189 | redux_strategy=None, 190 | ): 191 | assert q.ndim == 3 and c.ndim == 3 and q.size(-1) == c.size(-1) 192 | if redux_strategy is None: 193 | redux_strategy = "smeanmin" 194 | s1, s2 = q.size(1), c.size(1) 195 | q = rearrange(q, "b s c -> (b s) c") 196 | c = rearrange(c, "b s c -> (b s) c") 197 | dist = tops.pairwise_distance_matrix(q, c, mode="cos") 198 | dist = rearrange(dist, "(b1 s1) (b2 s2) -> b1 b2 s1 s2", s1=s1, s2=s2) 199 | if qmask is not None and cmask is not None: 200 | qmask = rearrange(qmask, "b s -> (b s)") 201 | cmask = rearrange(cmask, "b s -> (b s)") 202 | mask = qmask.view(-1, 1) | cmask.view(1, -1) 203 | mask = rearrange(mask, "(bq sq) (bc sc) -> bq bc sq sc", sq=s1, sc=s2) 204 | else: 205 | mask = None 206 | dist = tops.distance_tensor_redux(dist, redux_strategy, mask=mask) 207 | return dist 208 | 
import sys
import torch, math
from nnAudio import features  # type: ignore
from einops import rearrange

from lib import layers
from lib import tensor_ops as tops


class Model(torch.nn.Module):
    """CLEWS cover-detection model.

    Pipeline: audio -> fixed CQT -> conv frontend -> IBN-ResNet backbone ->
    GeM pooling -> linear projection, producing one embedding per audio
    "shingle" (fixed-length window). Trained with an alignment + uniformity
    style contrastive loss over shingle embeddings.
    """

    def __init__(self, conf, sr=16000, eps=1e-6, max_exp=10):
        super().__init__()
        self.sr = sr
        self.eps = eps
        # Shingling
        self.shingling_len = conf.shingling.len
        self.shingling_hop = conf.shingling.hop
        self.minlen = self.shingling_len  # set minlen to training shinglen
        # CQT (non-trainable) followed by temporal average pooling
        self.cqt = torch.nn.Sequential(
            features.CQT1992v2(
                sr=self.sr,
                hop_length=int(conf.cqt.hoplen * sr),
                n_bins=conf.cqt.noctaves * conf.cqt.nbinsoct,
                bins_per_octave=conf.cqt.nbinsoct,
                filter_scale=conf.cqt.fscale,
                trainable=False,
                verbose=False,
            ),
            torch.nn.AvgPool1d(conf.cqt.pool, stride=conf.cqt.pool),
        )
        # Model - Frontend
        ncha0, ncha = conf.frontend.channels
        self.frontend = torch.nn.Sequential(
            layers.CQTPrepare(pow=conf.frontend.cqtpow),
            torch.nn.Conv2d(1, ncha0, (12, 3), stride=(1, 2), bias=False),
            torch.nn.BatchNorm2d(ncha0),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(ncha0, ncha, (12, 3), stride=2, bias=False),
        )
        # Model - Backbone (stacks of IBN residual blocks per config)
        aux = []
        for nb, nc, st in zip(
            conf.backbone.blocks, conf.backbone.channels, conf.backbone.down
        ):
            aux += [layers.MyIBNResBlock(ncha, nc, stride=st)]
            for _ in range(nb - 1):
                aux += [layers.MyIBNResBlock(nc, nc)]
            ncha = nc
        self.backbone = torch.nn.Sequential(*aux)
        # Pooling & projection
        self.pool = layers.GeMPool()
        self.proj = torch.nn.Sequential(
            torch.nn.BatchNorm1d(ncha),
            torch.nn.Linear(ncha, conf.zdim, bias=False),
        )
        # Loss hyperparameters. beta folds exp(b) into the "numerically
        # friendly" branch of loss() so exponentials stay in a safe range.
        self.redux = conf.loss.redux
        self.gamma = conf.loss.gamma
        self.epsilon = conf.loss.epsilon
        self.b = max_exp
        self.beta = 1 / (self.epsilon * math.exp(self.b))

    def get_shingle_params(self):
        """Return the default (length, hop) shingling parameters, in seconds."""
        return self.shingling_len, self.shingling_hop

    ###########################################################################

    def forward(
        self,
        h,  # (B,T)
        shingle_len=None,
        shingle_hop=None,
    ):
        """Inference entry point: raw audio (B,T) -> embeddings (B,S,C)."""
        with torch.inference_mode():
            h = self.prepare(h, shingle_len=shingle_len, shingle_hop=shingle_hop)
        # clone() converts the inference-mode tensor into a regular tensor
        h = h.clone()
        h, _ = self.embed(h)
        return h  # (B,S,C)

    def prepare(
        self,
        h,  # (B,T)
        shingle_len=None,
        shingle_hop=None,
    ):
        """Cut audio into shingles and compute their CQTs.

        Returns a (B,S,C,T) tensor: batch, shingles, CQT bins, frames.
        """
        assert h.ndim == 2
        assert shingle_len is None or shingle_len > 0
        assert shingle_hop is None or shingle_hop > 0
        # Shingle
        slen = self.shingling_len if shingle_len is None else shingle_len
        shop = self.shingling_hop if shingle_hop is None else shingle_hop
        h = tops.get_frames(
            h, int(self.sr * slen), int(self.sr * shop), pad_mode="repeat"
        )
        # Check min shingle length
        h = tops.force_length(
            h, int(self.sr * self.minlen), dim=-1, pad_mode="repeat", allow_longer=True
        )
        # CQT (computed on the flattened batch*shingle axis)
        s = h.size(1)
        h = rearrange(h, "b s t -> (b s) t")
        h = self.cqt(h)
        h = rearrange(h, "(b s) c t -> b s c t", s=s)
        return h  # (B,S,C,T)

    def embed(
        self,
        h,  # (B,S,C,T)
    ):
        """Embed CQT shingles: (B,S,C,T) -> ((B,S,zdim), None)."""
        assert h.ndim == 4
        # Prepare
        s = h.size(1)
        h = rearrange(h, "b s c t -> (b s) 1 c t")
        # Feedforward
        h = self.frontend(h)
        h = self.backbone(h)
        # Pool and project
        h = self.pool(h)
        z = self.proj(h)
        # Out
        z = rearrange(z, "(b s) c -> b s c", s=s)
        return z, None  # (B,S,C)

    ###########################################################################

    def loss(
        self,
        z_label,  # (B)
        z_idx,  # (B)
        z,  # (B,S,C)
        extra=None,
        numerically_friendly=True,
    ):
        """Alignment + uniformity contrastive loss over shingle embeddings.

        Returns (loss, logdict) where logdict holds scalar diagnostics.
        """
        assert len(z_label) == len(z_idx) and len(z_label) == len(z)
        assert len(z) >= 4
        # If no negatives, add label noise for loss stability
        # (we assume positives exist due to batch construction).
        # FIX: clone first -- the original wrote into the caller's tensor,
        # silently corrupting labels outside this function.
        if len(z_label.unique()) == 1:
            z_label = z_label.clone()
            z_label[: max(2, int(len(z_label) * 0.01))] = -1

        # Prepare pair masks (True entries are the ones excluded by mmean)
        sz = z.size(1)
        z = rearrange(z, "b s c -> (b s) c")
        same_label = z_label.view(-1, 1) == z_label.view(1, -1)
        same_idx = z_idx.view(-1, 1) == z_idx.view(1, -1)
        mask_pos = (~same_label) | same_idx
        mask_neg = same_label

        # Distances, reduced over shingle pairs per track pair
        dist = tops.pairwise_distance_matrix(z, z, mode="nsqeuc")
        dist = rearrange(dist, "(b1 s1) (b2 s2) -> b1 b2 s1 s2", s1=sz, s2=sz)
        dpos = tops.distance_tensor_redux(dist, self.redux.pos)
        dneg = tops.distance_tensor_redux(dist, self.redux.neg)

        # Losses. The numerically friendly branch equals the plain branch up
        # to the additive constant -log(epsilon) (same gradients), but keeps
        # the exponent bounded by self.b to avoid overflow/underflow.
        loss_align = tops.mmean(dpos, mask=mask_pos, eps=self.eps)
        if numerically_friendly:
            loss_uniform = (
                self.beta
                * tops.mmean(
                    (self.b - self.gamma * dneg).exp(), mask=mask_neg, eps=self.eps
                )
            ).log1p()
        else:
            loss_uniform = (
                tops.mmean((-self.gamma * dneg).exp(), mask=mask_neg, eps=self.eps)
                + self.epsilon
            ).log()

        # Output
        loss = loss_align + loss_uniform
        logdict = {
            "l_main": loss,
            "l_cent": loss_align,
            "l_cont": loss_uniform,
            "v_dpos": tops.mmean(dpos, mask=mask_pos),
            "v_dneg": tops.mmean(dneg, mask=mask_neg),
            "v_zmax": z.abs().max(),
            "v_zmean": z.mean(),
            "v_zstd": z.std(),
        }
        return loss, logdict

    ###########################################################################

    def distances(
        self,
        q,  # (B,S,C)
        c,  # (B',S,C)
        qmask=None,
        cmask=None,
        redux_strategy=None,
    ):
        """Query/candidate distance matrix (B,B'), reduced over shingle pairs.

        qmask/cmask (B,S) mark invalid (padding) shingles to exclude.
        """
        assert q.ndim == 3 and c.ndim == 3 and q.size(-1) == c.size(-1)
        if redux_strategy is None:
            redux_strategy = self.redux.pos
        # Reshape and compute
        sq = q.size(1)
        sc = c.size(1)
        q = rearrange(q, "b s c -> (b s) c")
        c = rearrange(c, "b s c -> (b s) c")
        dist = tops.pairwise_distance_matrix(q, c, mode="nsqeuc")
        dist = rearrange(dist, "(bq sq) (bc sc) -> bq bc sq sc", sq=sq, sc=sc)
        if qmask is not None and cmask is not None:
            qmask = rearrange(qmask, "b s -> (b s)")
            cmask = rearrange(cmask, "b s -> (b s)")
            mask = qmask.view(-1, 1) | cmask.view(1, -1)
            mask = rearrange(mask, "(bq sq) (bc sc) -> bq bc sq sc", sq=sq, sc=sc)
        else:
            mask = None
        # Redux
        dist = tops.distance_tensor_redux(dist, redux_strategy, mask=mask)
        return dist
import sys
import torch, math
from nnAudio import features  # type: ignore
from einops import rearrange, repeat

from lib import tensor_ops as tops


class Model(torch.nn.Module):
    """DVINet+ baseline: fixed CQT front-end and a 5-stage VGG-like CNN
    producing one L2-normalized embedding per audio shingle, trained with a
    hardest-example triplet loss plus a similarity regularizer."""

    def __init__(self, conf, sr=16000, eps=1e-6):
        super().__init__()
        self.conf = conf
        self.sr = sr
        self.eps = eps
        self.minlen = conf.shingling.len
        # CQT (non-trainable) followed by temporal average pooling
        self.cqtbins = self.conf.cqt.noctaves * self.conf.cqt.nbinsoct
        self.cqt = features.CQT1992v2(
            sr=self.sr,
            hop_length=int(self.conf.cqt.hoplen * sr),
            n_bins=self.cqtbins,
            bins_per_octave=self.conf.cqt.nbinsoct,
            trainable=False,
            verbose=False,
        )
        self.cqtpool = torch.nn.AvgPool1d(
            self.conf.cqt.pool.len, stride=self.conf.cqt.pool.hop
        )
        # Model: five conv stages, channel width doubling per stage
        ncha = conf.ncha_in
        self.block1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, ncha, (12, 3), padding=(6, 0), bias=False),
            torch.nn.BatchNorm2d(ncha),
            torch.nn.ReLU(),
            torch.nn.Conv2d(ncha, 2 * ncha, (13, 3), dilation=(1, 2), bias=False),
            torch.nn.BatchNorm2d(2 * ncha),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((1, 2), (1, 2)),
        )
        self.block2 = torch.nn.Sequential(
            torch.nn.Conv2d(2 * ncha, 2 * ncha, (13, 3), bias=False),
            torch.nn.BatchNorm2d(2 * ncha),
            torch.nn.ReLU(),
            torch.nn.Conv2d(2 * ncha, 2 * ncha, (3, 3), dilation=(1, 2), bias=False),
            torch.nn.BatchNorm2d(2 * ncha),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((1, 2), (1, 2)),
        )
        self.block3 = torch.nn.Sequential(
            torch.nn.Conv2d(2 * ncha, 4 * ncha, (3, 3), bias=False),
            torch.nn.BatchNorm2d(4 * ncha),
            torch.nn.ReLU(),
            torch.nn.Conv2d(4 * ncha, 4 * ncha, (3, 3), dilation=(1, 2), bias=False),
            torch.nn.BatchNorm2d(4 * ncha),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((1, 2), (1, 2)),
        )
        self.block4 = torch.nn.Sequential(
            torch.nn.Conv2d(4 * ncha, 8 * ncha, (3, 3), bias=False),
            torch.nn.BatchNorm2d(8 * ncha),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8 * ncha, 8 * ncha, (3, 3), dilation=(1, 2), bias=False),
            torch.nn.BatchNorm2d(8 * ncha),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((1, 2), (1, 2)),
        )
        self.block5 = torch.nn.Sequential(
            torch.nn.Conv2d(8 * ncha, 16 * ncha, (3, 3), bias=False),
            torch.nn.BatchNorm2d(16 * ncha),
            torch.nn.ReLU(),
            torch.nn.Conv2d(16 * ncha, 16 * ncha, (3, 3), dilation=(1, 2), bias=False),
            torch.nn.BatchNorm2d(16 * ncha),
            torch.nn.ReLU(),
            torch.nn.AdaptiveMaxPool2d((1, 1)),
        )
        self.fc = torch.nn.Linear(16 * ncha, self.conf.zdim)
        self.margin = conf.margin
        self.relu = torch.nn.ReLU()
        self.lamb = conf.lamb

    def get_shingle_params(self):
        """Return the default (length, hop) shingling parameters, in seconds."""
        return self.conf.shingling.len, self.conf.shingling.hop

    ###########################################################################

    def forward(
        self,
        h,  # (B,T)
        shingle_len=None,
        shingle_hop=None,
    ):
        """Inference entry point: raw audio (B,T) -> embeddings (B,S,C)."""
        with torch.inference_mode():
            h = self.prepare(h, shingle_len=shingle_len, shingle_hop=shingle_hop)
        # clone() converts the inference-mode tensor into a regular tensor
        h = h.clone()
        h, _ = self.embed(h)
        return h  # (B,S,C) — note: S is the shingle axis, not squeezed

    def prepare(
        self,
        h,  # (B,T)
        shingle_len=None,
        shingle_hop=None,
    ):
        """Cut audio into shingles and compute pooled CQTs."""
        assert h.ndim == 2
        assert shingle_len is None or shingle_len > 0
        assert shingle_hop is None or shingle_hop > 0
        slen = self.conf.shingling.len if shingle_len is None else shingle_len
        shop = self.conf.shingling.hop if shingle_hop is None else shingle_hop
        # Shingle
        h = tops.get_frames(
            h, int(self.sr * slen), int(self.sr * shop), pad_mode="zeros"
        )
        # Check audio length
        h = tops.force_length(
            h, int(self.sr * self.minlen), dim=-1, pad_mode="zeros", allow_longer=True
        )
        # CQT (computed on the flattened batch*shingle axis)
        s = h.size(1)
        h = rearrange(h, "b s t -> (b s) t")
        h = self.cqt(h)
        h = self.cqtpool(h)
        h = rearrange(h, "(b s) c t -> b s c t", s=s)
        return h  # (B,S,C,T)

    def embed(
        self,
        h,  # (B,S,C,T)
    ):
        """Embed CQT shingles: (B,S,C,T) -> ((B,S,zdim), None).

        Embeddings are L2-normalized along the feature axis.
        """
        assert h.ndim == 4
        s = h.size(1)
        h = rearrange(h, "b s c t -> (b s) c t")
        # Per-shingle max normalization of the CQT magnitude
        h = h / (h.abs().max(1, keepdim=True)[0].max(2, keepdim=True)[0] + self.eps)
        h = h.unsqueeze(1)
        h = self.block1(h)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.block5(h)
        h = h.squeeze(-1).squeeze(-1)
        h = self.fc(h)
        h = torch.nn.functional.normalize(h, dim=-1)
        h = rearrange(h, "(b s) c -> b s c", s=s)
        return h, None  # (B,S,C)

    ###########################################################################

    def loss(
        self,
        label,  # (B)
        idx,  # (B)
        z,  # (B,S,C)
        extra=None,
    ):
        """Hardest-example triplet loss + negative-similarity regularizer.

        Returns (loss, logdict).
        """
        assert len(label) == len(idx) and len(label) == len(z)

        sz = z.size(1)
        z = rearrange(z, "b s c -> (b s) c")
        same_label = label.view(-1, 1) == label.view(1, -1)
        diagonal = torch.eye(len(label), device=z.device) > 0.5
        positive = same_label & (~diagonal)
        negative = (~same_label) & (~diagonal)
        # Sentinel strictly larger than any real distance between the
        # L2-normed embeddings; note z.size(1) here is the feature dim C.
        dlimit = 1.1 * math.sqrt(2 * z.size(1))  # because euc dist on l2-normed z

        # Mean euclidean distance over shingle pairs, then hardest
        # positive (max) / hardest negative (min) per anchor row
        dist = tops.pairwise_distance_matrix(z, z, mode="euc")
        dist = rearrange(dist, "(b1 s1) (b2 s2) -> b1 b2 s1 s2", s1=sz, s2=sz)
        dist = tops.distance_tensor_redux(dist, "mean")
        dpos = torch.where(positive, dist, -dlimit).max(1)[0]
        dneg = torch.where(negative, dist, dlimit).min(1)[0]
        loss_trip = self.relu(dpos - dneg + self.margin).mean()

        # Inspired by https://arxiv.org/pdf/1708.06320
        # Penalizes squared dot-product similarity between negatives
        sim = tops.pairwise_distance_matrix(z, z, mode="dotsim")
        sim = rearrange(sim, "(b1 s1) (b2 s2) -> b1 b2 s1 s2", s1=sz, s2=sz)
        sim = tops.distance_tensor_redux(sim, "mean")
        loss_reg = (negative * sim.pow(2)).sum() / (negative.sum() + self.eps)

        loss = loss_trip + self.lamb * loss_reg
        logd = {
            "l_main": loss,
            "l_cont": loss_trip,
            "l_cent": loss_reg,
            "v_dneg": dneg.mean(),
            "v_dpos": dpos.mean(),
        }
        return loss, logd

    ###########################################################################

    def distances(
        self,
        q,  # (B,S,C)
        c,  # (B',S',C)
        qmask=None,
        cmask=None,
        redux_strategy=None,
    ):
        """Query/candidate cosine-distance matrix (B,B'), reduced over
        shingle pairs (default reduction: "min")."""
        assert q.ndim == 3 and c.ndim == 3 and q.size(-1) == c.size(-1)
        if redux_strategy is None:
            redux_strategy = "min"
        s1, s2 = q.size(1), c.size(1)
        q = rearrange(q, "b s c -> (b s) c")
        c = rearrange(c, "b s c -> (b s) c")
        dist = tops.pairwise_distance_matrix(q, c, mode="cos")
        dist = rearrange(dist, "(b1 s1) (b2 s2) -> b1 b2 s1 s2", s1=s1, s2=s2)
        if qmask is not None and cmask is not None:
            qmask = rearrange(qmask, "b s -> (b s)")
            cmask = rearrange(cmask, "b s -> (b s)")
            mask = qmask.view(-1, 1) | cmask.view(1, -1)
            mask = rearrange(mask, "(bq sq) (bc sc) -> bq bc sq sc", sq=s1, sc=s2)
        else:
            mask = None
        dist = tops.distance_tensor_redux(dist, redux_strategy, mask=mask)
        return dist
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# author:liufeng
# datetime:2022/7/15 12:36 PM
# software: PyCharm

from typing import List
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    A Discriminative Feature Learning Approach for Deep Face Recognition.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, num_classes=10, feat_dim=2):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Learnable class centers, updated jointly with the model
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)
        # ||x||^2 + ||c||^2 terms of the squared euclidean distance
        distmat = (
            torch.pow(x, 2)
            .sum(dim=1, keepdim=True)
            .expand(batch_size, self.num_classes)
            + torch.pow(self.centers, 2)
            .sum(dim=1, keepdim=True)
            .expand(self.num_classes, batch_size)
            .t()
        )
        # -2 * x @ c^T term (modern keyword form of the deprecated
        # distmat.addmm_(1, -2, x, self.centers.t()))
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes, device=x.device).long()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        # Keep only the distance to each sample's own class center
        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e12).sum() / batch_size

        return loss


class FocalLoss(nn.Module):
    """Focal Loss implement for https://arxiv.org/abs/1708.02002"""

    def __init__(
        self,
        gamma: float = 2.0,
        alpha: List = None,
        num_cls: int = -1,
        reduction: str = "mean",
    ):

        super(FocalLoss, self).__init__()
        if reduction not in ["mean", "sum"]:
            raise NotImplementedError("Reduction {} not implemented.".format(reduction))
        self._reduction = reduction
        self._alpha = alpha
        self._gamma = gamma
        # FIX: num_cls was never stored, so forward_with_onehot crashed on
        # the (undefined) self.num_cls attribute.
        self._num_cls = num_cls
        if alpha is not None:
            assert len(alpha) <= num_cls, "{} != {}".format(len(alpha), num_cls)
            self._alpha = torch.tensor(self._alpha)
        self._eps = np.finfo(float).eps
        return

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """compute focal loss for pred and label

        Args:
            y_pred: [batch_size, num_cls]
            y_true: [batch_size]

        Returns:
            loss
        """
        b = y_pred.size(0)
        y_pred_softmax = torch.nn.Softmax(dim=1)(y_pred) + self._eps
        ce = -torch.log(y_pred_softmax)
        ce = ce.gather(1, y_true.view(-1, 1))

        # Focal weight (1 - p_t)^gamma down-weights easy examples
        y_pred_softmax = y_pred_softmax.gather(1, y_true.view(-1, 1))
        weight = torch.pow(torch.sub(1.0, y_pred_softmax), self._gamma)

        if self._alpha is not None:
            self._alpha = self._alpha.to(y_pred.device)
            alpha = self._alpha.gather(0, y_true.view(-1))
            alpha = alpha.unsqueeze(1)
            alpha = alpha / torch.sum(alpha) * b
            weight = torch.mul(alpha, weight)
        fl_loss = torch.mul(weight, ce).squeeze(1)
        return self._reduce(fl_loss)

    def forward_with_onehot(
        self, y_pred: torch.Tensor, y_true: torch.Tensor
    ) -> torch.Tensor:
        """Another implement for "forward" with onehot label matrix.

        It is not good because when ce_embed size is large, more memory will used.

        Args:
            y_pred: [batch_size, num_cls]
            y_true: [batch_size]

        Returns:
            loss

        """
        y_pred = torch.nn.Softmax(dim=1)(y_pred)
        # FIX: use the stored attributes; the original referenced
        # self.num_cls and self.alpha, which do not exist.
        y_true = torch.nn.functional.one_hot(y_true, self._num_cls)

        eps = np.finfo(float).eps
        y_pred = y_pred + eps
        ce = torch.mul(y_true, -torch.log(y_pred))
        weight = torch.mul(y_true, torch.pow(torch.sub(1.0, y_pred), self._gamma))
        fl_loss = torch.mul(weight, ce)
        if self._alpha is not None:
            fl_loss = torch.mul(self._alpha.to(y_pred.device), fl_loss)

        fl_loss, _ = torch.max(fl_loss, dim=1)
        return self._reduce(fl_loss)

    def _reduce(self, x):
        # Reduce per-sample losses according to the configured strategy
        if self._reduction == "mean":
            return torch.mean(x)
        else:
            return torch.sum(x)


class HardTripletLoss(nn.Module):
    """Hard/Hardest Triplet Loss
    (pytorch implementation of https://omoindrot.github.io/triplet-loss)
    For each anchor, we get the hardest positive and hardest negative to form a triplet.
    """

    def __init__(self, margin=0.1):
        """Args:
        margin: margin for triplet loss
        """
        super(HardTripletLoss, self).__init__()
        self._margin = margin
        return

    def forward(self, embeddings, labels, ids=None):
        """
        Args:
            labels: labels of the batch, of size (batch_size,)
            embeddings: tensor of shape (batch_size, embed_dim)

        Returns:
            triplet_loss: scalar tensor containing the triplet loss
        """
        pairwise_dist = self._pairwise_distance(embeddings, squared=False)

        # Hardest positive: furthest same-label sample per anchor
        mask_anchor_positive = self._get_anchor_positive_triplet_mask(
            labels, ids
        ).float()
        valid_positive_dist = pairwise_dist * mask_anchor_positive
        hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True)

        # Get the hardest negative pairs (closest different-label sample;
        # same-label entries are pushed up by the row max so min() skips them)
        mask_negative = self._get_anchor_negative_triplet_mask(labels).float()
        max_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)
        negative_dist = pairwise_dist + max_negative_dist * (1.0 - mask_negative)
        hardest_negative_dist, _ = torch.min(negative_dist, dim=1, keepdim=True)

        # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
        triplet_loss = F.relu(
            hardest_positive_dist - hardest_negative_dist + self._margin
        )
        triplet_loss = torch.mean(triplet_loss)
        return triplet_loss

    @staticmethod
    def _pairwise_distance(x, squared=False, eps=1e-16):
        # Compute the 2D matrix of distances between all the embeddings.

        cor_mat = torch.matmul(x, x.t())
        norm_mat = cor_mat.diag()
        distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)
        distances = F.relu(distances)

        if not squared:
            # eps trick keeps sqrt differentiable at exact zeros
            mask = torch.eq(distances, 0.0).float()
            distances = distances + mask * eps
            distances = torch.sqrt(distances)
            distances = distances * (1.0 - mask)
        return distances

    @staticmethod
    def _get_anchor_positive_triplet_mask(labels, ids):
        """Return a 2D mask where mask[a, p] is True, if a and p are distinct and
        have same label.

        """
        device = labels.device
        if ids is None:
            indices_not_equal = torch.eye(labels.shape[0]).to(device).byte() ^ 1
        else:
            indices_not_equal = (
                torch.unsqueeze(ids, 0) == torch.unsqueeze(ids, 1)
            ).byte() ^ 1
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
        mask = indices_not_equal * labels_equal
        return mask

    @staticmethod
    def _get_anchor_negative_triplet_mask(labels):
        """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels."""
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
        mask = labels_equal ^ 1
        return mask
# Evaluation script: loads a trained checkpoint, extracts query/candidate
# shingle embeddings over the chosen partition, and reports retrieval
# metrics (MAP, MR1, ARP) with 95% confidence intervals. Multi-GPU via
# Lightning Fabric/DDP; run as `python test.py checkpoint=... [key=value ...]`.
import sys
import os
import importlib
from omegaconf import OmegaConf
import torch, math
from lightning import Fabric
from lightning.fabric.strategies import DDPStrategy

from lib import eval, dataset
from lib import tensor_ops as tops
from utils import pytorch_utils, print_utils

# --- Get arguments (and set defaults) --- Basic ---
args = OmegaConf.from_cli()
assert "checkpoint" in args
log_path, _ = os.path.split(args.checkpoint)
if "ngpus" not in args:
    args.ngpus = 1
if "nnodes" not in args:
    args.nnodes = 1
args.precision = "32"  # evaluation always runs in full fp32
if "path_audio" not in args:
    args.path_audio = None
if "path_meta" not in args:
    args.path_meta = None
if "partition" not in args:
    args.partition = "test"
if "limit_num" not in args:
    args.limit_num = None

# --- Get arguments (and set defaults) --- Tunable ---
if "maxlen" not in args:  # maximum audio length
    args.maxlen = 10 * 60  # in seconds
if "redux" not in args:  # distance reduction
    args.redux = None
if "qslen" not in args:  # query shingle len
    args.qslen = None
if "qshop" not in args:  # query shingle hop (default = every 5 sec)
    args.qshop = 5
if "cslen" not in args:  # candidate shingle len
    args.cslen = None
if "cshop" not in args:  # candidate shingle hop (default = every 5 sec)
    args.cshop = 5

###############################################################################

# Init pytorch/Fabric
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
torch.set_float32_matmul_precision("medium")
torch.autograd.set_detect_anomaly(False)
fabric = Fabric(
    accelerator="cuda",
    devices=args.ngpus,
    num_nodes=args.nnodes,
    strategy=DDPStrategy(broadcast_buffers=False),
    precision=args.precision,
)
fabric.launch()

# Seed (random segment needs a seed); offset by rank so workers differ
fabric.barrier()
fabric.seed_everything(44 + fabric.global_rank, workers=True)

# Init my utils (print/progress only on the global-zero rank)
myprint = lambda s, end="\n": print_utils.myprint(
    s, end=end, doit=fabric.is_global_zero
)
myprogbar = lambda it, desc=None, leave=False: print_utils.myprogbar(
    it, desc=desc, leave=leave, doit=fabric.is_global_zero
)
timer = print_utils.Timer()
fabric.barrier()

# Load conf (training configuration stored next to the checkpoint)
myprint(OmegaConf.to_yaml(args))
myprint("Load model conf...")
conf = OmegaConf.load(os.path.join(log_path, "configuration.yaml"))

# Init model (dynamically import models/<name>.py)
myprint("Init model...")
module = importlib.import_module("models." + conf.model.name)
with fabric.init_module():
    model = module.Model(conf.model, sr=conf.data.samplerate)
model = fabric.setup(model)

# Load model
myprint(" Load checkpoint")
state = pytorch_utils.get_state(model, None, None, conf, None, None, None)
fabric.load(args.checkpoint, state)
model, _, _, conf, epoch, _, best = pytorch_utils.set_state(state)
myprint(f" ({epoch} epochs; best was {best:.3f})")
model.eval()
# Optional command-line overrides of the stored data paths
if args.path_audio is not None:
    conf.path.audio = args.path_audio
if args.path_meta is not None:
    conf.path.meta = args.path_meta
conf.data.path = conf.path

# Get dataset (full songs, no augmentation; batch_size=1 since song
# lengths differ)
myprint("Dataset...")
dset = dataset.Dataset(
    conf.data,
    args.partition,
    augment=False,
    fullsongs=True,
    verbose=fabric.is_global_zero,
)
dloader = torch.utils.data.DataLoader(
    dset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
    drop_last=False,
    pin_memory=False,
)
dloader = fabric.setup_dataloaders(dloader)

###############################################################################


@torch.inference_mode()
def extract_embeddings(shingle_len, shingle_hop, desc="Embed", eps=1e-6):
    """Embed every song of the dataloader.

    Returns (clique ids, song ids, embeddings (N,S,C), padding masks (N,S)).
    A non-positive shingle_len/hop means "one shingle spanning the song".
    """
    # Check shingle args
    shinglen, shinghop = model.get_shingle_params()
    if shingle_len is not None:
        shinglen = shingle_len
    if shingle_hop is not None:
        shinghop = shingle_hop
    mxlen = int(args.maxlen * model.sr)
    numshingles = int((mxlen - int(shinglen * model.sr)) / int(shinghop * model.sr))
    # Extract embeddings
    all_c = []
    all_i = []
    all_z = []
    all_m = []
    for batch in myprogbar(dloader, desc=desc, leave=True):
        # Get info & audio
        c, i, x = batch[:3]
        if x.size(1) > mxlen:
            x = x[:, :mxlen]
        # Get embedding (B=1,S,C)
        z = model(
            x,
            shingle_len=int(x.size(1) / model.sr) if shinglen <= 0 else shinglen,
            shingle_hop=int(0.99 * x.size(1) / model.sr) if shinghop <= 0 else shinghop,
        )
        # Make embedding shingles same size (zero-pad; mask marks padding)
        z = tops.force_length(
            z,
            1 if shinglen <= 0 else numshingles,
            dim=1,
            pad_mode="zeros",
            cut_mode="start",
        )
        m = z.abs().max(-1)[0] < eps
        # Append
        all_c.append(c)
        all_i.append(i)
        all_z.append(z)
        all_m.append(m)
        # Limit number of queries/candidates?
        if args.limit_num is not None and len(all_z) >= args.limit_num / args.ngpus:
            myprint("")
            myprint(" [Max num reached]")
            break
    # Concat single-song batches
    all_c = torch.cat(all_c, dim=0)
    all_i = torch.cat(all_i, dim=0)
    all_z = torch.cat(all_z, dim=0)
    all_m = torch.cat(all_m, dim=0)
    # Return
    return all_c, all_i, all_z, all_m


###############################################################################

# Let's go
with torch.inference_mode():

    # Extract embeddings (half precision to save memory during retrieval)
    query_c, query_i, query_z, query_m = extract_embeddings(
        args.qslen, args.qshop, desc="Query emb"
    )
    query_c = query_c.int()
    query_i = query_i.int()
    query_z = query_z.half()
    if args.cslen == args.qslen and args.cshop == args.qshop:
        # Same shingling for both roles: reuse the query embeddings
        myprint("Cand emb: (copy)")
        cand_c, cand_i, cand_z, cand_m = (
            query_c.clone(),
            query_i.clone(),
            query_z.clone(),
            query_m.clone(),
        )
    else:
        cand_c, cand_i, cand_z, cand_m = extract_embeddings(
            args.cslen, args.cshop, desc="Cand emb"
        )
        cand_c = cand_c.int()
        cand_i = cand_i.int()
        cand_z = cand_z.half()

    # Collect candidates from all GPUs + collapse to batch dim
    fabric.barrier()
    cand_c = fabric.all_gather(cand_c)
    cand_i = fabric.all_gather(cand_i)
    cand_z = fabric.all_gather(cand_z)
    cand_m = fabric.all_gather(cand_m)
    cand_c = torch.cat(torch.unbind(cand_c, dim=0), dim=0)
    cand_i = torch.cat(torch.unbind(cand_i, dim=0), dim=0)
    cand_z = torch.cat(torch.unbind(cand_z, dim=0), dim=0)
    cand_m = torch.cat(torch.unbind(cand_m, dim=0), dim=0)

    # Evaluate (one query at a time against all candidates)
    aps = []
    r1s = []
    rpcs = []
    for n in myprogbar(range(len(query_z)), desc="Retrieve", leave=True):
        ap, r1, rpc = eval.compute(
            model,
            query_c[n : n + 1],
            query_i[n : n + 1],
            query_z[n : n + 1],
            cand_c,
            cand_i,
            cand_z,
            queries_m=query_m[n : n + 1],
            candidates_m=cand_m,
            redux_strategy=args.redux,
            batch_size_candidates=2**15,
        )
        aps.append(ap)
        r1s.append(r1)
        rpcs.append(rpc)
    aps = torch.stack(aps)
    r1s = torch.stack(r1s)
    rpcs = torch.stack(rpcs)

    # Collect measures from all GPUs + collapse to batch dim
    fabric.barrier()
    aps = fabric.all_gather(aps)
    r1s = fabric.all_gather(r1s)
    rpcs = fabric.all_gather(rpcs)
    aps = torch.cat(torch.unbind(aps, dim=0), dim=0)
    r1s = torch.cat(torch.unbind(r1s, dim=0), dim=0)
    rpcs = torch.cat(torch.unbind(rpcs, dim=0), dim=0)

###############################################################################

# Print means and 95% confidence intervals (normal approximation)
logdict_mean = {
    "MAP": aps.mean(),
    "MR1": r1s.mean(),
    "ARP": rpcs.mean(),
}
logdict_ci = {
    "MAP": 1.96 * aps.std() / math.sqrt(len(aps)),
    "MR1": 1.96 * r1s.std() / math.sqrt(len(r1s)),
    "ARP": 1.96 * rpcs.std() / math.sqrt(len(rpcs)),
}
myprint("=" * 100)
myprint("Result:")
myprint(" Avg --> " + print_utils.report(logdict_mean, clean_line=False))
myprint(" c.i. -> " + print_utils.report(logdict_ci, clean_line=False))
myprint("=" * 100)
# Dataset preprocessing: scans the metadata of SHS100K / covers80 /
# DiscogsVI, keeps only versions whose audio file exists and is readable,
# rebuilds the train/valid/test clique splits accordingly, and saves
# [info, splits] with torch.save for the training/eval pipelines.
import sys
import os
import argparse
from tqdm import tqdm
import torch
from joblib import Parallel, delayed

from utils import file_utils, audio_utils, print_utils

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset", type=str, choices=["SHS100K", "covers80", "DiscogsVI"], required=True
)
parser.add_argument("--path_meta", type=str, default="data/xxx", required=True)
parser.add_argument("--path_audio", type=str, default="data/yyy", required=True)
parser.add_argument("--ext_in", type=str, default="mp3", required=True)
parser.add_argument(
    "--fn_out", type=str, default="cache/metadata-dataset-specs.pt", required=True
)
parser.add_argument("--njobs", type=int, default=-1)
args = parser.parse_args()
# Normalize the extension (also safe on an all-dot/empty string, unlike
# the previous char-by-char while loop)
args.ext_in = args.ext_in.lstrip(".")
print("=" * 100)
print(args)
print("=" * 100)

###############################################################################


def load_cliques_shs100k(fn):
    """Read an SHS100K split file and return {clique: sorted version ids}."""
    cliques = {}
    _, data, _ = file_utils.load_csv(fn, sep="\t")
    for c, n in zip(data[0], data[1]):
        if c not in cliques:
            cliques[c] = []
        cliques[c].append(n)
    for c in cliques:
        # Deduplicate, sort, and prefix every version with its clique id
        cliques[c] = list(set(cliques[c]))
        cliques[c].sort()
        for i in range(len(cliques[c])):
            cliques[c][i] = c + "-" + cliques[c][i]
    return cliques


def load_cliques_discogsvi(fn, i=0):
    """Read a DiscogsVI-YT split json and resolve audio/metadata on disk.

    Returns (cliques, per-version info dict, next running id, #not found).
    """
    jsoncliques = file_utils.load_json(fn)
    cliques = {}
    cliqueinfo = {}
    notfound = 0
    for c, versions in jsoncliques.items():
        clique = []
        for ver in versions:
            print(f"\r Version {i+1}", end=" ")
            sys.stdout.flush()
            v = ver["version_id"]
            ytid = ver["youtube_id"]
            idx = c + ":" + v
            # search basename (the two-char prefix directory may use any
            # upper/lower-case combination of the first two ytid chars)
            basename = None
            for pref in [
                ytid[:2],
                ytid[0].upper() + ytid[1].upper(),
                ytid[0].upper() + ytid[1].lower(),
                ytid[0].lower() + ytid[1].upper(),
                ytid[0].lower() + ytid[1].lower(),
            ]:
                fn_meta = os.path.join(args.path_audio, pref, ytid + ".meta")
                if os.path.exists(fn_meta):
                    basename = os.path.join(pref, ytid)
                    break
            if basename is None:
                notfound += 1
                continue
            # load metadata (best-effort: a broken .meta file is tolerated)
            try:
                data = file_utils.load_json(fn_meta)
            except Exception:
                data = {}
            # fill in now
            clique.append(idx)
            cliqueinfo[idx] = {
                "id": i,
                "clique": c,
                "version": v,
                "artist": data["artist"] if "artist" in data else "?",
                "title": data["title"] if "title" in data else "?",
                "filename": os.path.join(basename + "." + args.ext_in),
            }
            i += 1
        cliques[c] = clique
    print()
    return cliques, cliqueinfo, i, notfound


###############################################################################

timer = print_utils.Timer()

# Load cliques and splits
print(f"Load {args.dataset}")
if args.dataset == "SHS100K":

    # ********** SHS100K **********
    # Info
    fn = os.path.join(args.path_meta, "list")
    _, data, numrecords = file_utils.load_csv(fn, sep="\t")
    info = {}
    for i in range(numrecords):
        c, n = data[0][i], data[1][i]
        idx = c + "-" + n
        info[idx] = {
            "id": i,
            "clique": c,
            "version": n,
            "artist": data[3][i],
            "title": data[2][i],
            "filename": os.path.join(idx[:2], idx + "." + args.ext_in),
        }
    # Splits
    splits = {}
    for sp, suff in zip(["train", "valid", "test"], ["TRAIN", "VAL", "TEST"]):
        fn = os.path.join(args.path_meta, "SHS100K-" + suff)
        splits[sp] = load_cliques_shs100k(fn)
    # *****************************

elif args.dataset == "DiscogsVI":

    # ********* DiscogsVI **********
    # Splits + Info (running id `i` is shared across the three partitions)
    splits = {}
    info = {}
    i = 0
    for sp, suff in zip(["train", "valid", "test"], [".train", ".val", ".test"]):
        fn = os.path.join(args.path_meta, "DiscogsVI-YT-20240701-light.json" + suff)
        cliques, infosp, i, notfound = load_cliques_discogsvi(fn, i=i)
        splits[sp] = cliques
        info.update(infosp)
        if notfound > 0:
            print(f"({sp}: Could not find {notfound} songs)")
        else:
            print(f"({sp}: Found all songs)")

elif args.dataset == "covers80":

    # ********* covers80 **********
    # Info
    info = {}
    for prefix in ["list1", "list2"]:
        fn = os.path.join(args.path_meta, prefix + ".list")
        with open(fn, "r") as fh:
            lines = fh.readlines()
        for line in lines:
            c, n = line[:-1].split(os.sep)
            idx = line[:-1]
            info[idx] = {
                "id": len(info),
                "clique": c,
                "version": n,
                "artist": n.split("+")[0],
                "title": n.split("-")[-1],
                "filename": os.path.join(c, n + "." + args.ext_in),
            }
    # Splits (covers80 is tiny: the same cliques serve all partitions)
    cliques = {}
    for idx, ifo in info.items():
        c = ifo["clique"]
        if c not in cliques:
            cliques[c] = []
        cliques[c].append(idx)
    splits = {
        "train": cliques,
        "valid": cliques,
        "test": cliques,
    }
    # *****************************

else:
    raise NotImplementedError
print(f" Found {len(info)} songs")
nsongs = {}
for sp in splits.keys():
    nsongs[sp] = 0
    for cl, items in splits[sp].items():
        nsongs[sp] += len(items)
print(" Contains:", nsongs)

###############################################################################


def get_file_info(idx, info):
    """Probe one audio file; return (idx, enriched info) or (None, None)."""
    fn = os.path.join(args.path_audio, info["filename"])
    print(f"\r[{timer.time()}] " + fn, end=" ", flush=True)
    # Check if exists (safe load: decode 1 s starting at 1 s to verify
    # the file is actually readable, not just present)
    x = audio_utils.load_audio(
        fn, sample_rate=16000, n_channels=1, start=16000, length=16000
    )
    if x is None or x.size(-1) < 16000:
        return None, None
    # Get info (a file readable above may still have broken headers)
    try:
        audio_info = audio_utils.get_info(fn)
    except Exception:
        return None, None
    info["samplerate"] = audio_info.samplerate
    info["length"] = audio_info.length
    info["channels"] = audio_info.channels
    return idx, info


###############################################################################

# Filter existing ones
print("Filter existing")
keys = list(info.keys())
keys.sort()
if args.njobs == 1:
    # Sequential path (easier debugging)
    done = []
    for idx in keys:
        fn = os.path.join(args.path_audio, info[idx]["filename"])
        if os.path.exists(fn):
            done.append(get_file_info(idx, info[idx]))
else:
    # Parallel path via joblib
    todo = []
    for idx in tqdm(keys, ncols=100, ascii=True):
        fn = os.path.join(args.path_audio, info[idx]["filename"])
        if os.path.exists(fn):
            job = delayed(get_file_info)(idx, info[idx])
            todo.append(job)
    done = Parallel(n_jobs=args.njobs)(todo)
print()
new_info = {}
for idx, inf in done:
    if idx is not None and inf is not None:
        new_info[idx] = inf
print(f" Found {len(new_info)} songs ({100*len(new_info)/len(info):.1f}%)")
info = new_info

# Redo splits (drop missing songs; keep only cliques with >= 2 versions)
print("Filter split")
new_split = {}
new_nsongs = {}
for sp in splits.keys():
    new_split[sp] = {}
    new_nsongs[sp] = 0
    for cl, items in splits[sp].items():
        new_items = []
        for idx in items:
            if idx in info:
                new_items.append(idx)
        if len(new_items) > 1:
            new_split[sp][cl] = new_items[:]
            new_nsongs[sp] += len(new_items)
print(" Contains:", new_nsongs)
for sp in nsongs.keys():
    nsongs[sp] = 100 * new_nsongs[sp] / nsongs[sp]
print(" Percent:", nsongs)
splits = new_split

# Save
print(f"Save {args.fn_out}")
torch.save([info, splits], args.fn_out)

# Done
print("Done!")
todo.append(job) 235 | done = Parallel(n_jobs=args.njobs)(todo) 236 | print() 237 | new_info = {} 238 | for idx, inf in done: 239 | if idx is not None and inf is not None: 240 | new_info[idx] = inf 241 | print(f" Found {len(new_info)} songs ({100*len(new_info)/len(info):.1f}%)") 242 | info = new_info 243 | 244 | # Redo splits 245 | print(f"Filter split") 246 | new_split = {} 247 | new_nsongs = {} 248 | for sp in splits.keys(): 249 | new_split[sp] = {} 250 | new_nsongs[sp] = 0 251 | for cl, items in splits[sp].items(): 252 | new_items = [] 253 | for idx in items: 254 | if idx in info: 255 | new_items.append(idx) 256 | if len(new_items) > 1: 257 | new_split[sp][cl] = new_items[:] 258 | new_nsongs[sp] += len(new_items) 259 | print(" Contains:", new_nsongs) 260 | for sp in nsongs.keys(): 261 | nsongs[sp] = 100 * new_nsongs[sp] / nsongs[sp] 262 | print(" Percent:", nsongs) 263 | splits = new_split 264 | 265 | # Save 266 | print(f"Save {args.fn_out}") 267 | torch.save([info, splits], args.fn_out) 268 | 269 | # Done 270 | print(f"Done!") 271 | -------------------------------------------------------------------------------- /lib/layers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import math 4 | 5 | ################################################################################################### 6 | 7 | 8 | class CQTPrepare(torch.nn.Module): 9 | 10 | def __init__(self, pow=0.5, norm="max2d", noise=True, affine=True, eps=1e-6): 11 | super().__init__() 12 | assert norm in ("max1d", "max2d", "mean2d") 13 | self.pow = pow 14 | self.norm = norm 15 | self.noise = noise 16 | self.affine = affine 17 | if self.affine: 18 | self.gain = torch.nn.Parameter(torch.ones(1)) 19 | self.bias = torch.nn.Parameter(torch.zeros(1)) 20 | self.eps = eps 21 | 22 | def forward(self, h): 23 | h = h.clamp(min=0).pow(self.pow) 24 | h = self.normalize(h) 25 | if self.noise: 26 | h = h + self.eps * torch.rand_like(h) 27 | h = 
class CQTPrepare(torch.nn.Module):
    """Condition a (batch, channel, freq, time) CQT magnitude block.

    Pipeline: magnitude compression (power `pow`), min/max-style
    normalization, an optional tiny uniform noise floor (followed by a
    re-normalization), and an optional learnable scalar affine transform.
    """

    def __init__(self, pow=0.5, norm="max2d", noise=True, affine=True, eps=1e-6):
        super().__init__()
        assert norm in ("max1d", "max2d", "mean2d")
        self.pow = pow
        self.norm = norm
        self.noise = noise
        self.affine = affine
        if self.affine:
            # One learnable gain/bias shared across all bins.
            self.gain = torch.nn.Parameter(torch.ones(1))
            self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def forward(self, h):
        # Compress magnitudes, then scale into (roughly) [0, 1].
        h = self.normalize(h.clamp(min=0).pow(self.pow))
        if self.noise:
            # Tiny noise floor avoids exact zeros; re-normalize afterwards.
            h = self.normalize(h + self.eps * torch.rand_like(h))
        if self.affine:
            h = self.gain * h + self.bias
        return h

    def normalize(self, h):
        # Shift so the per-(batch, channel) minimum over (freq, time) is 0 ...
        h = h - h.min(2, keepdim=True)[0].min(3, keepdim=True)[0]
        # ... then divide by the selected statistic.
        if self.norm == "max2d":
            den = h.max(2, keepdim=True)[0].max(3, keepdim=True)[0]
        elif self.norm == "max1d":
            # Per-time-frame maximum only (keeps the frequency profile).
            den = h.max(2, keepdim=True)[0]
        else:  # "mean2d" (constructor asserts membership)
            den = h.mean((2, 3), keepdim=True)
        return h / (den + self.eps)


###################################################################################################


class Linear(torch.nn.Module):
    """Linear layer applied along an arbitrary tensor dimension `dim`."""

    def __init__(self, nin, nout, dim=1, bias=True):
        super().__init__()
        self.lin = torch.nn.Linear(nin, nout, bias=bias)
        self.dim = dim

    def forward(self, h):
        if self.dim == -1:
            return self.lin(h)
        # Move the target dim to the end, apply, and move it back.
        return self.lin(h.transpose(self.dim, -1)).transpose(self.dim, -1)


class PadConv2d(torch.nn.Module):
    """2D convolution with 'same'-style padding (odd kernel sizes only)."""

    def __init__(self, nin, nout, kern, stride=1, bias=True):
        super().__init__()
        assert kern % 2 == 1
        self.conv = torch.nn.Conv2d(
            nin, nout, kern, stride=stride, padding=kern // 2, bias=bias
        )

    def forward(self, h):
        return self.conv(h)


###################################################################################################


class Squeeze(torch.nn.Module):
    """Module wrapper around torch.squeeze (for use inside Sequential)."""

    def __init__(self, dim=-1):
        super().__init__()
        assert type(dim) in (int, tuple)
        self.dim = dim

    def forward(self, h):
        return torch.squeeze(h, dim=self.dim)


class Unsqueeze(torch.nn.Module):
    """Module wrapper around torch.unsqueeze (for use inside Sequential)."""

    def __init__(self, dim=-1):
        super().__init__()
        assert type(dim) is int
        self.dim = dim

    def forward(self, h):
        return torch.unsqueeze(h, dim=self.dim)
###################################################################################################


class InstanceBatchNorm1d(torch.nn.Module):
    """IBN layer (1D): batch-norm on one half of the channels, instance-norm on the other."""

    def __init__(self, ncha, affine=True):
        super().__init__()
        assert ncha % 2 == 0
        half = ncha // 2
        self.bn = torch.nn.BatchNorm1d(half, affine=affine)
        self.inst = torch.nn.InstanceNorm1d(half, affine=affine)

    def forward(self, h):
        first, second = torch.chunk(h, 2, dim=1)
        return torch.cat([self.bn(first), self.inst(second)], dim=1)


class InstanceBatchNorm2d(torch.nn.Module):
    """IBN layer (2D): batch-norm on one half of the channels, instance-norm on the other."""

    def __init__(self, ncha, affine=True):
        super().__init__()
        assert ncha % 2 == 0
        half = ncha // 2
        self.bn = torch.nn.BatchNorm2d(half, affine=affine)
        self.inst = torch.nn.InstanceNorm2d(half, affine=affine)

    def forward(self, h):
        first, second = torch.chunk(h, 2, dim=1)
        return torch.cat([self.bn(first), self.inst(second)], dim=1)


###################################################################################################


class GeMPool(torch.nn.Module):
    """Generalized-mean pooling over all spatial positions, learnable exponent per channel."""

    def __init__(self, ncha=1, init=3, eps=1e-6):
        super().__init__()
        self.flatten = torch.nn.Flatten(start_dim=2, end_dim=-1)
        self.softplus = torch.nn.Softplus()
        # Inverse-softplus so the effective exponent starts exactly at `init`.
        pinit = math.log(math.exp(init - 1) - 1)
        self.p = torch.nn.Parameter(pinit * torch.ones(1, ncha, 1))
        self.eps = eps

    def forward(self, h):
        flat = self.flatten(h)  # (B, C, positions)
        exponent = 1 + self.softplus(self.p)  # always > 1
        pooled = flat.clamp(min=self.eps).pow(exponent).mean(-1)
        return pooled.pow(1 / exponent.squeeze(-1))


class AutoPool(torch.nn.Module):
    """Adaptive softmax-weighted pooling with a learnable temperature per channel."""

    def __init__(self, ncha=1, p_init=1):
        super().__init__()
        self.flatten = torch.nn.Flatten(start_dim=2, end_dim=-1)
        self.p = torch.nn.Parameter(p_init * torch.ones(1, ncha, 1))

    def forward(self, h):
        flat = self.flatten(h)
        weights = torch.softmax(self.p * flat, -1)
        return (flat * weights).sum(dim=-1)


class SoftPool(torch.nn.Module):
    """Attention-style pooling: one linear layer predicts values and (normalized) logits."""

    def __init__(self, ncha):
        super().__init__()
        self.flatten = torch.nn.Flatten(start_dim=2, end_dim=-1)
        self.lin = Linear(ncha, 2 * ncha, dim=1, bias=False)
        self.norm = torch.nn.InstanceNorm1d(ncha, affine=True)

    def forward(self, h):
        proj = self.lin(self.flatten(h))
        values, logits = torch.chunk(proj, 2, dim=1)
        attn = torch.softmax(self.norm(logits), dim=-1)
        return (values * attn).sum(dim=-1)


###################################################################################################


class ResNet50BottBlock(torch.nn.Module):
    """ResNet-50-style bottleneck (1x1 -> kxk -> 1x1) with optional IBN and SE."""

    def __init__(
        self,
        ncin,
        ncout,
        ncfactor=0.25,
        kern=3,
        stride=1,
        ibn=False,
        se=False,
    ):
        super().__init__()
        assert kern % 2 == 1
        pad = kern // 2
        # Bottleneck width, rounded up to an even count (IBN needs it even).
        ncmid = int(max(ncin, ncout) * ncfactor)
        ncmid += ncmid % 2
        layers = [torch.nn.Conv2d(ncin, ncmid, 1, bias=False)]
        layers.append(
            InstanceBatchNorm2d(ncmid) if ibn else torch.nn.BatchNorm2d(ncmid)
        )
        layers += [
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(
                ncmid, ncmid, kern, stride=stride, padding=pad, bias=False
            ),
            torch.nn.BatchNorm2d(ncmid),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(ncmid, ncout, 1, bias=False),
            torch.nn.BatchNorm2d(ncout),
        ]
        if se:
            layers.append(SqueezeExcitation2d(ncout))
        self.convs = torch.nn.Sequential(*layers)
        if ncin == ncout and stride == 1:
            self.residual = torch.nn.Identity()
        else:
            # NOTE(review): the projection shortcut uses a kxk conv (not the
            # usual 1x1) — presumably intentional; confirm.
            self.residual = torch.nn.Sequential(
                torch.nn.Conv2d(
                    ncin, ncout, kern, stride=stride, padding=pad, bias=False
                ),
                torch.nn.BatchNorm2d(ncout),
            )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, h):
        return self.relu(self.convs(h) + self.residual(h))


###################################################################################################


class MyIBNResBlock(torch.nn.Module):
    """Pre-activation residual block with configurable IBN/SE placement and zero-init gain."""

    def __init__(
        self,
        ncin,
        ncout,
        factor=0.5,
        kern=3,
        stride=1,
        ibn="pre",
        se="none",
    ):
        super().__init__()
        ncmid = max(1, int(max(ncin, ncout) * factor))
        ncmid += ncmid % 2  # keep it even for InstanceBatchNorm2d
        body = []
        body.append(
            InstanceBatchNorm2d(ncin) if ibn == "pre" else torch.nn.BatchNorm2d(ncin)
        )
        if se == "pre":
            body.append(SqueezeExcitation2d(ncin))
        body += [
            torch.nn.ReLU(inplace=True),
            PadConv2d(ncin, ncmid, kern, stride=stride, bias=False),
        ]
        body.append(
            InstanceBatchNorm2d(ncmid) if ibn == "post" else torch.nn.BatchNorm2d(ncmid)
        )
        body += [
            torch.nn.ReLU(inplace=True),
            PadConv2d(ncmid, ncout, kern, bias=False),
        ]
        if se == "post":
            body.append(SqueezeExcitation2d(ncout))
        self.convs = torch.nn.Sequential(*body)
        if ncin == ncout and stride == 1:
            self.skip = torch.nn.Identity()
        else:
            self.skip = torch.nn.Sequential(
                torch.nn.BatchNorm2d(ncin),
                torch.nn.ReLU(inplace=True),
                PadConv2d(ncin, ncout, kern, stride=stride, bias=False),
            )
        # Zero-init gain: at init the block reduces to its skip path.
        self.gain = torch.nn.Parameter(torch.zeros(1))

    def forward(self, h):
        return self.gain * self.convs(h) + self.skip(h)


###################################################################################################


class SqueezeExcitation2d(torch.nn.Module):
    """Squeeze-and-excitation channel gating with reduction ratio r."""

    def __init__(self, ncha, r=2):
        super().__init__()
        self.pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
        nmid = max(1, int(ncha / r))
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(ncha, nmid, bias=False),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(nmid, ncha, bias=False),
            torch.nn.Sigmoid(),
        )

    def forward(self, h):
        # (B,C,H,W) -> (B,1,1,C) so the MLP acts on the channel axis.
        s = self.pooling(h).transpose(1, -1)
        s = self.mlp(s).transpose(-1, 1)
        return h * s
import sys
import torch, math
import torchaudio, julius
from torchvision import transforms
from einops import rearrange

from lib import tensor_ops as tops


class Augment:
    """Stochastic data augmentation for training.

    Two entry points: `waveform` (time domain, raw audio) and `cqgram`
    (CQT-like spectrograms). Which augmentations run is driven by the keys
    present in `conf` (an OmegaConf-style mapping); each augmentation has its
    own probability `p` and is gated per batch item by an independent
    Bernoulli draw, except where noted.
    """

    def __init__(self, conf, sr=22050, random_order=True, eps=1e-6):
        """conf: per-augmentation settings; sr: waveform sample rate;
        random_order: shuffle the order augmentations are applied in."""
        self.conf = conf
        self.sr = sr
        self.random_order = random_order
        self.eps = eps

    def waveform(
        self,
        x,  # (B,T)
        noise=None,
    ):
        """Apply the configured waveform-domain augmentations to x (B, T).

        noise: optional noise tensor for the "noise" augmentation; when None
        that augmentation is skipped. Returns the (possibly shortened) batch.
        """
        assert x.ndim == 2
        # Randomize augmentations
        augs = list(self.conf.keys())
        if self.random_order:
            ids = torch.randperm(len(augs)).tolist()
        else:
            ids = list(range(len(augs)))
        # Waveform domain augmentations
        for i in ids:
            if augs[i] == "polarity" and self.conf.polarity.p > 0:

                # --- Polarity augmentation ---
                mask = torch.rand(x.size(0), 1, device=x.device) < self.conf.polarity.p
                x = torch.where(mask, -x, x)

            elif augs[i] == "gain" and self.conf.gain.p > 0:

                # --- Gain augmentation ---
                # Target peak drawn in [rmin, rmax]; result clamped to [-1, 1].
                rmin, rmax = self.conf.gain.r
                r = rmin + (rmax - rmin) * torch.rand(x.size(0), 1, device=x.device)
                r /= x.abs().max(-1, keepdim=True)[0] + self.eps
                mask = torch.rand(x.size(0), 1, device=x.device) < self.conf.gain.p
                x = torch.where(mask, (x * r).clamp(min=-1, max=1), x)
                del r

            elif augs[i] == "noise" and self.conf.noise.p > 0 and noise is not None:

                # --- Noise augmentation ---
                # Mix at a random SNR (dB); output rescaled to the original peak.
                # NOTE(review): divides by mean power, not RMS (sqrt of it) —
                # confirm the intended SNR scaling.
                rmin, rmax = self.conf.noise.snr
                r = rmin + (rmax - rmin) * torch.rand(x.size(0), 1, device=x.device)
                r = 10 ** (r / 20)
                xnorm = x / (x.pow(2).mean(1, keepdim=True) + self.eps)
                nnorm = noise / (noise.pow(2).mean(1, keepdim=True) + self.eps)
                xnew = r * xnorm + nnorm
                xnew /= xnew.abs().max(1, keepdim=True)[0] + self.eps
                xnew *= x.abs().max(1, keepdim=True)[0]
                mask = torch.rand(x.size(0), 1, device=x.device) < self.conf.noise.p
                x = torch.where(mask, xnew, x)
                del r, xnorm, nnorm, xnew

            elif augs[i] == "reqtime" and self.conf.reqtime.p > 0:

                # --- Random EQ (time) ---
                # TODO: Really slow. Also unchecked.
                # Applies nf random biquad peaking filters (log-uniform
                # frequency and Q) to the whole batch, then gates per item.
                nfmin, nfmax = self.conf.reqtime.nfreqs
                nf = torch.randint(nfmin, nfmax + 1, (1,)).item()
                gmin, gmax = self.conf.reqtime.gains
                fmin, fmax = math.log(20), math.log(self.sr * 0.5 * 0.98)
                qmin, qmax = self.conf.reqtime.qrange
                qmin, qmax = math.log(qmin), math.log(qmax)
                xeq = x.clone()
                for _ in range(nf):
                    gain = gmin + (gmax - gmin) * torch.rand(1).item()
                    freq = math.exp(fmin + (fmax - fmin) * torch.rand(1).item())
                    q = math.exp(qmin + (qmax - qmin) * torch.rand(1).item())
                    xeq = torchaudio.functional.equalizer_biquad(
                        xeq, self.sr, freq, gain, Q=q
                    )
                mask = torch.rand(x.size(0), 1, device=x.device) < self.conf.reqtime.p
                x = torch.where(mask, xeq.clamp(min=-1, max=1), x)
                del xeq

            elif augs[i] == "clipping" and self.conf.clipping.p > 0:

                # --- Clipping augmentation ---
                # Threshold at a random amplitude quantile; soft (tanh) or
                # hard clipping chosen per item with probability p_soft.
                qtl = (
                    1
                    - torch.rand(x.size(0), 1, device=x.device)
                    * self.conf.clipping.max_qtl
                )
                thres = tops.tensor_quantile(x.abs(), qtl, dim=-1, keepdim=True)
                mask = (
                    torch.rand(x.size(0), 1, device=x.device)
                    < self.conf.clipping.p_soft
                )
                xclip = torch.tanh(x * 2 / (thres + self.eps)) * thres
                xclip = torch.where(mask, xclip, x.clamp(min=-thres, max=thres))
                mask = torch.rand(x.size(0), 1, device=x.device) < self.conf.clipping.p
                x = torch.where(mask, xclip, x)
                del qtl, thres, xclip

            elif augs[i] == "length" and self.conf.length.p > 0:

                # --- Length augmentation ---
                # Single draw for the whole batch (changes the tensor shape).
                if torch.rand(1).item() < self.conf.length.p:
                    rmin = self.conf.length.rmin
                    r = rmin + (1 - rmin) * torch.rand(1).item()
                    x = x[:, : int(r * x.size(-1))]

            elif augs[i] == "compexp" and self.conf.compexp.p > 0:

                # --- Basic compression/expansion augmentation ---
                # |x|**r with random exponent r, sign preserved.
                rmin, rmax = self.conf.compexp.r
                r = rmin + (rmax - rmin) * torch.rand(x.size(0), 1, device=x.device)
                mask = torch.rand(x.size(0), 1, device=x.device) < self.conf.compexp.p
                x = torch.where(mask, x.sign() * (x.abs() ** r), x)

        return x  # (B,T)

    def cqgram(
        self,
        y,  # (B,C,T) or (B,N,C,T)
    ):
        """Apply the configured CQT-domain augmentations.

        y: (B, C, T) or (B, N, C, T); a 4D input is flattened over (B, N)
        and restored before returning.
        """
        assert y.ndim == 3 or y.ndim == 4
        # Reshape?
        if y.ndim == 4:
            nnn = y.size(1)
            y = rearrange(y, "b n c t -> (b n) c t")
        else:
            nnn = None
        # Randomize augmentations
        augs = list(self.conf.keys())
        if self.random_order:
            ids = torch.randperm(len(augs)).tolist()
        else:
            ids = list(range(len(augs)))
        # CQT domain augmentations
        for i in ids:

            if augs[i] == "specaugment" and self.conf.specaugment.p > 0:

                # --- Specaugment ---
                # Drop n random frequency/time stripes (full=True) or random
                # rectangles (full=False), then gate per item.
                ydrop = y.clone()
                n = torch.randint(1, self.conf.specaugment.n + 1, (1,)).item()
                for _ in range(n):
                    fpc = (
                        torch.rand(y.size(0), 1, 1, device=y.device)
                        * self.conf.specaugment.f_pc
                    )
                    flen = (fpc * y.size(1)).clamp(min=1).long()
                    fmax = y.size(1) - flen
                    f0 = (torch.rand_like(fpc) * fmax).long()
                    tpc = (
                        torch.rand(y.size(0), 1, 1, device=y.device)
                        * self.conf.specaugment.t_pc
                    )
                    tlen = (tpc * y.size(2)).clamp(min=1).long()
                    tmax = y.size(2) - tlen
                    t0 = (torch.rand_like(tpc) * tmax).long()
                    fids = torch.arange(0, y.size(1), device=y.device).view(1, -1, 1)
                    tids = torch.arange(0, y.size(2), device=y.device).view(1, 1, -1)
                    if self.conf.specaugment.full:
                        cond = ((fids >= f0) & (fids < f0 + flen)) | (
                            (tids >= t0) & (tids < t0 + tlen)
                        )
                    else:
                        cond = ((fids >= f0) & (fids < f0 + flen)) & (
                            (tids >= t0) & (tids < t0 + tlen)
                        )
                    ydrop = torch.where(cond, 0, ydrop)
                mask = (
                    torch.rand(y.size(0), 1, 1, device=y.device)
                    < self.conf.specaugment.p
                )
                y = torch.where(mask, ydrop, y)
                del ydrop

            elif augs[i] == "timestretch" and self.conf.timestretch.p > 0:

                # --- Time stretch ---
                # Resize the time axis per item, then pad/cut back to T.
                rmin, rmax = self.conf.timestretch.r
                ys = y.clone()
                for j in range(ys.size(0)):
                    r = rmin + (rmax - rmin) * torch.rand(1).item()
                    length = int(ys.size(2) * r)
                    if ys.size(2) != length:
                        resize = transforms.Resize((ys.size(1), length))
                        aux = resize(ys[j : j + 1].unsqueeze(1)).squeeze(1)
                        ys[j : j + 1] = tops.force_length(
                            aux,
                            ys.size(2),
                            dim=2,
                            pad_mode=self.conf.timestretch.pad_mode,
                            cut_mode=self.conf.timestretch.cut_mode,
                        )
                mask = (
                    torch.rand(y.size(0), 1, 1, device=y.device)
                    < self.conf.timestretch.p
                )
                y = torch.where(mask, ys, y)
                del ys

            elif augs[i] == "pitchstretch" and self.conf.pitchstretch.p > 0:

                # --- Pitch stretch ---
                # Resize the frequency axis per item, then pad/cut back to C.
                rmin, rmax = self.conf.pitchstretch.r
                ys = y.clone()
                for j in range(ys.size(0)):
                    r = rmin + (rmax - rmin) * torch.rand(1).item()
                    length = int(ys.size(1) * r)
                    if ys.size(1) != length:
                        resize = transforms.Resize((length, ys.size(2)))
                        aux = resize(ys[j : j + 1].unsqueeze(1)).squeeze(1)
                        # Fix: write back only item j (mirrors the timestretch
                        # branch). Previously `ys = tops.force_length(...)`
                        # replaced the whole batch with this single item, so
                        # later iterations indexed a collapsed tensor.
                        ys[j : j + 1] = tops.force_length(
                            aux,
                            ys.size(1),
                            dim=1,
                            pad_mode=self.conf.pitchstretch.pad_mode,
                            cut_mode=self.conf.pitchstretch.cut_mode,
                        )
                mask = (
                    torch.rand(y.size(0), 1, 1, device=y.device)
                    < self.conf.pitchstretch.p
                )
                y = torch.where(mask, ys, y)
                del ys

            elif augs[i] == "pitchtranspose" and self.conf.pitchtranspose.p > 0:

                # --- Pitch transposition ---
                # Shift the frequency axis by an integer number of bins,
                # zero-filling the vacated bins.
                rmin, rmax = self.conf.pitchtranspose.r
                yt = torch.zeros_like(y)
                for j in range(yt.size(0)):
                    r = torch.randint(rmin, rmax + 1, (1,)).item()
                    if r == 0:
                        yt[j, :, :] = y[j, :, :]
                    elif r > 0:
                        yt[j, r:, :] = y[j, :-r, :]
                    else:
                        yt[j, : -abs(r), :] = y[j, abs(r) :, :]
                mask = (
                    torch.rand(y.size(0), 1, 1, device=y.device)
                    < self.conf.pitchtranspose.p
                )
                y = torch.where(mask, yt, y)
                del yt

            elif augs[i] == "reqcqt" and self.conf.reqcqt.p > 0:

                # --- Random EQ (CQT) ---
                # Smooth random gain curve over frequency: low-passed random
                # walk, normalized to [0, 1], scaled to [rmin, rmax] dB-like
                # units, then converted to a multiplicative factor.
                rmin, rmax = self.conf.reqcqt.r
                r = torch.cumsum(torch.randn(y.size(0), y.size(1), device=y.device), 1)
                r = julius.lowpass.lowpass_filter(
                    r, self.conf.reqcqt.lpf, zeros=4
                ).unsqueeze(-1)
                r -= r.min(1, keepdim=True)[0]
                r /= r.max(1, keepdim=True)[0] + self.eps
                r *= rmin + (rmax - rmin) * torch.rand(r.size(0), 1, 1, device=y.device)
                r = 10 ** (r / 10)
                mask = torch.rand(y.size(0), 1, 1, device=y.device) < self.conf.reqcqt.p
                y = torch.where(mask, y * r, y)

        # Reshape
        if nnn is not None:
            y = rearrange(y, "(b n) c t -> b n c t", n=nnn)
        return y  # (B,C,T) or (B,N,C,T), same as input
False? 33 | torch.backends.cudnn.deterministic = False 34 | torch.set_float32_matmul_precision("medium") 35 | torch.autograd.set_detect_anomaly(False) 36 | fabric = Fabric( 37 | accelerator="cuda", 38 | devices=conf.fabric.ngpus, 39 | num_nodes=conf.fabric.nnodes, 40 | strategy=DDPStrategy(broadcast_buffers=False), 41 | precision=conf.fabric.precision, 42 | loggers=pytorch_utils.get_logger(conf.path.logs), 43 | ) 44 | fabric.launch() 45 | 46 | # Common seed to have same model everywhere 47 | # (for different rands per GPU see re-seed below) 48 | fabric.barrier() 49 | fabric.seed_everything(conf.seed, workers=True) 50 | 51 | # Init my utils 52 | myprint = lambda s, end="\n": print_utils.myprint( 53 | s, end=end, doit=fabric.is_global_zero 54 | ) 55 | myprogbar = lambda it, desc=None, leave=False: print_utils.myprogbar( 56 | it, desc=desc, leave=leave, doit=fabric.is_global_zero 57 | ) 58 | timer = print_utils.Timer() 59 | 60 | # Print config 61 | myprint("-" * 65) 62 | myprint(OmegaConf.to_yaml(conf)[:-1]) 63 | myprint("-" * 65) 64 | 65 | 66 | ######################################################################################### 67 | # Model, optim, scheduler, load checkpoint... 68 | ######################################################################################### 69 | 70 | # Init model 71 | myprint("Init model...") 72 | module = importlib.import_module("models." 
+ conf.model.name) 73 | with fabric.init_module(): 74 | model = module.Model(conf.model, sr=conf.data.samplerate) 75 | if fabric.is_global_zero: 76 | torchinfo.summary(model, depth=1) 77 | model = fabric.setup(model) 78 | model.mark_forward_method("prepare") 79 | model.mark_forward_method("embed") 80 | model.mark_forward_method("loss") 81 | 82 | # Init optimizer & scheduler 83 | myprint("Init optimizer...") 84 | optim = pytorch_utils.get_optimizer(conf.training.optim, model) 85 | optim = fabric.setup_optimizers(optim) 86 | sched, sched_on_epoch = pytorch_utils.get_scheduler( 87 | conf.training.optim, 88 | optim, 89 | epochs=conf.training.numepochs, 90 | mode=conf.training.monitor.mode, 91 | ) 92 | 93 | # Init local variables 94 | myprint("Init variables...") 95 | epoch = 0 96 | cost_best = torch.inf if conf.training.monitor.mode == "min" else -torch.inf 97 | if conf.training.optim.sched.startswith("plateau"): 98 | sched.step(cost_best) 99 | lr = sched.get_last_lr()[0] 100 | 101 | # Restore from previous checkpoint? 
102 | fn_ckpt = None 103 | if conf.checkpoint is not None: 104 | fn_ckpt = conf.checkpoint 105 | elif os.path.exists(fn_ckpt_last): 106 | fn_ckpt = fn_ckpt_last 107 | if fn_ckpt is not None: 108 | myprint("Loading checkpoint...") 109 | state = pytorch_utils.get_state(model, optim, sched, conf, epoch, lr, cost_best) 110 | fabric.load(fn_ckpt, state) 111 | model, optim, sched, conf, epoch, lr, cost_best = pytorch_utils.set_state(state) 112 | myprint(" Loaded " + fn_ckpt) 113 | 114 | # Re-seed with global_rank to have truly different augmentations 115 | myprint("Re-seed...") 116 | fabric.barrier() 117 | fabric.seed_everything((epoch + 1) * (conf.seed + fabric.global_rank), workers=True) 118 | 119 | 120 | ######################################################################################### 121 | # Data 122 | ######################################################################################### 123 | 124 | # Dataset & augmentations 125 | myprint("Load data...") 126 | ds_train = dataset.Dataset( 127 | conf.data, 128 | "train", 129 | augment=True, 130 | verbose=fabric.is_global_zero, 131 | ) 132 | ds_valid = dataset.Dataset( 133 | conf.data, 134 | "valid", 135 | augment=False, 136 | verbose=fabric.is_global_zero, 137 | ) 138 | assert conf.training.batchsize > 1 139 | dl_train = torch.utils.data.DataLoader( 140 | ds_train, 141 | batch_size=conf.training.batchsize, 142 | shuffle=True, 143 | num_workers=conf.data.nworkers, 144 | drop_last=True, 145 | persistent_workers=False, 146 | pin_memory=True, 147 | ) 148 | dl_valid = torch.utils.data.DataLoader( 149 | ds_valid, 150 | batch_size=conf.training.batchsize, 151 | shuffle=False, 152 | num_workers=conf.data.nworkers, 153 | drop_last=False, 154 | persistent_workers=False, 155 | pin_memory=True, 156 | ) 157 | dl_train, dl_valid = fabric.setup_dataloaders(dl_train, dl_valid) 158 | augment = augmentations.Augment(conf.augmentations, sr=conf.data.samplerate) 159 | 160 | 
#########################################################################################
# Main loss function
#########################################################################################


def main_loss_func(batch, logdict, training=False):
    """Run one batch through the model; when `training`, also do the optimizer step.

    batch: sequence laid out as [labels, idx_0, wav_0, idx_1, wav_1, ...] with
        (len(batch) - 1) // 2 versions per class; batch[0] is reused as the
        label tensor for every version (NOTE(review): presumably clique ids —
        confirm against dataset.Dataset).
    logdict: pytorch_utils.LogDict, appended to in place.
    Returns (outputs, logdict), where outputs mirrors the batch layout with
    embeddings in place of waveforms.
    """
    # Prepare data (no autograd needed for collation/augmentation)
    with torch.inference_mode():
        n_per_class = (len(batch) - 1) // 2
        cc = [batch[0]] * n_per_class
        cc = torch.cat(cc, dim=0)
        ii = batch[1::2]
        ii = torch.cat(ii, dim=0)
        xx = batch[2::2]
        xx = torch.cat(xx, dim=0)
        # Export audio?
        # torch.save([cc, ii, xx], "explore.pt")
        # sys.exit()
        # Augmentations - Waveform domain
        if training:
            xx = augment.waveform(xx)
        # Model - Shingle and CQT (at validation, fixed hop of half a shingle)
        xx = model.prepare(
            xx,
            shingle_hop=None if training else model.get_shingle_params()[-1] / 2,
        )
        # Augmentations - CQ domain
        if training:
            xx = augment.cqgram(xx)
        # Clone to leave inference mode (inference tensors cannot join autograd)
        cc, ii, xx = cc.clone(), ii.clone(), xx.clone()
    # Train procedure
    if training:
        optim.zero_grad(set_to_none=True)
    zz, extra = model.embed(xx)
    loss, logdct = model.loss(cc, ii, zz, extra=extra)
    if training:
        fabric.backward(loss)
        optim.step()
        if not sched_on_epoch:
            # Per-batch schedulers step here; per-epoch ones in the main loop.
            sched.step()
    # Outputs and logdict
    with torch.inference_mode():
        clist = torch.chunk(cc, n_per_class, dim=0)
        ilist = torch.chunk(ii, n_per_class, dim=0)
        zlist = torch.chunk(zz, n_per_class, dim=0)
        outputs = [clist[0]] + [None] * (2 * n_per_class)
        outputs[1::2] = ilist
        outputs[2::2] = zlist
        logdict.append(logdct)
    return outputs, logdict


#########################################################################################
# Train/valid loops
#########################################################################################


def train_loop(desc=None):
    """One training epoch over dl_train; returns the accumulated LogDict."""
    # Init
    model.train()
    logdict = pytorch_utils.LogDict()
    # Loop
    fabric.barrier()
    for n, batch in enumerate(myprogbar(dl_train, desc=desc)):
        if conf.limit_batches is not None and n >= conf.limit_batches:
            break
        # Regular loss calc
        _, logdict = main_loss_func(batch, logdict, training=True)
        losses = logdict.get("l_main")
        myprint(f" [L*={losses[-1]:.3f}, L={losses.mean():.3f}]", end="")
    return logdict


@torch.inference_mode()
def valid_loop(desc=None):
    """One validation epoch: loss over dl_valid plus retrieval metrics
    (MAP/MR1/ARP and a combined score) computed over all GPUs' embeddings."""
    # Init
    model.eval()
    logdict = pytorch_utils.LogDict()
    queries_c = []
    queries_i = []
    queries_z = []
    # Loop
    fabric.barrier()
    for n, batch in enumerate(myprogbar(dl_valid, desc=desc)):
        # if conf.limit_batches is not None and n >= conf.limit_batches:
        #     break
        # Regular loss calc
        outputs, logdict = main_loss_func(batch, logdict, training=False)
        losses = logdict.get("l_main")
        myprint(f" [L*={losses[-1]:.3f}, L={losses.mean():.3f}]", end="")
        # Keep z for evaluating MAP (only the first version of each class)
        cl, i1, z1 = outputs[:3]
        queries_c.append(cl)
        queries_i.append(i1)
        queries_z.append(z1)
    queries_c = torch.cat(queries_c, dim=0)  # (B)
    queries_i = torch.cat(queries_i, dim=0)  # (B)
    queries_z = torch.cat(queries_z, dim=0)  # (B,C) or (B,S,C)
    # Gather all multi-gpu tensors
    fabric.barrier()
    all_c = fabric.all_gather(queries_c)  # (N,B)
    all_i = fabric.all_gather(queries_i)  # (N,B)
    all_z = fabric.all_gather(queries_z)  # (N,B,C) or (N,B,S,C)
    all_c = torch.cat(torch.unbind(all_c, dim=0), dim=0)
    all_i = torch.cat(torch.unbind(all_i, dim=0), dim=0)
    all_z = torch.cat(torch.unbind(all_z, dim=0), dim=0)
    # Eval kNN (local queries against the globally gathered corpus)
    myprint("Eval... ", end="")
    aps, r1s, rpcs = eval.compute(
        model,
        queries_c,
        queries_i,
        queries_z,
        all_c,
        all_i,
        all_z,
    )
    comp = (rpcs * (1 - aps)) ** 0.5
    logdict.append({"m_MAP": aps, "m_MR1": r1s, "m_ARP": rpcs, "m_COMP": comp})
    return logdict


#########################################################################################
# Main loop
#########################################################################################

# Main loop (epoch)
myprint("Training...")
stop = None
start_epoch = epoch
for epoch in range(start_epoch, conf.training.numepochs):
    desc = f"{epoch+1:{len(str(conf.training.numepochs))}d}/{conf.training.numepochs}"
    fabric.log("hpar/epoch", epoch + 1, step=epoch + 1)
    # Train
    logdict_train = train_loop(desc="Train " + desc)
    logdict_train.sync_and_mean(fabric)
    fabric.log_dict(logdict_train.get(prefix="train/"), step=epoch + 1)
    # Valid
    logdict_valid = valid_loop(desc="Valid " + desc)
    logdict_valid.sync_and_mean(fabric)
    fabric.log_dict(logdict_valid.get(prefix="valid/"), step=epoch + 1)
    # Report & check NaN/inf (stop is applied at the end of the epoch body)
    tmp = logdict_valid.get(keys=["l_main", "m_MAP", "m_ARP", "m_COMP"])
    tmp["l_main_t"] = logdict_train.get("l_main")
    report = print_utils.report(tmp, desc=f"[{timer.time()}] Epoch {desc}")
    for aux in tmp.values():
        if math.isnan(aux) or math.isinf(aux):
            stop = "NaN or inf reached!"
            break
    # Get current cost
    cost_current = logdict_valid.get(conf.training.monitor.quantity)
    # Optimizer schedule?
    fabric.log("hpar/lr", lr, step=epoch + 1)
    if sched_on_epoch:
        if conf.training.optim.sched.startswith("plateau"):
            sched.step(cost_current)
        else:
            with warnings.catch_warnings():
                # otherwise it warns about passing the epoch number (?)
                warnings.simplefilter("ignore")
                sched.step()
    new_lr = sched.get_last_lr()[0]
    if new_lr != lr:
        if conf.training.optim.sched.startswith("plateau"):
            report += f" (lr={new_lr:.1e})"
        lr = new_lr
    if "min_lr" in conf.training.optim and lr < conf.training.optim.min_lr:
        stop = "Min lr reached."
    # Checkpoint & best
    if (
        conf.training.save_freq is not None
        and (epoch + 1) % conf.training.save_freq == 0
    ):
        fn = fn_ckpt_epoch.replace("$epoch$", "epoch" + str(epoch + 1))
        state = pytorch_utils.get_state(
            model, optim, sched, conf, epoch + 1, lr, cost_best
        )
        fabric.save(fn, state)
    if (conf.training.monitor.mode == "max" and cost_current > cost_best) or (
        conf.training.monitor.mode == "min" and cost_current < cost_best
    ):
        cost_best = cost_current
        state = pytorch_utils.get_state(
            model, optim, sched, conf, epoch + 1, lr, cost_best
        )
        fabric.save(fn_ckpt_best, state)
        report += " *"
    state = pytorch_utils.get_state(model, optim, sched, conf, epoch + 1, lr, cost_best)
    fabric.save(fn_ckpt_last, state)
    # Done
    myprint(report)
    if stop is not None:
        myprint(stop + " Stop.")
        break

#########################################################################################

# --- lib/tensor_ops.py ---

import sys
import torch, math
from einops import rearrange
def tensor_quantile(x, q, dim=-1, keepdim=False):
    """Quantiles along `dim` via sort + gather.

    `q` holds quantile fractions in [0, 1] and must have the same ndim as
    `x`; each fraction is mapped to the nearest index of the sorted values
    (nearest-rank method, no interpolation).
    """
    assert x.ndim == q.ndim
    qn = (q.clamp(min=0, max=1) * (x.size(dim) - 1)).round().long()
    sx = x.sort(dim=dim)[0]
    xq = torch.gather(sx, dim, qn)
    if keepdim:
        return xq
    return xq.squeeze(dim)


###################################################################################################


def debug_inf_nan(ten, txt):
    """Debug helper: exit the process if `ten` contains any NaN or Inf."""
    if torch.isnan(ten).float().sum() > 0:
        print()
        print("nan " + txt)
        sys.exit()
    if torch.isinf(ten).float().sum() > 0:
        print()
        print("inf " + txt)
        sys.exit()


###################################################################################################


def force_length(
    x, length, dim=-1, pad_mode="repeat", cut_mode="start", allow_longer=False
):
    """Pad and/or cut `x` along `dim` so it has exactly `length` elements.

    pad_mode: "repeat" (tile copies of x), "zeros" (append zeros), or
    "crazy" (randomly prepend/append copies or zeros).
    cut_mode: keep the "start", the "end", or a "random" window.
    With allow_longer=True, inputs already longer than `length` are
    returned untouched.
    """
    assert pad_mode in ("repeat", "zeros", "crazy")
    assert cut_mode in ("start", "end", "random")
    # Fast bypass
    if x.size(dim) == length or (x.size(dim) > length and allow_longer):
        return x
    # Pad until at least `length`
    aux = x.clone()
    while aux.size(dim) < length:
        if pad_mode == "repeat":
            aux = torch.cat([aux, x], dim=dim)
        elif pad_mode == "zeros":
            aux = torch.cat([aux, torch.zeros_like(x)], dim=dim)
        elif pad_mode == "crazy":
            r = torch.randint(0, 4, (1,)).item()
            if r == 0:
                aux = torch.cat([aux, x], dim=dim)
            elif r == 1:
                aux = torch.cat([x, aux], dim=dim)
            elif r == 2:
                aux = torch.cat([aux, torch.zeros_like(x)], dim=dim)
            elif r == 3:
                aux = torch.cat([torch.zeros_like(x), aux], dim=dim)
    # Cut back to exactly `length`
    # (bugfix: test the padded dimension, not always the last one; the old
    # `aux.size(-1) > length` check could skip cropping when dim != -1)
    if not allow_longer and aux.size(dim) > length:
        if dim != -1:
            aux = aux.transpose(dim, -1)
        if cut_mode == "start":
            aux = aux[..., :length]
        elif cut_mode == "end":
            aux = aux[..., -length:]
        elif cut_mode == "random":
            r = torch.randint(0, aux.size(-1) - length + 1, (1,)).item()
            aux = aux[..., r : r + length]
        if dim != -1:
            aux = aux.transpose(-1, dim)
    return aux


###################################################################################################


def frames(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1):
    """Slice `signal` into overlapping frames along `axis` (TF-style).

    With pad_end=True, the signal is right-padded with `pad_value` so the
    last partial frame is completed before unfolding.
    """
    if pad_end:
        signal_length = signal.shape[axis]
        frames_overlap = frame_length - frame_step
        rest_samples = abs(signal_length - frames_overlap) % abs(frame_step)
        if rest_samples != 0:
            pad_size = int(frame_length - rest_samples)
            # Bugfix: F.pad takes a flat list of (left, right) pairs ordered
            # from the LAST dimension backwards; the old code produced an
            # odd-length or misplaced spec whenever ndim != 2 or axis != -1.
            d = axis if axis >= 0 else signal.ndim + axis
            pad_spec = [0, 0] * (signal.ndim - d)
            pad_spec[2 * (signal.ndim - 1 - d) + 1] = pad_size
            signal = torch.nn.functional.pad(signal, pad_spec, "constant", pad_value)
    return signal.unfold(axis, frame_length, frame_step)


def get_frames(
    x, length, step, dim=-1, pad_end=True, pad_mode="zeros", cut_mode="start"
):
    """Frame `x` along `dim`, padding with `force_length` so every sample
    is covered by a whole number of frames when pad_end=True."""
    if pad_end:
        # Smallest length >= x.size(dim) that fits an integer number of frames
        newlength = (
            max(int(math.ceil((x.size(dim) - length) / step)), 0) * step + length
        )
        x = force_length(
            x,
            newlength,
            dim=dim,
            pad_mode=pad_mode,
            cut_mode=cut_mode,
            allow_longer=False,
        )
    return x.unfold(dim, length, step)


###################################################################################################


def covariance(x, eps=1e-6):
    """Mean squared off-diagonal covariance of a feature matrix `x` (n, d).

    Used as a decorrelation penalty: only the strictly upper triangle of
    the (d, d) covariance contributes.
    """
    xx = x - x.mean(0, keepdim=True)
    cov = torch.matmul(xx.T, xx) / (len(xx) - 1)
    weight = torch.triu(torch.ones_like(cov), diagonal=1)
    cov = (weight * cov.pow(2)).sum() / (weight.sum() + eps)
    return cov


###################################################################################################


def roughly_equal(x, y, tol=1e-6):
    """Elementwise approximate equality within absolute tolerance `tol`."""
    return (x - y).abs() < tol


###################################################################################################


def pairwise_euclidean_distance_matrix(x, y, squared=False, eps=1e-6):
    """All-pairs (squared) Euclidean distances between rows of `x` and `y`.

    Uses the ||x||^2 - 2<x,y> + ||y||^2 expansion, with safeguards so the
    sqrt has a finite gradient at zero distance.
    """
    squared_x = x.pow(2).sum(1).view(-1, 1)
    squared_y = y.pow(2).sum(1).view(1, -1)
    dot_product = torch.mm(x, y.t())
    distance_matrix = squared_x - 2 * dot_product + squared_y
    # get rid of negative distances due to numerical instabilities
    distance_matrix[distance_matrix <= 0] = 0
    if not squared:
        # derivative of sqrt at 0 is infinite, so temporarily lift exact
        # zeros to eps, take the root, then restore the zeros
        mask = (distance_matrix == 0.0).type_as(distance_matrix)
        distance_matrix += mask * eps
        distance_matrix = torch.sqrt(distance_matrix)
        distance_matrix *= 1.0 - mask
    return distance_matrix


def pairwise_distance_matrix(x, y, mode="fro", p=2, eps=1e-6):
    """All-pairs distance/similarity matrix between rows of `x` and `y`.

    mode: "fro"/"euc" (p-norm via cdist), "nfro"/"neuc" (dim-normalized),
    "sqeuc"/"nsqeuc" (squared Euclidean, optionally dim-normalized),
    "cos"/"dot" (1 - similarity), "cossim"/"dotsim" (raw similarity).
    """
    assert x.ndim == y.ndim and x.ndim <= 2
    if x.ndim == 1:
        x = x.unsqueeze(-1)
        y = y.unsqueeze(-1)
    if mode == "euc" or mode == "neuc":
        p = 2
    if mode in ("fro", "nfro", "euc", "neuc"):
        dist = torch.cdist(x.unsqueeze(0), y.unsqueeze(0), p=p).squeeze(0)
        if mode == "nfro" or mode == "neuc":
            dist = dist / (x.size(-1) ** (1 / p))
    elif mode in ("sqeuc", "nsqeuc"):
        dist = pairwise_euclidean_distance_matrix(x, y, squared=True)
        if mode == "nsqeuc":
            dist = dist / x.size(-1)
    elif mode in ("cos", "cossim", "dot", "dotsim"):
        if mode == "cos" or mode == "cossim":
            x = x / (torch.norm(x, dim=-1, keepdim=True) + eps)
            y = y / (torch.norm(y, dim=-1, keepdim=True) + eps)
        dist = torch.matmul(x, y.T)
        if mode == "cos" or mode == "dot":
            dist = 1 - dist
    else:
        raise NotImplementedError
    return dist


###################################################################################################


def msum(x, mask=None, dim=None, keepdim=False):
    """Sum of `x` over unmasked positions (mask==True means excluded)."""
    if mask is None:
        included = torch.ones_like(x)
    else:
        included = (~mask).type_as(x)
    if dim is None:
        sum = (included * x).sum()
        if keepdim:
            while sum.ndim < x.ndim:
                sum = sum.unsqueeze(0)
    else:
        sum = (included * x).sum(dim=dim, keepdim=keepdim)
    return sum


def mmean(x, mask=None, dim=None, keepdim=False, eps=1e-7):
    """Mean of `x` over unmasked positions (mask==True means excluded)."""
    if mask is None:
        included = torch.ones_like(x)
    else:
        included = (~mask).type_as(x)
    if dim is None:
        num = (included * x).sum()
        den = included.sum()
        if keepdim:
            while num.ndim < x.ndim:
                num = num.unsqueeze(0)
                den = den.unsqueeze(0)
    else:
        num = (included * x).sum(dim=dim, keepdim=keepdim)
        den = included.sum(dim=dim, keepdim=keepdim)
    return num / den.clamp(min=eps)


def mmin(x, mask=None, dim=None, keepdim=False, ctt=torch.inf):
    """Min of `x` over unmasked positions; masked cells are filled with `ctt`."""
    if mask is None:
        tmp = x
    else:
        tmp = torch.where(mask, ctt, x)
    if dim is None:
        tmp = tmp.min()
        if keepdim:
            while tmp.ndim < x.ndim:
                tmp = tmp.unsqueeze(0)
    else:
        if type(dim) == int:
            dim = [dim]
        else:
            dim = list(dim)
        for d in dim:
            tmp = tmp.min(d, keepdim=True)[0]
        if not keepdim:
            for d in dim:
                tmp = tmp.squeeze(d)
    return tmp


def mmax(x, mask=None, dim=None, keepdim=False, ctt=-torch.inf):
    """Max of `x` over unmasked positions; masked cells are filled with `ctt`."""
    if mask is None:
        tmp = x
    else:
        tmp = torch.where(mask, ctt, x)
    if dim is None:
        tmp = tmp.max()
        if keepdim:
            while tmp.ndim < x.ndim:
                tmp = tmp.unsqueeze(0)
    else:
        if type(dim) == int:
            dim = [dim]
        else:
            dim = list(dim)
        for d in dim:
            tmp = tmp.max(d, keepdim=True)[0]
        if not keepdim:
            for d in dim:
                tmp = tmp.squeeze(d)
    return tmp


def mrand(x, mask=None, dim=None, keepdim=False, ctt=torch.inf, eps=1e-7):
    """Pick one random unmasked element of `x` along `dim`.

    Implemented by drawing uniform noise, masking everything except the
    noise argmin, and averaging over the single surviving position.
    """
    r = torch.rand_like(x)
    if mask is not None:
        r = torch.where(mask, ctt, r)
    mr = r > mmin(r, mask=mask, dim=dim, keepdim=True, ctt=ctt)
    return mmean(x, mask=mr, dim=dim, keepdim=keepdim, eps=eps)


def mbest(x, k, mask=None, dim=None, keepdim=False, ctt=torch.inf, eps=1e-7):
    """Mean of the k smallest unmasked values of `x` along integer `dim`."""
    assert type(dim) == int
    if mask is not None:
        x = torch.where(mask, ctt, x)
    x = x.topk(k, dim=dim, largest=False)[0]
    # Masked entries were filled with ctt (+inf); exclude any that survive topk
    return mmean(x, mask=x >= ctt, dim=dim, keepdim=keepdim, eps=eps)


def mworst(x, k, mask=None, dim=None, keepdim=False, ctt=-torch.inf, eps=1e-7):
    """Mean of the k largest unmasked values of `x` along integer `dim`."""
    assert type(dim) == int
    if mask is not None:
        x = torch.where(mask, ctt, x)
    x = x.topk(k, dim=dim, largest=True)[0]
    # Bugfix: masked entries are filled with ctt (-inf), so the exclusion
    # test must be `x <= ctt`; the old `x >= ctt` matched every finite
    # value and made mworst always return 0.
    return mmean(x, mask=x <= ctt, dim=dim, keepdim=keepdim, eps=eps)


###################################################################################################


def distance_tensor_redux(dist, redux, mask=None, squeeze=True, eps=1e-7, inf=1e12):
    """Reduce the last two (sequence) dims of a (b1, b2, s1, s2) distance tensor.

    redux options: "min", "max", "mean", "minmean", "meanmin", "randmin",
    "bpwr[-n]" (greedy best pairs without replacement), "bestmin[-k]",
    "best[-k]", "worst[-k]", or any option prefixed with "s" for a
    symmetrized (transpose-averaged) version. `mask` marks invalid cells
    (True = excluded). Returns (b1, b2) if squeeze else (b1, b2, 1, 1).
    """
    if redux == "min":
        dist = mmin(dist, mask=mask, dim=(-1, -2), keepdim=True, ctt=inf)
    elif redux == "max":
        dist = mmax(dist, mask=mask, dim=(-1, -2), keepdim=True, ctt=-inf)
    elif redux == "mean":
        dist = mmean(dist, mask=mask, dim=(-1, -2), keepdim=True, eps=eps)
    elif redux == "minmean":
        dist = mmean(dist, mask=mask, dim=-1, keepdim=True, eps=eps)
        dist = mmin(dist, mask=mask, dim=(-1, -2), keepdim=True, ctt=inf)
    elif redux == "meanmin":
        dist = mmin(dist, mask=mask, dim=-1, keepdim=True, ctt=inf)
        dist = mmean(dist, mask=mask, dim=(-1, -2), keepdim=True, eps=eps)
    elif redux == "randmin":
        dist = mmin(dist, mask=mask, dim=-1, keepdim=True, ctt=inf)
        dist = mrand(dist, mask=mask, dim=(-1, -2), keepdim=True, ctt=inf, eps=eps)
    elif redux.startswith("bpwr"):  # best pairs without replacement
        # transpose if smaller so rows are the shorter side
        if dist.size(3) < dist.size(2):
            dist = dist.transpose(2, 3)
            if mask is not None:
                mask = mask.transpose(2, 3)
        # set max iters
        if "-" not in redux:
            n = dist.size(2)
        else:
            n = max(1, min(int(redux.split("-")[-1]), dist.size(2)))
        # try to avoid ties
        dist = dist + eps * torch.rand_like(dist)
        # init
        if mask is None:
            mask = dist > inf
        all_sel = dist > inf
        # greedy loop: take the global min, then exclude its row and column
        for i in range(n):
            mn = mmin(dist, mask=mask, dim=(-1, -2), keepdim=True, ctt=inf)
            sel = (dist <= mn) & (~mask)
            all_sel = all_sel | sel
            if i < n - 1:
                mask = (
                    mask
                    | (mmin(dist, mask=mask, dim=-1, keepdim=True, ctt=inf) <= mn)
                    | (mmin(dist, mask=mask, dim=-2, keepdim=True, ctt=inf) <= mn)
                )
        # average over all selected pairs
        dist = mmean(dist, mask=(~all_sel), dim=(-1, -2), keepdim=True, eps=eps)
    elif redux.startswith("bestmin"):
        # NOTE: must be matched BEFORE "best" ("bestmin".startswith("best")
        # is True); the original branch order made this case unreachable.
        if "-" not in redux:
            k = 1
        else:
            k = max(1, min(int(redux.split("-")[-1]), dist.size(2)))
        dist = mmin(dist, mask=mask, dim=-1, keepdim=True, ctt=inf)  # (b1,b2,s1,1)
        # rows that were fully masked collapsed to `inf`; exclude them
        # (mbest requires an integer dim, so reduce over the s1 axis)
        dist = mbest(dist, k, mask=dist >= inf, dim=-2, keepdim=True, ctt=inf, eps=eps)
    elif redux.startswith("best"):
        if "-" not in redux:
            k = 1
        else:
            k = max(1, min(int(redux.split("-")[-1]), dist.size(2) * dist.size(3)))
        # flatten (s1, s2) -> (1, s1*s2); equivalent to the previous einops
        # rearrange "b1 b2 s1 s2 -> b1 b2 1 (s1 s2)" without the dependency
        dist = dist.reshape(dist.size(0), dist.size(1), 1, -1)
        if mask is not None:
            mask = mask.reshape(mask.size(0), mask.size(1), 1, -1)
        dist = mbest(dist, k, mask=mask, dim=-1, keepdim=True, ctt=inf, eps=eps)
    elif redux.startswith("worst"):
        if "-" not in redux:
            k = 1
        else:
            k = max(1, min(int(redux.split("-")[-1]), dist.size(2) * dist.size(3)))
        dist = dist.reshape(dist.size(0), dist.size(1), 1, -1)
        if mask is not None:
            mask = mask.reshape(mask.size(0), mask.size(1), 1, -1)
        dist = mworst(dist, k, mask=mask, dim=-1, keepdim=True, ctt=-inf, eps=eps)
    elif redux[0] == "s":
        # symmetrized: average the reduction with its transposed counterpart
        aux1 = distance_tensor_redux(dist, redux[1:], mask=mask, squeeze=False)
        dist = dist.transpose(2, 3)
        if mask is not None:
            mask = mask.transpose(2, 3)
        aux2 = distance_tensor_redux(dist, redux[1:], mask=mask, squeeze=False)
        aux2 = aux2.transpose(2, 3)
        dist = 0.5 * (aux1 + aux2)
    else:
        raise NotImplementedError
    if squeeze:
        dist = dist.squeeze((-1, -2))
    return dist


###################################################################################################