├── Original.png ├── specmixed.png ├── README.md └── specmix.py /Original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anas-rz/specmix-pytorch/HEAD/Original.png -------------------------------------------------------------------------------- /specmixed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anas-rz/specmix-pytorch/HEAD/specmixed.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of SpecMix : A Mixed Sample Data Augmentation Method for Training with Time-Frequency Domain Features 2 | 3 | [Link to the paper](https://arxiv.org/abs/2108.03020) 4 | 5 | Before: 6 | 7 | ![Original Spectrogram](Original.png "Original Spectrogram Without SpecMix") 8 | 9 | After: 10 | 11 | ![SpecMixed](specmixed.png "Time-Frequency Mixed Spectrogram") 12 | 13 | Special thanks to the paper's author, [Gwantae Kim](https://sites.google.com/korea.ac.kr/gwantae-kim/), for his feedback and wonderful insight. 
"""SpecMix: mixed-sample augmentation that swaps time/frequency bands."""
import random
from typing import Tuple

import torch


def get_band(x: torch.Tensor, min_band_size: int, max_band_size: int,
             band_type: str, mask: torch.Tensor) -> torch.Tensor:
    """Mark one randomly placed band in ``mask`` and return ``mask``.

    ``mask`` is modified in place. Its dim 0 is sized against ``x`` dim 1
    (the 'time' axis) and its dim 1 against ``x`` dim 2 (the 'freq' axis).

    Args:
        x: Batch whose dim 1 / dim 2 lengths bound the band position.
        min_band_size: Smallest band width to draw (inclusive).
        max_band_size: Largest band width to draw (inclusive).
        band_type: ``'freq'`` or ``'time'`` (case-insensitive).
        mask: 2-D tensor of shape ``x.size()[1:3]`` that receives the band.

    Returns:
        The same ``mask`` object, with the selected band set to 1.

    Raises:
        ValueError: If ``band_type`` is neither ``'freq'`` nor ``'time'``.
    """
    kind = band_type.lower()
    # Raise instead of assert: asserts disappear under ``python -O``.
    if kind not in ('freq', 'time'):
        raise ValueError("band_type must be in ['freq', 'time']")

    axis = 2 if kind == 'freq' else 1
    axis_len = x.size()[axis]

    band_size = random.randint(min_band_size, max_band_size)
    # A band wider than the axis would make the start range empty and crash
    # ``randint``; clamp so short inputs get a full-axis band instead.
    band_size = max(0, min(band_size, axis_len))
    mask_start = random.randint(0, axis_len - band_size)
    mask_end = mask_start + band_size

    if kind == 'freq':
        mask[:, mask_start:mask_end] = 1
    else:
        mask[mask_start:mask_end, :] = 1
    return mask


def specmix(x: torch.Tensor, y: torch.Tensor, prob: float,
            min_band_size: int, max_band_size: int,
            max_frequency_bands: int = 3,
            max_time_bands: int = 3) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply SpecMix to a batch with probability ``prob``.

    Each sample is mixed with another sample of the same batch (chosen by a
    random permutation): inside the masked frequency/time bands the partner's
    values replace the sample's own. Targets are blended by the masked area
    fraction ``lam``.

    Args:
        x: Input batch; dim 0 is the batch, dims 1 and 2 are the time and
            frequency axes used to build the mask.
        y: Targets indexed along dim 0 like ``x`` (e.g. one-hot/soft labels).
        prob: Probability of applying the augmentation; must not be negative.
        min_band_size: Smallest band width (inclusive).
        max_band_size: Largest band width (inclusive).
        max_frequency_bands: Upper bound on the number of frequency bands;
            the count is drawn uniformly from ``1..max_frequency_bands``.
        max_time_bands: Upper bound on the number of time bands; the count is
            drawn uniformly from ``1..max_time_bands``.

    Returns:
        ``(x, y)``: the mixed batch and targets, or the inputs unchanged when
        the augmentation is not applied.

    Raises:
        ValueError: If ``prob`` is negative.
    """
    if prob < 0:
        raise ValueError('prob must be a positive value')

    k = random.random()
    if k > 1 - prob:
        batch_size = x.size()[0]
        batch_idx = torch.randperm(batch_size)
        # Build the mask on x's device so the blend below also works when the
        # batch lives on an accelerator.
        mask = torch.zeros(x.size()[1:3], device=x.device)

        # Draw exactly as many bands as sampled; ``range(1, n)`` would drop
        # one and leave the mask empty whenever n == 1.
        num_frequency_bands = random.randint(1, max_frequency_bands)
        for _ in range(num_frequency_bands):
            mask = get_band(x, min_band_size, max_band_size, 'freq', mask)
        num_time_bands = random.randint(1, max_time_bands)
        for _ in range(num_time_bands):
            mask = get_band(x, min_band_size, max_band_size, 'time', mask)

        # Fraction of the time-frequency plane taken from the partner sample.
        lam = torch.sum(mask) / (x.size()[1] * x.size()[2])
        x = x * (1 - mask) + x[batch_idx] * mask
        y = y * (1 - lam) + y[batch_idx] * lam
    return x, y