├── .gitignore
├── builder.py
├── configs
│   ├── default.json
│   ├── gentest.json
│   ├── gentest_snr-2.json
│   ├── gentest_snr-4.json
│   ├── gentest_snr-6.json
│   ├── gentest_snr0.json
│   ├── gentest_snr2.json
│   ├── gentest_snr4.json
│   ├── gentest_snr6.json
│   ├── gentest_snr8.json
│   ├── ir_no.json
│   ├── ir_no_snr-10_10.json
│   ├── ir_pad1s.json
│   ├── ir_pad1s_snr-10_10.json
│   ├── n640d64.json
│   ├── noise_no.json
│   ├── noise_snr-10_10.json
│   ├── noise_snr-5_10.json
│   ├── seg.json
│   ├── shuffle_1.json
│   ├── shuffle_10.json
│   ├── shuffle_100.json
│   ├── shuffle_1000.json
│   ├── snr.json
│   ├── timeshift_250ms.json
│   └── timeshift_no.json
├── cpp
│   ├── faisscputest.cpp
│   └── seqscore.cpp
├── cppmatcher.py
├── database.py
├── datautil
│   ├── __init__.py
│   ├── audio.py
│   ├── dataset.py
│   ├── dataset_v2.py
│   ├── ir.py
│   ├── melspec.py
│   ├── mock_data.py
│   ├── musicdata.py
│   ├── noise.py
│   ├── preprocess.py
│   └── specaug.py
├── denoise
│   └── createdataset.py
├── ensemble
│   ├── drawheatmap.py
│   ├── drawheatmap2.py
│   ├── extractscore.py
│   ├── lmscore.py
│   ├── svmdraw.py
│   ├── svmheatmap.py
│   ├── svmheatmap2.py
│   ├── svmtrain.py
│   └── svmval.py
├── extractemb.py
├── genall.sh
├── genquery.py
├── genquery_naf.py
├── lists
│   └── readme.md
├── matchemb.py
├── matcher.py
├── matchfromgt.py
├── model.py
├── preview.py
├── readme.md
├── rebuild.py
├── simpleutils.py
├── testall.sh
├── thesis.pdf
├── tools
│   ├── accuracy.py
│   ├── audioset.py
│   ├── audioset2.py
│   ├── convert_naf_to_pfann.py
│   ├── cosinedecay.py
│   ├── csv2txt.py
│   ├── filterduration.py
│   ├── fit.py
│   ├── fma_full.py
│   ├── fma_large.py
│   ├── listaudio.py
│   ├── mirexacc.py
│   ├── stat.py
│   ├── traintestsplit.py
│   └── wham.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | # cache files
 2 | __pycache__
 3 | caches
 4 | 
 5 | # output files
 6 | *.mp3
 7 | *.wav
 8 | 
 9 | # log files
10 | dlyt.txt
11 | runs
12 | 
13 | # the list is too long
14 | configs/*.csv
15 | lists/*
16 | !lists/readme.md
17 | 
18 | # model files
19 | *.pt
20 | out/models
21 | 
22 | # database files
23 | db/
24 | out/dbs
25 | fastdb
26 | slowdb
27 | 
28 | # generated queries
29 | out/queries
30 | 
31 | # matcher results
32 | out/results
33 | 
34 | # visualize
35 | *.wav.png
36 | *.mp3.png
37 | 
38 | # compiled files
39 | faisscputest
40 | seqscore
41 | *.exe
42 | *.dll
43 | *.exp
44 | *.lib
45 | *.obj
46 | 
47 | logs/
48 | .DS_Store
49 | 
--------------------------------------------------------------------------------
/builder.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import shutil
 3 | import sys
 4 | import time
 5 | import warnings
 6 | 
 7 | import faiss
 8 | import numpy as np
 9 | import torch
10 | from torch.utils.data import DataLoader
11 | import torch.nn.functional as F
12 | import torch.multiprocessing as mp
13 | import tqdm
14 | 
15 | # torchaudio currently (0.7) will throw a warning that cannot be disabled
16 | with warnings.catch_warnings():
17 |     warnings.simplefilter("ignore")
18 |     import torchaudio
19 | 
20 | import simpleutils
21 | from model import FpNetwork
22 | from datautil.melspec import build_mel_spec_layer
23 | from datautil.musicdata import MusicDataset
24 | 
25 | if __name__ == "__main__":
26 |     logger_init = simpleutils.MultiProcessInitLogger('builder')
27 |     logger_init()
28 | 
29 |     mp.set_start_method('spawn')
30 |     if len(sys.argv) < 3:
31 |         print('Usage: python %s <music list file> <db location> [config]' % sys.argv[0])
32 |         sys.exit()
33 |     file_list_for_db = sys.argv[1]
34 |     dir_for_db = sys.argv[2]
35 |     configs = 'configs/default.json'
36 |     if len(sys.argv) >= 4:
37 |         configs = sys.argv[3]
38 |     if os.path.isdir(configs):
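        # a model directory may be given instead of a .json file; in that case
        # the configs.json saved inside that directory is used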
39 |         configs_path = os.path.join(configs, 'configs.json')
40 |         params = simpleutils.read_config(configs_path)
41 |         params['model_dir'] = configs
42 |         configs = configs_path
43 |     else:
44 |         params = simpleutils.read_config(configs)
45 | 
46 |     d = params['model']['d']
47 |     h = params['model']['h']
48 |     u = params['model']['u']
49 |     F_bin = params['n_mels']
50 |     segn = int(params['segment_size'] * params['sample_rate'])
51 |     T = (segn + params['stft_hop'] - 1) // params['stft_hop']
52 | 
53 |     print('loading model...')
54 |     device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
55 |     model = FpNetwork(d, h, u, F_bin, T, params['model']).to(device)
56 |     model.load_state_dict(torch.load(os.path.join(params['model_dir'], 'model.pt'), map_location=device))
57 |     print('model loaded')
58 | 
59 |     # doing inference, turn off gradient
60 |     model.eval()
61 |     for param in model.parameters():
62 |         param.requires_grad = False
63 | 
64 |     params['indexer']['frame_shift_mul'] = 1
65 |     dataset = MusicDataset(file_list_for_db, params)
66 |     loader = DataLoader(dataset, num_workers=4, batch_size=None, worker_init_fn=logger_init)
67 | 
68 |     mel = build_mel_spec_layer(params).to(device)
69 | 
70 |     os.makedirs(dir_for_db, exist_ok=True)
71 |     embeddings_file = open(os.path.join(dir_for_db, 'embeddings'), 'wb')
72 |     lbl = []
73 |     landmarkKey = []
74 |     embeddings = 0
75 |     for dat in tqdm.tqdm(loader):
76 |         logger = mp.get_logger()
77 |         i, name, wav = dat
78 |         logger.info('get music %s', name)
79 |         tm_0 = time.time()
80 |         i = int(i) # i is leaking file handles!
81 | 
82 |         if wav.shape[0] == 0:
83 |             # load file error!
84 |             print('load %s error!' % name)
85 |             landmarkKey.append(0)
86 |             continue
87 | 
88 |         for batch in torch.split(wav, 32):
89 |             g = batch.to(device)
90 | 
91 |             # Mel spectrogram
92 |             with warnings.catch_warnings():
93 |                 # torchaudio is still using deprecated function torch.rfft
94 |                 warnings.simplefilter("ignore")
95 |                 g = mel(g)
96 |             z = model(g).cpu()
97 |             for _ in z:
98 |                 lbl.append(i)
99 |             embeddings_file.write(z.numpy().tobytes())
100 |             embeddings += z.shape[0]
101 |         landmarkKey.append(int(wav.shape[0]))
102 |         tm_1 = time.time()
103 |         logger.info('compute embedding %.6fs', tm_1 - tm_0)
104 |     embeddings_file.flush()
105 |     print('total', embeddings, 'embeddings')
106 |     if embeddings == 0:
107 |         print('The database is empty!')
108 |     #writer = tensorboardX.SummaryWriter()
109 |     #writer.add_embedding(embeddings, lbl)
110 | 
111 |     # train indexer
112 |     print('training indexer')
113 |     try:
114 |         index = faiss.index_factory(d, params['indexer']['index_factory'], faiss.METRIC_INNER_PRODUCT)
115 |     except RuntimeError as x:
116 |         if 'not implemented for inner prod search' in str(x) or "Error: 'metric == METRIC_L2' failed" in str(x):
117 |             print(x)
118 |             index = faiss.index_factory(d, params['indexer']['index_factory'], faiss.METRIC_L2)
119 |         else:
120 |             raise
121 | 
122 |     embeddings = np.fromfile(os.path.join(dir_for_db, 'embeddings'), dtype=np.float32).reshape([-1, d])
123 |     if not index.is_trained:
124 |         index.verbose = True
125 |         try:
126 |             index.train(embeddings)
127 |         except RuntimeError as x:
128 |             print(x)
129 |             if "Error: 'nx >= k' failed" in str(x):
130 |                 index = faiss.IndexFlatIP(d)
131 |     #index = faiss.IndexFlatIP(d)
132 | 
133 |     # write database
134 |     print('writing database')
135 |     index.add(embeddings)
136 |     faiss.write_index(index, os.path.join(dir_for_db, 'landmarkValue'))
137 | 
138 |     landmarkKey = np.array(landmarkKey, dtype=np.int32)
139 |     landmarkKey.tofile(os.path.join(dir_for_db, 'landmarkKey'))
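    # on-disk database layout written above (and read back by database.py):
    #   embeddings    - raw float32 matrix, one d-dimensional vector per row
    #   landmarkValue - the trained faiss index over those vectors
    #   landmarkKey   - int32 embedding count per song; database.py cumsums it
    #                   to map faiss ids back to (song, frame offset) pairs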
140 | 
141 |     shutil.copyfile(file_list_for_db, os.path.join(dir_for_db, 'songList.txt'))
142 | 
143 |     # write settings
144 |     shutil.copyfile(configs, os.path.join(dir_for_db, 'configs.json'))
145 | 
146 |     # write model
147 |     shutil.copyfile(os.path.join(params['model_dir'], 'model.pt'),
148 |         os.path.join(dir_for_db, 'model.pt'))
149 | else:
150 |     torch.set_num_threads(1)
151 | 
--------------------------------------------------------------------------------
/configs/default.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "train_csv": "lists/fma_medium_train.csv",
 3 |   "validate_csv": "lists/fma_medium_val.csv",
 4 |   "test_csv": "lists/fma_medium_test.csv",
 5 |   "music_dir": "../pfann_dataset/fma_medium",
 6 |   "model_dir": "baseline_model",
 7 |   "cache_dir": "caches",
 8 |   "batch_size": 640,
 9 |   "shuffle_size": null,
10 |   "fftconv_n": 32768,
11 |   "sample_rate": 8000,
12 |   "stft_n": 1024,
13 |   "stft_hop": 256,
14 |   "n_mels": 256,
15 |   "f_min": 300,
16 |   "f_max": 4000,
17 |   "segment_size": 1,
18 |   "hop_size": 0.5,
19 |   "time_offset": 1.2,
20 |   "pad_start": 0,
21 |   "epoch": 100,
22 |   "lr": 1e-4,
23 |   "tau": 0.05,
24 |   "noise": {
25 |     "train": "lists/noise_train.csv",
26 |     "validate": "lists/noise_val.csv",
27 |     "dir": "../pfann_dataset/audioset",
28 |     "snr_max": 10,
29 |     "snr_min": 0
30 |   },
31 |   "micirp": {
32 |     "train": "lists/micirp_train.csv",
33 |     "validate": "lists/micirp_val.csv",
34 |     "dir": "../pfann_dataset/micirp",
35 |     "length": 0.5
36 |   },
37 |   "air": {
38 |     "train": "lists/air_train.csv",
39 |     "validate": "lists/air_val.csv",
40 |     "dir": "../pfann_dataset/AIR_1_4",
41 |     "length": 1
42 |   },
43 |   "cutout_min": 0.1,
44 |   "cutout_max": 0.5,
45 |   "model": {
46 |     "d": 128,
47 |     "h": 1024,
48 |     "u": 32,
49 |     "fuller": true,
50 |     "conv_activation": "ReLU"
51 |   },
52 |   "indexer": {
53 |     "index_factory": "IVF200,PQ64x8np",
54 |     "top_k": 100,
55 |     "frame_shift_mul": 1
56 |   }
57 | }
58 | 
--------------------------------------------------------------------------------
/configs/gentest.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "train_csv": "lists/fma_inside_test.csv",
 3 |   "validate_csv": "lists/fma_medium_val.csv",
 4 |   "test_csv": "lists/fma_medium_test.csv",
 5 |   "music_dir": "../pfann_dataset/fma_medium",
 6 |   "cache_dir": "caches",
 7 |   "fftconv_n": 32768,
 8 |   "sample_rate": 8000,
 9 |   "f_min": 300,
10 |   "f_max": 4000,
11 |   "segment_size": 1,
12 |   "hop_size": 0.5,
13 |   "time_offset": 1.2,
14 |   "pad_start": 0,
15 |   "noise": {
16 |     "train": "lists/noise_train.csv",
17 |     "validate": "lists/noise_val.csv",
18 |     "dir": "../pfann_dataset/audioset",
19 |     "snr_max": 10,
20 |     "snr_min": -10
21 |   },
22 |   "micirp": {
23 |     "train": "lists/micirp_train.csv",
24 |     "validate": "lists/micirp_val.csv",
25 |     "dir": "../pfann_dataset/micirp",
26 |     "length": 0.5
27 |   },
28 |   "air": {
29 |     "train": "lists/air_train.csv",
30 |     "validate": "lists/air_val.csv",
31 |     "dir": "../pfann_dataset/AIR_1_4",
32 |     "length": 1
33 |   }
34 | }
35 | 
--------------------------------------------------------------------------------
/configs/gentest_snr-2.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "train_csv": "lists/fma_inside_test.csv",
 3 |   "validate_csv": "lists/fma_medium_val.csv",
 4 |   "test_csv": "lists/fma_medium_test.csv",
 5 |   "music_dir": "../pfann_dataset/fma_medium",
 6 |   "cache_dir": "caches",
 7 |   "fftconv_n": 32768,
 8 |   "sample_rate": 8000,
 9 |   "f_min": 300,
10 |   "f_max": 4000,
11 | "segment_size": 1, 12 | "hop_size": 0.5, 13 | "time_offset": 1.2, 14 | "pad_start": 0, 15 | "noise": { 16 | "train": "lists/noise_train.csv", 17 | "validate": "lists/noise_val.csv", 18 | "dir": "../pfann_dataset/audioset", 19 | "snr_max": -2, 20 | "snr_min": -2 21 | }, 22 | "micirp": { 23 | "train": "lists/micirp_train.csv", 24 | "validate": "lists/micirp_val.csv", 25 | "dir": "../pfann_dataset/micirp", 26 | "length": 0.5 27 | }, 28 | "air": { 29 | "train": "lists/air_train.csv", 30 | "validate": "lists/air_val.csv", 31 | "dir": "../pfann_dataset/AIR_1_4", 32 | "length": 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/gentest_snr-4.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_inside_test.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "cache_dir": "caches", 7 | "fftconv_n": 32768, 8 | "sample_rate": 8000, 9 | "f_min": 300, 10 | "f_max": 4000, 11 | "segment_size": 1, 12 | "hop_size": 0.5, 13 | "time_offset": 1.2, 14 | "pad_start": 0, 15 | "noise": { 16 | "train": "lists/noise_train.csv", 17 | "validate": "lists/noise_val.csv", 18 | "dir": "../pfann_dataset/audioset", 19 | "snr_max": -4, 20 | "snr_min": -4 21 | }, 22 | "micirp": { 23 | "train": "lists/micirp_train.csv", 24 | "validate": "lists/micirp_val.csv", 25 | "dir": "../pfann_dataset/micirp", 26 | "length": 0.5 27 | }, 28 | "air": { 29 | "train": "lists/air_train.csv", 30 | "validate": "lists/air_val.csv", 31 | "dir": "../pfann_dataset/AIR_1_4", 32 | "length": 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/gentest_snr-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_inside_test.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "cache_dir": "caches", 7 | "fftconv_n": 32768, 8 | "sample_rate": 8000, 9 | "f_min": 300, 10 | "f_max": 4000, 11 | "segment_size": 1, 12 | "hop_size": 0.5, 13 | "time_offset": 1.2, 14 | "pad_start": 0, 15 | "noise": { 16 | "train": "lists/noise_train.csv", 17 | "validate": "lists/noise_val.csv", 18 | "dir": "../pfann_dataset/audioset", 19 | "snr_max": -6, 20 | "snr_min": -6 21 | }, 22 | "micirp": { 23 | "train": "lists/micirp_train.csv", 24 | "validate": "lists/micirp_val.csv", 25 | "dir": "../pfann_dataset/micirp", 26 | "length": 0.5 27 | }, 28 | "air": { 29 | "train": "lists/air_train.csv", 30 | "validate": "lists/air_val.csv", 31 | "dir": "../pfann_dataset/AIR_1_4", 32 | "length": 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/gentest_snr0.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_inside_test.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "cache_dir": "caches", 7 | "fftconv_n": 32768, 8 | "sample_rate": 8000, 9 | "f_min": 300, 10 | "f_max": 4000, 11 | "segment_size": 1, 12 | "hop_size": 0.5, 13 | "time_offset": 1.2, 14 | "pad_start": 0, 15 | "noise": { 16 | "train": "lists/noise_train.csv", 17 | "validate": "lists/noise_val.csv", 18 | "dir": "../pfann_dataset/audioset", 19 | "snr_max": 0, 20 
| "snr_min": 0 21 | }, 22 | "micirp": { 23 | "train": "lists/micirp_train.csv", 24 | "validate": "lists/micirp_val.csv", 25 | "dir": "../pfann_dataset/micirp", 26 | "length": 0.5 27 | }, 28 | "air": { 29 | "train": "lists/air_train.csv", 30 | "validate": "lists/air_val.csv", 31 | "dir": "../pfann_dataset/AIR_1_4", 32 | "length": 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/gentest_snr2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_inside_test.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "cache_dir": "caches", 7 | "fftconv_n": 32768, 8 | "sample_rate": 8000, 9 | "f_min": 300, 10 | "f_max": 4000, 11 | "segment_size": 1, 12 | "hop_size": 0.5, 13 | "time_offset": 1.2, 14 | "pad_start": 0, 15 | "noise": { 16 | "train": "lists/noise_train.csv", 17 | "validate": "lists/noise_val.csv", 18 | "dir": "../pfann_dataset/audioset", 19 | "snr_max": 2, 20 | "snr_min": 2 21 | }, 22 | "micirp": { 23 | "train": "lists/micirp_train.csv", 24 | "validate": "lists/micirp_val.csv", 25 | "dir": "../pfann_dataset/micirp", 26 | "length": 0.5 27 | }, 28 | "air": { 29 | "train": "lists/air_train.csv", 30 | "validate": "lists/air_val.csv", 31 | "dir": "../pfann_dataset/AIR_1_4", 32 | "length": 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/gentest_snr4.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_inside_test.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "cache_dir": "caches", 7 | "fftconv_n": 32768, 8 | "sample_rate": 8000, 9 | "f_min": 300, 10 | "f_max": 4000, 11 | "segment_size": 1, 12 | "hop_size": 0.5, 13 | "time_offset": 1.2, 14 | "pad_start": 0, 15 | "noise": { 16 | "train": "lists/noise_train.csv", 17 | "validate": "lists/noise_val.csv", 18 | "dir": "../pfann_dataset/audioset", 19 | "snr_max": 4, 20 | "snr_min": 4 21 | }, 22 | "micirp": { 23 | "train": "lists/micirp_train.csv", 24 | "validate": "lists/micirp_val.csv", 25 | "dir": "../pfann_dataset/micirp", 26 | "length": 0.5 27 | }, 28 | "air": { 29 | "train": "lists/air_train.csv", 30 | "validate": "lists/air_val.csv", 31 | "dir": "../pfann_dataset/AIR_1_4", 32 | "length": 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/gentest_snr6.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_inside_test.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "cache_dir": "caches", 7 | "fftconv_n": 32768, 8 | "sample_rate": 8000, 9 | "f_min": 300, 10 | "f_max": 4000, 11 | "segment_size": 1, 12 | "hop_size": 0.5, 13 | "time_offset": 1.2, 14 | "pad_start": 0, 15 | "noise": { 16 | "train": "lists/noise_train.csv", 17 | "validate": "lists/noise_val.csv", 18 | "dir": "../pfann_dataset/audioset", 19 | "snr_max": 6, 20 | "snr_min": 6 21 | }, 22 | "micirp": { 23 | "train": "lists/micirp_train.csv", 24 | "validate": "lists/micirp_val.csv", 25 | "dir": "../pfann_dataset/micirp", 26 | "length": 0.5 27 | }, 28 | "air": { 29 | "train": "lists/air_train.csv", 30 | "validate": 
"lists/air_val.csv", 31 | "dir": "../pfann_dataset/AIR_1_4", 32 | "length": 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/gentest_snr8.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_inside_test.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "cache_dir": "caches", 7 | "fftconv_n": 32768, 8 | "sample_rate": 8000, 9 | "f_min": 300, 10 | "f_max": 4000, 11 | "segment_size": 1, 12 | "hop_size": 0.5, 13 | "time_offset": 1.2, 14 | "pad_start": 0, 15 | "noise": { 16 | "train": "lists/noise_train.csv", 17 | "validate": "lists/noise_val.csv", 18 | "dir": "../pfann_dataset/audioset", 19 | "snr_max": 8, 20 | "snr_min": 8 21 | }, 22 | "micirp": { 23 | "train": "lists/micirp_train.csv", 24 | "validate": "lists/micirp_val.csv", 25 | "dir": "../pfann_dataset/micirp", 26 | "length": 0.5 27 | }, 28 | "air": { 29 | "train": "lists/air_train.csv", 30 | "validate": "lists/air_val.csv", 31 | "dir": "../pfann_dataset/AIR_1_4", 32 | "length": 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/ir_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/ir_no", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.25, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": -5 30 | }, 31 | "micirp": { 32 | "train": "", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/ir_no_snr-10_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/ir_no_snr-10_10", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.25, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | 
"noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": -10 30 | }, 31 | "micirp": { 32 | "train": "", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/ir_pad1s.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/ir_pad1s", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.25, 20 | "pad_start": 1, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": -5 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/ir_pad1s_snr-10_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/ir_pad1s_snr-10_10", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.25, 20 | "pad_start": 1, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": -10 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | 
"validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/n640d64.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "configs/train.csv", 3 | "validate_csv": "configs/validate.csv", 4 | "test_csv": "configs/test.csv", 5 | "model_dir": "n640d64", 6 | "cache_dir": "caches", 7 | "batch_size": 640, 8 | "shuffle_size": 20000, 9 | "clips_per_song": 60, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "dynamic_range": 80, 16 | "f_min": 300, 17 | "f_max": 4000, 18 | "segment_size": 1, 19 | "hop_size": 0.5, 20 | "time_offset": 1.2, 21 | "pad_start": 1, 22 | "noise": { 23 | "train": "configs/noise_train.csv", 24 | "validate": "configs/noise_val.csv", 25 | "snr_max": 10, 26 | "snr_min": 0 27 | }, 28 | "micirp": { 29 | "train": "configs/micirp_train.csv", 30 | "validate": "configs/micirp_val.csv", 31 | "length": 0.5 32 | }, 33 | "air": { 34 | "train": "configs/air_train.csv", 35 | "validate": "configs/air_val.csv", 36 | "length": 1 37 | }, 38 | "cutout_min": 0.1, 39 | "cutout_max": 0.5, 40 | "model": { 41 | "d": 64, 42 | "h": 1024, 43 | "u": 32, 44 | "fuller": false 45 | }, 46 | "indexer": { 47 | "index_factory": "IVF200,PQ64x8np", 48 | "top_k": 100 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /configs/noise_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/noise_no", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.25, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10000, 29 | "snr_min": 10000 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/noise_snr-10_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": 
"lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/noise_snr-10_10", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.25, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": -10 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/noise_snr-5_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/noise_snr-5_10", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.25, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": -5 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/seg.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "configs/train.csv", 3 | "validate_csv": "configs/validate.csv", 4 | "test_csv": "configs/test.csv", 5 | "model_dir": "n640d128seg", 6 | "cache_dir": "caches", 7 | "batch_size": 640, 8 | "shuffle_size": 20000, 9 | "clips_per_song": 60, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | 
"dynamic_range": 80, 16 | "f_min": 300, 17 | "f_max": 4000, 18 | "segment_size": 1, 19 | "hop_size": 0.5, 20 | "time_offset": 1.25, 21 | "time_shift_type": "uniform", 22 | "pad_start": 1, 23 | "noise": { 24 | "train": "configs/noise_train.csv", 25 | "validate": "configs/noise_val.csv", 26 | "snr_max": 10, 27 | "snr_min": 0 28 | }, 29 | "micirp": { 30 | "train": "configs/micirp_train.csv", 31 | "validate": "configs/micirp_val.csv", 32 | "length": 0.5 33 | }, 34 | "air": { 35 | "train": "configs/air_train.csv", 36 | "validate": "configs/air_val.csv", 37 | "length": 1 38 | }, 39 | "cutout_min": 0.1, 40 | "cutout_max": 0.5, 41 | "model": { 42 | "d": 128, 43 | "h": 1024, 44 | "u": 32 45 | }, 46 | "indexer": { 47 | "index_factory": "Flat", 48 | "top_k": 100 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /configs/shuffle_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/shuffle_1", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 1, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.2, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": 0 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/shuffle_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/shuffle_10", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 10, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.2, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": 0 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": 
"lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/shuffle_100.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/shuffle_100", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.2, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": 0 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/shuffle_1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/shuffle_1000", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 1000, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.2, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": 0 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 
100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/snr.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "configs/train.csv", 3 | "validate_csv": "configs/validate.csv", 4 | "test_csv": "configs/test.csv", 5 | "model_dir": "snr128", 6 | "cache_dir": "caches", 7 | "batch_size": 640, 8 | "shuffle_size": 20000, 9 | "clips_per_song": 60, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "dynamic_range": 80, 16 | "f_min": 300, 17 | "f_max": 4000, 18 | "segment_size": 1, 19 | "hop_size": 0.5, 20 | "time_offset": 1.2, 21 | "pad_start": 1, 22 | "noise": { 23 | "train": "configs/noise_train.csv", 24 | "validate": "configs/noise_val.csv", 25 | "snr_max": 10, 26 | "snr_min": 0, 27 | "snr_only_in_f_range": true 28 | }, 29 | "micirp": { 30 | "train": "configs/micirp_train.csv", 31 | "validate": "configs/micirp_val.csv", 32 | "length": 0.5 33 | }, 34 | "air": { 35 | "train": "configs/air_train.csv", 36 | "validate": "configs/air_val.csv", 37 | "length": 1 38 | }, 39 | "cutout_min": 0.1, 40 | "cutout_max": 0.5, 41 | "model": { 42 | "d": 128, 43 | "h": 1024, 44 | "u": 32 45 | }, 46 | "indexer": { 47 | "index_factory": "IVF200,PQ64x8np", 48 | "top_k": 100 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /configs/timeshift_250ms.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/timeshift_250ms", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | "stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1.25, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": 0 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/timeshift_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_csv": "lists/fma_medium_train.csv", 3 | "validate_csv": "lists/fma_medium_val.csv", 4 | "test_csv": "lists/fma_medium_test.csv", 5 | "music_dir": "../pfann_dataset/fma_medium", 6 | "model_dir": "out/models/timeshift_no", 7 | "cache_dir": "caches", 8 | "batch_size": 640, 9 | "shuffle_size": 100, 10 | "fftconv_n": 32768, 11 | "sample_rate": 8000, 12 | 
"stft_n": 1024, 13 | "stft_hop": 256, 14 | "n_mels": 256, 15 | "f_min": 300, 16 | "f_max": 4000, 17 | "segment_size": 1, 18 | "hop_size": 0.5, 19 | "time_offset": 1, 20 | "pad_start": 0, 21 | "epoch": 100, 22 | "lr": 1e-4, 23 | "tau": 0.05, 24 | "noise": { 25 | "train": "lists/noise_train.csv", 26 | "validate": "lists/noise_val.csv", 27 | "dir": "../pfann_dataset/audioset", 28 | "snr_max": 10, 29 | "snr_min": 0 30 | }, 31 | "micirp": { 32 | "train": "lists/micirp_train.csv", 33 | "validate": "lists/micirp_val.csv", 34 | "dir": "../pfann_dataset/micirp", 35 | "length": 0.5 36 | }, 37 | "air": { 38 | "train": "lists/air_train.csv", 39 | "validate": "lists/air_val.csv", 40 | "dir": "../pfann_dataset/AIR_1_4", 41 | "length": 1 42 | }, 43 | "cutout_min": 0.1, 44 | "cutout_max": 0.5, 45 | "model": { 46 | "d": 128, 47 | "h": 1024, 48 | "u": 32, 49 | "fuller": true, 50 | "conv_activation": "ReLU" 51 | }, 52 | "indexer": { 53 | "index_factory": "IVF200,PQ64x8np", 54 | "top_k": 100, 55 | "frame_shift_mul": 1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /cpp/faisscputest.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Install conda and download faiss source code, then 3 | On Linux: 4 | g++ -O3 -I ../../faiss faisscputest.cpp ~/miniconda3/lib/libfaiss_avx2.so -o faisscputest 5 | or nvcc for GPU acceleration 6 | 7 | On Windows: 8 | cl /O2 /I ../../faiss /EHsc faisscputest.cpp %HomePath%\Miniconda3\Library\lib\faiss_avx2.lib /Fefaisscputest 9 | */ 10 | #ifdef __NVCC__ 11 | #include 12 | #include 13 | #include 14 | #endif 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #ifdef _WIN32 27 | #include 28 | #include 29 | #endif 30 | 31 | using faiss::Index; 32 | 33 | int idx_to_song_id(const int64_t *song_pos, int n_songs, int64_t idx) { 34 | return std::upper_bound(song_pos, song_pos + n_songs, idx) - song_pos - 1; 35 | } 36 | 37 | void my_search(const Index *idx, const int64_t *song_pos, int n_songs, int len, const float *query, const Index *idx2) { 38 | if (len > 100) len = 100; 39 | 40 | int k = 100; 41 | int d = idx->d; 42 | 43 | std::vector labels(len * k); 44 | std::vector distances(len * k); 45 | std::vector > candidates; 46 | std::vector vec(d); 47 | 48 | idx->search(len, query, k, distances.data(), labels.data()); 49 | 50 | for (int t = 0; t < len; t++) { 51 | for (int i = 0; i < k; i++) { 52 | if (labels[t*k+i] < 0) continue; 53 | 54 | int song_id = idx_to_song_id(song_pos, n_songs, labels[t*k+i]); 55 | candidates.emplace_back(song_id, int(labels[t*k+i] - song_pos[song_id] - t)); 56 | } 57 | } 58 | std::sort(candidates.begin(), candidates.end()); 59 | candidates.resize(std::unique(candidates.begin(), candidates.end()) - candidates.begin()); 60 | 61 | float best = -len - 1; 62 | int best_song = -1; 63 | for (auto c : candidates) { 64 | int song_id = c.first; 65 | int song_len = song_pos[song_id+1] - song_pos[song_id]; 66 | int64_t song_start = song_pos[song_id]; 67 | int t = c.second; 68 | 69 | float sco = 0; 70 | for (int i = 0; i < len; i++) { 71 | if (t+i < 0 || t+i >= song_len) continue; 72 | idx2->reconstruct(song_start + t+i, vec.data()); 73 | for (int j = 0; j < d; j++) { 74 | sco += vec[j] * query[i*d + j]; 75 | } 76 | } 77 | if (sco > best) { 78 | best = sco; 79 | best_song = song_id; 80 | } 81 | } 82 | fwrite(&best_song, 4, 1, stdout); 83 | fflush(stdout); 84 | } 85 | 86 | int main(int argc, char 
*argv[]) { 87 | #ifdef _WIN32 88 | _setmode(_fileno(stdin), _O_BINARY); 89 | _setmode(_fileno(stdout), _O_BINARY); 90 | #endif 91 | if (argc < 2) return 1; 92 | 93 | faiss::Index *index = NULL; 94 | std::string filename; 95 | filename = std::string(argv[1]) + "/landmarkValue"; 96 | 97 | #ifdef __NVCC__ 98 | faiss::gpu::StandardGpuResources res; 99 | faiss::gpu::GpuClonerOptions opt; 100 | opt.useFloat16 = true; 101 | #endif 102 | 103 | faiss::Index *index2 = NULL; 104 | try { 105 | index2 = index = faiss::read_index(filename.c_str()); 106 | #ifdef __NVCC__ 107 | index2 = faiss::gpu::index_cpu_to_gpu(&res, 0, index, &opt); 108 | #endif 109 | } 110 | catch (faiss::FaissException x) { 111 | puts(x.what()); 112 | return 1; 113 | } 114 | 115 | filename = std::string(argv[1]) + "/landmarkKey"; 116 | std::vector song_pos(1); 117 | FILE *fin = fopen(filename.c_str(), "rb"); 118 | if (!fin) { 119 | printf("database corrupt!\n"); 120 | return 1; 121 | } 122 | int32_t tmp; 123 | while (fread(&tmp, 4, 1, fin) == 1) { 124 | song_pos.push_back(song_pos.back() + tmp); 125 | } 126 | int n_songs = song_pos.size() - 1; 127 | fclose(fin); 128 | 129 | //printf("I read %lld data!\n", index->ntotal); 130 | if (faiss::IndexIVF *ivf = dynamic_cast(index)) { 131 | ivf->make_direct_map(); 132 | //ivf->nprobe = ivf->invlists->nlist; 133 | ivf->nprobe = 50; 134 | } 135 | #ifdef __NVCC__ 136 | if (faiss::gpu::GpuIndexIVF *ivf = dynamic_cast(index2)) { 137 | ivf->setNumProbes(50); 138 | } 139 | #endif 140 | int d = index->d; 141 | uint32_t len; 142 | while (fread(&len, 4, 1, stdin) == 1) { 143 | if (len % d != 0) return 1; 144 | std::vector query(len); 145 | uint32_t actual = fread(query.data(), 4, len, stdin); 146 | //if (len > d * 100) len = d * 100; 147 | //printf("cpu: "); 148 | //my_search(index, song_pos.data(), n_songs, len/d, query.data(), index); 149 | //printf("gpu: "); 150 | my_search(index2, song_pos.data(), n_songs, len/d, query.data(), index); 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /cpp/seqscore.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Install conda and download faiss source code, then 3 | On Linux: 4 | g++ -O3 -I ../../faiss -shared -fPIC -fopenmp seqscore.cpp ~/miniconda3/lib/libfaiss_avx2.so -o seqscore 5 | 6 | On Windows: 7 | cl /O2 /I ../../faiss /EHsc /LD /openmp seqscore.cpp %HomePath%\Miniconda3\Library\lib\faiss_avx2.lib /Feseqscore 8 | 9 | On Mac miniforge: please use conda compilers to compile faiss 10 | clang++ -O3 -I ../../faiss -shared -fPIC -Xclang -fopenmp seqscore.cpp ../../faiss/build/faiss/python/_swigfaiss.so -L $CONDA_PREFIX/lib -l omp -std=c++11 -o seqscore 11 | */ 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #ifndef _WIN32 20 | #define __declspec(x) 21 | #endif 22 | 23 | int idx_to_song_id(const int64_t *song_pos, int n_songs, int64_t idx) { 24 | return std::upper_bound(song_pos, song_pos + n_songs, idx) - song_pos - 1; 25 | } 26 | 27 | extern "C" __declspec(dllexport) 28 | long long version() { 29 | return 20220625002LL; 30 | } 31 | 32 | extern "C" __declspec(dllexport) 33 | int seq_score( 34 | void *index, 35 | const int64_t *song_pos, 36 | int n_songs, 37 | const float *query, 38 | int query_len, 39 | const int64_t *labels, 40 | int top_k, 41 | float *song_scores, 42 | int frame_shift_mul, 43 | float score_alpha) 44 | { 45 | const faiss::Index *idx = (faiss::Index *) index; 46 | const int d = idx->d; 47 | 
std::vector > candidates; 48 | 49 | for (int t = 0; t < query_len; t++) { 50 | int tim = t / frame_shift_mul; 51 | int shift = t % frame_shift_mul; 52 | for (int i = 0; i < top_k; i++) { 53 | if (labels[t*top_k+i] < 0) continue; 54 | 55 | int song_id = idx_to_song_id(song_pos, n_songs, labels[t*top_k+i]); 56 | candidates.emplace_back(song_id, int(labels[t*top_k+i] - song_pos[song_id] - tim), shift); 57 | } 58 | } 59 | std::sort(candidates.begin(), candidates.end()); 60 | candidates.resize(std::unique(candidates.begin(), candidates.end()) - candidates.begin()); 61 | 62 | float best = -INFINITY; 63 | int best_song = -1; 64 | std::vector tmp_score(candidates.size()); 65 | std::vector tmp_t(candidates.size()); 66 | 67 | int64_t mod = 1; 68 | while (mod < query_len) { 69 | mod *= 2; 70 | } 71 | 72 | #pragma omp parallel 73 | { 74 | std::vector cache(mod, -1); 75 | std::vector vec(d * mod); 76 | float my_best = -INFINITY; 77 | int my_best_song = -1; 78 | #pragma omp for 79 | for (int i = 0; i < candidates.size(); i++) { 80 | int song_id = std::get<0>(candidates[i]); 81 | if (song_id >= n_songs || song_id < 0) continue; 82 | int song_len = song_pos[song_id+1] - song_pos[song_id]; 83 | int64_t song_start = song_pos[song_id]; 84 | int t = std::get<1>(candidates[i]); 85 | int shift = std::get<2>(candidates[i]); 86 | 87 | float sco = 0; 88 | int my_query_len = (query_len - shift + frame_shift_mul - 1) / frame_shift_mul; 89 | for (int j = 0; j < my_query_len; j++) { 90 | int query_idx = j * frame_shift_mul + shift; 91 | if (t+j < 0 || t+j >= song_len) continue; 92 | int64_t song_at = song_start + t+j; 93 | int64_t song_at_hash = song_at & (mod - 1); 94 | float *my_vec = &vec[song_at_hash * d]; 95 | if (cache[song_at_hash] != song_at) { 96 | idx->reconstruct(song_at, my_vec); 97 | cache[song_at_hash] = song_at; 98 | } 99 | float innerprod = 0; 100 | for (int k = 0; k < d; k++) { 101 | innerprod += my_vec[k] * query[query_idx*d + k]; 102 | } 103 | // reference paper: Query adaptive similarity for large scale object retrieval 104 | // by D. Qin, C. Wengert, and L. V. Gool. 
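                // for unit-norm embeddings ||q - v||^2 = 2 * (1 - q.v), so the
                // l2norm below is half the squared L2 distance; exp(-alpha * l2norm^2)
                // acts as a Gaussian-style weight on it, and score_alpha == 0 keeps
                // plain inner-product scoring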
105 | float l2norm = 1.0f - 1.0f * innerprod; 106 | if (score_alpha == 0.0f) { 107 | sco += innerprod; 108 | } else if (score_alpha > 0.0f) { 109 | sco += expf(-score_alpha * l2norm * l2norm); 110 | } 111 | } 112 | sco /= std::max(my_query_len, 1); 113 | tmp_score[i] = sco; 114 | tmp_t[i] = t * frame_shift_mul - shift; 115 | if (sco > my_best) { 116 | my_best = sco; 117 | my_best_song = song_id; 118 | } 119 | } 120 | #pragma omp critical 121 | if (my_best > best || my_best == best && my_best_song < best_song) { 122 | best = my_best; 123 | best_song = my_best_song; 124 | } 125 | } 126 | for (int i = 0; i < candidates.size(); i++) { 127 | int song_id = std::get<0>(candidates[i]); 128 | if (song_id >= n_songs || song_id < 0) continue; 129 | if (tmp_score[i] > song_scores[song_id*2]) { 130 | song_scores[song_id*2] = tmp_score[i]; 131 | song_scores[song_id*2+1] = tmp_t[i]; 132 | } 133 | } 134 | //printf("hit %lld miss %lld\n", hit, miss); 135 | return best_song; 136 | } 137 | -------------------------------------------------------------------------------- /cppmatcher.py: -------------------------------------------------------------------------------- 1 | # need conda to run this program 2 | import csv 3 | import math 4 | import os 5 | import sys 6 | import warnings 7 | 8 | import faiss 9 | import julius 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import DataLoader 13 | import torch.nn.functional as F 14 | import torch.multiprocessing as mp 15 | import tqdm 16 | import subprocess 17 | if os.name == 'nt': 18 | print(os.name) 19 | import msvcrt 20 | 21 | # torchaudio currently (0.7) will throw warning that cannot be disabled 22 | with warnings.catch_warnings(): 23 | warnings.simplefilter("ignore") 24 | import torchaudio 25 | 26 | import simpleutils 27 | from model import FpNetwork 28 | from datautil.melspec import build_mel_spec_layer 29 | from datautil.musicdata import MusicDataset 30 | 31 | if __name__ == "__main__": 32 | mp.set_start_method('spawn') 33 | if len(sys.argv) < 4: 34 | print('Usage: python %s ' % sys.argv[0]) 35 | sys.exit() 36 | file_list_for_query = sys.argv[1] 37 | dir_for_db = sys.argv[2] 38 | result_file = sys.argv[3] 39 | result_file2 = os.path.splitext(result_file) # for more detailed output 40 | result_file2 = result_file2[0] + '_detail.csv' 41 | result_file_score = result_file + '.bin' 42 | configs = os.path.join(dir_for_db, 'configs.json') 43 | params = simpleutils.read_config(configs) 44 | 45 | visualize = False 46 | 47 | d = params['model']['d'] 48 | h = params['model']['h'] 49 | u = params['model']['u'] 50 | F_bin = params['n_mels'] 51 | segn = int(params['segment_size'] * params['sample_rate']) 52 | T = (segn + params['stft_hop'] - 1) // params['stft_hop'] 53 | 54 | top_k = params['indexer']['top_k'] 55 | frame_shift_mul = params['indexer'].get('frame_shift_mul', 1) 56 | 57 | print('loading model...') 58 | device = torch.device('cuda') 59 | model = FpNetwork(d, h, u, F_bin, T, params['model']).to(device) 60 | model.load_state_dict(torch.load(os.path.join(dir_for_db, 'model.pt'))) 61 | print('model loaded') 62 | 63 | print('loading database...') 64 | with open(os.path.join(dir_for_db, 'songList.txt'), 'r', encoding='utf8') as fin: 65 | songList = [] 66 | for line in fin: 67 | if line.endswith('\n'): line = line[:-1] 68 | songList.append(line) 69 | 70 | # doing inference, turn off gradient 71 | model.eval() 72 | for param in model.parameters(): 73 | param.requires_grad = False 74 | 75 | dataset = MusicDataset(file_list_for_query, params) 76 | # 
no task parallelism 77 | loader = DataLoader(dataset, num_workers=0) 78 | 79 | # open my c++ program 80 | env = {**os.environ} 81 | env['LD_LIBRARY_PATH'] = os.environ['CONDA_PREFIX'] + '/lib' 82 | query_proc = subprocess.Popen(['cpp/faisscputest', dir_for_db] 83 | , stdin=subprocess.PIPE, stdout=subprocess.PIPE 84 | , universal_newlines=False, env=env) 85 | if os.name == 'nt': 86 | # only Windows needs this! 87 | print('nt') 88 | msvcrt.setmode(query_proc.stdin.fileno(), os.O_BINARY) 89 | msvcrt.setmode(query_proc.stdout.fileno(), os.O_BINARY) 90 | 91 | mel = build_mel_spec_layer(params).to(device) 92 | 93 | fout = open(result_file, 'w', encoding='utf8', newline='\n') 94 | fout2 = open(result_file2, 'w', encoding='utf8', newline='\n') 95 | fout_score = open(result_file_score, 'wb') 96 | detail_writer = csv.writer(fout2) 97 | detail_writer.writerow(['query', 'answer', 'score', 'time', 'part_scores']) 98 | 99 | torch.set_num_threads(1) 100 | for dat in tqdm.tqdm(loader): 101 | embeddings = [] 102 | grads = [] 103 | specs = [] 104 | i, name, wav = dat 105 | i = int(i) # i is leaking file handles! 106 | # batch size should be less than 20 because query contains at most 19 segments 107 | for batch in torch.split(wav.squeeze(0), 16): 108 | g = batch.to(device) 109 | 110 | # Mel spectrogram 111 | with warnings.catch_warnings(): 112 | # torchaudio is still using deprecated function torch.rfft 113 | warnings.simplefilter("ignore") 114 | g = mel(g) 115 | z = model.forward(g, norm=False).cpu() 116 | z = torch.nn.functional.normalize(z, p=2) 117 | embeddings.append(z) 118 | embeddings = torch.cat(embeddings) 119 | song_score = np.zeros(len(songList), dtype=np.float32) 120 | 121 | query_proc.stdin.write(embeddings[0::frame_shift_mul].numpy().size.to_bytes(4, 'little')) 122 | query_proc.stdin.write(embeddings[0::frame_shift_mul].numpy().tobytes()) 123 | query_proc.stdin.flush() 124 | ans = int.from_bytes(query_proc.stdout.read(4), 'little') 125 | 126 | ans = songList[ans] 127 | sco = 0 128 | tim = 0 129 | upsco = [] 130 | tim /= frame_shift_mul 131 | tim *= params['hop_size'] 132 | fout.write('%s\t%s\n' % (name[0], ans)) 133 | fout.flush() 134 | detail_writer.writerow([name[0], ans, sco, tim] + upsco) 135 | fout2.flush() 136 | 137 | #fout_score.write(song_score.tobytes()) 138 | fout.close() 139 | fout2.close() 140 | else: 141 | torch.set_num_threads(1) 142 | -------------------------------------------------------------------------------- /database.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | from ctypes import cdll, c_float, c_int, c_int64, c_void_p, POINTER 3 | import os 4 | import time 5 | 6 | import faiss 7 | import numpy as np 8 | import torch.multiprocessing as mp 9 | 10 | import simpleutils 11 | 12 | cpp_accelerate = False 13 | gpu_accelerate = False 14 | if cpp_accelerate: 15 | mydll = cdll.LoadLibrary('cpp/seqscore') 16 | mydll.seq_score.argtypes = [ 17 | c_void_p, 18 | POINTER(c_int64), 19 | c_int, 20 | POINTER(c_float), 21 | c_int, 22 | POINTER(c_int64), 23 | c_int, 24 | POINTER(c_float), 25 | c_int, 26 | c_float 27 | ] 28 | mydll.seq_score.restype = c_int 29 | mydll.version.restype = c_int64 30 | if mydll.version() != 20220625002: 31 | print('seqscore.cpp Wrong version! 
--------------------------------------------------------------------------------
/database.py:
--------------------------------------------------------------------------------
1 | import ctypes
2 | from ctypes import cdll, c_float, c_int, c_int64, c_void_p, POINTER
3 | import os
4 | import time
5 | 
6 | import faiss
7 | import numpy as np
8 | import torch.multiprocessing as mp
9 | 
10 | import simpleutils
11 | 
12 | cpp_accelerate = False
13 | gpu_accelerate = False
14 | if cpp_accelerate:
15 |     mydll = cdll.LoadLibrary('cpp/seqscore')
16 |     mydll.seq_score.argtypes = [
17 |         c_void_p,
18 |         POINTER(c_int64),
19 |         c_int,
20 |         POINTER(c_float),
21 |         c_int,
22 |         POINTER(c_int64),
23 |         c_int,
24 |         POINTER(c_float),
25 |         c_int,
26 |         c_float
27 |     ]
28 |     mydll.seq_score.restype = c_int
29 |     mydll.version.restype = c_int64
30 |     if mydll.version() != 20220625002:
31 |         print('seqscore.cpp Wrong version! Please recompile')
32 |         exit(1)
33 | 
34 | 
35 | def make_direct_map(index):
36 |     if isinstance(index, faiss.Index):
37 |         index = faiss.downcast_index(index)
38 |     elif isinstance(index, faiss.IndexBinary):
39 |         index = faiss.downcast_IndexBinary(index)
40 |     if hasattr(index, 'make_direct_map'):
41 |         index.make_direct_map()
42 |         return True
43 |     elif isinstance(index, faiss.IndexPreTransform):
44 |         return make_direct_map(index.index)
45 |     elif isinstance(index, faiss.IndexFlat):
46 |         return True
47 |     else:
48 |         print(type(index), 'does not support direct map yet!')
49 |         return False
50 | 
51 | def set_search_params(index, params):
52 |     def helper(subindex, subparam):
53 |         for name in subparam:
54 |             value = subparam[name]
55 |             if hasattr(subindex, name):
56 |                 if isinstance(value, dict):
57 |                     helper(getattr(subindex, name), value)
58 |                 else:
59 |                     setattr(subindex, name, value)
60 |             else:
61 |                 print(subindex, 'has no attribute', name)
62 |     if 'search_params' in params:
63 |         helper(index, params['search_params'])
64 | 
65 |     # set nprobes
66 |     myindex = index
67 |     if isinstance(myindex, faiss.IndexPreTransform):
68 |         myindex = faiss.downcast_index(myindex.index)
69 |     if isinstance(myindex, faiss.IndexIVF):
70 |         print('inverse list count:', myindex.nlist)
71 |         myindex.nprobe = params.get('nprobe', 50)
72 |         print('num probes:', myindex.nprobe)
73 | 
74 | class Database:
75 |     def __init__(self, dir_for_db, indexer_params, hop_size):
76 |         self.dir_for_db = dir_for_db
77 |         self.params = indexer_params
78 |         self.top_k = self.params['top_k']
79 |         self.frame_shift_mul = self.params.get('frame_shift_mul', 1)
80 |         self.hop_size = hop_size
81 | 
82 |         self.songList = simpleutils.read_file_list(os.path.join(dir_for_db, 'songList.txt'))
83 | 
84 |         self.song_pos = np.fromfile(os.path.join(dir_for_db, 'landmarkKey'), dtype=np.int32)
85 |         assert len(self.songList) == self.song_pos.shape[0]
86 |         self.song_pos = np.pad(np.cumsum(self.song_pos, dtype=np.int64), (1, 0))  # prefix sums: song i owns rows [song_pos[i], song_pos[i+1])
87 | 
88 |         self.index = faiss.read_index(os.path.join(dir_for_db, 'landmarkValue'))
89 |         try:
90 |             self.embedding = None
91 |             if self.index.ntotal > 0:
92 |                 self.index.reconstruct(0)
93 |         except RuntimeError:
94 |             if not make_direct_map(self.index):
95 |                 print('This index cannot recover vectors')
96 |                 self.embedding = np.fromfile(os.path.join(dir_for_db, 'embeddings'), dtype=np.float32)
97 |                 self.embedding = self.embedding.reshape([-1, self.index.d])
98 | 
99 |         set_search_params(self.index, self.params)
100 | 
101 |         if gpu_accelerate and self.params.get('gpu', False):
102 |             co = faiss.GpuMultipleClonerOptions()
103 |             co.useFloat16 = True
104 |             self.gpu_index = faiss.index_cpu_to_all_gpus(self.index, co, 1)
105 |         else:
106 |             self.gpu_index = self.index
107 |         logger = mp.get_logger()
108 |         self.score_alpha = self.params.get('score_alpha', 0)
109 |         logger.info('score alpha: %s', self.score_alpha)
110 | 
111 |     def query_embeddings(self, query):
112 |         if cpp_accelerate:
113 |             return self.query_embeddings_cpp(query)
114 |         else:
115 |             return self.query_embeddings_base(query)
116 | 
117 |     def query_embeddings_base(self, query):
118 |         logger = mp.get_logger()
119 |         tm_1 = time.time()
120 |         d = self.index.d
121 |         distances, labels = self.gpu_index.search(query, self.top_k)
122 |         tm_2 = time.time()
123 |         best = -1e999
124 |         best_song_t = -1, 0
125 |         song_score = np.zeros([len(self.songList), 2], dtype=np.float32)
126 |         if self.gpu_index.ntotal == 0:
127 |             return best, best_song_t, song_score
128 | 
129 |         for shift in range(self.frame_shift_mul):
130 |             candidates = []
131 |             subquery = query[shift::self.frame_shift_mul]
132 |             sub_len = subquery.shape[0]
133 |             for t in range(sub_len):
134 |                 lab = labels[t * self.frame_shift_mul + shift]
135 |                 lab = lab[lab != -1]
136 |                 song_id = np.searchsorted(self.song_pos, lab, side='right') - 1
137 |                 song_t = lab - self.song_pos[song_id] - t
138 |                 candidates.append(np.stack([song_id, song_t], axis=1))
139 |             # according to NumPy, np.unique returns sorted array
140 |             candidates = np.unique(np.concatenate(candidates), axis=0)
141 | 
142 |             vec = np.zeros_like(subquery)
143 |             for c in candidates:
144 |                 song_id = c[0].item()
145 |                 song_start = self.song_pos[song_id].item()
146 |                 song_len = self.song_pos[song_id+1].item() - song_start
147 |                 t = c[1].item()
148 |                 real_time = (t - shift / self.frame_shift_mul) * self.hop_size
149 | 
150 |                 # get corresponding embeddings from db
151 |                 for i in range(sub_len):
152 |                     if t+i < 0 or t+i >= song_len:
153 |                         vec[i] = 0.0
154 |                     else:
155 |                         self.index.reconstruct(song_start + t+i, vec[i])
156 |                 # compute average score
157 |                 sco = np.dot(vec.flatten(), subquery.flatten()).item() / sub_len
158 |                 if sco > song_score[song_id, 0]:
159 |                     song_score[song_id, 0] = sco
160 |                     song_score[song_id, 1] = real_time
161 |                 if sco > best:
162 |                     best = sco
163 |                     best_song_t = song_id, real_time
164 |         tm_3 = time.time()
165 |         logger.info('search %.6fs rerank %.6fs', tm_2-tm_1, tm_3-tm_2)
166 |         return best, best_song_t, song_score
167 | 
168 |     def query_embeddings_cpp(self, query):
169 |         logger = mp.get_logger()
170 |         tm_1 = time.time()
171 |         d = self.index.d
172 |         distances, labels = self.gpu_index.search(query, self.top_k)
173 |         tm_2 = time.time()
174 |         best = -1e999
175 |         best_song_t = -1, 0
176 |         song_score = np.zeros([self.song_pos.shape[0] - 1, 2], dtype=np.float32)
177 | 
178 |         song_id = mydll.seq_score(
179 |             int(self.index.this),
180 |             self.song_pos.ctypes.data_as(POINTER(c_int64)),
181 |             self.song_pos.shape[0]-1,
182 |             query.ctypes.data_as(POINTER(c_float)),
183 |             query.shape[0],
184 |             labels.ctypes.data_as(POINTER(c_int64)),
185 |             self.top_k,
186 |             song_score.ctypes.data_as(POINTER(c_float)),
187 |             self.frame_shift_mul,
188 |             self.score_alpha
189 |         )
190 |         best = song_score[song_id, 0].item()
191 |         best_song_t = song_id, song_score[song_id, 1].item() * self.hop_size / self.frame_shift_mul
192 |         tm_3 = time.time()
193 |         song_score[:, 1] *= self.hop_size / self.frame_shift_mul
194 |         logger.info('search %.6fs rerank %.6fs', tm_2-tm_1, tm_3-tm_2)
195 |         return best, best_song_t, song_score
196 | 
--------------------------------------------------------------------------------
/datautil/__init__.py:
--------------------------------------------------------------------------------
1 | """HELP ME!"""
--------------------------------------------------------------------------------
/datautil/audio.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import subprocess
4 | from pathlib import Path
5 | 
6 | import numpy as np
7 | import wave
8 | import io
9 | 
10 | import simpleutils
11 | 
12 | # because builtin wave won't read wav files with more than 2 channels
13 | class HackExtensibleWave:
14 |     def __init__(self, stream):
15 |         self.stream = stream
16 |         self.pos = 0
17 |     def read(self, n):
18 |         r = self.stream.read(n)
19 |         new_pos = self.pos + len(r)
20 |         if self.pos < 20 and self.pos + n >= 20:
21 |             r = r[:20-self.pos] + b'\x01\x00'[:new_pos-20] + r[22-self.pos:]
22 |         elif 20 <= self.pos < 22:
23 |             r = b'\x01\x00'[self.pos-20:new_pos-20] + r[22-self.pos:]
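        # Note on the header patch above: in a canonical RIFF/WAVE file, bytes
        # 20-21 hold the fmt chunk's wFormatTag. The two branches rewrite those
        # bytes to PCM (0x0001) on the fly, so the builtin wave module accepts
        # WAVE_FORMAT_EXTENSIBLE (0xFFFE) multi-channel files it would otherwise reject.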
24 |         self.pos = new_pos
25 |         return r
26 | 
27 | def ffmpeg_get_audio(filename):
28 |     error_log = open(os.devnull, 'w')
29 |     proc = subprocess.Popen(['ffmpeg', '-i', filename, '-f', 'wav', 'pipe:1'],
30 |         stderr=error_log,
31 |         stdin=open(os.devnull),
32 |         stdout=subprocess.PIPE,
33 |         bufsize=1000000)
34 |     try:
35 |         dat = proc.stdout.read()
36 |         wav = wave.open(HackExtensibleWave(io.BytesIO(dat)))
37 |         ch = wav.getnchannels()
38 |         rate = wav.getframerate()
39 |         n = wav.getnframes()
40 |         dat = wav.readframes(n)
41 |         del wav
42 |         samples = np.frombuffer(dat, dtype=np.int16) / 32768
43 |         samples = samples.reshape([-1, ch]).T
44 |         return samples, rate
45 |     except (wave.Error, EOFError):
46 |         print('failed to decode %s. maybe the file is broken!' % filename)
47 |         return np.zeros([1, 0]), 44100
48 | 
49 | def wave_get_audio(filename):
50 |     with open(filename, 'rb') as fin:
51 |         wav = wave.open(HackExtensibleWave(fin))
52 |         smpwidth = wav.getsampwidth()
53 |         if smpwidth not in {1, 2, 3}:
54 |             return None
55 |         n = wav.getnframes()
56 |         if smpwidth == 1:
57 |             samples = np.frombuffer(wav.readframes(n), dtype=np.uint8) / 128 - 1
58 |         elif smpwidth == 2:
59 |             samples = np.frombuffer(wav.readframes(n), dtype=np.int16) / 32768
60 |         elif smpwidth == 3:
61 |             a = np.frombuffer(wav.readframes(n), dtype=np.uint8)
62 |             samples = np.stack([a[0::3], a[1::3], a[2::3], -(a[2::3]>>7)], axis=1).view(np.int32).squeeze(1)
63 |             del a
64 |             samples = samples / 8388608
65 |         samples = samples.reshape([-1, wav.getnchannels()]).T
66 |         return samples, wav.getframerate()
67 | 
68 | def get_audio(filename):
69 |     if str(filename).endswith('.wav'):
70 |         try:
71 |             a = wave_get_audio(filename)
72 |             if a: return a
73 |         except Exception:
74 |             pass
75 |     return ffmpeg_get_audio(filename)
76 | 
77 | class FfmpegStream:
78 |     def __init__(self, proc, sample_rate, nchannels, tmpfile):
79 |         self.proc = proc
80 |         self.sample_rate = sample_rate
81 |         self.nchannels = nchannels
82 |         self.tmpfile = None
83 |         self.stream = self.gen_stream()
84 |         self.tmpfile = tmpfile
85 |     def __del__(self):
86 |         self.proc.terminate()
87 |         self.proc.communicate()
88 |         del self.proc
89 |         if self.tmpfile:
90 |             os.unlink(self.tmpfile)
91 |     def gen_stream(self):
92 |         num = yield np.array([], dtype=np.int16)
93 |         if not num: num = 1024
94 |         while True:
95 |             to_read = num * self.nchannels * 2
96 |             dat = self.proc.stdout.read(to_read)
97 |             num = yield np.frombuffer(dat, dtype=np.int16)
98 |             if not num: num = 1024
99 |             if len(dat) < to_read:
100 |                 break
101 | 
102 | def ffmpeg_stream_audio(filename, is_tmp=False):
103 |     while 1:
104 |         try:
105 |             stderr=open(os.devnull, 'w')
106 |             stdin=open(os.devnull)
107 |             break
108 |         except PermissionError:
109 |             print('PermissionError occurred, try again')
110 |     proc = subprocess.Popen(['ffprobe', '-i', filename, '-show_streams',
111 |             '-select_streams', 'a', '-print_format', 'json'],
112 |         stderr=stderr,
113 |         stdin=stdin,
114 |         stdout=subprocess.PIPE)
115 |     prop = json.loads(proc.stdout.read())
116 |     if 'streams' not in prop:
117 |         raise RuntimeError('FFmpeg cannot decode audio')
118 |     sample_rate = int(prop['streams'][0]['sample_rate'])
119 |     nchannels = prop['streams'][0]['channels']
120 |     proc = subprocess.Popen(['ffmpeg', '-i', filename,
121 |             '-f', 's16le', '-acodec', 'pcm_s16le', 'pipe:1'],
122 |         stderr=stderr,
123 |         stdin=stdin,
124 |         stdout=subprocess.PIPE)
125 |     tmpfile = None
126 |     if is_tmp:
127 |         tmpfile = filename
128 |     return FfmpegStream(proc, sample_rate, nchannels, tmpfile=tmpfile)
129 | 
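A sketch of how these stream objects are driven, inferred from gen_stream above (the file name is hypothetical): the generator is primed with next() and then fed frame counts via send(); plain iteration, as MusicDataset does, defaults to 1024-frame reads.

stm = ffmpeg_stream_audio('example.mp3')
next(stm.stream)               # prime the generator
chunk = stm.stream.send(4096)  # up to 4096 frames of interleaved int16 samples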
130 | class WaveStream:
131 |     def __init__(self, filename, is_tmp=False):
132 |         self.is_tmp = None
133 |         self.file = open(filename, 'rb')
134 |         self.wave = wave.open(HackExtensibleWave(self.file))
135 |         self.smpsize = self.wave.getnchannels() * self.wave.getsampwidth()
136 |         self.sample_rate = self.wave.getframerate()
137 |         self.nchannels = self.wave.getnchannels()
138 |         if self.wave.getsampwidth() != 2:
139 |             raise NotImplementedError('wave stream currently only supports 16bit wav')
140 |         self.stream = self.gen_stream()
141 |         self.is_tmp = filename if is_tmp else None
142 |     def gen_stream(self):
143 |         num = yield np.array([], dtype=np.int16)
144 |         if not num: num = 1024
145 |         while True:
146 |             dat = self.wave.readframes(num)
147 |             num = yield np.frombuffer(dat, dtype=np.int16)
148 |             if not num: num = 1024
149 |             if len(dat) < num * self.smpsize:
150 |                 break
151 |     def __del__(self):
152 |         if self.is_tmp:
153 |             os.unlink(self.is_tmp)
154 | 
155 | def stream_audio(filename):
156 |     is_tmp = False
157 |     if filename.startswith('s3://'):
158 |         tmpname = simpleutils.download_tmp_from_s3(filename)
159 |         is_tmp = True
160 |         filename = tmpname
161 |     try:
162 |         return WaveStream(filename, is_tmp=is_tmp)
163 |     except Exception:
164 |         pass
165 |     try:
166 |         return ffmpeg_stream_audio(filename, is_tmp=is_tmp)
167 |     except:  # bare except is deliberate: clean up the temp file, then re-raise
168 |         if is_tmp:
169 |             os.unlink(tmpname)
170 |         raise
171 | 
--------------------------------------------------------------------------------
/datautil/ir.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import warnings
5 | 
6 | import scipy.io
7 | import numpy as np
8 | import torch
9 | import torch.fft
10 | with warnings.catch_warnings():
11 |     warnings.simplefilter("ignore")
12 |     import torchaudio
13 | 
14 | from datautil.audio import get_audio
15 | 
16 | class AIR:
17 |     def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000):
18 |         print('loading Aachen IR dataset')
19 |         with open(list_csv, 'r') as fin:
20 |             reader = csv.reader(fin)
21 |             airs = []
22 |             firstrow = next(reader)
23 |             for row in reader:
24 |                 airs.append(row[0])
25 |         data = []
26 |         to_len = int(length * sample_rate)
27 |         self.names = []
28 |         for name in airs:
29 |             mat = scipy.io.loadmat(os.path.join(air_dir, name))
30 |             h_air = torch.tensor(mat['h_air'].astype(np.float32))
31 |             assert h_air.shape[0] == 1
32 |             h_air = h_air[0]
33 |             air_info = mat['air_info']
34 |             fs = int(air_info['fs'][0][0][0][0])
35 |             self.names.append(str(air_info['room'][0][0][0]))
36 |             resampled = torchaudio.transforms.Resample(fs, sample_rate)(h_air)
37 |             truncated = resampled[0:to_len]
38 |             freqd = torch.fft.rfft(truncated, fftconv_n)
39 |             data.append(freqd)
40 |         self.data = torch.stack(data)
41 | 
42 |     def random_choose(self, num):
43 |         indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
44 |         return self.data[indices]
45 | 
46 |     def random_choose_name(self):
47 |         index = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item()
48 |         return self.data[index], self.names[index]
49 | 
50 | class MicIRP:
51 |     def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000):
52 |         print('loading microphone IR dataset')
53 |         with open(list_csv, 'r') as fin:
54 |             reader = csv.reader(fin)
55 |             mics = []
56 |             firstrow = next(reader)
57 |             for row in reader:
58 |                 mics.append(row[0])
59 |         data = []
60 |         to_len = int(length * sample_rate)
61 |         for name in mics:
62 |             smp, smprate = get_audio(os.path.join(mic_dir, name))
63 |             smp = torch.FloatTensor(smp).mean(dim=0)
64 |             resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
65 |             truncated = resampled[0:to_len]
66 |             freqd = torch.fft.rfft(truncated, fftconv_n)
67 |             data.append(freqd)
68 |         self.data = torch.stack(data)
69 | 
70 |     def random_choose(self, num):
71 |         indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
72 |         return self.data[indices]
73 | 
74 | if __name__ == '__main__':
75 |     args = argparse.ArgumentParser()
76 |     args.add_argument('air')
77 |     args.add_argument('out')
78 |     args = args.parse_args()
79 | 
80 |     with open(args.out, 'w', encoding='utf8', newline='\n') as fout:
81 |         writer = csv.writer(fout)
82 |         writer.writerow(['file'])
83 |         files = []
84 |         for name in os.listdir(args.air):
85 |             if name.endswith('.mat'):
86 |                 files.append(name)
87 |         files.sort()
88 |         for name in files:
89 |             writer.writerow([name])
90 | 
--------------------------------------------------------------------------------
/datautil/melspec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | 
4 | class MelSpec(torch.nn.Module):
5 |     def __init__(self,
6 |             sample_rate=8000,
7 |             stft_n=1024,
8 |             stft_hop=256,
9 |             f_min=300,
10 |             f_max=4000,
11 |             n_mels=256,
12 |             naf_mode=False,
13 |             mel_log='log',
14 |             spec_norm='l2'):
15 |         super(MelSpec, self).__init__()
16 |         self.naf_mode = naf_mode
17 |         self.mel_log = mel_log
18 |         self.spec_norm = spec_norm
19 |         self.mel = torchaudio.transforms.MelSpectrogram(
20 |             sample_rate=sample_rate,
21 |             n_fft=stft_n,
22 |             hop_length=stft_hop,
23 |             f_min=f_min,
24 |             f_max=f_max,
25 |             n_mels=n_mels,
26 |             window_fn=torch.hann_window,
27 |             power = 1 if naf_mode else 2,
28 |             pad_mode = 'constant' if naf_mode else 'reflect',
29 |             norm = 'slaney' if naf_mode else None,
30 |             mel_scale = 'slaney' if naf_mode else 'htk'
31 |         )
32 | 
33 |     def forward(self, x):
34 |         # normalize volume
35 |         p = 1e999 if self.spec_norm == 'max' else 2
36 |         x = torch.nn.functional.normalize(x, p=p, dim=-1)
37 | 
38 |         if self.naf_mode:
39 |             x = self.mel(x) + 0.06
40 |         else:
41 |             x = self.mel(x) + 1e-8
42 | 
43 |         if self.mel_log == 'log10':
44 |             x = torch.log10(x)
45 |         elif self.mel_log == 'log':
46 |             x = torch.log(x)
47 | 
48 |         if self.spec_norm == 'max':
49 |             x = x - torch.amax(x, dim=(-2,-1), keepdim=True)
50 |         return x
51 | 
52 | def build_mel_spec_layer(params):
53 |     return MelSpec(
54 |         sample_rate = params['sample_rate'],
55 |         stft_n = params['stft_n'],
56 |         stft_hop = params['stft_hop'],
57 |         f_min = params['f_min'],
58 |         f_max = params['f_max'],
59 |         n_mels = params['n_mels'],
60 |         naf_mode = params.get('naf_mode', False),
61 |         mel_log = params.get('mel_log', 'log'),
62 |         spec_norm = params.get('spec_norm', 'l2')
63 |     )
64 | 
65 | if __name__ == '__main__':
66 |     import simpleutils
67 |     params = simpleutils.read_config('configs/default.json')
68 |     device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
69 |     mel = build_mel_spec_layer(params).to(device)
70 |     x = torch.rand(2, 8000).to(device) - 0.5
71 |     y = mel(x)
72 |     print(y)
73 |     print(y.shape)
74 | 
--------------------------------------------------------------------------------
/datautil/mock_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader, BatchSampler
3 | from datautil.dataset_v2 import TwoStageShuffler
4 | 
5 | def make_false_data(N, F_bin, T):
6 |     mock = torch.rand([N, F_bin, T], dtype=torch.float32)
7 |     mock2 = mock + torch.rand([N, F_bin, T], dtype=torch.float32) * 1 - 0.5
8 |     mock = torch.stack([mock, mock2], dim=1)
9 |     return mock
10 | 
11 | # use this when you don't have training data
12 | class MockedDataLoader:
13 |     def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
14 |         assert train_val in {'train', 'validate'}
15 |         F_bin = configs['n_mels']
16 |         segn = int(configs['segment_size'] * configs['sample_rate'])
17 |         T = (segn + configs['stft_hop'] - 1) // configs['stft_hop']
18 |         # 1/50 of real training data
19 |         num_fake_data = 0
20 |         if train_val == 'train':
21 |             num_fake_data = 584183//50
22 |         else:
23 |             num_fake_data = 29215//50
24 |         self.dataset = make_false_data(num_fake_data, F_bin, T)
25 |         assert configs['batch_size'] % 2 == 0
26 |         self.batch_size = configs['batch_size']
27 |         self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
28 |         self.sampler = BatchSampler(self.shuffler, self.batch_size//2, False)
29 |         self.num_workers = num_workers
30 |         self.configs = configs
31 |         self.pin_memory = pin_memory
32 |         self.prefetch_factor = prefetch_factor
33 | 
34 |         # you can change shuffle to True/False
35 |         self.shuffle = True
36 |         # you can change augmented to True/False
37 |         self.augmented = True
38 |         # you can change eval time shift to True/False
39 |         self.eval_time_shift = False
40 | 
41 |         self.loader = DataLoader(
42 |             self.dataset,
43 |             sampler=self.sampler,
44 |             batch_size=None,
45 |             num_workers=self.num_workers,
46 |             pin_memory=self.pin_memory,
47 |             prefetch_factor=self.prefetch_factor
48 |         )
49 | 
50 |     def set_epoch(self, epoch):
51 |         self.shuffler.set_epoch(epoch)
52 | 
53 |     def __iter__(self):
54 |         #self.dataset.augmented = self.augmented
55 |         #self.dataset.eval_time_shift = self.eval_time_shift
56 |         self.shuffler.shuffle = self.shuffle
57 |         return iter(self.loader)
58 | 
59 |     def __len__(self):
60 |         return len(self.loader)
61 | 
--------------------------------------------------------------------------------
/datautil/musicdata.py:
--------------------------------------------------------------------------------
1 | import julius
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import multiprocessing as mp
6 | import time
7 | 
8 | from datautil.audio import stream_audio
9 | 
10 | import simpleutils
11 | 
12 | class MusicDataset(torch.utils.data.Dataset):
13 |     def __init__(self, file_list, params):
14 |         self.params = params
15 |         self.sample_rate = self.params['sample_rate']
16 |         self.segment_size = int(self.sample_rate * self.params['segment_size'])
17 |         self.hop_size = int(self.sample_rate * self.params['hop_size'])
18 |         self.frame_shift_mul = self.params['indexer'].get('frame_shift_mul', 1)
19 |         self.files = simpleutils.read_file_list(file_list)
20 | 
21 |     def unsafe_getitem(self, index):
22 |         logger = mp.get_logger()
23 |         logger.info('MusicDataset getitem %s', self.files[index])
24 |         tm_0 = time.time()
25 |         smprate = self.sample_rate
26 | 
27 |         # resample
28 |         stm = stream_audio(self.files[index])
29 |         resampler = julius.ResampleFrac(stm.sample_rate, smprate)
30 |         arr = []
31 |         n = 0
32 |         total = 0
33 |         minute = stm.sample_rate * 60
34 |         second = stm.sample_rate
35 |         new_min = smprate * 60
36 |         new_sec = smprate
37 |         strip_head = 0
38 |         wav = []
39 | 
40 |         tm_1 = time.time()
41 |         tm_resample = 0.0
42 |         tm_load = tm_1 - tm_0
43 | 
44 |         for b in stm.stream:
45 |             tm_2 = time.time()
46 |             tm_load += tm_2 - tm_1
47 |             b = np.array(b).reshape([-1, stm.nchannels])
48 |             b = np.multiply(b, 1/32768, dtype=np.float32)
49 |             arr.append(b)
50 |             n += b.shape[0]
51 |             total += b.shape[0]
52 |             if n >= minute:
53 |                 arr = np.concatenate(arr)
54 |                 b = arr[:minute]
55 |                 out = torch.from_numpy(b.T)
56 |                 wav.append(resampler(out)[:, strip_head : new_min-new_sec//2])
57 |                 arr = [arr[minute-second:].copy()]
58 |                 strip_head = new_sec//2
59 |                 n -= minute-second
60 |             tm_1 = time.time()
61 |             tm_resample += tm_1 - tm_2
62 |         # resample tail part
63 |         arr = np.concatenate(arr)
64 |         out = torch.from_numpy(arr.T)
65 |         wav.append(resampler(out)[:, strip_head : ])
66 |         wav = torch.cat(wav, dim=1)
67 | 
68 |         tm_2 = time.time()
69 |         tm_resample += tm_2 - tm_1
70 |         logger.info('load %.6fs resample %.6fs', tm_load, tm_resample)
71 | 
72 |         # stereo to mono
73 |         # check if it is fake stereo
74 |         if wav.shape[0] == 2:
75 |             pow1 = ((wav[0] - wav[1])**2).mean()
76 |             pow2 = ((wav[0] + wav[1])**2).mean()
77 |             if pow1 > pow2 * 1000:
78 |                 logger.warning('fake stereo with opposite phase detected: %s', self.files[index])
79 |                 wav[1] *= -1
80 |         wav = wav.mean(dim=0)
81 | 
82 |         if wav.shape[0] < self.segment_size:
83 |             # this "music" is too short and needs to be extended
84 |             wav = F.pad(wav, (0, self.segment_size - wav.shape[0]))
85 | 
86 |         # slice overlapping segments
87 |         wav = wav.unfold(0, self.segment_size, self.hop_size//self.frame_shift_mul)
88 |         wav = wav - wav.mean(dim=1).unsqueeze(1)
89 | 
90 |         tm_3 = time.time()
91 |         logger.info('stereo to mono %.6fs', tm_3 - tm_2)
92 | 
93 |         return index, self.files[index], wav
94 | 
95 |     def __getitem__(self, index):
96 |         try:
97 |             return self.unsafe_getitem(index)
98 |         except Exception as x:
99 |             logger = mp.get_logger()
100 |             logger.exception(x)
101 |             return index, self.files[index], torch.zeros(0, self.segment_size)
102 | 
103 |     def __len__(self):
104 |         return len(self.files)
105 | 
--------------------------------------------------------------------------------
/datautil/noise.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import warnings
4 | 
5 | import tqdm
6 | import numpy as np
7 | import torch
8 | with warnings.catch_warnings():
9 |     warnings.simplefilter("ignore")
10 |     import torchaudio
11 | 
12 | from simpleutils import get_hash
13 | from datautil.audio import get_audio
14 | 
15 | class NoiseData:
16 |     def __init__(self, noise_dir, list_csv, sample_rate, cache_dir):
17 |         print('loading noise dataset')
18 |         hashes = []
19 |         with open(list_csv, 'r') as fin:
20 |             reader = csv.reader(fin)
21 |             noises = []
22 |             firstrow = next(reader)
23 |             for row in reader:
24 |                 noises.append(row[0])
25 |                 hashes.append(get_hash(row[0]))
26 |         hash = get_hash(''.join(hashes))
27 |         #self.data = self.load_from_cache(list_csv, cache_dir, hash)
28 |         #if self.data is not None:
29 |         #    print(self.data.shape)
30 |         #    return
31 |         data = []
32 |         silence_threshold = 0
33 |         self.names = []
34 |         for name in tqdm.tqdm(noises):
35 |             smp, smprate = get_audio(os.path.join(noise_dir, name))
36 |             smp = torch.from_numpy(smp.astype(np.float32))
37 | 
38 |             # convert to mono
39 |             smp = smp.mean(dim=0)
40 | 
41 |             # strip silence start/end
42 |             abs_smp = torch.abs(smp)
43 |             if torch.max(abs_smp) <= silence_threshold:
44 |                 print('%s is too silent' % name)
45 |                 continue
46 |             has_sound = (abs_smp > silence_threshold).to(torch.int)
47 |             start = int(torch.argmax(has_sound))
48 |             end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))
49 |             smp = smp[max(start, 0) : end]
50 | 
51 |             resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
52 |             resampled = torch.nn.functional.normalize(resampled, dim=0, p=1e999)
53 |             data.append(resampled)
54 |             self.names.append(name)
55 |         self.data = torch.cat(data)
56 |         self.boundary = [0] + [x.shape[0] for x in data]
57 |         self.boundary = torch.LongTensor(self.boundary).cumsum(0)
58 |         del data
59 |         #self.save_to_cache(list_csv, cache_dir, hash, self.data)
60 |         print(self.data.shape)
61 | 
62 |     def load_from_cache(self, list_csv, cache_dir, hash):
63 |         loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
64 |         loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
65 |         if os.path.exists(loc) and os.path.exists(loc2):
66 |             with open(loc2, 'r') as fin:
67 |                 read_hash = fin.read().strip()
68 |             if read_hash != hash:
69 |                 return None
70 |             print('cache hit!')
71 |             return torch.from_numpy(np.fromfile(loc, dtype=np.float32))
72 |         return None
73 | 
74 |     def save_to_cache(self, list_csv, cache_dir, hash, audio):
75 |         os.makedirs(cache_dir, exist_ok=True)
76 |         loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
77 |         loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
78 |         with open(loc2, 'w') as fout:
79 |             fout.write(hash)
80 |         print('save to cache')
81 |         audio.numpy().tofile(loc)
82 | 
83 |     def random_choose(self, num, duration, out_name=False):
84 |         indices = torch.randint(0, self.data.shape[0] - duration, size=(num,), dtype=torch.long)
85 |         out = torch.zeros([num, duration], dtype=torch.float32)
86 |         for i in range(num):
87 |             start = int(indices[i])
88 |             end = start + duration
89 |             out[i] = self.data[start:end]
90 |         name_lookup = torch.searchsorted(self.boundary, indices, right=True) - 1
91 |         if out_name:
92 |             return out, [self.names[x] for x in name_lookup]
93 |         return out
94 | 
95 |     # x is a 2d array
96 |     def add_noises(self, x, snr_min, snr_max, out_name=False):
97 |         eps = 1e-12
98 |         noise = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
99 |         if out_name:
100 |             noise, noise_name = noise
101 |         vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
102 |         vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
103 |         snr = torch.FloatTensor(x.shape[0]).uniform_(snr_min, snr_max)
104 |         ratio = vol_x / vol_noise
105 |         ratio *= 10 ** -(snr / 20)
106 |         x_aug = x + ratio.unsqueeze(1) * noise
107 |         if out_name:
108 |             return x_aug, noise_name, snr
109 |         return x_aug
110 | 
--------------------------------------------------------------------------------
/datautil/preprocess.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.multiprocessing as mp
7 | from torch.utils.data import DataLoader, Dataset
8 | import torchaudio
9 | import tqdm
10 | 
11 | from datautil.audio import get_audio
12 | 
13 | class Preprocessor(Dataset):
14 |     def __init__(self, files, dir, sample_rate):
15 |         self.files = files
16 |         self.dir = dir
17 |         self.resampler = {}
18 |         self.sample_rate = sample_rate
19 | 
20 |     def __getitem__(self, n):
21 |         dat = get_audio(os.path.join(self.dir, self.files[n]))
22 |         wav, smprate = dat
23 |         if smprate not in self.resampler:
24 |             self.resampler[smprate] = torchaudio.transforms.Resample(smprate, self.sample_rate)
25 |         wav = torch.Tensor(wav)
26 |         wav = wav.mean(dim=0)
27 |         wav = self.resampler[smprate](wav)
28 | 
29 |         # quantize to 16 bit again
30 |         wav *= 32768
31 |         torch.clamp(wav, -32768, 32767, out=wav)
32 |         wav = wav.to(torch.int16)
33 |         return wav
34 | 
35 |     def __len__(self):
36 |         return len(self.files)
37 | 
38 | def preprocess_music(music_dir, music_csv, sample_rate, preprocess_out):
39 |     print('converting music to wav')
40 |     with open(music_csv) as fin:
41 |         reader = csv.reader(fin)
42 |         next(reader)
43 |         files = [row[0] for row in reader]
44 | 
45 |     preprocessor = Preprocessor(files, music_dir, sample_rate)
46 |     loader = DataLoader(preprocessor, num_workers=4, batch_size=None)
47 |     out_file = open(preprocess_out + '.bin', 'wb')
48 |     song_lens = []
49 |     for wav in tqdm.tqdm(loader):
50 |         # torch.set_num_threads(1) # default multithreading causes cpu contention
51 | 
52 |         wav = wav.numpy()
53 |         out_file.write(wav.tobytes())
54 |         song_lens.append(wav.shape[0])
55 |     out_file.close()
56 |     np.save(preprocess_out, np.array(song_lens, dtype=np.int64))
57 | 
--------------------------------------------------------------------------------
/datautil/specaug.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class SpecAugment:
4 |     def __init__(self, params):
5 |         self.freq_min = params.get('cutout_min', 0.1) # 5 (note: freq/time masks read the same cutout_* keys)
6 |         self.freq_max = params.get('cutout_max', 0.5) # 20
7 |         self.time_min = params.get('cutout_min', 0.1) # 5
8 |         self.time_max = params.get('cutout_max', 0.5) # 16
9 | 
10 |         self.cutout_min = params.get('cutout_min', 0.1) # 0.1
11 |         self.cutout_max = params.get('cutout_max', 0.5) # 0.4
12 | 
13 |     def get_mask(self, F, T):
14 |         mask = torch.zeros(F, T)
15 |         # cutout
16 |         cutout_max = self.cutout_max
17 |         cutout_min = self.cutout_min
18 |         f = F * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
19 |         f = int(f)
20 |         f0 = torch.randint(0, F - f + 1, (1,))
21 |         t = T * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
22 |         t = int(t)
23 |         t0 = torch.randint(0, T - t + 1, (1,))
24 |         mask[f0:f0+f, t0:t0+t] = 1
25 | 
26 |         # frequency masking
27 |         f = F * (self.freq_min + torch.rand(1) * (self.freq_max - self.freq_min))
28 |         f = int(f)
29 |         f0 = torch.randint(0, F - f + 1, (1,))
30 |         mask[f0:f0+f, :] = 1
31 | 
32 |         # time masking
33 |         t = T * (self.time_min + torch.rand(1) * (self.time_max - self.time_min))
34 |         t = int(t)
35 |         t0 = torch.randint(0, T - t + 1, (1,))
36 |         mask[:, t0:t0+t] = 1
37 |         return mask
38 | 
39 |     def augment(self, x):
40 |         mask = self.get_mask(x.shape[-2], x.shape[-1])
41 |         x = x * (1 - mask)
42 |         return x
43 | 
--------------------------------------------------------------------------------
/denoise/createdataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import warnings
5 | 
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | with warnings.catch_warnings():
10 |     warnings.simplefilter("ignore")
11 |     import torchaudio
12 | import tqdm
13 | import scipy.signal
14 | 
15 | import simpleutils
16 | from datautil.audio import get_audio
17 | from datautil.ir import AIR, MicIRP
18 | from datautil.noise import NoiseData
19 | 
20 | def biquad_faster(waveform, b0, b1, b2, a0, a1, a2):
21 |     waveform = waveform.numpy()
22 |     b = np.array([b0, b1, b2], dtype=waveform.dtype)
23 |     a = np.array([a0, a1, a2], dtype=waveform.dtype)
24 |     return torch.from_numpy(scipy.signal.lfilter(b, a, waveform))
25 | torchaudio.functional.biquad = biquad_faster
26 | 
27 | class QueryGen(torch.utils.data.Dataset):
28 |     def __init__(self, music_dir, music_list, noise, air, micirp, query_len, params):
29 |         self.music_dir = music_dir
30 |         self.music_list = music_list
31 |         self.noise = noise
32 |         self.air = air
33 |         self.micirp = micirp
34 |         self.query_len = query_len
35 |         self.params = params
36 |         self.pad_start = params['pad_start']
37 |         self.sample_rate = params['sample_rate']
38 | 
39 |     def __getitem__(self, index):
40 |         # load music
41 |         name = self.music_list[index % len(self.music_list)]
42 |         music, smprate = get_audio(os.path.join(self.music_dir, name))
43 | 
44 |         # crop a music clip
45 |         sel_smp = int(smprate * self.query_len)
46 |         pad_smp = int(smprate * self.pad_start)
47 |         hop_smp = int(smprate * self.params['hop_size'])
48 |         if music.shape[1] > sel_smp:
49 |             time_offset = torch.randint(low=0, high=music.shape[1]-sel_smp, size=(1,))
50 |             music = music[:, max(0,time_offset-pad_smp):time_offset+sel_smp]
51 |             music = np.pad(music, ((0,0), (max(pad_smp-time_offset,0),0)))
52 |         else:
53 |             time_offset = 0
54 |             music = np.pad(music, ((0,0), (pad_smp, sel_smp-music.shape[1])))
55 |         music = torch.from_numpy(music)
56 | 
57 |         # stereo to mono and resample
58 |         music = music.mean(dim=0)
59 |         music = torchaudio.transforms.Resample(smprate, self.sample_rate)(music)
60 | 
61 |         # fix size
62 |         sel_smp = int(self.sample_rate * self.query_len)
63 |         pad_smp = int(self.sample_rate * self.pad_start)
64 |         if music.shape[0] > sel_smp+pad_smp:
65 |             music = music[:sel_smp+pad_smp]
66 |         else:
67 |             music = F.pad(music, (0, sel_smp+pad_smp-music.shape[0]))
68 | 
69 |         # background mixing
70 |         music -= music.mean()
71 |         amp = torch.sqrt((music**2).mean())
72 |         snr_max = self.params['noise']['snr_max']
73 |         snr_min = self.params['noise']['snr_min']
74 |         snr = snr_min + torch.rand(1) * (snr_max - snr_min)
75 |         if self.noise:
76 |             noise = self.noise.random_choose(1, music.shape[0])[0]
77 |             noise_amp = torch.sqrt((noise**2).mean())
78 |             noise = noise * (amp / noise_amp * torch.pow(10, -0.05*snr))
79 |         else:
80 |             noise = torch.normal(mean=torch.zeros_like(music), std=(amp*torch.pow(10, -0.05*snr)))
81 | 
82 |         # IR filters
83 |         music_freq = torch.fft.rfft(music, self.params['fftconv_n'])
84 |         noise_freq = torch.fft.rfft(noise, self.params['fftconv_n'])
85 |         if self.air:
86 |             aira, reverb = self.air.random_choose_name()
87 |             music_freq *= aira
88 |             noise_freq *= aira
89 |         if self.micirp:
90 |             micirp = self.micirp.random_choose(1)[0]
91 |             music_freq *= micirp
92 |             noise_freq *= micirp
93 |         music = torch.fft.irfft(music_freq, self.params['fftconv_n'])
94 |         music = music[pad_smp:pad_smp+sel_smp]
95 |         noise = torch.fft.irfft(noise_freq, self.params['fftconv_n'])
96 |         noise = noise[pad_smp:pad_smp+sel_smp]
97 |         mix = music + noise
98 | 
99 |         # normalize volume
100 |         vol = max(torch.max(torch.abs(mix)), torch.max(torch.abs(music)), torch.max(torch.abs(noise)))
101 |         music /= vol
102 |         noise /= vol
103 |         mix /= vol
104 | 
105 |         return name, music, noise, mix
106 | 
107 |     def __len__(self):
108 |         return len(self.music_list)
109 | 
110 | def gen_for(train_val, args, params):
111 |     sample_rate = params['sample_rate']
112 | 
113 |     if args.noise:
114 |         noise = NoiseData(noise_dir=args.noise,
115 |             list_csv=params['noise'][train_val],
116 |             sample_rate=sample_rate, cache_dir=params['cache_dir'])
117 |     else:
118 |         noise = None
119 | 
120 |     if args.air:
121 |         air = AIR(air_dir=args.air,
122 |             list_csv=params['air'][train_val],
123 |             length=params['air']['length'],
124 |             fftconv_n=params['fftconv_n'], sample_rate=sample_rate)
125 |     else:
126 |         air = None
127 | 
128 |     if args.micirp:
129 |         micirp = MicIRP(mic_dir=args.micirp,
130 |             list_csv=params['micirp'][train_val],
131 |             length=params['micirp']['length'],
132 |             fftconv_n=params['fftconv_n'], sample_rate=sample_rate)
133 |     else:
134 |         micirp = None
135 | 
136 |     with open(params[train_val+'_csv'], 'r') as fin:
137 |         music_list = []
138 |         reader = csv.reader(fin)
139 |         next(reader)
140 |         for line in reader:
141 |             music_list.append(line[0])
142 | 
143 |     gen = QueryGen(args.data, music_list, noise, air, micirp, args.length, params)
144 |     runall = torch.utils.data.DataLoader(
145 |         dataset=gen,
146 |         num_workers=3
147 |     )
148 |     os.makedirs(args.out, exist_ok=True)
149 |     fout = open(os.path.join(args.out, 'denoise_'+train_val+'.csv'), 'w', encoding='utf8', newline='\n')
150 |     writer = csv.writer(fout)
151 |     writer.writerow(['mix_path', 'music_path', 'noise_path', 'duration'])
152 |     os.makedirs(os.path.join(args.out, 'music'), exist_ok=True)
153 |     os.makedirs(os.path.join(args.out, 'mix'), exist_ok=True)
154 |     os.makedirs(os.path.join(args.out, 'noise'), exist_ok=True)
155 |     for i, (name,music,noise,mix) in enumerate(tqdm.tqdm(runall)):
156 |         name = os.path.split(name[0])[1]
157 |         name = os.path.splitext(name)[0] + '.wav'
158 |         writer.writerow(['mix/'+name, 'music/'+name, 'noise/'+name, float(args.length)])  # column order matches the header above
159 | 
160 |         torchaudio.save(os.path.join(args.out, 'music', name), music, gen.sample_rate)
161 |         torchaudio.save(os.path.join(args.out, 'mix', name), mix, gen.sample_rate)
162 |         torchaudio.save(os.path.join(args.out, 'noise', name), noise, gen.sample_rate)
163 |     fout.close()
164 | 
165 | if __name__ == '__main__':
166 |     # don't delete this line, because my data loader uses queues
167 |     torch.multiprocessing.set_start_method('spawn')
168 |     args = argparse.ArgumentParser()
169 |     args.add_argument('-d', '--data', required=True)
170 |     args.add_argument('--noise')
171 |     args.add_argument('--air')
172 |     args.add_argument('--micirp')
173 |     args.add_argument('-p', '--params', default='configs/default.json')
174 |     args.add_argument('-l', '--length', type=float, default=30)
175 |     args.add_argument('-o', '--out', required=True)
176 |     args = args.parse_args()
177 | 
178 |     params = simpleutils.read_config(args.params)
179 |     sample_rate = params['sample_rate']
180 |     win = (params['pad_start'] + args.length + params['air']['length'] + params['micirp']['length']) * sample_rate
181 |     train_val = 'validate'
182 |     fftconv_n = 2048
183 |     while fftconv_n < win:
184 |         fftconv_n *= 2
185 |     params['fftconv_n'] = fftconv_n
186 |     gen_for('train', args, params)
187 |     gen_for('validate', args, params)
188 | 
--------------------------------------------------------------------------------
/ensemble/drawheatmap.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import math
4 | 
5 | import matplotlib.pyplot as plt
6 | import seaborn
7 | 
8 | args = argparse.ArgumentParser()
9 | args.add_argument('file')
10 | args = args.parse_args()
11 | 
12 | with open(args.file) as fin:
13 |     reader = csv.reader(fin)
14 |     col_names = [float(x) for x in next(reader)[1:]]
15 |     row_names = []
16 |     data = []
17 |     for row in reader:
18 |         row_names.append(float(row[0]))
19 |         data.append([float(x) for x in row[1:]])
20 | 
21 | col_names = ['$10^{%d}$' % math.log10(x) for x in col_names]
22 | row_names = ['$10^{%d}$' % math.log10(x) for x in row_names]
23 | 
24 | seaborn.set(font_scale=0.5)
25 | 
26 | seaborn.heatmap(data, annot=True, xticklabels=col_names, yticklabels=row_names, fmt='.4f', cmap='viridis')
27 | plt.xlabel('gamma')
28 | plt.ylabel('C')
29 | plt.savefig(args.file + '.pdf')
30 | plt.show()
31 | 
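Usage sketch for the plot script above (the CSV file name is hypothetical; the grid is produced by svmheatmap.py further below):

  python ensemble/drawheatmap.py out/svm_grid.csv

The first row of the CSV carries the gamma values, the first column the C values, and each cell an ensemble accuracy; both axes are rendered as powers of ten.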
--------------------------------------------------------------------------------
/ensemble/drawheatmap2.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import math
4 | 
5 | import matplotlib
6 | matplotlib.rcParams['font.family'] = ['Heiti TC']
7 | 
8 | import matplotlib.pyplot as plt
9 | import seaborn
10 | 
11 | args = argparse.ArgumentParser()
12 | args.add_argument('file')
13 | args = args.parse_args()
14 | 
15 | with open(args.file) as fin:
16 |     reader = csv.reader(fin)
17 |     col_names = [x for x in next(reader)[1:]]
18 |     row_names = []
19 |     data = []
20 |     for row in reader:
21 |         row_names.append(float(row[0]))
22 |         data.append([float(x) for x in row[1:]])
23 | 
24 | col_names = ['FMA\n-6dB', 'FMA\n-4dB', 'FMA\n-2dB', 'FMA\n0dB', 'FMA\n2dB', 'FMA\n4dB', 'FMA\n6dB', 'FMA\n8dB', 'FMA\n-10~10dB', 'MIREX']
25 | row_names = ['$10^{%d}$' % math.log10(x) for x in row_names]
26 | 
27 | seaborn.set(font="Heiti TC", font_scale=0.5)
28 | 
29 | seaborn.heatmap(data, annot=True, xticklabels=col_names, yticklabels=row_names, fmt='.4f', cmap='viridis')
30 | plt.xlabel('validation dataset')
31 | plt.ylabel('C')
32 | plt.savefig(args.file + '.pdf')
33 | plt.show()
34 | 
--------------------------------------------------------------------------------
/ensemble/extractscore.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import argparse
3 | import os
4 | import numpy as np
5 | 
6 | args = argparse.ArgumentParser()
7 | args.add_argument('songlist')
8 | args.add_argument('groundtruth')
9 | args.add_argument('predict')
10 | args.add_argument('out')
11 | args = args.parse_args()
12 | 
13 | def extract_ans_txt(file):
14 |     with open(file, 'r') as fin:
15 |         out = []
16 |         for line in fin:
17 |             if line.endswith('\n'): line = line[:-1]
18 |             query, ans = line.split('\t')
19 |             my_query = os.path.splitext(os.path.split(query)[1])[0]
20 |             my_ans = os.path.splitext(os.path.split(ans)[1])[0]
21 |             out.append((my_query, my_ans))
22 |         return out
23 | 
24 | def extract_ans_csv(file):
25 |     with open(file, 'r') as fin:
26 |         out = []
27 |         reader = csv.reader(fin)
28 |         next(reader)
29 |         for line in reader:
30 |             query, ans = line[:2]
31 |             my_query = os.path.splitext(os.path.split(query)[1])[0]
32 |             my_ans = os.path.splitext(os.path.split(ans)[1])[0]
33 |             if any(q == my_query for q, _ in out):
34 |                 print('Warning! query %s occurred twice' % query)
35 |             out.append((my_query, my_ans))
36 |         return out
37 | 
38 | def extract_ans(file):
39 |     if file.endswith('.csv'):
40 |         return extract_ans_csv(file)
41 |     return extract_ans_txt(file)
42 | 
43 | GT = dict(extract_ans(args.groundtruth))
44 | PR = extract_ans(args.predict)
45 | 
46 | with open(args.songlist) as fin:
47 |     song_list = []
48 |     song_ids = {}
49 |     for i, line in enumerate(fin):
50 |         if line.endswith('\n'): line = line[:-1]
51 |         line = os.path.splitext(os.path.split(line)[1])[0]
52 |         song_list.append(line)
53 |         song_ids[line] = i
54 | 
55 | sco_bin = np.fromfile(args.predict+'.bin', dtype=np.float32)
56 | sco_bin = sco_bin.reshape([-1, len(song_list), 2])
57 | 
58 | scores = []
59 | for i in range(len(PR)):
60 |     query, ans = PR[i]
61 |     if query in GT:
62 |         real_ans = GT[query]
63 |         sco = sco_bin[i, song_ids[ans], 0]
64 |         scores.append((sco, ans==real_ans))
65 |     else:
66 |         print('query %s from prediction file not found in ground truth!!' % query)
67 |         print('ARE YOU KIDDING ME?')
68 |         exit(1)
69 | scores = np.array(scores, dtype=np.float32)
70 | np.save(args.out, scores)
71 | 
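The .npy files produced above feed the SVM ensemble scripts that follow: each row holds (matcher score, whether the answer was correct) for one query, and svmtrain.py stacks the neural-network and landmark scores into 2-D features. A sketch of the expected layout (file names hypothetical):

import numpy as np
nn = np.load('nn_out2.npy')  # shape [n_queries, 2]: column 0 = score, column 1 = correctness flag
lm = np.load('lm_out2.npy')  # same layout, produced from the landmark matcher
x2 = np.stack([nn[:, 0], lm[:, 0]], axis=1)  # per-query 2-D feature for the SVM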
--------------------------------------------------------------------------------
/ensemble/lmscore.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import argparse
3 | import os
4 | import numpy as np
5 | 
6 | args = argparse.ArgumentParser()
7 | args.add_argument('songlist')
8 | args.add_argument('groundtruth')
9 | args.add_argument('predict')
10 | args.add_argument('out')
11 | args = args.parse_args()
12 | 
13 | def extract_ans_txt(file):
14 |     with open(file, 'r') as fin:
15 |         out = []
16 |         for line in fin:
17 |             if line.endswith('\n'): line = line[:-1]
18 |             query, ans = line.split('\t')
19 |             my_query = os.path.splitext(os.path.split(query)[1])[0]
20 |             my_ans = os.path.splitext(os.path.split(ans)[1])[0]
21 |             out.append((my_query, my_ans))
22 |         return out
23 | 
24 | def extract_ans_csv(file):
25 |     with open(file, 'r') as fin:
26 |         out = []
27 |         reader = csv.reader(fin)
28 |         next(reader)
29 |         for line in reader:
30 |             query, ans = line[:2]
31 |             my_query = os.path.splitext(os.path.split(query)[1])[0]
32 |             my_ans = os.path.splitext(os.path.split(ans)[1])[0]
33 |             if any(q == my_query for q, _ in out):
34 |                 print('Warning! query %s occurred twice' % query)
35 |             out.append((my_query, my_ans))
36 |         return out
37 | 
38 | def extract_ans(file):
39 |     if file.endswith('.csv'):
40 |         return extract_ans_csv(file)
41 |     return extract_ans_txt(file)
42 | 
43 | GT = dict(extract_ans(args.groundtruth))
44 | PR = extract_ans(args.predict)
45 | 
46 | with open(args.songlist) as fin:
47 |     song_list = []
48 |     song_ids = {}
49 |     for i, line in enumerate(fin):
50 |         if line.endswith('\n'): line = line[:-1]
51 |         line = os.path.splitext(os.path.split(line)[1])[0]
52 |         song_list.append(line)
53 |         song_ids[line] = i
54 | 
55 | sco_bin = np.fromfile(args.predict+'.bin', dtype=np.int32)
56 | sco_bin = sco_bin.reshape([-1, len(song_list), 2])
57 | 
58 | scores = []
59 | for i in range(len(PR)):
60 |     query, ans = PR[i]
61 |     if query in GT:
62 |         real_ans = GT[query]
63 |         sco = sco_bin[i, song_ids[ans], 1]
64 |         scores.append((sco, ans==real_ans))
65 |     else:
66 |         print('query %s from prediction file not found in ground truth!!' % query)
67 |         print('ARE YOU KIDDING ME?')
68 |         exit(1)
69 | scores = np.array(scores, dtype=np.float32)
70 | np.save(args.out, scores)
71 | 
--------------------------------------------------------------------------------
/ensemble/svmdraw.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | 
4 | import numpy as np
5 | from sklearn.svm import SVC
6 | import matplotlib.pyplot as plt
7 | 
8 | args = argparse.ArgumentParser()
9 | args.add_argument('lm_npy')
10 | args.add_argument('nn_npy')
11 | args.add_argument('--svm')
12 | args.add_argument('--out')
13 | args = args.parse_args()
14 | 
15 | a = np.load(args.nn_npy)
16 | b = np.load(args.lm_npy)
17 | select = a[:,1] + b[:,1] == 1
18 | x = np.stack([a[select,0], b[select,0]], axis=1)
19 | y = a[select,1]
20 | x2 = np.stack([a[:,0], b[:,0]], axis=1)
21 | print('nn wins', np.sum(y==1))
22 | print('landmark wins', np.sum(y==0))
23 | 
24 | xx = np.linspace(0, 0.8, 200)
25 | yy = np.linspace(0, 0.025, 200)
26 | xx, yy = np.meshgrid(xx, yy)
27 | if args.svm:
28 |     with open(args.svm, 'rb') as fin:
29 |         model = pickle.load(fin)
30 |     pred = model.predict(x2)
31 |     ok = np.where(pred, a[:,1], b[:,1])
32 |     acc = np.mean(ok)
33 |     print('acc=%.4f' % acc)
34 | 
35 |     Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
36 |     plt.contourf(xx, yy, Z.reshape(200, 200), np.arange(101)*0.01)
37 | print('nn score too big', x[np.where(x[:,0] > 0.8)])
38 | print('lm score too big', x[np.where(x[:,1] > 0.025)], y[np.where(x[:,1] > 0.025)])
39 | plt.scatter(x[y==0, 0], x[y==0, 1])
40 | plt.scatter(x[y==1, 0], x[y==1, 1])
41 | plt.xlabel('neural network score')
42 | plt.xlim(0, 0.8)
43 | plt.ylabel('landmark score')
44 | plt.ylim(0, 0.025)
45 | if args.out:
46 |     plt.savefig(args.out)
47 | plt.show()
48 | 
--------------------------------------------------------------------------------
/ensemble/svmheatmap.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import pickle
5 | 
6 | import numpy as np
7 | from sklearn.svm import SVC
8 | 
9 | args = argparse.ArgumentParser()
10 | args.add_argument('lm_npy')
11 | args.add_argument('nn_npy')
12 | args.add_argument('svms')
13 | args.add_argument('out')
14 | args = args.parse_args()
15 | 
16 | a = np.load(args.nn_npy)
17 | b = np.load(args.lm_npy)
18 | select = a[:,1] + b[:,1] == 1
19 | x = np.stack([a[select,0], b[select,0]], axis=1)
20 | y = a[select,1]
21 | #print('nn wins', np.sum(y==1))
22 | #print('landmark wins', np.sum(y==0))
23 | 
24 | x2 = np.stack([a[:,0], b[:,0]], axis=1)
25 | 
26 | gammas = ['1e-09','1e-08','1e-07','1e-06','1e-05','0.0001','0.001','0.01','0.1','1','10','100','1000']
27 | 
28 | dats = [['']+gammas]
29 | for C in ['0.01','0.1'] + [str(10**x) for x in range(0,11)]:
30 |     dats.append([C])
31 |     for gamma in gammas:
32 |         svm = 'rbf_C' + C + '_gamma' + gamma + '.pkl'
33 |         with open(os.path.join(args.svms, svm), 'rb') as fin:
34 |             model = pickle.load(fin)
35 |         pred = model.predict(x2)
36 |         ok = np.where(pred, a[:,1], b[:,1])
37 |         acc = np.mean(ok)
38 |         dats[-1].append(acc)
39 |         #print('%s acc=%.4f' % (svm, acc))
40 | with open(args.out, 'w', newline='\n') as fout:
41 |     writer = csv.writer(fout)
42 |     writer.writerows(dats)
43 | 
--------------------------------------------------------------------------------
/ensemble/svmheatmap2.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import pickle
5 | 
6 | import numpy as np
7 | from sklearn.svm import SVC
8 | 
9 | args = argparse.ArgumentParser()
10 | args.add_argument('lm_npy')
11 | args.add_argument('nn_npy')
12 | args.add_argument('svms')
13 | args.add_argument('out')
14 | args = args.parse_args()
15 | 
16 | snrs = ['out2_snr-6', 'out2_snr-4', 'out2_snr-2', 'out2_snr0', 'out2_snr2', 'out2_snr4', 'out2_snr6', 'out2_snr8', 'out2', 'mirex']
17 | 
18 | dats = [['C']+snrs]
19 | for C in ['0.01','0.1'] + [str(10**x) for x in range(0,11)]:
20 |     dats.append([C])
21 |     for snr in snrs:
22 |         svm = 'lin_C' + C + '.pkl'
23 |         a = np.load(args.nn_npy + snr + '.npy')
24 |         b = np.load(args.lm_npy + snr + '.npy')
25 |         select = a[:,1] + b[:,1] == 1
26 |         x2 = np.stack([a[:,0], b[:,0]], axis=1)
27 |         with open(os.path.join(args.svms, svm), 'rb') as fin:
28 |             model = pickle.load(fin)
29 |         pred = model.predict(x2)
30 |         ok = np.where(pred, a[:,1], b[:,1])[select]
31 |         acc = np.mean(ok)
32 |         dats[-1].append(acc)
33 |         #print('%s acc=%.4f' % (svm, acc))
34 | with open(args.out, 'w', newline='\n') as fout:
35 |     writer = csv.writer(fout)
36 |     writer.writerows(dats)
37 | 
--------------------------------------------------------------------------------
/ensemble/svmtrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import os
4 | 
5 | import numpy as np
6 | from sklearn.svm import SVC
7 | 
8 | args = argparse.ArgumentParser()
9 | args.add_argument('lm_npy')
10 | args.add_argument('nn_npy')
11 | args.add_argument('out')
12 | args = args.parse_args()
13 | 
14 | a = np.load(args.nn_npy)
15 | b = np.load(args.lm_npy)
16 | select = a[:,1] + b[:,1] == 1
17 | x = np.stack([a[select,0], b[select,0]], axis=1)
18 | y = a[select,1]
19 | print('nn wins', np.sum(y==1))
20 | print('landmark wins', np.sum(y==0))
21 | 
22 | print('Linear SVM')
23 | for C in range(-2, 11):
24 |     model = SVC(C=10**C, kernel='linear')
25 |     model.fit(x, y)
26 |     acc = np.mean(model.predict(x) == y)
27 |     print('C={} train acc={:.4f}'.format(10**C, acc))
28 |     with open(os.path.join(args.out, 'lin_C{}.pkl'.format(10**C)), 'wb') as fout:
29 |         pickle.dump(model, fout)
30 | 
31 | print('RBF SVM')
32 | for C in range(-2, 11):
33 |     for gamma in range(-9, 4):
34 |         model = SVC(C=10**C, kernel='rbf', gamma=10**gamma)
35 |         model.fit(x, y)
36 |         acc = np.mean(model.predict(x) == y)
37 |         print('C={} gamma={} train acc={:.4f}'.format(10**C, 10**gamma, acc))
38 |         with open(os.path.join(args.out, 'rbf_C{}_gamma{}.pkl'.format(10**C, 10**gamma)), 'wb') as fout:
39 |             pickle.dump(model, fout)
40 | 
--------------------------------------------------------------------------------
/ensemble/svmval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import os
4 | 
5 | import numpy as np
6 | from sklearn.svm import SVC
7 | 
8 | args = argparse.ArgumentParser()
9 | args.add_argument('lm_npy')
10 | args.add_argument('nn_npy')
11 | args.add_argument('svms')
12 | args = args.parse_args()
13 | 
14 | a = np.load(args.nn_npy)
15 | b = np.load(args.lm_npy)
16 | select = a[:,1] + b[:,1] == 1
17 | x = np.stack([a[select,0], b[select,0]], axis=1)
18 | y = a[select,1]
19 | print('nn wins', np.sum(y==1))
20 | print('landmark wins', np.sum(y==0))
21 | 
22 | x2 = np.stack([a[:,0], b[:,0]], axis=1)
23 | 
24 | for svm in sorted(os.listdir(args.svms)):
25 |     if svm.endswith('.pkl'):
26 |         with open(os.path.join(args.svms, svm), 'rb') as fin:
27 |             model = pickle.load(fin)
28 |         pred = model.predict(x2)
29 |         ok = np.where(pred, a[:,1], b[:,1])
30 |         acc = np.mean(ok)
31 |         print('%s acc=%.4f' % (svm, acc))
32 | 
--------------------------------------------------------------------------------
/extractemb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import sys
4 | import time
5 | 
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import torch.multiprocessing as mp
10 | import tqdm
11 | 
12 | import simpleutils
13 | from model import FpNetwork
14 | from datautil.melspec import build_mel_spec_layer
15 | from datautil.musicdata import MusicDataset
16 | 
17 | if __name__ == "__main__":
18 |     logger_init = simpleutils.MultiProcessInitLogger('nnextract')
19 |     logger_init()
20 | 
21 |     mp.set_start_method('spawn')
22 |     if len(sys.argv) < 4:
23 |         print('Usage: python %s <query list> <db location> <output embedding dir>' % sys.argv[0])
24 |         sys.exit()
25 |     file_list_for_query = sys.argv[1]
26 |     dir_for_db = sys.argv[2]
27 |     out_embed_dir = sys.argv[3]
28 |     configs = os.path.join(dir_for_db, 'configs.json')
29 |     params = simpleutils.read_config(configs)
30 | 
31 |     d = params['model']['d']
32 |     h = params['model']['h']
33 |     u = params['model']['u']
34 |     F_bin = params['n_mels']
35 |     segn = int(params['segment_size'] * params['sample_rate'])
36 |     T = (segn + params['stft_hop'] - 1) // params['stft_hop']
37 | 
38 |     frame_shift_mul = params['indexer'].get('frame_shift_mul', 1)
39 | 
40 |     print('loading model...')
41 |     device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
42 |     model = FpNetwork(d, h, u, F_bin, T, params['model']).to(device)
43 |     model.load_state_dict(torch.load(os.path.join(dir_for_db, 'model.pt'), map_location=device))
44 |     print('model loaded')
45 | 
46 |     # doing inference, turn off gradient
47 |     model.eval()
48 |     for param in model.parameters():
49 |         param.requires_grad = False
50 | 
51 |     dataset = MusicDataset(file_list_for_query, params)
52 |     loader = DataLoader(dataset, num_workers=4, batch_size=None, worker_init_fn=logger_init)
53 | 
54 |     mel = build_mel_spec_layer(params).to(device)
55 | 
56 |     os.makedirs(out_embed_dir, exist_ok=True)
57 |     embeddings_file = open(os.path.join(out_embed_dir, 'query_embeddings'), 'wb')
58 |     query_idx = open(os.path.join(out_embed_dir, 'query_index'), 'wb')
59 |     tm_0 = time.time()
60 |     idx_pos = 0
61 |     for dat in tqdm.tqdm(loader):
62 |         logger = mp.get_logger()
63 |         i, name, wav = dat
64 |         logger.info('get query %s', name)
65 |         tm_1 = time.time()
66 |         i = int(i)  # i is leaking file handles!
67 | 
68 |         if wav.shape[0] == 0:
69 |             # load file error!
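            # write a zero-length index entry so query_index stays aligned with queryList.txt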
70 |             logger.error('load %s error!', name)
71 | 
72 |             query_idx.write(np.array([idx_pos, 0], dtype=np.int64))
73 |             continue
74 | 
75 |         idx_start = idx_pos
76 |         # batch size should be less than 20 because query contains at most 19 segments
77 |         for batch in torch.split(wav, 16):
78 |             g = batch.to(device)
79 | 
80 |             # Mel spectrogram
81 |             g = mel(g)
82 |             z = model(g).cpu()
83 |             embeddings_file.write(z.numpy().tobytes())
84 |             idx_pos += z.shape[0]
85 |         query_idx.write(np.array([idx_start, idx_pos - idx_start], dtype=np.int64))
86 | 
87 |         tm_2 = time.time()
88 |         logger.info('compute embedding %.6fs', tm_2 - tm_1)
89 |     embeddings_file.flush()
90 |     print('total', idx_pos, 'embeddings')
91 |     shutil.copyfile(file_list_for_query, os.path.join(out_embed_dir, 'queryList.txt'))
92 | 
93 |     # write settings
94 |     shutil.copyfile(configs, os.path.join(out_embed_dir, 'configs.json'))
95 | 
96 |     logger.info('total extract time %.6fs', time.time() - tm_0)
97 | 
--------------------------------------------------------------------------------
/genall.sh:
--------------------------------------------------------------------------------
1 | for snr in -6 -4 -2 0 2 4 6 8
2 | do
3 | python genquery.py --params configs/gentest_snr$snr.json --length 10 --num 2000 --mode test --out out/queries/out2_snr$snr
4 | done
5 | 
--------------------------------------------------------------------------------
/genquery.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import math
4 | import os
5 | import warnings
6 | import json
7 | 
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | with warnings.catch_warnings():
12 |     warnings.simplefilter("ignore")
13 |     import torchaudio
14 | import tqdm
15 | import scipy.signal
16 | 
17 | import simpleutils
18 | from datautil.audio import get_audio
19 | from datautil.ir import AIR, MicIRP
20 | from datautil.noise import NoiseData
21 | 
22 | def biquad_faster(waveform, b0, b1, b2, a0, a1, a2):
23 |     waveform = waveform.numpy()
24 |     b = np.array([b0, b1, b2], dtype=waveform.dtype)
25 |     a = np.array([a0, a1, a2], dtype=waveform.dtype)
26 |     return torch.from_numpy(scipy.signal.lfilter(b, a, waveform))
27 | torchaudio.functional.biquad = biquad_faster
28 | 
29 | class QueryGen(torch.utils.data.Dataset):
30 |     def __init__(self, music_dir, music_list, noise, air, micirp, query_len, num_queries, params):
31 |         self.music_dir = music_dir
32 |         self.music_list = music_list
33 |         self.noise = noise
34 |         self.air = air
35 |         self.micirp = micirp
36 |         self.query_len = query_len
37 |         self.num_queries = num_queries
38 |         self.params = params
39 |         self.pad_start = params['pad_start']
40 |         self.sample_rate = params['sample_rate']
41 | 
42 |     def __getitem__(self, index):
43 |         torch.manual_seed(9000 + index)
44 |         # load music
45 |         name = self.music_list[index % len(self.music_list)]
46 |         audio, smprate = get_audio(os.path.join(self.music_dir, name))
47 | 
48 |         # crop a music clip
49 |         sel_smp = int(smprate * self.query_len)
50 |         pad_smp = int(smprate * self.pad_start)
51 |         hop_smp = int(smprate * self.params['hop_size'])
52 |         if audio.shape[1] > sel_smp:
53 |             time_offset = torch.randint(low=0, high=audio.shape[1]-sel_smp, size=(1,)).item()
54 |             audio = audio[:, max(0,time_offset-pad_smp):time_offset+sel_smp]
55 |             audio = np.pad(audio, ((0,0), (max(pad_smp-time_offset,0),0)))
56 |         else:
57 |             time_offset = 0
58 |             audio = np.pad(audio, ((0,0), (pad_smp, sel_smp-audio.shape[1])))
torch.from_numpy(audio.astype(np.float32)) 60 | 61 | # stereo to mono and resample 62 | audio = audio.mean(dim=0) 63 | audio = torchaudio.transforms.Resample(smprate, self.sample_rate)(audio) 64 | 65 | # fix size 66 | sel_smp = int(self.sample_rate * self.query_len) 67 | pad_smp = int(self.sample_rate * self.pad_start) 68 | if audio.shape[0] > sel_smp+pad_smp: 69 | audio = audio[:sel_smp+pad_smp] 70 | else: 71 | audio = F.pad(audio, (0, sel_smp+pad_smp-audio.shape[0])) 72 | 73 | # background mixing 74 | snr_max = self.params['noise']['snr_max'] 75 | snr_min = self.params['noise']['snr_min'] 76 | if self.noise: 77 | audio, noise, snr = self.noise.add_noises(audio.unsqueeze(0), snr_min, snr_max, out_name=True) 78 | audio = audio[0] 79 | noise = noise[0] 80 | snr = snr.item() 81 | 82 | # IR filters 83 | audio_freq = torch.fft.rfft(audio, self.params['fftconv_n']) 84 | reverb = '' 85 | if self.air: 86 | aira, reverb = self.air.random_choose_name() 87 | audio_freq *= aira 88 | if self.micirp: 89 | audio_freq *= self.micirp.random_choose(1)[0] 90 | audio = torch.fft.irfft(audio_freq, self.params['fftconv_n']) 91 | audio = audio[pad_smp:pad_smp+sel_smp] 92 | 93 | # normalize volume 94 | audio = F.normalize(audio, p=np.inf, dim=0) 95 | 96 | return name, time_offset/smprate, audio, snr, reverb 97 | 98 | def __len__(self): 99 | return self.num_queries 100 | 101 | if __name__ == '__main__': 102 | # don't delete this line, because my data loader uses queues 103 | torch.multiprocessing.set_start_method('spawn') 104 | args = argparse.ArgumentParser() 105 | args.add_argument('-p', '--params', default='configs/default.json') 106 | args.add_argument('-l', '--length', type=float, default=1) 107 | args.add_argument('--num', type=int, default=10) 108 | args.add_argument('--mode', default='test', choices=['train', 'validate', 'test']) 109 | args.add_argument('-o', '--out', required=True) 110 | args = args.parse_args() 111 | 112 | # warn user (actually just me!) if query files exist 113 | if os.path.exists(args.out): 114 | yesno = input('Folder %s exists, overwrite anyway? 
(y/n) ' % args.out) 115 | while yesno not in {'y', 'n'}: 116 | yesno = input('Please enter y or n: ') 117 | if yesno == 'n': 118 | exit() 119 | 120 | params = simpleutils.read_config(args.params) 121 | train_val = 'validate' if args.mode == 'test' else args.mode 122 | train_val_test = args.mode 123 | sample_rate = params['sample_rate'] 124 | win = (params['pad_start'] + args.length + params['air']['length'] + params['micirp']['length']) * sample_rate 125 | fftconv_n = 2048 126 | while fftconv_n < win: 127 | fftconv_n *= 2 128 | params['fftconv_n'] = fftconv_n 129 | 130 | noise = NoiseData(noise_dir=params['noise']['dir'], 131 | list_csv=params['noise'][train_val], 132 | sample_rate=sample_rate, cache_dir=params['cache_dir']) 133 | 134 | air = AIR(air_dir=params['air']['dir'], 135 | list_csv=params['air'][train_val], 136 | length=params['air']['length'], 137 | fftconv_n=params['fftconv_n'], sample_rate=sample_rate) 138 | 139 | micirp = MicIRP(mic_dir=params['micirp']['dir'], 140 | list_csv=params['micirp'][train_val], 141 | length=params['micirp']['length'], 142 | fftconv_n=params['fftconv_n'], sample_rate=sample_rate) 143 | 144 | music_list = simpleutils.read_file_list(params[train_val_test + '_csv']) 145 | 146 | gen = QueryGen(params['music_dir'], music_list, noise, air, micirp, args.length, args.num, params) 147 | runall = torch.utils.data.DataLoader( 148 | dataset=gen, 149 | num_workers=3, 150 | batch_size=None 151 | ) 152 | os.makedirs(args.out, exist_ok=True) 153 | fout = open(os.path.join(args.out, 'expected.csv'), 'w', encoding='utf8', newline='\n') 154 | fout2 = open(os.path.join(args.out, 'list.txt'), 'w', encoding='utf8') 155 | writer = csv.writer(fout) 156 | writer.writerow(['query', 'answer', 'time', 'snr', 'reverb']) 157 | for i, (name,time_offset,sound,snr,reverb) in enumerate(tqdm.tqdm(runall)): 158 | safe_name = os.path.splitext(os.path.split(name)[1])[0] 159 | out_name = 'q%04d_%s_snr%d_%s.wav' % (i+1, safe_name, math.floor(snr), reverb) 160 | writer.writerow([out_name, name, time_offset, snr, reverb]) 161 | path = os.path.join(args.out, out_name) 162 | torchaudio.save(path, sound.unsqueeze(0), gen.sample_rate, encoding='PCM_S', bits_per_sample=16) 163 | fout2.write(path + '\n') 164 | fout.close() 165 | fout2.close() 166 | params['genquery'] = {'mode': train_val_test, 'length': args.length} 167 | with open(os.path.join(args.out, 'configs.json'), 'w') as fout: 168 | json.dump(params, fout, indent=2) 169 | -------------------------------------------------------------------------------- /genquery_naf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import warnings 5 | import json 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | with warnings.catch_warnings(): 11 | warnings.simplefilter("ignore") 12 | import torchaudio 13 | import tqdm 14 | import scipy.signal 15 | 16 | import simpleutils 17 | from datautil.audio import get_audio 18 | from datautil.ir import AIR, MicIRP 19 | from datautil.noise import NoiseData 20 | 21 | def biquad_faster(waveform, b0, b1, b2, a0, a1, a2): 22 | waveform = waveform.numpy() 23 | b = np.array([b0, b1, b2], dtype=waveform.dtype) 24 | a = np.array([a0, a1, a2], dtype=waveform.dtype) 25 | return torch.from_numpy(scipy.signal.lfilter(b, a, waveform)) 26 | torchaudio.functional.biquad = biquad_faster 27 | 28 | class QueryGen(torch.utils.data.Dataset): 29 | def __init__(self, music_dir, music_list, noise, air, micirp, query_len, 
num_queries, params): 30 | self.music_dir = music_dir 31 | self.music_list = music_list 32 | self.noise = noise 33 | self.air = air 34 | self.micirp = micirp 35 | self.query_len = query_len 36 | self.num_queries = num_queries 37 | self.params = params 38 | self.pad_start = params['pad_start'] 39 | self.sample_rate = params['sample_rate'] 40 | 41 | def __getitem__(self, index): 42 | torch.manual_seed(9000 + index) 43 | # load music 44 | name = self.music_list[index % len(self.music_list)] 45 | audio, smprate = get_audio(os.path.join(self.music_dir, name)) 46 | audio = torch.from_numpy(audio.astype(np.float32)) 47 | 48 | # stereo to mono and resample 49 | audio = audio.mean(dim=0) 50 | audio = torchaudio.transforms.Resample(smprate, self.sample_rate)(audio) 51 | 52 | # random crop 53 | sel_smp = int(self.sample_rate * self.params['segment_size']) 54 | total_segs = max(int(audio.shape[0] / sel_smp), 1) 55 | shift_smp = int(self.sample_rate * self.params['time_offset']) - sel_smp 56 | crop_pos = torch.randint(low=-shift_smp, high=shift_smp+1, size=[total_segs]) 57 | segs = [] 58 | for i in range(total_segs): 59 | offset = crop_pos[i] + sel_smp * i 60 | seg = audio[max(0,offset) : max(0,offset+sel_smp)] 61 | seg = F.pad(seg, ( 62 | max(0, -offset), 63 | max(0, (offset+sel_smp)-audio.shape[0]) 64 | )) 65 | segs.append(seg) 66 | audio = torch.stack(segs, 0) 67 | 68 | # background mixing 69 | audio -= audio.mean(1, keepdim=True) 70 | snr_max = self.params['noise']['snr_max'] 71 | snr_min = self.params['noise']['snr_min'] 72 | audio = self.noise.add_noises(audio, snr_min, snr_max) 73 | 74 | # IR filters 75 | audio_freq = torch.fft.rfft(audio, self.params['fftconv_n']) 76 | if self.air: 77 | audio_freq *= self.air.random_choose(audio_freq.shape[0]) 78 | if self.micirp: 79 | audio_freq *= self.micirp.random_choose(audio_freq.shape[0]) 80 | audio = torch.fft.irfft(audio_freq, self.params['fftconv_n']) 81 | audio = audio[..., 0:0+sel_smp] 82 | 83 | # normalize volume 84 | audio = F.normalize(audio, p=np.inf, dim=1) 85 | 86 | # random select part 87 | audio = audio.flatten() 88 | hop_size = int(self.params['hop_size'] * self.sample_rate) 89 | n_segs = int((audio.shape[0] - sel_smp) / hop_size) + 1 90 | q_len = int(self.query_len * self.sample_rate) 91 | need_segs = int((q_len - sel_smp) / hop_size) + 1 92 | r = torch.randint(0, n_segs - need_segs + 1, (1,)).item() 93 | time_offset = r * hop_size 94 | audio = audio[r*hop_size : r*hop_size+q_len] 95 | 96 | return name, time_offset / self.sample_rate, audio 97 | 98 | def __len__(self): 99 | return self.num_queries 100 | 101 | if __name__ == '__main__': 102 | # don't delete this line, because my data loader uses queues 103 | torch.multiprocessing.set_start_method('spawn') 104 | args = argparse.ArgumentParser() 105 | args.add_argument('-p', '--params', default='configs/default.json') 106 | args.add_argument('-l', '--length', type=float, default=1) 107 | args.add_argument('--num', type=int, default=10) 108 | args.add_argument('--mode', default='test', choices=['train', 'validate', 'test']) 109 | args.add_argument('-o', '--out', required=True) 110 | args = args.parse_args() 111 | 112 | # warn user (actually just me!) if query files exist 113 | if os.path.exists(args.out): 114 | yesno = input('Folder %s exists, overwrite anyway? 
(y/n) ' % args.out)
115 | while yesno not in {'y', 'n'}:
116 | yesno = input('Please enter y or n: ')
117 | if yesno == 'n':
118 | exit()
119 |
120 | params = simpleutils.read_config(args.params)
121 | train_val = 'validate' if args.mode == 'test' else args.mode
122 | train_val_test = args.mode
123 | sample_rate = params['sample_rate']
124 | win = (params['pad_start'] + args.length + params['air']['length'] + params['micirp']['length']) * sample_rate
125 | fftconv_n = 2048
126 | while fftconv_n < win:
127 | fftconv_n *= 2
128 | params['fftconv_n'] = fftconv_n
129 |
130 | noise = NoiseData(noise_dir=params['noise']['dir'],
131 | list_csv=params['noise'][train_val],
132 | sample_rate=sample_rate, cache_dir=params['cache_dir'])
133 |
134 | air = AIR(air_dir=params['air']['dir'],
135 | list_csv=params['air'][train_val],
136 | length=params['air']['length'],
137 | fftconv_n=params['fftconv_n'], sample_rate=sample_rate)
138 |
139 | micirp = MicIRP(mic_dir=params['micirp']['dir'],
140 | list_csv=params['micirp'][train_val],
141 | length=params['micirp']['length'],
142 | fftconv_n=params['fftconv_n'], sample_rate=sample_rate)
143 |
144 | music_list = simpleutils.read_file_list(params[train_val_test + '_csv'])
145 |
146 | gen = QueryGen(params['music_dir'], music_list, noise, air, micirp, args.length, args.num, params)
147 | runall = torch.utils.data.DataLoader(
148 | dataset=gen,
149 | num_workers=3,
150 | batch_size=None
151 | )
152 | os.makedirs(args.out, exist_ok=True)
153 | fout = open(os.path.join(args.out, 'expected.csv'), 'w', encoding='utf8', newline='\n')
154 | fout2 = open(os.path.join(args.out, 'list.txt'), 'w', encoding='utf8')
155 | writer = csv.writer(fout)
156 | writer.writerow(['query', 'answer', 'time'])
157 | for i, (name,time_offset,sound) in enumerate(tqdm.tqdm(runall)):
158 | safe_name = os.path.splitext(os.path.split(name)[1])[0]
159 | out_name = 'q%04d_%s_%.1f.wav' % (i+1, safe_name, time_offset)
160 | writer.writerow([out_name, name, time_offset])
161 | path = os.path.join(args.out, out_name)
162 | torchaudio.save(path, sound.unsqueeze(0), gen.sample_rate, encoding='PCM_S', bits_per_sample=16)
163 | fout2.write(path + '\n')
164 | fout.close()
165 | fout2.close()
166 | params['genquery'] = {'mode': train_val_test, 'length': args.length}
167 | with open(os.path.join(args.out, 'configs.json'), 'w') as fout:
168 | json.dump(params, fout, indent=2)
169 |
--------------------------------------------------------------------------------
/lists/readme.md:
--------------------------------------------------------------------------------
1 | Dataset file lists are stored here.
2 |
--------------------------------------------------------------------------------
/matchemb.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import math
3 | import os
4 | import sys
5 | import time
6 | import warnings
7 |
8 | import faiss
9 | import numpy as np
10 | import tqdm
11 | import torch.multiprocessing as mp
12 |
13 | import simpleutils
14 | from database import Database
15 |
16 | if __name__ == "__main__":
17 | logger_init = simpleutils.MultiProcessInitLogger('matchemb')
18 | logger_init()
19 |
20 | mp.set_start_method('spawn')
21 | logger = mp.get_logger()
22 | if len(sys.argv) < 4:
23 | print('Usage: python %s <query embeddings dir> <db location> <result file>' % sys.argv[0])
24 | sys.exit()
25 | dir_for_query = sys.argv[1]
26 | dir_for_db = sys.argv[2]
27 | result_file = sys.argv[3]
28 | result_file2 = os.path.splitext(result_file) # for more detailed output
29 | result_file2 = result_file2[0]
+ '_detail.csv'
30 | result_file_score = result_file + '.bin'
31 | configs = os.path.join(dir_for_db, 'configs.json')
32 | params = simpleutils.read_config(configs)
33 | file_list = simpleutils.read_file_list(os.path.join(dir_for_query, 'queryList.txt'))
34 | logger.info('command args: %s', sys.argv)
35 | logger.info('params: %s', params)
36 |
37 | d = params['model']['d']
38 |
39 | top_k = params['indexer']['top_k']
40 | frame_shift_mul = params['indexer'].get('frame_shift_mul', 1)
41 |
42 | print('loading database...')
43 | db = Database(dir_for_db, params['indexer'], params['hop_size'])
44 | print('database loaded')
45 |
46 | print('loading queries')
47 | query_embeddings = np.fromfile(os.path.join(dir_for_query, 'query_embeddings'), dtype=np.float32)
48 | query_embeddings = query_embeddings.reshape([-1, d])
49 | query_index = np.fromfile(os.path.join(dir_for_query, 'query_index'), dtype=np.int64)
50 | query_index = query_index.reshape([-1, 2])
51 | print('queries loaded')
52 |
53 | tm_0 = time.time()
54 | fout = open(result_file, 'w', encoding='utf8', newline='\n')
55 | fout2 = open(result_file2, 'w', encoding='utf8', newline='\n')
56 | fout_score = open(result_file_score, 'wb')
57 | detail_writer = csv.writer(fout2)
58 | detail_writer.writerow(['query', 'answer', 'score', 'time', 'part_scores'])
59 | for i, name in enumerate(tqdm.tqdm(file_list)):
60 | logger.info('get query %s', name)
61 | tm_1 = time.time()
62 |
63 | my_idx = query_index[i, 0]
64 | my_len = query_index[i, 1]
65 | embeddings = query_embeddings[my_idx:my_idx+my_len]
66 |
67 | tm_2 = time.time()
68 | logger.info('compute embedding %.6fs', tm_2 - tm_1)
69 |
70 | sco, (ans, tim), song_score = db.query_embeddings(embeddings)
71 | upsco = []
72 | ans = db.songList[ans]
73 |
74 | tm_1 = time.time()
75 | fout.write('%s\t%s\n' % (name, ans))
76 | fout.flush()
77 | detail_writer.writerow([name, ans, sco, tim] + upsco)
78 | fout2.flush()
79 |
80 | fout_score.write(song_score.tobytes())
81 | tm_2 = time.time()
82 | logger.info('output answer %.6fs', tm_2 - tm_1)
83 | fout.close()
84 | fout2.close()
85 | logger.info('total query time %.6fs', time.time() - tm_0)
86 |
--------------------------------------------------------------------------------
/matcher.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import math
3 | import os
4 | import sys
5 | import time
6 | import warnings
7 |
8 | import faiss
9 | import numpy as np
10 | import torch
11 | from torch.utils.data import DataLoader
12 | import torch.nn.functional as F
13 | import torch.multiprocessing as mp
14 | import torchvision
15 | import tqdm
16 |
17 | # torchaudio currently (0.7) will throw warning that cannot be disabled
18 | with warnings.catch_warnings():
19 | warnings.simplefilter("ignore")
20 | import torchaudio
21 |
22 | import simpleutils
23 | from model import FpNetwork
24 | from datautil.melspec import build_mel_spec_layer
25 | from datautil.musicdata import MusicDataset
26 | from database import Database
27 |
28 | if __name__ == "__main__":
29 | logger_init = simpleutils.MultiProcessInitLogger('nnmatcher')
30 | logger_init()
31 |
32 | mp.set_start_method('spawn')
33 | logger = mp.get_logger()
34 | if len(sys.argv) < 4:
35 | print('Usage: python %s <query list> <db location> <result file>' % sys.argv[0])
36 | sys.exit()
37 | file_list_for_query = sys.argv[1]
38 | dir_for_db = sys.argv[2]
39 | result_file = sys.argv[3]
40 | result_file2 = os.path.splitext(result_file) # for more detailed output
41 | result_file2 = result_file2[0] + '_detail.csv'
42 |
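# raw per-song scores are additionally dumped in binary form; see the "BIN file" notes in readme.md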
result_file_score = result_file + '.bin' 43 | configs = os.path.join(dir_for_db, 'configs.json') 44 | params = simpleutils.read_config(configs) 45 | 46 | visualize = False 47 | 48 | d = params['model']['d'] 49 | h = params['model']['h'] 50 | u = params['model']['u'] 51 | F_bin = params['n_mels'] 52 | segn = int(params['segment_size'] * params['sample_rate']) 53 | T = (segn + params['stft_hop'] - 1) // params['stft_hop'] 54 | 55 | top_k = params['indexer']['top_k'] 56 | frame_shift_mul = params['indexer'].get('frame_shift_mul', 1) 57 | 58 | print('loading model...') 59 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 60 | model = FpNetwork(d, h, u, F_bin, T, params['model']).to(device) 61 | model.load_state_dict(torch.load(os.path.join(dir_for_db, 'model.pt'), map_location=device)) 62 | print('model loaded') 63 | 64 | print('loading database...') 65 | db = Database(dir_for_db, params['indexer'], params['hop_size']) 66 | print('database loaded') 67 | 68 | # doing inference, turn off gradient 69 | model.eval() 70 | for param in model.parameters(): 71 | param.requires_grad = False 72 | 73 | dataset = MusicDataset(file_list_for_query, params) 74 | # no task parallelism 75 | loader = DataLoader(dataset, num_workers=0, batch_size=None) 76 | 77 | mel = build_mel_spec_layer(params).to(device) 78 | 79 | tm_0 = time.time() 80 | fout = open(result_file, 'w', encoding='utf8', newline='\n') 81 | fout2 = open(result_file2, 'w', encoding='utf8', newline='\n') 82 | fout_score = open(result_file_score, 'wb') 83 | detail_writer = csv.writer(fout2) 84 | detail_writer.writerow(['query', 'answer', 'score', 'time', 'part_scores']) 85 | for dat in tqdm.tqdm(loader): 86 | embeddings = [] 87 | grads = [] 88 | specs = [] 89 | i, name, wav = dat 90 | logger.info('get query %s', name) 91 | tm_1 = time.time() 92 | i = int(i) # i is leaking file handles! 93 | 94 | if wav.shape[0] == 0: 95 | # load file error! 
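# write a placeholder answer and an all-zero score row, so the output files stay aligned with the query list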
96 | logger.error('load %s error!', name) 97 | ans = 'error' 98 | sco = -1e999 99 | tim = 0 100 | fout.write('%s\t%s\n' % (name, ans)) 101 | fout.flush() 102 | detail_writer.writerow([name, ans, sco, tim]) 103 | fout2.flush() 104 | 105 | song_score = np.zeros([len(db.songList), 2], dtype=np.float32) 106 | fout_score.write(song_score.tobytes()) 107 | continue 108 | 109 | # batch size should be less than 20 because query contains at most 19 segments 110 | for batch in torch.split(wav, 16): 111 | g = batch.to(device) 112 | 113 | # Mel spectrogram 114 | with warnings.catch_warnings(): 115 | # torchaudio is still using deprecated function torch.rfft 116 | warnings.simplefilter("ignore") 117 | g = mel(g) 118 | if visualize: 119 | g.requires_grad = True 120 | z = model.forward(g, norm=False).cpu() 121 | if visualize: 122 | z.backward(z) 123 | z.detach_() 124 | grads.append(g.grad.cpu()) 125 | specs.append(g.detach().cpu()) 126 | z = torch.nn.functional.normalize(z, p=2) 127 | embeddings.append(z) 128 | embeddings = torch.cat(embeddings) 129 | 130 | tm_2 = time.time() 131 | logger.info('compute embedding %.6fs', tm_2 - tm_1) 132 | 133 | if visualize: 134 | grads = torch.cat(grads) 135 | specs = torch.cat(specs) 136 | sco, (ans, tim), song_score = db.query_embeddings(embeddings.numpy()) 137 | upsco = [] 138 | ans = db.songList[ans] 139 | #tim /= frame_shift_mul 140 | #tim *= params['hop_size'] 141 | #song_score[:, 1] *= params['hop_size'] / frame_shift_mul 142 | if visualize: 143 | grads = torch.abs(grads) 144 | grads = torch.nn.functional.normalize(grads, p=np.inf) 145 | grads = grads.transpose(0, 1).flatten(1, 2) 146 | grads = grads.repeat(3, 1, 1) 147 | specs = specs.transpose(0, 1).flatten(1, 2) 148 | grads[1] = specs - math.log(1e-6) 149 | grads[1] /= torch.max(grads[1]) 150 | grads[0] = torch.nn.functional.relu(grads[0]) 151 | grads[1] *= 1 - grads[0] 152 | grads[2] = 0 153 | grads = torch.flip(grads, [1]) 154 | grads[:,:,::32] = 0 155 | torchvision.utils.save_image(grads, '%s.png' % os.path.basename(name[0])) 156 | 157 | tm_1 = time.time() 158 | fout.write('%s\t%s\n' % (name, ans)) 159 | fout.flush() 160 | detail_writer.writerow([name, ans, sco, tim] + upsco) 161 | fout2.flush() 162 | 163 | fout_score.write(song_score.tobytes()) 164 | tm_2 = time.time() 165 | logger.info('output answer %.6fs', tm_2 - tm_1) 166 | fout.close() 167 | fout2.close() 168 | logger.info('total query time %.6fs', time.time() - tm_0) 169 | else: 170 | torch.set_num_threads(1) 171 | -------------------------------------------------------------------------------- /matchfromgt.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import math 3 | import os 4 | import sys 5 | import argparse 6 | import warnings 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | import torch.multiprocessing as mp 13 | import torchvision 14 | import tqdm 15 | 16 | # torchaudio currently (0.7) will throw warning that cannot be disabled 17 | with warnings.catch_warnings(): 18 | warnings.simplefilter("ignore") 19 | import torchaudio 20 | 21 | import simpleutils 22 | from model import FpNetwork 23 | from datautil.melspec import build_mel_spec_layer 24 | from datautil.musicdata import MusicDataset 25 | 26 | if __name__ == "__main__": 27 | mp.set_start_method('spawn') 28 | args = argparse.ArgumentParser() 29 | args.add_argument('file_list') 30 | args.add_argument('gt') 31 | args.add_argument('db') 32 | 
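# gt is a TSV of query<TAB>answer lines (the same shape matcher.py writes); db is a database directory from builder.py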
args.add_argument('result') 33 | args = args.parse_args() 34 | 35 | file_list_for_query = args.file_list 36 | dir_for_db = args.db 37 | result_file = args.result 38 | configs = os.path.join(dir_for_db, 'configs.json') 39 | params = simpleutils.read_config(configs) 40 | 41 | d = params['model']['d'] 42 | h = params['model']['h'] 43 | u = params['model']['u'] 44 | F_bin = params['n_mels'] 45 | segn = int(params['segment_size'] * params['sample_rate']) 46 | T = (segn + params['stft_hop'] - 1) // params['stft_hop'] 47 | 48 | frame_shift_mul = params['indexer'].get('frame_shift_mul', 1) 49 | 50 | print('loading model...') 51 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 52 | model = FpNetwork(d, h, u, F_bin, T, params['model']).to(device) 53 | model.load_state_dict(torch.load(os.path.join(dir_for_db, 'model.pt'), map_location=device)) 54 | print('model loaded') 55 | 56 | print('loading database...') 57 | with open(os.path.join(dir_for_db, 'songList.txt'), 'r', encoding='utf8') as fin: 58 | songList = [] 59 | for line in fin: 60 | if line.endswith('\n'): line = line[:-1] 61 | songList.append(line) 62 | 63 | landmarkKey = np.fromfile(os.path.join(dir_for_db, 'landmarkKey'), dtype=np.int32) 64 | assert len(songList) == landmarkKey.shape[0] 65 | index2song = np.repeat(np.arange(len(songList)), landmarkKey) 66 | landmarkKey = np.pad(np.cumsum(landmarkKey, dtype=np.int64), (1,0)) 67 | songEmb = np.fromfile(os.path.join(dir_for_db, 'embeddings'), dtype=np.float32) 68 | songEmb = songEmb.reshape([-1, d]) 69 | songEmb = torch.from_numpy(songEmb) 70 | print('database loaded') 71 | 72 | print('loading ground truth...') 73 | songList_noext = [os.path.splitext(os.path.basename(x))[0] for x in songList] 74 | with open(args.gt, 'r', encoding='utf8') as fin: 75 | gt = {} 76 | for i in fin: 77 | query, ans = i.split('\t') 78 | ans = ans.rstrip() 79 | gt[query] = songList_noext.index(ans) 80 | print('ground truth loaded') 81 | 82 | # doing inference, turn off gradient 83 | model.eval() 84 | for param in model.parameters(): 85 | param.requires_grad = False 86 | 87 | dataset = MusicDataset(file_list_for_query, params) 88 | # no task parallelism 89 | loader = DataLoader(dataset, num_workers=0) 90 | 91 | mel = build_mel_spec_layer(params).to(device) 92 | 93 | fout = open(result_file, 'w', encoding='utf8', newline='\n') 94 | detail_writer = csv.writer(fout) 95 | detail_writer.writerow(['query', 'answer', 'score', 'time', 'part_scores']) 96 | for dat in tqdm.tqdm(loader): 97 | embeddings = [] 98 | grads = [] 99 | specs = [] 100 | i, name, wav = dat 101 | i = int(i) # i is leaking file handles! 
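# this script aligns each query with its ground-truth answer (instead of searching the index) and reports how the matched segments rank against the rest of the database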
102 | 103 | # get song name 104 | query = os.path.splitext(os.path.basename(name[0]))[0] 105 | 106 | 107 | if query not in gt: 108 | print('query %s does not have ground truth' % query) 109 | continue 110 | ansId = gt[query] 111 | ans = songList[ansId] 112 | 113 | # batch size should be less than 20 because query contains at most 19 segments 114 | for batch in DataLoader(wav.squeeze(0), batch_size=16): 115 | g = batch.to(device) 116 | 117 | # Mel spectrogram 118 | with warnings.catch_warnings(): 119 | # torchaudio is still using deprecated function torch.rfft 120 | warnings.simplefilter("ignore") 121 | g = mel(g) 122 | z = model.forward(g, norm=False).cpu() 123 | z = torch.nn.functional.normalize(z, p=2) 124 | embeddings.append(z) 125 | embeddings = torch.cat(embeddings) 126 | 127 | idx1 = landmarkKey[ansId] 128 | idx2 = landmarkKey[ansId+1] 129 | T = (embeddings.shape[0]-1) // frame_shift_mul + 1 130 | slen = idx2 - idx1 131 | # find alignment 132 | scos = embeddings @ songEmb[idx1:idx2].T 133 | accum_scos = torch.zeros([frame_shift_mul, slen + T]) 134 | for t in range(embeddings.shape[0]): 135 | t0 = T - t//frame_shift_mul 136 | accum_scos[t % frame_shift_mul, t0:t0+slen] += scos[t] 137 | # these are invalid time shifts 138 | accum_scos[:, 0] = -T*2 139 | accum_scos[(embeddings.shape[0]-1)%frame_shift_mul+1:, 1] = -T*2 140 | 141 | tim = torch.argmax(accum_scos).item() 142 | tim1, tim2 = divmod(tim, slen + T) 143 | tim = -tim1 + (tim2-T) * frame_shift_mul 144 | 145 | tim /= frame_shift_mul 146 | tim *= params['hop_size'] 147 | sco = accum_scos[tim1, tim2].item() 148 | myscos = [] 149 | myvecs = [] 150 | tidxs = [] 151 | for t in range(T): 152 | tidx = t*frame_shift_mul + tim1 153 | if 0 <= tidx < embeddings.shape[0] and 0 <= tim2-T+t < slen: 154 | mysco = scos[tidx, tim2-T + t].item() 155 | tidxs.append(tidx) 156 | myscos.append(mysco) 157 | myvecs.append(embeddings[tidx]) 158 | myvecs = torch.stack(myvecs) 159 | score_seg = myvecs @ songEmb.T 160 | 161 | upsco = [] 162 | for i in range(len(myscos)): 163 | score_seg[i, idx1 + (tim2-T) + i] = -10 164 | rank = (score_seg[i] >= myscos[i]).sum().item() + 1 165 | upsco += [myscos[i], tidxs[i], rank] 166 | 167 | detail_writer.writerow([name[0], ans, sco, tim] + upsco) 168 | fout.flush() 169 | del score_seg 170 | fout.close() 171 | else: 172 | torch.set_num_threads(1) 173 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 5 | from torch.nn import Module, Conv2d, LayerNorm, ReLU, ModuleList, Conv1d, ELU, ZeroPad2d 6 | 7 | def get_activation(name): 8 | if name == 'ReLU': 9 | return ReLU() 10 | elif name == 'ELU': 11 | return ELU() 12 | raise KeyError(name) 13 | 14 | class SeparableConv2d(Module): 15 | def __init__(self, i, o, k, s, in_F, in_T, fuller=False, activation='ReLU', relu_after_bn=True): 16 | super(SeparableConv2d, self).__init__() 17 | # this is actually "same" padding, but PyTorch doesn't support that 18 | padding = (in_T-1)//s[0] * s[0] + k - in_T 19 | self.pad1 = ZeroPad2d((padding//2, padding - padding//2, 0, 0)) 20 | self.conv1 = Conv2d(i, o, kernel_size=(1, k), stride=(1, s[0])) 21 | self.ln1 = LayerNorm((o, in_F, (in_T-1)//s[0]+1)) 22 | self.relu1 = get_activation(activation) 23 | # this is actually "same" padding, but PyTorch doesn't support that 24 | padding = (in_F-1)//s[1] * s[1] + k - in_F 25 | self.pad2 = ZeroPad2d((0, 0, padding//2, padding - 
padding//2)) 26 | if fuller: 27 | self.conv2 = Conv2d(o, o, kernel_size=(k, 1), stride=(s[1], 1)) 28 | else: 29 | self.conv2 = Conv2d(o, o, kernel_size=(k, 1), stride=(s[1], 1), groups=o) 30 | self.ln2 = LayerNorm((o, (in_F-1)//s[1]+1, (in_T-1)//s[0]+1)) 31 | self.relu2 = get_activation(activation) 32 | 33 | self.relu_after_bn = relu_after_bn 34 | self.hacked = False 35 | 36 | # I found a way to do Keras same padding for stride=2 without zero padding layer/function 37 | # Just flip the image, then the builtin Conv2d padding will do the right thing 38 | def hack(self): 39 | self.hacked = not self.hacked 40 | with torch.no_grad(): 41 | self.conv1.weight.set_(self.conv1.weight.flip([2, 3])) 42 | self.ln1.weight.set_(self.ln1.weight.flip([1, 2])) 43 | self.ln1.bias.set_(self.ln1.bias.flip([1, 2])) 44 | self.conv2.weight.set_(self.conv2.weight.flip([2, 3])) 45 | self.ln2.weight.set_(self.ln2.weight.flip([1, 2])) 46 | self.ln2.bias.set_(self.ln2.bias.flip([1, 2])) 47 | if self.hacked: 48 | self.conv1.padding = self.pad1.padding[3::-2] 49 | self.conv2.padding = self.pad2.padding[3::-2] 50 | else: 51 | self.conv1.padding = (0,0) 52 | self.conv2.padding = (0,0) 53 | 54 | def forward(self, x): 55 | if not self.hacked: 56 | x = self.pad1(x) 57 | x = self.conv1(x) 58 | if self.relu_after_bn: 59 | x = self.ln1(x) 60 | x = self.relu1(x) 61 | else: 62 | x = self.relu1(x) 63 | x = self.ln1(x) 64 | if not self.hacked: 65 | x = self.pad2(x) 66 | x = self.conv2(x) 67 | if self.relu_after_bn: 68 | x = self.ln2(x) 69 | x = self.relu2(x) 70 | else: 71 | x = self.relu2(x) 72 | x = self.ln2(x) 73 | return x 74 | 75 | class MyF(Module): 76 | def __init__(self, d, h, u, in_F, in_T, fuller=False, activation='ReLU', 77 | strides=None, relu_after_bn=True): 78 | super(MyF, self).__init__() 79 | channels = [1, d, d, 2*d, 2*d, 4*d, 4*d, h, h] 80 | convs = [] 81 | for i in range(8): 82 | k = 3 83 | s = 2, 2 84 | if strides is not None: 85 | s = strides[i][0][1], strides[i][1][0] 86 | sepconv = SeparableConv2d(channels[i], channels[i+1], k, s, in_F, in_T, 87 | fuller=fuller, 88 | activation=activation, 89 | relu_after_bn=relu_after_bn 90 | ) 91 | convs.append(sepconv) 92 | in_F = (in_F-1)//s[1] + 1 93 | in_T = (in_T-1)//s[0] + 1 94 | assert in_F==in_T==1, 'output must be 1x1' 95 | self.convs = ModuleList(convs) 96 | 97 | def hack(self): 98 | for conv in self.convs: 99 | conv.hack() 100 | 101 | def forward(self, x): 102 | x = x.unsqueeze(1) 103 | for i, conv in enumerate(self.convs): 104 | x = conv(x) 105 | #assert x.shape[2]==x.shape[3]==1, 'output must be 1x1' 106 | return x 107 | 108 | class MyG(Module): 109 | __constants__ = ['d', 'h'] 110 | def __init__(self, d, h, u): 111 | super(MyG, self).__init__() 112 | assert h%d == 0, 'h must be divisible by d' 113 | v = h//d 114 | self.d = d 115 | self.h = h 116 | self.u = u 117 | self.v = v 118 | self.linear1 = Conv1d(d * v, d * u, kernel_size=(1,), groups=d) 119 | self.elu = ELU() 120 | self.linear2 = Conv1d(d * u, d, kernel_size=(1,), groups=d) 121 | 122 | def forward(self, x, norm=True): 123 | x = x.reshape([-1, self.h, 1]) 124 | x = self.linear1(x) 125 | x = self.elu(x) 126 | x = self.linear2(x) 127 | x = x.reshape([-1, self.d]) 128 | if norm: 129 | x = torch.nn.functional.normalize(x, p=2.0) 130 | return x 131 | 132 | class FpNetwork(Module): 133 | def __init__(self, d, h, u, F, T, params): 134 | super(FpNetwork, self).__init__() 135 | self.f = MyF(d, h, u, F, T, 136 | fuller=params.get('fuller', False), 137 | activation=params.get('conv_activation', 'ReLU'), 138 | 
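# per-layer strides may be supplied by a converted NAF config (see tools/convert_naf_to_pfann.py); None keeps the default 2x2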
strides=params.get('strides'),
139 | relu_after_bn=params.get('relu_after_bn', True)
140 | )
141 | self.g = MyG(d, h, u)
142 | self.hacked = False
143 |
144 | def hack(self):
145 | self.hacked = not self.hacked
146 | self.f.hack()
147 |
148 | def forward(self, x, norm=True):
149 | if self.hacked:
150 | x = x.flip([1, 2])
151 | x = self.f(x)
152 | x = self.g(x, norm=norm)
153 | return x
154 |
--------------------------------------------------------------------------------
/preview.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | import argparse
4 |
5 | import torch
6 | with warnings.catch_warnings():
7 | warnings.simplefilter("ignore")
8 | import torchaudio
9 |
10 | from datautil.dataset import build_data_loader
11 | import simpleutils
12 |
13 | if __name__ == '__main__':
14 | # don't delete this line, because my data loader uses queues
15 | torch.multiprocessing.set_start_method('spawn')
16 | args = argparse.ArgumentParser()
17 | args.add_argument('-d', '--data', required=True)
18 | args.add_argument('--noise')
19 | args.add_argument('--air')
20 | args.add_argument('--micirp')
21 | args.add_argument('-p', '--params', default='configs/default.json')
22 | args = args.parse_args()
23 |
24 | params = simpleutils.read_config(args.params)
25 |
26 | train_data = build_data_loader(params, args.data, args.noise, args.air, args.micirp)
27 | i = 0
28 | train_data.dataset.output_wav = True
29 | train_data.sampler.sampler.shuffle = False
30 | iterator = iter(train_data)
31 | for a in iterator:
32 | i += 1
33 | sound = a.transpose(0,1).flatten(1,2)
34 | sound *= 0.5 / torch.max(torch.abs(sound))
35 | torchaudio.save('temp%d.wav' % i, sound, 8000)
36 | print(i)
37 | if i >= 3:
38 | iterator._shutdown_workers()
39 | # kill my preloader
40 | train_data.sampler.sampler.preloader.terminate()
41 | break
42 | print('stopping...')
43 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # pfann
2 | This is an unofficial reproduction of the paper ["Neural Audio Fingerprint for High-specific Audio Retrieval based on Contrastive Learning."](https://arxiv.org/abs/2010.11910)
3 |
4 | Now I have a thesis that is a "trivial" improvement to the above paper: "Improvement of Neural Network- and Landmark-based Audio Fingerprinting" (in Traditional Chinese). [Link here](thesis.pdf)
5 |
6 | Note: I am now employed ~~and our company does not allow GitHub login during work~~.
7 | I have less time to work on my side project or maintain my thesis code, ~~and I do not have access to a high performance GPU (currently), so I cannot solve compatibility issues or problems related to training.~~
8 | Finally, I bought a gaming computer in 2023, so now I can help you solve training issues.
9 |
10 | ## Install
11 |
12 | ```sh
13 | conda install python=3.9 # python 3.10 doesn't work with faiss...
14 | conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia # I forget which version of PyTorch I used, but the latest PyTorch seems to work
15 | conda install -c pytorch faiss-gpu # can also be faiss-cpu if you don't need to test GPU-accelerated search
16 | pip install tqdm
17 | pip install tensorboardX
18 | pip install torch_optimizer
19 | pip install scipy
20 | pip install julius
21 | pip install matplotlib # for visualization purpose, not needed for server
22 | pip install seaborn # for visualization purpose, not needed for server
23 | pip install scikit-learn
24 | ```
25 |
26 | ## Prepare dataset
27 |
28 | ### FMA dataset
29 |
30 | Download fma_medium from https://github.com/mdeff/fma and unzip to
31 | `../pfann_dataset/fma_medium`.
32 |
33 | ```
34 | python tools/listaudio.py --folder ../pfann_dataset/fma_medium --out lists/fma_medium.csv
35 | python tools/filterduration.py --csv lists/fma_medium.csv --min-len 29.9 --out lists/fma_medium_30s.csv
36 | python tools/traintestsplit.py --csv lists/fma_medium_30s.csv --train lists/fma_medium_train.csv --train-size 10000 --test lists/fma_medium_valtest.csv --test-size 1000
37 | python tools/traintestsplit.py --csv lists/fma_medium_valtest.csv --train lists/fma_medium_val.csv --train-size 500 --test lists/fma_medium_test.csv --test-size 500
38 | python tools/traintestsplit.py --csv lists/fma_medium_train.csv --train-size 2000 --train lists/fma_inside_test.csv
39 | rm test.csv
40 | python tools/listaudio.py --folder ../pfann_dataset/fma_large --out lists/fma_large.csv
41 | ```
42 |
43 | ### AudioSet
44 |
45 | Download the 3 CSV files `unbalanced_train_segments.csv`, `balanced_train_segments.csv`, `eval_segments.csv`, and `ontology.json` from https://research.google.com/audioset/download.html.
46 | Then run these to list all the videos needed:
47 |
48 | ```
49 | python tools/audioset.py /path/to/unbalanced_train_segments.csv lists/audioset1.csv --ontology /path/to/ontology.json
50 | python tools/audioset.py /path/to/balanced_train_segments.csv lists/audioset2.csv --ontology /path/to/ontology.json
51 | python tools/audioset.py /path/to/eval_segments.csv lists/audioset3.csv --ontology /path/to/ontology.json
52 | ```
53 |
54 | Use these commands to crawl videos from YouTube and convert them to WAV:
55 |
56 | ```
57 | python tools/audioset2.py lists/audioset1.csv ../pfann_dataset/audioset
58 | python tools/audioset2.py lists/audioset2.csv ../pfann_dataset/audioset
59 | python tools/audioset2.py lists/audioset3.csv ../pfann_dataset/audioset
60 | ```
61 |
62 | After downloading, run this command to list all successfully downloaded files:
63 |
64 | ```
65 | python tools/listaudio.py --folder ../pfann_dataset/audioset --out lists/noise.csv
66 | ```
67 |
68 | This command will show errors because some videos are unavailable.
69 |
70 | Finally, run these commands:
71 |
72 | ```
73 | python tools/filterduration.py --csv lists/noise.csv --min-len 9.9 --out lists/noise_10s.csv
74 | python tools/traintestsplit.py --csv lists/noise_10s.csv --train lists/noise_train.csv --train-size 8 --test lists/noise_val.csv --test-size 2 -p
75 | ```
76 |
77 | ### Microphone impulse response dataset
78 |
79 | Go to http://micirp.blogspot.com/ and download the files to `../pfann_dataset/micirp`.
Then run the commands:
80 |
81 | ```
82 | python tools/listaudio.py --folder ../pfann_dataset/micirp --out lists/micirp.csv
83 | python tools/traintestsplit.py --csv lists/micirp.csv --train lists/micirp_train.csv --train-size 8 --test lists/micirp_val.csv --test-size 2 -p
84 | ```
85 |
86 | ### Aachen Impulse Response Database
87 |
88 | Download the zip from https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/aachen-impulse-response-database/
89 | and unzip it to `../pfann_dataset/AIR_1_4`.
90 |
91 | ```
92 | python -m datautil.ir ../pfann_dataset/AIR_1_4 lists/air.csv
93 | python tools/traintestsplit.py --csv lists/air.csv --train lists/air_train.csv --train-size 8 --test lists/air_val.csv --test-size 2 -p
94 | ```
95 |
96 | ## Train
97 |
98 | ```
99 | python train.py --param configs/default.json -w4
100 | ```
101 |
102 | ## Generate query
103 | Inside test (not used in my thesis anymore):
104 | ```
105 | python genquery.py --params configs/gentest.json --len 10 --num 2000 --mode train --out out/queries/inside
106 | ```
107 |
108 | Assuming you have installed all the datasets, just run this to generate all queries:
109 | ```sh
110 | ./genall.sh
111 | ```
112 |
113 | It will output to the folders `out/queries/out2_snr$snr`, where `$snr` is one of -6, -4, -2, 0, 2, 4, 6, 8.
114 | The query list (used by `matcher.py`) is `out/queries/out2_snr$snr/list.txt`, and the ground truth is `out/queries/out2_snr$snr/expected.csv`.
115 |
116 | ## Build a fingerprint database
117 | Inside test (not used in my thesis anymore):
118 | ```
119 | python tools/csv2txt.py --dir ../pfann_dataset/fma_medium lists/fma_medium_train.csv --out lists/fma_medium_train.txt
120 | python builder.py lists/fma_medium_train.txt /path/to/db configs/default.json
121 | ```
122 |
123 | Usage of `builder.py`:
124 | ```
125 | python builder.py <music list file> <database location> [model config]
126 | ```
127 | The music list file contains a list of music file paths, one per line.
128 | The file must be UTF-8 without BOM. For example:
129 | ```
130 | /path/to/fma_medium/000/000002.mp3
131 | /path/to/fma_medium/000/000005.mp3
132 | /path/to/your/music/aaa.wav
133 | /path/to/your/music/bbb.wav
134 | ```
135 | The model config is a JSON file like those in the `configs/` folder.
136 | It is used to load a trained model.
137 | If omitted, it defaults to `configs/default.json`.
138 |
139 | This program supports both MP3 and WAV audio formats.
140 | Relative paths are supported but not recommended.
141 |
142 | ## Recognize music
143 | Usage of `matcher.py`:
144 | ```
145 | python matcher.py <query list> <database location> <result file>
146 | ```
147 |
148 | The query list is a file containing a list of query file paths, one per line. For example:
149 | ```
150 | /path/to/queries/out2_snr2/000002.wav
151 | /path/to/queries/out2_snr2/000005.wav
152 | /path/to/song_recorded_on_street1.wav
153 | /path/to/song_recorded_on_street2.wav
154 | ```
155 | The database location is the place where `builder.py` saved the database.
156 |
157 | The result file will be a TSV file with 2 fields, query file path and matched music path, without a header.
158 | It may look like this:
159 | ```
160 | /path/to/queries/out2_snr2/000002.wav /path/to/fma_medium/000/000002.mp3
161 | /path/to/queries/out2_snr2/000005.wav /path/to/fma_medium/000/000005.mp3
162 | /path/to/song_recorded_on_street1.wav /path/to/your/music/aaa.wav
163 | /path/to/song_recorded_on_street2.wav /path/to/your/music/aaa.wav
164 | ```
165 |
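For scripting, the result file is easy to load back. A minimal sketch (my own illustration, not a script in this repo; the result path is hypothetical):

```python
# parse matcher.py output: each line is "<query path>\t<matched music path>", no header
results = {}
with open('out/results/example.txt', encoding='utf8') as fin:  # hypothetical path
    for line in fin:
        query, answer = line.rstrip('\n').split('\t')
        results[query] = answer
```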
166 | The matcher will also generate a `_detail.csv` file and a `.bin` file.
167 | The CSV file contains more information about the matches.
168 | It has 5 columns: query, answer, score, time, and part_scores.
169 | * query: Query file path
170 | * answer: Matched music path
171 | * score: Matching score, used in my thesis
172 | * time: The time when the query clip starts in the matched music, in seconds
173 | * part_scores: Mainly used for debugging, currently empty
174 |
175 | The BIN file contains the matching score of every database music for each query.
176 | It is used in my ensemble experiments.
177 | The file format is a flattened 2D array of the following structure, without a header:
178 | ```c++
179 | struct match_t {
180 | float score; // Matching score
181 | float offset; // The time when the query clip starts in the matched music, in seconds
182 | };
183 | ```
184 | The record for the j-th database music in the i-th query is at index `i * database size + j`.
185 |
186 | ## Evaluation
187 | ```
188 | python tools/accuracy.py /path/to/query6s/expected.csv /path/to/result_detail.csv
189 | ```
190 |
191 | ## Ensemble experiment
192 | ```bash
193 | python ensemble/svmheatmap2.py out/lm_ out/shift_4_ out/svm lin_acc.csv
194 | ```
195 | More info TODO
196 |
--------------------------------------------------------------------------------
/rebuild.py:
--------------------------------------------------------------------------------
1 | # rebuild.py reindexes an embedding database using a different index type
2 | import os
3 | import shutil
4 | import sys
5 | import time
6 | import warnings
7 |
8 | import faiss
9 | import numpy as np
10 |
11 | import simpleutils
12 |
13 | def set_verbose(index):
14 | if isinstance(index, faiss.Index):
15 | index = faiss.downcast_index(index)
16 | elif isinstance(index, faiss.IndexBinary):
17 | index = faiss.downcast_IndexBinary(index)
18 | index.verbose = True
19 | if isinstance(index, faiss.IndexPreTransform):
20 | set_verbose(index.index)
21 | elif isinstance(index, faiss.IndexIVF):
22 | index.cp.verbose = True
23 |
24 | if __name__ == "__main__":
25 | if len(sys.argv) < 2:
26 | print('Usage: python %s <db location>' % sys.argv[0])
27 | sys.exit()
28 | dir_for_db = sys.argv[1]
29 | configs = os.path.join(dir_for_db, 'configs.json')
30 | params = simpleutils.read_config(configs)
31 |
32 | d = params['model']['d']
33 | h = params['model']['h']
34 | u = params['model']['u']
35 | F_bin = params['n_mels']
36 | segn = int(params['segment_size'] * params['sample_rate'])
37 | T = (segn + params['stft_hop'] - 1) // params['stft_hop']
38 |
39 | print('loading embeddings')
40 | embeddings = np.fromfile(os.path.join(dir_for_db, 'embeddings'), dtype=np.float32).reshape([-1, d])
41 |
42 | # train indexer
43 | print('training indexer')
44 | try:
45 | index = faiss.index_factory(d, params['indexer']['index_factory'], faiss.METRIC_INNER_PRODUCT)
46 | except RuntimeError as x:
47 | if 'not implemented for inner prod search' in str(x) or "Error: 'metric == METRIC_L2' failed" in str(x):
48 | print(x)
49 | index = faiss.index_factory(d, params['indexer']['index_factory'], faiss.METRIC_L2)
50 | else:
51 | raise
52 |
53 | set_verbose(index)
54 | if not index.is_trained:
55 | try:
56 | index.train(embeddings)
57 | except RuntimeError as x:
58 | print(x)
59 | if "Error: 'nx >= k' failed" in str(x):
60 | index = faiss.IndexFlatIP(d)
61 | #index = faiss.IndexFlatIP(d)
62 |
63 | # write database
64 | print('writing database')
65 | index.add(embeddings)
66 | emb_db_path = os.path.join(dir_for_db, 'landmarkValue')
67 | faiss.write_index(index, emb_db_path)
68 | print('embedding size:', os.stat(emb_db_path).st_size)
69 |
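As a supplement to the readme's description of the `.bin` score files above: the `match_t` layout maps directly onto a NumPy structured dtype. A minimal reading sketch (my own illustration, not part of the repo; the file path and database size are hypothetical, and the floats are assumed to be the native little-endian float32 that `ndarray.tobytes()` produces on typical x86 machines):

```python
import numpy as np

n_songs = 10000  # number of musics in the database (hypothetical)
match_t = np.dtype([('score', np.float32), ('offset', np.float32)])
scores = np.fromfile('out/results/example.txt.bin', dtype=match_t)  # hypothetical path
scores = scores.reshape(-1, n_songs)  # row i = i-th query, column j = j-th database music
best = scores['score'].argmax(axis=1)  # best-matching music id for each query
```

After the reshape, the number of rows should equal the number of queries; checking that is a quick way to verify `n_songs` is right.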
-------------------------------------------------------------------------------- /simpleutils.py: -------------------------------------------------------------------------------- 1 | # every utils that don't use torch 2 | import csv 3 | import datetime 4 | import hashlib 5 | import json 6 | import logging 7 | import multiprocessing as mp 8 | import os 9 | import tempfile 10 | import time 11 | 12 | class Timing(): 13 | def __init__(self, name='run time'): 14 | self.name = name 15 | self.t = time.time() 16 | self.entered = False 17 | def __enter__(self): 18 | self.t = time.time() 19 | self.entered = True 20 | def __exit__(self, *ignored): 21 | self.showRunTime(self.name) 22 | def showRunTime(self, name): 23 | print(self.name, ':', time.time() - self.t, 's') 24 | 25 | def get_hash(s): 26 | m = hashlib.md5() 27 | m.update(s.encode('utf8')) 28 | return m.hexdigest() 29 | 30 | def read_config(path): 31 | with open(path, 'r') as fin: 32 | return json.load(fin) 33 | 34 | def read_file_list(list_file): 35 | files = [] 36 | if list_file.endswith('.csv'): 37 | with open(list_file, 'r') as fin: 38 | reader = csv.reader(fin) 39 | firstrow = next(reader) 40 | files = [row[0] for row in reader] 41 | else: 42 | with open(list_file, 'r', encoding='utf8') as fin: 43 | for line in fin: 44 | if line.endswith('\n'): 45 | line = line[:-1] 46 | files.append(line) 47 | return files 48 | 49 | s3_resource = None 50 | def get_s3_resource(): 51 | import boto3 52 | global s3_resource 53 | if s3_resource is None: 54 | s3_resource = boto3.resource('s3', endpoint_url='https://cos.twcc.ai') 55 | return s3_resource 56 | 57 | def download_tmp_from_s3(s3url): 58 | s3_res = get_s3_resource() 59 | d1 = s3url.find('/', 5) 60 | bucket_name = s3url[5:d1] 61 | object_name = s3url[d1+1:] 62 | ext = os.path.splitext(s3url)[1] 63 | obj = s3_res.Object(bucket_name, object_name) 64 | _, tmpname = tempfile.mkstemp(suffix=ext, prefix='pfann') 65 | try: 66 | obj.download_file(tmpname) 67 | return tmpname 68 | except Exception as x: 69 | os.unlink(tmpname) 70 | raise RuntimeError('Unable to download %s: %s' % (s3url, x)) 71 | 72 | def init_logger(app_name): 73 | os.makedirs('logs', exist_ok=True) 74 | logger = mp.get_logger() 75 | logger.setLevel(logging.INFO) 76 | handler = logging.FileHandler('logs/%s.log' % app_name, encoding="utf8") 77 | handler.setFormatter(logging.Formatter('[%(asctime)s] [%(processName)s/%(levelname)s] %(message)s')) 78 | logger.addHandler(handler) 79 | 80 | class MultiProcessInitLogger: 81 | def __init__(self, app_name): 82 | date_str = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') 83 | self.log_name = app_name + '-' + date_str 84 | def __call__(self, *args): 85 | init_logger(self.log_name) 86 | -------------------------------------------------------------------------------- /testall.sh: -------------------------------------------------------------------------------- 1 | # model is one of: baseline shuffle_1000 shuffle_100 shuffle_10 shuffle_1 2 | # dataset is one of: inside out1 out2 mirex 3 | builder() { 4 | # model dataset 5 | if [ $2 == inside ]; then 6 | list=lists/fma_medium_train.txt 7 | elif [ $2 == out1 ]; then 8 | list=lists/fma_out1.txt 9 | elif [ $2 == out2 ]; then 10 | list=lists/fma_out2.txt 11 | elif [ $2 == mirex ]; then 12 | list=lists/mirex-db.txt 13 | else 14 | echo $2 is not a supported dataset 15 | exit 2 16 | fi 17 | python builder.py $list out/dbs/$1_$2 out/models/$1 18 | } 19 | matcher() { 20 | # model dataset 21 | if [[ $1 =~ ^lm ]]; then 22 | prog=../pfa/matcher 23 | else 24 | 
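# anything that is not a landmark (lm*) model falls through to the neural matcher in this repo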
prog="python matcher.py" 25 | fi 26 | if [ $2 == mirex ]; then 27 | $prog lists/mirex-query.txt out/dbs/$1_$2 out/results/$1_$2.txt 28 | else 29 | $prog out/queries/$2/list.txt out/dbs/$1_$2 out/results/$1_$2.txt 30 | fi 31 | } 32 | matcher_snr() { 33 | # model dataset snr 34 | if [[ $1 =~ ^lm ]]; then 35 | prog=../pfa/matcher 36 | else 37 | prog="python matcher.py" 38 | fi 39 | $prog out/queries/$2_snr$3/list.txt out/dbs/$1_$2 out/results/$1_$2_snr$3.txt 40 | } 41 | matcher_snr_full() { 42 | # model dataset snr 43 | if [[ $1 =~ ^lm ]]; then 44 | prog=../pfa/matcher 45 | else 46 | prog="python matcher.py" 47 | fi 48 | $prog out/queries/$2_snr$3/list.txt out/dbs/$1_full out/results/$1_$2_full_snr$3.txt 49 | } 50 | accuracy() { 51 | # model dataset 52 | if [ $2 == mirex ]; then 53 | python tools/mirexacc.py lists/mirex-answer.txt out/results/$1_$2.txt 54 | else 55 | python tools/accuracy.py out/queries/$2/expected.csv out/results/$1_$2_detail.csv 56 | fi 57 | } 58 | accuracy_snr() { 59 | echo snr=$3 60 | if [[ $1 =~ ^lm ]]; then 61 | python tools/accuracy.py out/queries/$2_snr$3/expected.csv out/results/$1_$2_snr$3.txt.csv 62 | else 63 | python tools/accuracy.py out/queries/$2_snr$3/expected.csv out/results/$1_$2_snr$3_detail.csv 64 | fi 65 | } 66 | accuracy_snr_full() { 67 | echo snr=$3 68 | if [[ $1 =~ ^lm ]]; then 69 | python tools/accuracy.py out/queries/$2_snr$3/expected.csv out/results/$1_$2_full_snr$3.txt.csv 70 | else 71 | python tools/accuracy.py out/queries/$2_snr$3/expected.csv out/results/$1_$2_full_snr$3_detail.csv 72 | fi 73 | } 74 | forall_snr() { 75 | # some_command model dataset 76 | for snr in -6 -4 -2 0 2 4 6 8 77 | do 78 | $1 $2 $3 $snr 79 | done 80 | } 81 | model="$1" 82 | dataset="$2" 83 | shift 2 84 | while [[ $# -gt 0 ]] 85 | do 86 | action="$1" 87 | shift 88 | case "$action" in 89 | "-build" ) 90 | builder $model $dataset || exit 1;; 91 | "-match_snr" ) 92 | forall_snr matcher_snr $model $dataset || exit 1;; 93 | "-accuracy_snr" ) 94 | forall_snr accuracy_snr $model $dataset || exit 1;; 95 | "-match_snr_full" ) 96 | forall_snr matcher_snr_full $model $dataset || exit 1;; 97 | "-accuracy_snr_full" ) 98 | forall_snr accuracy_snr_full $model $dataset || exit 1;; 99 | "-match" ) 100 | matcher $model $dataset || exit 1;; 101 | "-accuracy" ) 102 | accuracy $model $dataset || exit 1;; 103 | esac 104 | done 105 | -------------------------------------------------------------------------------- /thesis.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stdio2016/pfann/815b6e310d8055517fe59dd94712f42870b28101/thesis.pdf -------------------------------------------------------------------------------- /tools/accuracy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | 5 | args = argparse.ArgumentParser() 6 | args.add_argument('groundtruth') 7 | args.add_argument('predict') 8 | args = args.parse_args() 9 | 10 | with open(args.groundtruth, 'r') as fin: 11 | reader = csv.DictReader(fin) 12 | gt = {} 13 | for row in reader: 14 | name = os.path.basename(row['query']) 15 | gt[name] = row 16 | 17 | with open(args.predict, 'r') as fin: 18 | reader = csv.DictReader(fin) 19 | predict = list(reader) 20 | total = 0 21 | correct = 0 22 | correct_near = 0 23 | correct_exact = 0 24 | fail_time = [] 25 | all_time = [] 26 | for row in predict: 27 | name = os.path.basename(row['query']) 28 | ans = os.path.basename(row['answer']) 29 | actual = 
gt[name] 30 | actual_ans = os.path.basename(actual['answer']) 31 | total += 1 32 | tm = float(row['time']) 33 | actual_tm = float(actual['time']) 34 | if actual_ans == ans: 35 | correct += 1 36 | if abs(actual_tm - tm) <= 0.25: 37 | correct_exact += 1 38 | if abs(actual_tm - tm) <= 0.5: 39 | correct_near += 1 40 | else: 41 | fail_time.append(actual_tm % 0.5) 42 | all_time.append(actual_tm % 0.5) 43 | print("exact match correct %d acc %.2f" % (correct_exact, correct_exact/total*100)) 44 | print("near match correct %d acc %.2f" % (correct_near, correct_near/total*100)) 45 | print("song correct %d acc %.2f" % (correct, correct/total*100)) 46 | -------------------------------------------------------------------------------- /tools/audioset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | 5 | subway = '/m/0195fx' 6 | singing = '/m/015lz1' 7 | music = '/m/04rlf' 8 | music_related = set() 9 | 10 | def recursive_mark(ont, lbl): 11 | item = ont[lbl] 12 | if lbl not in music_related: 13 | music_related.add(lbl) 14 | for i in item['child_ids']: 15 | recursive_mark(ont, i) 16 | 17 | if __name__ == '__main__': 18 | args = argparse.ArgumentParser() 19 | args.add_argument('csv') 20 | args.add_argument('out') 21 | args.add_argument('--ontology') 22 | args = args.parse_args() 23 | 24 | if args.ontology: 25 | with open(args.ontology, 'r', encoding='utf8') as fin: 26 | ontology = json.load(fin) 27 | ontology = {o['id']: o for o in ontology} 28 | recursive_mark(ontology, singing) 29 | recursive_mark(ontology, music) 30 | 31 | with open(args.csv, 'r', encoding='utf8') as fin: 32 | reader = csv.reader(fin, skipinitialspace=True) 33 | segments = [] 34 | for item in reader: 35 | if item[0].startswith('#'): continue 36 | lbls = set(item[3].split(',')) 37 | if subway in lbls and len(music_related & lbls) == 0: 38 | segments.append(item) 39 | 40 | with open(args.out, 'w', encoding='utf8', newline='\n') as fout: 41 | writer = csv.writer(fout, lineterminator="\r\n") 42 | writer.writerow(['# YTID', 'start_seconds', 'end_seconds', 'positive_labels']) 43 | writer.writerows(segments) 44 | 45 | print(len(segments)) 46 | -------------------------------------------------------------------------------- /tools/audioset2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import subprocess 5 | import time 6 | from datetime import datetime 7 | 8 | def yt_rename(name): 9 | new = [] 10 | for ch in name: 11 | if ch.islower(): 12 | new.append('=') 13 | new.append(ch) 14 | return ''.join(new) 15 | 16 | def download(name, start, end, where, loger): 17 | out_name = '%s_%d.wav' % (name, start) 18 | out_path = os.path.join(where, out_name) 19 | if os.path.exists(out_path): 20 | return 21 | 22 | tmp_name = '%s_%d_tmp.wav' % (name, start) 23 | tmp_path = os.path.join(where, tmp_name) 24 | t1 = time.time() 25 | print('download %s from %d to %d' % (name, start, end)) 26 | loger.write('%s download %s from %d to %d\n' % (datetime.now(), name, start, end)) 27 | loger.flush() 28 | proc = subprocess.Popen(['youtube-dl', '-f', 'bestaudio', 29 | '--get-url', 'https://youtube.com/watch?v=%s' % name, 30 | ], stdout=subprocess.PIPE, stderr=loger) 31 | link = proc.stdout.read().strip() 32 | proc.wait() 33 | if proc.returncode == 0: 34 | proc = subprocess.Popen(['ffmpeg', '-loglevel', 'error', 35 | '-ss', str(start), '-i', link, '-t', str(end-start), 36 | '-y', out_path 37 | ], 
stdin=None, stderr=subprocess.PIPE) 38 | errs = proc.stderr.read().decode('utf8') 39 | print(errs, end='') 40 | loger.write(errs) 41 | if not os.path.exists(out_path): 42 | print('failed to download ;-(') 43 | loger.write('%s download %s error!\n' % (datetime.now(), name)) 44 | loger.flush() 45 | else: 46 | print('failed to download ;-(') 47 | with open(out_path, 'wb') as fout: 48 | pass 49 | t2 = time.time() 50 | print('stop for a moment~~~') 51 | time.sleep(max(2, 10 - (t2-t1))) 52 | 53 | if __name__ == '__main__': 54 | args = argparse.ArgumentParser() 55 | args.add_argument('csv') 56 | args.add_argument('folder') 57 | args = args.parse_args() 58 | 59 | os.makedirs(args.folder, exist_ok=True) 60 | 61 | with open(args.csv, 'r', encoding='utf8') as fin: 62 | reader = csv.reader(fin, skipinitialspace=True) 63 | segments = [] 64 | nameuu = set() 65 | for item in reader: 66 | if item[0].startswith('#'): continue 67 | 68 | name = item[0] 69 | start = float(item[1]) 70 | end = float(item[2]) 71 | segments.append([name, start, end]) 72 | nameuu.add(name.upper()) 73 | print(len(nameuu), len(segments)) 74 | 75 | loger = open('dlyt.txt', 'a') 76 | loger.write('%s start program...\n' % datetime.now()) 77 | loger.flush() 78 | for name, start, end in segments: 79 | download(name, start, end, args.folder, loger) 80 | loger.write('%s end program...\n' % datetime.now()) 81 | loger.close() 82 | -------------------------------------------------------------------------------- /tools/convert_naf_to_pfann.py: -------------------------------------------------------------------------------- 1 | # Copy this program to neural-audio-fp repo from https://github.com/mimbres/neural-audio-fp 2 | from collections import OrderedDict 3 | import json 4 | import os 5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 6 | 7 | import torch 8 | import numpy as np 9 | import tensorflow as tf 10 | import argparse 11 | 12 | # from neural-audio-fp repo 13 | from model.utils.config_gpu_memory_lim import allow_gpu_memory_growth 14 | import run 15 | from model.generate import build_fp, load_checkpoint 16 | 17 | def convert_conv2d(conv, prefix, out): 18 | out[prefix + '.weight'] = conv.get_weights()[0].transpose([3, 2, 0, 1]) 19 | out[prefix + '.bias'] = conv.get_weights()[1] 20 | 21 | def convert_layernorm(ln, prefix, out): 22 | out[prefix + '.weight'] = ln.get_weights()[0].transpose([2, 0, 1]) 23 | out[prefix + '.bias'] = ln.get_weights()[1].transpose([2, 0, 1]) 24 | 25 | def convert_conv_layer(conv, prefix, out): 26 | convert_conv2d(conv.conv2d_1x3, prefix + '.conv1', out) 27 | convert_layernorm(conv.BN_1x3, prefix + '.ln1', out) 28 | convert_conv2d(conv.conv2d_3x1, prefix + '.conv2', out) 29 | convert_layernorm(conv.BN_3x1, prefix + '.ln2', out) 30 | return conv.conv2d_1x3.strides, conv.conv2d_3x1.strides 31 | 32 | if __name__ == '__main__': 33 | args = argparse.ArgumentParser() 34 | args.add_argument('checkpoint_name') 35 | args.add_argument('--checkpoint-index') 36 | args.add_argument('--config', default='default') 37 | args.add_argument('pfann') 38 | args = args.parse_args() 39 | 40 | cfg = run.load_config(args.config) 41 | 42 | # copied from https://github.com/mimbres/neural-audio-fp/blob/main/model/generate.py 43 | # load model from checkpoint 44 | m_pre, m_fp = build_fp(cfg) 45 | model = tf.train.Checkpoint(model=m_fp) 46 | checkpoint_root_dir = cfg['DIR']['LOG_ROOT_DIR'] + 'checkpoint/' 47 | checkpoint_index = load_checkpoint(checkpoint_root_dir, args.checkpoint_name, 48 | args.checkpoint_index, m_fp) 49 | 50 | n_frame = 
int(cfg['MODEL']['DUR'] * cfg['MODEL']['FS']) 51 | 52 | # initialize model 53 | x = np.zeros([1, 1, n_frame]) 54 | y = m_fp(m_pre(x)).numpy() 55 | 56 | # convert weight 57 | out = OrderedDict() 58 | strides = [] 59 | for lv, conv in enumerate(m_fp.front_conv.layers[:-1]): 60 | stride = convert_conv_layer(conv, 'f.convs.%d' % lv, out) 61 | strides.append(stride) 62 | h = list(out.items())[-1][1].shape[0] 63 | fc1w = [] 64 | fc1b = [] 65 | fc2w = [] 66 | fc2b = [] 67 | for seq in m_fp.div_enc.split_fc_layers: 68 | fc1w.append(seq.layers[0].weights[0]) 69 | fc1b.append(seq.layers[0].weights[1]) 70 | u = seq.layers[0].weights[1].shape[0] 71 | fc2w.append(seq.layers[1].weights[0]) 72 | fc2b.append(seq.layers[1].weights[1]) 73 | out['g.linear1.weight'] = np.expand_dims(np.concatenate(fc1w, axis=1).T, 2) 74 | out['g.linear1.bias'] = np.concatenate(fc1b) 75 | out['g.linear2.weight'] = np.expand_dims(np.concatenate(fc2w, axis=1).T, 2) 76 | out['g.linear2.bias'] = np.concatenate(fc2b) 77 | out = {x:torch.from_numpy(out[x]) for x in out} 78 | 79 | # save weight 80 | os.makedirs(args.pfann, exist_ok=True) 81 | torch.save(out, os.path.join(args.pfann, 'model.pt')) 82 | params = { 83 | "model_dir": args.pfann, 84 | "fftconv_n": 32768, 85 | "sample_rate": cfg['MODEL']['FS'], 86 | "stft_n": cfg['MODEL']['STFT_WIN'], 87 | "stft_hop": cfg['MODEL']['STFT_HOP'], 88 | "n_mels": cfg['MODEL']['N_MELS'], 89 | "dynamic_range": 80, 90 | "f_min": cfg['MODEL']['F_MIN'], 91 | "f_max": cfg['MODEL']['F_MAX'], 92 | "segment_size": cfg['MODEL']['DUR'], 93 | "hop_size": cfg['MODEL']['HOP'], 94 | "naf_mode": True, 95 | "mel_log": "log10", 96 | "spec_norm": "max", 97 | "model": { 98 | "d": cfg['MODEL']['EMB_SZ'], 99 | "h": h, 100 | "u": u, 101 | "fuller": True, 102 | "conv_activation": "ELU", 103 | "relu_after_bn": False, 104 | "strides": strides, 105 | }, 106 | "indexer": { 107 | "index_factory": "IVF200,PQ64x8np", 108 | "top_k": 100, 109 | } 110 | } 111 | with open(os.path.join(args.pfann, 'configs.json'), 'w') as fout: 112 | json.dump(params, fout, indent=2) 113 | -------------------------------------------------------------------------------- /tools/cosinedecay.py: -------------------------------------------------------------------------------- 1 | # draw cosine decay learning rate curve 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | A = torch.Tensor([1]) 6 | A.requires_grad=True 7 | 8 | optimizer = torch.optim.Adam([A], lr=1e-4) 9 | sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10 | T_0=100, eta_min=1e-7) 11 | lrs = [] 12 | for i in range(100): 13 | lrs.append(optimizer.param_groups[0]['lr']) 14 | sched.step() 15 | plt.xlabel('epoch') 16 | plt.ylabel('learning rate') 17 | plt.plot(lrs) 18 | plt.show() 19 | -------------------------------------------------------------------------------- /tools/csv2txt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | 5 | args = argparse.ArgumentParser() 6 | args.add_argument('csv') 7 | args.add_argument('--dir', required=True) 8 | args.add_argument('--out') 9 | args = args.parse_args() 10 | 11 | if not args.out: 12 | args.out = args.csv + '.txt' 13 | 14 | with open(args.csv, 'r', encoding='utf8') as fin, open(args.out, 'w', encoding='utf8') as fout: 15 | reader = csv.reader(fin) 16 | next(reader) 17 | for row in reader: 18 | file_path = os.path.abspath(os.path.join(args.dir, row[0])) 19 | fout.write(file_path + '\n') 20 | 
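A usage sketch for csv2txt.py, assuming the list/dataset layout used by the fma_* scripts below:

    python tools/csv2txt.py lists/fma_medium_train.csv --dir ../pfann_dataset/fma_medium

This skips the CSV header row and writes the absolute path of each listed file to lists/fma_medium_train.csv.txt (override the destination with --out).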
-------------------------------------------------------------------------------- /tools/filterduration.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | 4 | argp = argparse.ArgumentParser() 5 | argp.add_argument('--csv', required=True) 6 | argp.add_argument('--min-len', type=float, default=0) 7 | argp.add_argument('--max-len', type=float, default=float('inf')) 8 | argp.add_argument('--out', required=True) 9 | args = argp.parse_args() 10 | 11 | out = [] 12 | with open(args.csv) as fin: 13 | reader = csv.reader(fin) 14 | out.append(next(reader)) 15 | n = 0 16 | for row in reader: 17 | duration = float(row[1]) 18 | n += 1 19 | if args.min_len <= duration <= args.max_len: 20 | out.append(row) 21 | print('total %d sounds, %d remain after filtering' % (n, len(out)-1)) 22 | 23 | with open(args.out, 'w', newline='\n') as fout: 24 | writer = csv.writer(fout) 25 | writer.writerows(out)
-------------------------------------------------------------------------------- /tools/fit.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import curve_fit 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | lm_score = [71.10, 79.65, 86.85, 91.10, 93.30, 95.20, 96.60, 97.70] 6 | nn_score = [59.05, 75.20, 86.40, 92.55, 95.95, 97.30, 98.05, 99.00] 7 | nn_score2 = [81.70, 89.55, 93.30, 95.60, 97.30, 98.10, 98.60, 98.90] 8 | svm_score = [83.75, 90.30, 93.85, 96.05, 97.55, 98.40, 98.80, 99.05] 9 | snr = [-6, -4, -2, 0, 2, 4, 6, 8] 10 | lm_score = np.array(lm_score) * 0.01 11 | nn_score = np.array(nn_score) * 0.01 12 | nn_score2 = np.array(nn_score2) * 0.01 13 | svm_score = np.array(svm_score) * 0.01 14 | snr = np.array(snr) 15 | ali_snr = np.linspace(-7, 10, 100) 16 | 17 | # https://stackoverflow.com/questions/55725139/fit-sigmoid-function-s-shape-curve-to-data-using-python 18 | def sigmoid(x, L, x0, k): 19 | y = L / (1 + np.exp(-k*(x-x0))) 20 | return y 21 | 22 | p0 = [max(lm_score), np.median(snr), 1] # a mandatory initial guess 23 | 24 | popt, pcov = curve_fit(sigmoid, snr, lm_score, p0, method='dogbox') 25 | print(popt) 26 | plt.plot(ali_snr, sigmoid(ali_snr, *popt)) 27 | plt.scatter(snr, lm_score) 28 | 29 | popt, pcov = curve_fit(sigmoid, snr, nn_score, p0, method='dogbox') 30 | print(popt) 31 | plt.plot(ali_snr, sigmoid(ali_snr, *popt)) 32 | plt.scatter(snr, nn_score) 33 | 34 | popt, pcov = curve_fit(sigmoid, snr, nn_score2, p0, method='dogbox') 35 | print(popt) 36 | plt.plot(ali_snr, sigmoid(ali_snr, *popt)) 37 | plt.scatter(snr, nn_score2) 38 | 39 | popt, pcov = curve_fit(sigmoid, snr, svm_score, p0, method='dogbox') 40 | print(popt) 41 | plt.plot(ali_snr, sigmoid(ali_snr, *popt)) 42 | plt.scatter(snr, svm_score) 43 | 44 | plt.xlabel('SNR (dB)') 45 | plt.ylabel('accuracy') 46 | plt.legend(['lm', 'lm', 'nn old', 'nn old', 'nn new', 'nn new', 'svm', 'svm']) # each series appears twice: fitted curve, then scatter 47 | plt.show() 48 |
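In the fitted parameters popt = [L, x0, k] of the sigmoid above, L is the saturation accuracy, x0 the SNR (in dB) at which accuracy reaches L/2, and k the slope. A quick readout sketch after any of the fits:

    L, x0, k = popt
    print('saturates at %.1f%% accuracy; half-max near %.2f dB SNR' % (L * 100, x0))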
-------------------------------------------------------------------------------- /tools/fma_full.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | 4 | dummys = set() 5 | querys = [] 6 | with open('lists/fma_full.csv', 'r') as fin: 7 | reader = csv.reader(fin) 8 | next(reader) 9 | for row in reader: 10 | du = float(row[1]) 11 | if du > 3600 or du < 30: 12 | continue 13 | dummys.add(row[0]) 14 | 15 | with open('lists/fma_medium_test.csv', 'r') as fin: 16 | reader = csv.reader(fin) 17 | next(reader) 18 | for row in reader: 19 | du = float(row[1]) 20 | dummys.discard(row[0]) 21 | querys.append(row[0]) 22 | 23 | dummys = list(dummys) 24 | random.seed(3) 25 | random.shuffle(dummys) 26 | dummys = dummys[0:100000] 27 | dummys.sort() 28 | querys.sort() 29 | 30 | with open('lists/fma_dummy_large.txt', 'w') as fout: 31 | for x in dummys: 32 | fout.write('../pfann_dataset/fma_full/' + x + '\n') 33 | for x in querys: 34 | fout.write('../pfann_dataset/fma_medium/' + x + '\n') 35 |
-------------------------------------------------------------------------------- /tools/fma_large.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | 4 | dummys = set() 5 | with open('lists/fma_large.csv', 'r') as fin: 6 | reader = csv.reader(fin) 7 | next(reader) 8 | for row in reader: 9 | du = float(row[1]) 10 | if du < 29.9: 11 | continue 12 | dummys.add(row[0]) 13 | 14 | with open('lists/fma_medium_train.csv', 'r') as fin: 15 | reader = csv.reader(fin) 16 | next(reader) 17 | for row in reader: 18 | du = float(row[1]) 19 | dummys.discard(row[0]) 20 | 21 | vals = [] 22 | with open('lists/fma_medium_val.csv', 'r') as fin: 23 | reader = csv.reader(fin) 24 | next(reader) 25 | for row in reader: 26 | dummys.discard(row[0]) 27 | vals.append(row[0]) 28 | 29 | tests = [] 30 | with open('lists/fma_medium_test.csv', 'r') as fin: 31 | reader = csv.reader(fin) 32 | next(reader) 33 | for row in reader: 34 | dummys.discard(row[0]) 35 | tests.append(row[0]) 36 | 37 | dummys = list(dummys) 38 | random.seed(3) 39 | random.shuffle(dummys) 40 | dummys = dummys[0:10000] 41 | dummys.sort() 42 | vals.sort() 43 | tests.sort() 44 | 45 | with open('lists/fma_out1.txt', 'w') as fout: 46 | for x in dummys: 47 | fout.write('../pfann_dataset/fma_large/' + x + '\n') 48 | for x in vals: 49 | fout.write('../pfann_dataset/fma_medium/' + x + '\n') 50 | 51 | with open('lists/fma_out2.txt', 'w') as fout: 52 | for x in dummys: 53 | fout.write('../pfann_dataset/fma_large/' + x + '\n') 54 | for x in tests: 55 | fout.write('../pfann_dataset/fma_medium/' + x + '\n') 56 |
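Both list builders rely on discard() so the sampled dummy (database-only) tracks stay disjoint from the query tracks appended after them. A quick sanity-check sketch for a generated list (a sketch only; it assumes FMA's track-id file names are unique across subsets):

    with open('lists/fma_out1.txt') as fin:
        paths = [line.strip() for line in fin]
    dummy = {p.split('/')[-1] for p in paths if '/fma_large/' in p}
    query = {p.split('/')[-1] for p in paths if '/fma_medium/' in p}
    assert not dummy & query  # no track appears as both dummy and query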
-------------------------------------------------------------------------------- /tools/listaudio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import multiprocessing 4 | from multiprocessing import Pool 5 | import os 6 | import random 7 | import subprocess 8 | import wave 9 | 10 | import tqdm 11 | 12 | argp = argparse.ArgumentParser() 13 | argp.add_argument('--folder', required=True) 14 | argp.add_argument('--sample', type=int) 15 | argp.add_argument('--threads', type=int) 16 | argp.add_argument('--out', default='out.csv') 17 | args = argp.parse_args() 18 | 19 | class HackExtensibleWave: # wraps a WAV stream and patches its format tag to PCM so the stdlib wave module accepts it 20 | def __init__(self, stream): 21 | self.stream = stream 22 | self.pos = 0 23 | def read(self, n): 24 | r = self.stream.read(n) 25 | new_pos = self.pos + len(r) 26 | if self.pos < 20 and self.pos + n >= 20: # this read crosses the wFormatTag field at bytes 20-21 27 | r = r[:20-self.pos] + b'\x01\x00'[:new_pos-20] + r[22-self.pos:] 28 | elif 20 <= self.pos < 22: # this read starts inside the wFormatTag field 29 | r = b'\x01\x00'[self.pos-20:new_pos-20] + r[22-self.pos:] 30 | self.pos = new_pos 31 | return r 32 | 33 | def ffmpeg_get_audio_length(filename): 34 | proc = subprocess.Popen(['ffmpeg', '-i', filename, '-f', 'wav', 'pipe:1'], 35 | stderr=open(os.devnull, 'w'), 36 | stdin=open(os.devnull), 37 | stdout=subprocess.PIPE, 38 | bufsize=1000000) 39 | try: 40 | wav = wave.open(HackExtensibleWave(proc.stdout)) 41 | smprate = wav.getframerate() 42 | has = 1 43 | n = 0 44 | while has: 45 | has = len(wav.readframes(1000000)) 46 | n += has 47 | n //= wav.getsampwidth() * wav.getnchannels() 48 | #smprate, wav = scipy.io.wavfile.read(proc.stdout) 49 | return n / smprate, smprate, wav.getnchannels() 50 | #return wav.shape[0] / smprate, smprate, wav.shape[1] # unreachable leftover from the scipy version above 51 | except (wave.Error, EOFError) as x: 52 | try: 53 | n = os.stat(filename).st_size 54 | if n == 0: 55 | print('file %s is empty!' % filename) 56 | else: 57 | print('failed to decode %s. maybe the file is broken!' % filename) 58 | except OSError: 59 | print('failed to stat %s. maybe it is not a file anymore!' % filename) 60 | return None 61 | 62 | def get_audio_length(filename): 63 | # every format, including wav, is decoded through ffmpeg 64 | return ffmpeg_get_audio_length(filename) 65 | 66 | formats = {'.wav', '.mp3', '.m4a', '.aac', '.ogg', '.flac', '.webm'} 67 | def find_all_audio(folder, relative, all_files): 68 | for name in os.listdir(folder): 69 | full_name = os.path.join(folder, name) 70 | nxt_relative = os.path.join(relative, name) 71 | if os.path.isdir(full_name): 72 | find_all_audio(full_name, nxt_relative, all_files) 73 | else: 74 | ext = os.path.splitext(name)[1] 75 | if ext in formats: 76 | all_files.append(nxt_relative) 77 | return all_files 78 | 79 | def worker(filename): 80 | folder, relative = filename 81 | return relative, get_audio_length(os.path.join(folder, relative)) 82 | 83 | if __name__ == '__main__': 84 | all_files = [] 85 | print('searching audio files...') 86 | find_all_audio(args.folder, '', all_files) 87 | all_files = [(args.folder, x) for x in all_files] 88 | multiprocessing.set_start_method('spawn') 89 | with Pool(args.threads) as p: 90 | sound_files = [] 91 | with tqdm.tqdm(total=len(all_files)) as pbar: 92 | for i, (filename, du) in enumerate(p.imap_unordered(worker, all_files)): 93 | if du is not None: 94 | sound_files.append([filename, *du]) 95 | pbar.update() 96 | sound_files.sort() 97 | if args.sample: 98 | sound_files = random.sample(sound_files, args.sample) 99 | with open(args.out, 'w', encoding='utf8', newline='\n') as fout: 100 | if args.out.endswith('.csv'): 101 | # csv format with duration info 102 | writer = csv.writer(fout, lineterminator="\r\n") 103 | writer.writerow(['file', 'duration', 'sample_rate', 'channels']) 104 | writer.writerows(sound_files) 105 | else: 106 | # plain text list 107 | for sound_name, duration, smprate, nchannels in sound_files: 108 | fout.write(sound_name + '\n') 109 |
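ffmpeg may emit WAVE_FORMAT_EXTENSIBLE (tag 0xFFFE at bytes 20-21 of the canonical header), which Python's wave module rejects with "unknown format: 65534"; HackExtensibleWave rewrites those two bytes to 1 (PCM) on the fly. A standalone illustration (extensible.wav is a hypothetical WAVE_FORMAT_EXTENSIBLE file):

    import io, struct, wave
    raw = open('extensible.wav', 'rb').read()
    print(hex(struct.unpack_from('<H', raw, 20)[0]))    # 0xfffe: plain wave.open would fail here
    w = wave.open(HackExtensibleWave(io.BytesIO(raw)))  # the patched view parses fine
    print(w.getframerate(), w.getnchannels())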
-------------------------------------------------------------------------------- /tools/mirexacc.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import argparse 3 | import os 4 | 5 | args = argparse.ArgumentParser() 6 | args.add_argument('groundtruth') 7 | args.add_argument('predict') 8 | args = args.parse_args() 9 | 10 | def extract_ans_txt(file): 11 | with open(file, 'r') as fin: 12 | out = {} 13 | for line in fin: 14 | if line.endswith('\n'): line = line[:-1] 15 | query, ans = line.split('\t') 16 | my_query = os.path.splitext(os.path.split(query)[1])[0] 17 | my_ans = os.path.splitext(os.path.split(ans)[1])[0] 18 | if my_query in out: 19 | print('Warning! query %s occurred twice' % query) 20 | out[my_query] = my_ans, 0 21 | return out 22 | 23 | def extract_ans_csv(file): 24 | with open(file, 'r') as fin: 25 | out = {} 26 | reader = csv.reader(fin) 27 | next(reader) 28 | for line in reader: 29 | query, ans = line[:2] 30 | my_query = os.path.splitext(os.path.split(query)[1])[0] 31 | my_ans = os.path.splitext(os.path.split(ans)[1])[0] 32 | if my_query in out: 33 | print('Warning! query %s occurred twice' % query) 34 | out[my_query] = my_ans, float(line[2]) 35 | return out 36 | 37 | def extract_ans(file): 38 | if file.endswith('.csv'): 39 | return extract_ans_csv(file) 40 | return extract_ans_txt(file) 41 | 42 | GT = extract_ans(args.groundtruth) 43 | PR = extract_ans(args.predict) 44 | 45 | correct = 0 46 | total = 0 47 | scores = [] 48 | for query in PR: 49 | ans, sco = PR[query] 50 | if query in GT: 51 | real_ans, _ = GT[query] 52 | total += 1 53 | if ans == real_ans: 54 | correct += 1 55 | scores.append((sco, ans==real_ans)) 56 | else: 57 | print('query %s from prediction file not found in ground truth!!' % query) 58 | print('ARE YOU KIDDING ME?') 59 | exit(1) 60 | print('song correct %d acc %.2f' % (correct, correct/total * 100)) 61 | scores.sort() 62 | if correct == 0: 63 | print('totally wrong') 64 | elif correct == total: 65 | print('all correct') 66 | else: 67 | thres = (scores[total-correct-1][0] + scores[total-correct][0]) / 2 68 | FN = 0 69 | for sco, ok in scores: 70 | if sco > thres: break 71 | FN += ok 72 | print('threshold %f TP %d FN %d' % (thres, correct - FN, FN)) 73 |
-------------------------------------------------------------------------------- /tools/stat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | from collections import Counter 4 | 5 | args = argparse.ArgumentParser() 6 | args.add_argument('log') 7 | args = args.parse_args() 8 | 9 | total_times = Counter() 10 | with open(args.log, encoding='utf8') as fin: 11 | for line in fin: 12 | split = line.rfind('] ') 13 | if split == -1: 14 | body = line 15 | else: 16 | body = line[split+2:] 17 | for task in ['load', 'resample', 'stereo to mono', 'compute embedding', 'search', 'rerank', 'output answer', 'total query time']: 18 | s = re.search(task + r' (\d+\.\d+)s', body) 19 | if s: 20 | secs = float(s[1]) 21 | total_times[task] += secs 22 | for task in total_times: 23 | print('%s %.3f s' % (task, total_times[task])) 24 |
-------------------------------------------------------------------------------- /tools/traintestsplit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import random 4 | 5 | argp = argparse.ArgumentParser() 6 | argp.add_argument('--csv', default='out.csv') 7 | argp.add_argument('--train-size', type=int) 8 | argp.add_argument('--train', default='train.csv') 9 | argp.add_argument('--test-size', type=int) 10 | argp.add_argument('--test', default='test.csv') 11 | argp.add_argument('-p', '--portion', action='store_true') 12 | args = argp.parse_args() 13 | 14 | random.seed(1) 15 | with open(args.csv, 'r', encoding='utf8') as fin: 16 | reader = csv.reader(fin) 17 | data = [] 18 | firstrow = next(reader) 19 | for row in reader: 20 | data.append(row) 21 | 22 | n = len(data) 23 | if args.portion: 24 | ab = args.train_size + args.test_size 25 | train_size = n * args.train_size // ab 26 | test_size = n - train_size 27 | else: 28 | if args.train_size is None: 29 | if args.test_size is None: 30 | train_size = n//2 31 | else: 32 | train_size = n - args.test_size 33 | else: 34 | train_size = args.train_size 35 | if args.test_size is None: 36 | test_size = n - train_size 37 | else: 38 | test_size = args.test_size 39 | print('There are %d rows of data' % n) 40 | assert train_size + test_size <= n, 'Not enough data for train/test split' 41 | 42 | train_index = random.sample(list(range(n)), train_size) 43 | train_index.sort() 44 | less_index = list(set(range(n)) - set(train_index)) 45 | test_index = random.sample(less_index, test_size) 46 | test_index.sort() 47 | train_data = map(lambda x: data[x], train_index) 48 | 49 | with open(args.train, 'w', encoding='utf8', newline='\n') as fout: 50 | writer = csv.writer(fout) 51 | if firstrow: 52 | writer.writerow(firstrow) 53 | writer.writerows(train_data) 54 | print('train data: %d' % train_size) 55 | 56 | test_data = map(lambda x: data[x], test_index) 57 | with open(args.test, 'w', encoding='utf8', newline='\n') as fout: 58 | writer = csv.writer(fout) 59 | if firstrow: 60 | writer.writerow(firstrow) 61 | writer.writerows(test_data) 62 | print('test data: %d' % test_size) 63 |
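A usage sketch: with -p the two sizes are read as a ratio rather than absolute counts, so an 80/20 split of out.csv (file names here are the script's own defaults) is:

    python tools/traintestsplit.py --csv out.csv --train train.csv --train-size 8 --test test.csv --test-size 2 -p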
-------------------------------------------------------------------------------- /tools/wham.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import julius 8 | import miniaudio 9 | import torch 10 | import tqdm 11 | 12 | TOTAL_HOURS = 2.3 13 | TOTAL_SECS = TOTAL_HOURS * 3600 14 | NEW_SAMPLE_RATE = 8000 15 | 16 | def gen_clips(noise_dir, noises, out_dir, out_type, total_secs): 17 | longs = 0 18 | wham_list = [] 19 | out_dir = os.path.join(out_dir, out_type) 20 | os.makedirs(out_dir, exist_ok=True) 21 | with tqdm.tqdm(total=total_secs) as t: 22 | for name in noises: 23 | info = miniaudio.wav_read_file_f32(os.path.join(noise_dir, name)) 24 | du = info.duration 25 | wham_list.append([os.path.join(out_type, name), du]) 26 | longs += du 27 | with open(os.path.join(noise_dir, name), 'rb') as fin: 28 | code = fin.read() 29 | with open(os.path.join(out_dir, name), 'wb') as fout: 30 | fout.write(code) 31 | t.update(du) # count the clip just copied before checking the quota 32 | if longs >= total_secs: 33 | break 34 | with open(os.path.join(out_dir, 'list.csv'), 'w', encoding='utf8', newline='\n') as fout: 35 | writer = csv.writer(fout) 36 | writer.writerows(wham_list) 37 | return wham_list 38 | 39 | if __name__ == '__main__': 40 | args = argparse.ArgumentParser() 41 | args.add_argument('--wham', required=True) 42 | args.add_argument('--audioset', required=True) 43 | args = args.parse_args() 44 | 45 | wham_dir = os.path.join(args.wham, 'tr') 46 | noises = os.listdir(wham_dir) 47 | random.shuffle(noises) 48 | lst = gen_clips(wham_dir, noises, args.audioset, 'tr', TOTAL_SECS * 0.8) 49 | 50 | wham_dir = os.path.join(args.wham, 'cv') 51 | noises = os.listdir(wham_dir) 52 | random.shuffle(noises) 53 | lst = gen_clips(wham_dir, noises, args.audioset, 'cv', TOTAL_SECS * 0.2) 54 |
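A usage sketch for wham.py (directory names are illustrative): --wham points at the extracted WHAM! noise root, which contains the tr/ and cv/ splits, and the selected clips plus a list.csv index are copied under the --audioset output root:

    python tools/wham.py --wham ../wham_noise --audioset ../pfann_dataset/noise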
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import shutil 5 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.multiprocessing as mp 12 | import tensorboardX 13 | import torch_optimizer as optim 14 | 15 | from model import FpNetwork 16 | from datautil.dataset_v2 import SegmentedDataLoader 17 | from datautil.mock_data import MockedDataLoader 18 | import simpleutils 19 | from datautil.specaug import SpecAugment 20 | 21 | from torch.cuda.amp import autocast, GradScaler 22 | 23 | # fix PyTorch bug #49630 24 | # apply pull request #49631 25 | CosineAnnealingWarmRestarts = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts 26 | def new_cosinedecay_init(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False): 27 | if T_0 <= 0 or not isinstance(T_0, int): 28 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 29 | if T_mult < 1 or not isinstance(T_mult, int): 30 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) 31 | self.T_0 = T_0 32 | self.T_i = T_0 33 | self.T_mult = T_mult 34 | self.eta_min = eta_min 35 | 36 | self.T_cur = 0 if last_epoch < 0 else last_epoch 37 | super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose) 38 | 39 | torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.__init__ = new_cosinedecay_init 40 | 41 | def similarity_loss(y, tau): 42 | a = torch.matmul(y, y.T) 43 | a /= tau 44 | Ls = [] 45 | for i in range(y.shape[0]): 46 | nn_self = torch.cat([a[i,:i], a[i,i+1:]]) 47 | softmax = torch.nn.functional.log_softmax(nn_self, dim=0) 48 | Ls.append(softmax[i if i%2 == 0 else i-1]) # rows 2k and 2k+1 are two views of the same segment; pick the positive's log-probability 49 | Ls = torch.stack(Ls) 50 | 51 | loss = torch.sum(Ls) / -y.shape[0] 52 | return loss 53 | 54 | def train(model, optimizer, train_data, val_data, batch_size, device, params, writer, start_epoch, scaler): 55 | logger = mp.get_logger() 56 | minibatch = 40 57 | if torch.cuda.get_device_properties(0).total_memory > 11e9: 58 | minibatch = 640 59 | total_epoch = params.get('epoch', 100) 60 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 61 | T_0=total_epoch, eta_min=1e-7, last_epoch=start_epoch) 62 | os.makedirs(params['model_dir'], exist_ok=True) 63 | specaug = SpecAugment(params) 64 | for epoch in range(start_epoch+1, total_epoch): 65 | logger.info('epoch %d', epoch+1) 66 | model.train() 67 | tau = params.get('tau', 0.05) 68 | print('epoch %d' % (epoch+1)) 69 | losses = [] 70 | # set dataloader to train mode 71 | train_data.shuffle = True 72 | train_data.eval_time_shift = False 73 | train_data.augmented = True 74 | train_data.set_epoch(epoch) 75 | 76 | pbar = tqdm(train_data, ncols=80) 77 | for x in pbar: 78 | optimizer.zero_grad() 79 | 80 | x = torch.flatten(x, 0, 1) 81 | x = specaug.augment(x) 82 | if minibatch < batch_size: # batch too big for GPU memory: embed without grad, backprop the loss to the embeddings, then replay each minibatch with grad 83 | with torch.no_grad(): 84 | xs = torch.split(x, minibatch) 85 | ys = [] 86 | for xx in xs: 87 | ys.append(model(xx.to(device))) 88 | # compute gradient of model output 89 | y = torch.cat(ys) 90 | y.requires_grad = True 91 | loss = similarity_loss(y, tau) 92 | loss.backward() 93 | # manual backward 94 | ys = torch.split(y.grad, minibatch) 95 | for xx, yg in zip(xs, ys): 96 | yy = model(xx.to(device)) 97 | yy.backward(yg.to(device)) 98 | else: 99 | with autocast(): 100 | y = model(x.to(device)) 101 | loss = similarity_loss(y, tau) 102 | scaler.scale(loss).backward() 103 | scaler.step(optimizer) 104 | scaler.update() 105 | lossnum = float(loss.item()) 106 | pbar.set_description('loss=%f'%lossnum) 107 | losses.append(lossnum) 108 | writer.add_scalar('train/loss', np.mean(losses), epoch) 109 | print('loss: %f' % np.mean(losses)) 110 | 111 | model.eval() 112 | with torch.no_grad(): 113 | print('validating') 114 | x_embed = [] 115 | # set dataloader to eval mode 116 | train_data.shuffle = False 117 | train_data.eval_time_shift = 
True 118 | train_data.augmented = False 119 | 120 | for x in tqdm(train_data, desc='train data', ncols=80): 121 | x = x[:, 0] 122 | for xx in torch.split(x, minibatch): 123 | y = model(xx.to(device)).cpu() 124 | x_embed.append(y) 125 | x_embed = torch.cat(x_embed) 126 | train_N = x_embed.shape[0] 127 | acc = 0 128 | validate_N = 0 129 | y_embed = [] 130 | for x in tqdm(val_data, desc='val data', ncols=80): 131 | x = torch.flatten(x, 0, 1) 132 | for xx in torch.split(x, minibatch): 133 | y = model(xx.to(device)).cpu() 134 | y_embed.append(y) 135 | y_embed = torch.cat(y_embed) 136 | y_embed_org = y_embed[0::2] 137 | y_embed_aug = y_embed[1::2].to(device) 138 | 139 | # compute validation score on GPU 140 | self_score = [] 141 | for embeds in torch.split(y_embed_org, 320): 142 | A = torch.matmul(y_embed_aug, embeds.T.to(device)) 143 | self_score.append(A.diagonal(-validate_N).cpu()) # the diagonal offset tracks how many original embeddings were consumed so far 144 | validate_N += embeds.shape[0] 145 | self_score = torch.cat(self_score).to(device) 146 | 147 | ranks = torch.zeros(validate_N, dtype=torch.long).to(device) 148 | for embeds in torch.split(x_embed, 320): 149 | A = torch.matmul(y_embed_aug, embeds.T.to(device)) 150 | ranks += (A.T >= self_score).sum(dim=0) 151 | for embeds in torch.split(y_embed_org, 320): 152 | A = torch.matmul(y_embed_aug, embeds.T.to(device)) 153 | ranks += (A.T >= self_score).sum(dim=0) 154 | acc = int((ranks == 1).sum()) 155 | acc10 = int((ranks <= 10).sum()) 156 | acc20 = int((ranks <= 20).sum()) 157 | acc100 = int((ranks <= 100).sum()) 158 | print('validate score: %f' % (acc / validate_N,)) 159 | writer.add_scalar('validation/accuracy', acc / validate_N, epoch) 160 | writer.add_scalar('validation/top10', acc10 / validate_N, epoch) 161 | writer.add_scalar('validation/top20', acc20 / validate_N, epoch) 162 | writer.add_scalar('validation/top100', acc100 / validate_N, epoch) 163 | #writer.add_scalar('validation/MRR', (1/ranks).mean(), epoch) 164 | scheduler.step() 165 | del A, ranks, self_score, y_embed_aug, y_embed_org, y_embed 166 | writer.flush() 167 | 168 | # save checkpoint 169 | check = { 170 | 'epoch': epoch, 171 | 'model': model.state_dict(), 172 | 'optimizer': optimizer.state_dict(), 173 | 'scaler': scaler.state_dict(), 174 | } 175 | torch.save(check, os.path.join(params['model_dir'], 'checkpoint%d.ckpt' % epoch)) 176 | # cleanup old checkpoints 177 | if epoch % 10 != 0: 178 | try: 179 | os.unlink(os.path.join(params['model_dir'], 'checkpoint%d.ckpt' % (epoch-10))) 180 | except OSError: # the old checkpoint may have been removed already 181 | pass 182 | with open(os.path.join(params['model_dir'], 'epochs.txt'), 'w') as fout: 183 | fout.write('%d\n' % epoch) 184 | os.makedirs(params['model_dir'], exist_ok=True) 185 | torch.save(model.state_dict(), os.path.join(params['model_dir'], 'model.pt')) 186 | 187 | def test_train(args): 188 | logger = mp.get_logger() 189 | params = simpleutils.read_config(args.params) 190 | torch.manual_seed(123) 191 | torch.cuda.manual_seed(123) 192 | torch.backends.cudnn.benchmark = False 193 | torch.backends.cudnn.deterministic = True 194 | d = params['model']['d'] 195 | h = params['model']['h'] 196 | u = params['model']['u'] 197 | F_bin = params['n_mels'] 198 | segn = int(params['segment_size'] * params['sample_rate']) 199 | T = (segn + params['stft_hop'] - 1) // params['stft_hop'] 200 | batch_size = params['batch_size'] 201 | device = torch.device('cuda') 202 | model = FpNetwork(d, h, u, F_bin, T, params['model']).to(device) 203 | 204 | optimizer = params.get('optimizer', 'adam') 205 | if optimizer == 'lamb': 206 | optimizer = optim.Lamb(model.parameters(), 
lr=params.get('lr', 1e-4), 207 | weight_decay=1e-6, clamp_value=1e3, debias=True) 208 | else: 209 | optimizer = torch.optim.Adam(model.parameters(), lr=params.get('lr', 1e-4)) 210 | scaler = GradScaler() 211 | 212 | # load checkpoint 213 | os.makedirs(params['model_dir'], exist_ok=True) 214 | epoch = -1 215 | if os.path.exists(os.path.join(params['model_dir'], 'date.txt')): 216 | with open(os.path.join(params['model_dir'], 'date.txt')) as fin: 217 | date_str = next(fin).strip() 218 | else: 219 | date_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 220 | with open(os.path.join(params['model_dir'], 'date.txt'), 'w') as fout: 221 | fout.write(date_str + '\n') 222 | 223 | if os.path.exists(os.path.join(params['model_dir'], 'epochs.txt')): 224 | with open(os.path.join(params['model_dir'], 'epochs.txt')) as fin: 225 | epoch = int(fin.read().strip()) 226 | if epoch+1 >= params.get('epoch', 100): 227 | print('This model has finished training!') 228 | exit(1) 229 | print('Load from epoch %d' % (epoch+1)) 230 | check = torch.load(os.path.join(params['model_dir'], 'checkpoint%d.ckpt' % epoch), map_location='cpu') 231 | model.load_state_dict(check['model']) 232 | optimizer.load_state_dict(check['optimizer']) 233 | if 'scaler' in check: 234 | scaler.load_state_dict(check['scaler']) 235 | torch.cuda.empty_cache() 236 | else: 237 | shutil.copyfile(args.params, os.path.join(params['model_dir'], 'configs.json')) 238 | 239 | # tensorboard visualize 240 | safe_name = os.path.split(params['model_dir'])[1] 241 | if safe_name == '': 242 | safe_name = os.path.split(os.path.split(params['model_dir'])[0])[1] 243 | log_dir = "runs/" + safe_name + '-' + date_str 244 | writer = tensorboardX.SummaryWriter(log_dir) 245 | 246 | if torch.cuda.is_available(): 247 | print('GPU mem usage: %dMB' % (torch.cuda.memory_allocated()/1024**2)) 248 | 249 | logger.info('load augmentation data') 250 | ADataLoader = SegmentedDataLoader 251 | if args.mock: 252 | ADataLoader = MockedDataLoader 253 | 254 | train_data = ADataLoader('train', params, num_workers=args.workers) 255 | print('training data contains %d samples' % len(train_data.dataset)) 256 | 257 | val_data = ADataLoader('validate', params, num_workers=args.workers) 258 | val_data.shuffle = False 259 | val_data.eval_time_shift = True 260 | print('validation data contains %d samples' % len(val_data.dataset)) 261 | 262 | train(model, optimizer, train_data, val_data, batch_size, device, params, writer, epoch, scaler) 263 | 264 | if __name__ == "__main__": 265 | logger_init = simpleutils.MultiProcessInitLogger('train') 266 | logger_init() 267 | logger = mp.get_logger() 268 | logger.info('logger init') 269 | torch.use_deterministic_algorithms(True) 270 | torch.set_num_threads(2) 271 | mp.set_start_method('spawn') 272 | args = argparse.ArgumentParser() 273 | args.add_argument('-p', '--params', default='configs/default.json') 274 | args.add_argument('-w', '--workers', type=int, default=4) 275 | args.add_argument('--mock', action='store_true') 276 | args = args.parse_args() 277 | logger.info(args) 278 | test_train(args) 279 | --------------------------------------------------------------------------------
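A small self-contained check of the pairing convention in similarity_loss above (toy data, not from the repo): rows arrive as [a0, a0', a1, a1', ...], so row 2k takes row 2k+1 as its positive and vice versa, and the loss is the average InfoNCE term over all rows.

    import torch
    torch.manual_seed(0)
    z = torch.nn.functional.normalize(torch.randn(4, 64), dim=1)
    y = z.repeat_interleave(2, dim=0)    # two identical 'views' per segment
    print(similarity_loss(y, tau=0.05))  # near zero: each positive has similarity 1, negatives are lower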