├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── data ├── Readme.txt └── patient_list_foldwise.txt ├── eval.py ├── image_dataloader.py ├── images ├── blank_region_clipping.PNG ├── concatenation_augmentation.PNG └── overview_large.PNG ├── models ├── 67_sota_training_progress.txt └── ckpt_best.pkl ├── nets ├── __init__.py ├── network_cnn.py └── network_hybrid.py ├── scripts ├── devicewise_script_run.sh ├── eval_script.sh └── train_script_run.sh ├── train.py └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RespireNet 2020 2 | This is the official implementation of the work **[RespireNet](https://arxiv.org/abs/2011.00196)**. 3 | 4 | ## Dependencies: 5 | 6 | ``` 7 | * Python3 8 | * Pytorch (torch, torchvision and other dependencies for Pytorch) 9 | * Numpy 10 | * Librosa (0.8.0) 11 | * nlpaug (0.0.14) 12 | * OpenCV (4.2.0) 13 | * Pandas (0.22.0) 14 | * scikit-learn (0.23.1) 15 | * tqdm (4.48.0) 16 | * cudnn (CUDA for training on GPU) 17 | ``` 18 | 19 | These are all easily installable via, e.g., `pip install numpy`. Any reasonably recent version of these packages shold work. 20 | It is recommended to use a python `virtual` environment to setup the dependencies and the code. 21 | 22 | ## Dataset 23 | * The dataset used is the **[ICBHI Respiratory Challenge 2017 dataset](https://bhichallenge.med.auth.gr/ICBHI_2017_Challenge)**. 
24 | * To find out more details about the dataset please visit the official ICBHI Challenge [webpage](https://bhichallenge.med.auth.gr/ICBHI_2017_Challenge). 25 | * Download from above and place the dataset in the `data` folder. 26 | 27 | ## Train and Test Script 28 | * We follow both the official `60-40` train-test split as well as the `80-20` split. 29 | * For training we employ a 2-stage training protocol. 30 | * Stage 1: The model is trained end-to-end on train data from all the 4 devices. 31 | * Stage 2: The model is fine-tuned (with a lower learning rate `1e-4`) on only the data from a single device. We do this separately for each device (device specific fine-tuning). 32 | 33 | ### Train Command: 34 | 35 | Stage 1 36 | 37 | `python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-3 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4` 38 | 39 | Stage 2 (Device specific fine-tuning) 40 | 41 | `python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 50 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 0 42 | ` 43 | 44 | replace the `stetho_id` as `0 or 1 or 2 or 3` for devices `0-3` 45 | 46 | Please go through our paper for more details. 47 | 48 | ### Test Command: 49 | 50 | Evaluation script 51 | 52 | `python eval.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --batch_size 64 --num_worker 4 --test_fold 4 --checkpoint models/ckpt_best.pkl` 53 | 54 | ## RespireNet Overview 55 | 56 |

57 | 58 |

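RespireNet's pipeline, as sketched in the overview figure, splits each recording into breathing cycles, converts every cycle into a mel-spectrogram image, and classifies it with a ResNet34-based CNN into one of four classes. The snippet below is an informal single-clip inference sketch assembled from pieces already in this repo (`create_mel_raw` from `utils.py`, the model from `nets/network_cnn.py`, and the normalization constants printed in `eval.py`); the audio path is a placeholder, and the dataloader's padding and blank region clipping steps are omitted for brevity.

```
import cv2
import torch
import librosa
from torchvision.transforms import Compose, Normalize, ToTensor

from utils import create_mel_raw
from nets.network_cnn import model

CLASSES = ["normal", "crackle", "wheeze", "both"]

# load a single breathing-cycle clip at the 4 kHz rate used by the dataloader
audio, sr = librosa.load("example_cycle.wav", sr=4000)  # placeholder path

# audio -> mel-spectrogram image, same parameters as image_dataloader.py
spec = create_mel_raw(audio, sr, f_max=2000, n_mels=64, nfft=256, hop=128, resz=3)
spec = cv2.cvtColor(spec, cv2.COLOR_BGR2RGB)

# normalization constants computed on the training split (they vary per split)
transform = Compose([ToTensor(), Normalize([0.5091, 0.1739, 0.4363], [0.2288, 0.1285, 0.0743])])
x = transform(spec).unsqueeze(0)

net = model(num_classes=4)
net.load_state_dict(torch.load("models/ckpt_best.pkl", map_location="cpu"))
net.eval()
with torch.no_grad():
    print(CLASSES[net(x).argmax(dim=1).item()])
```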
59 | 60 | ## Quantitative Results 61 | 62 | ### Performance of the proposed model 63 | 64 | | Split & Task | Method | Sp | Se | Score | 65 | |--------------|:------:|:---:|:---:|:-----:| 66 | | 60-40 split & 4-class| CNN | 71.4% | 39.0%| 55.2%| 67 | | | CNN+CBA+BRC | 71.8% | 39.6% | 55.7%| 68 | | | CNN+CBA+BRC+FT | 72.3% | 40.1% | 56.2%| 69 | | 80/20 split & 4-class | CNN | 78.8% | 53.6% | 66.2% | 70 | | |CNN+CBA+BRC | 79.7% | 54.4% | 67.1% | 71 | | |CNN+CBA+BRC+FT | 83.3% | 53.7% | 68.5%| 72 | | 80/20 split & 2-class | CNN | 83.3% | 60.5% | 71.9%| 73 | | | CNN+CBA+BRC | 76.4% | 71.0% |73.7%| 74 | | | CNN+CBA+BRC+FT | 80.9% | 73.1% | 77.0%| 75 | 76 | ``` 77 | Performance comparison of the proposed model with the state-of-the-art systems following random splits. We see significant improvements from our proposed techniques: concatenation-based augmentation (CBA), blank region clipping (BRC) and device specific fine-tuning (FT). 78 | ``` 79 | 80 | ### Effect of time window 81 | |Length | 1 sec | 2 sec | 3 sec | 4 sec | 5 sec | 6 sec | 7 sec | 8 sec | 9 sec| 82 | |----------|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:| 83 | |Scores | 56.6% | 59.0% | 60.3% | 61.1% | 62.3% | 64.4% | 66.2% | 65.1% | 65.5%| 84 | 85 | ### Some additional results 86 | 87 | | Method | Sp | Se | Score | 88 | |--------:|:---:|:---:|:-----:| 89 | | CNN + Mixup | 73.9% | 48.9% | 61.4%| 90 | | CNN + VGGish*| 76.0% | 42.2% | 59.1%| 91 | | Hybrid CNN + Transformer | 75.3% | 49.9% | 63.2%| 92 | 93 | ``` 94 | We also tried experiments with Mixup augmentations, using pre-trained VGGish features and a novel Hybrid CNN + Transformer architecture, however they did not prove to be very useful. 95 | However transformers with appropriate pretraining have found to be useful in many applications (especially NLP tasks) and may prove to be useful in the future. 96 | ``` 97 | *[VGGish link](https://github.com/tensorflow/models/tree/master/research/audioset/vggish)* 98 | 99 | ### About the code ### 100 | 101 | * `train.py`: Main trainer code. 102 | * `image_dataloader.py`: Dataloader module. 103 | * `utils.py`: Contains a lot of utility functions for audio file loading, feature extraction, augmentations, etc. 104 | * `eval.py`: Evaluation source files for trained model. 105 | * `scripts`: Directory which contains the runner scripts. 106 | * `nets`: Contains the different network modules. 107 | * `data`: Training-Testing split and should contain the ICBHI data. 108 | * `models`: Contains the trained checkpoint for our proposed framework. 109 | 110 | ### Blank Region Clipping Scheme 111 | 112 |

113 | 114 |

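Blank region clipping removes the empty band at the top of the spectrogram (frequency rows that carry almost no energy) before the image is fed to the network. The function below is a self-contained restatement of the logic in `__getitem__` of `image_dataloader.py` (the thresholds `10` and `0.80` come from that code), shown here only for illustration.

```
import cv2
import numpy as np
from utils import create_mel_raw

def blank_region_clip(audio, sample_rate, n_mels=64, nfft=256, hop=128, f_max=2000):
    # colour spectrogram used as model input (3x upscaled)
    image = cv2.cvtColor(create_mel_raw(audio, sample_rate, f_max=f_max, n_mels=n_mels,
                                        nfft=nfft, hop=hop, resz=3), cv2.COLOR_BGR2RGB)
    # small grayscale copy used only to locate the empty top band
    gray = cv2.cvtColor(create_mel_raw(audio, sample_rate, f_max=f_max, n_mels=n_mels,
                                       nfft=nfft, hop=hop), cv2.COLOR_BGR2GRAY)
    gray[gray < 10] = 0
    for row in range(gray.shape[0]):
        # stop at the first row that is less than 80% black
        if np.mean(gray[row, :] == 0) < 0.80:
            break
    # crop the blank band and resize back to the expected height
    if (row + 1) * 3 < image.shape[0]:
        image = image[(row + 1) * 3:, :, :]
        image = cv2.resize(image, (image.shape[1], n_mels * 3), interpolation=cv2.INTER_LINEAR)
    return image
```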
115 | 116 | ### Concatenation Based Augmentation 117 | 118 | 119 |
120 |

121 | 122 |

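Concatenation based augmentation generates additional samples for the rarer classes by concatenating two breathing cycles in the time domain; the label of the new sample is the union of the two source labels (normal + crackle gives crackle, crackle + wheeze gives both, and so on). Below is a minimal sketch of the idea, simplified from `new_augment` in `image_dataloader.py`; the helper name and the uniform sampling are illustrative, while the actual code mixes specific class pairs with fixed probabilities.

```
import random
import numpy as np

def concat_augment(classwise_cycles, cls_i, cls_j, new_label):
    """classwise_cycles[c] is a list of cycle tuples whose first element is the waveform."""
    a = random.choice(classwise_cycles[cls_i])
    b = random.choice(classwise_cycles[cls_j])
    return np.concatenate([a[0], b[0]]), new_label

# e.g. synthesise an extra "both" (label 3) sample from a crackle (1) and a wheeze (2) cycle
# new_wave, label = concat_augment(classwise_cycle_list, 1, 2, 3)
```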
123 | 124 | ## To cite this work: 125 | ``` 126 | @misc{gairola2020respirenet, 127 | title={RespireNet: A Deep Neural Network for Accurately Detecting Abnormal Lung Sounds in Limited Data Setting}, 128 | author={Siddhartha Gairola and Francis Tom and Nipun Kwatra and Mohit Jain}, 129 | year={2020}, 130 | eprint={2011.00196}, 131 | archivePrefix={arXiv}, 132 | primaryClass={cs.SD} 133 | } 134 | ``` 135 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 
40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /data/Readme.txt: -------------------------------------------------------------------------------- 1 | Keep the data here. 
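For reference, the runner scripts and the README commands assume roughly the following layout (the directory names only need to match whatever you pass to --data_dir and --folds_file):

    data/
        patient_list_foldwise.txt        <- provided in this repo
        icbhi_dataset/
            audio_text_data/             <- place the ICBHI .wav recordings and their .txt annotation files here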
2 | -------------------------------------------------------------------------------- /data/patient_list_foldwise.txt: -------------------------------------------------------------------------------- 1 | 158 0 2 | 193 0 3 | 177 0 4 | 170 0 5 | 180 0 6 | 211 0 7 | 147 0 8 | 107 0 9 | 162 0 10 | 156 0 11 | 146 0 12 | 200 0 13 | 138 0 14 | 160 0 15 | 203 0 16 | 204 0 17 | 172 0 18 | 207 0 19 | 163 0 20 | 205 0 21 | 213 0 22 | 114 0 23 | 130 0 24 | 154 0 25 | 186 0 26 | 184 0 27 | 153 1 28 | 115 1 29 | 224 1 30 | 223 1 31 | 201 1 32 | 218 1 33 | 127 1 34 | 137 1 35 | 215 1 36 | 161 1 37 | 206 1 38 | 101 1 39 | 168 1 40 | 131 1 41 | 216 1 42 | 120 1 43 | 188 1 44 | 167 1 45 | 210 1 46 | 197 1 47 | 183 1 48 | 152 1 49 | 173 1 50 | 108 1 51 | 208 1 52 | 105 2 53 | 110 2 54 | 116 2 55 | 196 2 56 | 182 2 57 | 222 2 58 | 166 2 59 | 209 2 60 | 144 2 61 | 111 2 62 | 165 2 63 | 148 2 64 | 164 2 65 | 159 2 66 | 121 2 67 | 157 2 68 | 217 2 69 | 123 2 70 | 169 2 71 | 179 2 72 | 190 2 73 | 125 2 74 | 129 2 75 | 225 2 76 | 136 2 77 | 118 3 78 | 185 3 79 | 112 3 80 | 124 3 81 | 104 3 82 | 195 3 83 | 175 3 84 | 212 3 85 | 140 3 86 | 219 3 87 | 132 3 88 | 142 3 89 | 220 3 90 | 122 3 91 | 191 3 92 | 128 3 93 | 226 3 94 | 141 3 95 | 103 3 96 | 134 3 97 | 117 3 98 | 192 3 99 | 106 3 100 | 155 3 101 | 199 3 102 | 174 4 103 | 145 4 104 | 151 4 105 | 176 4 106 | 178 4 107 | 133 4 108 | 198 4 109 | 214 4 110 | 149 4 111 | 143 4 112 | 187 4 113 | 202 4 114 | 119 4 115 | 194 4 116 | 126 4 117 | 150 4 118 | 171 4 119 | 102 4 120 | 109 4 121 | 113 4 122 | 139 4 123 | 189 4 124 | 181 4 125 | 221 4 126 | 135 4 127 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)) 3 | 4 | import os 5 | import itertools 6 | import argparse 7 | import random 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from torch.optim import Adam, lr_scheduler 15 | 16 | 17 | import torchvision 18 | from torchvision.transforms import Compose, Normalize, ToTensor 19 | 20 | import numpy as np 21 | import pandas as pd 22 | import matplotlib 23 | matplotlib.use('Agg') 24 | import matplotlib.pyplot as plt 25 | from sklearn.metrics import confusion_matrix, accuracy_score 26 | 27 | # load external modules 28 | from utils import * 29 | from image_dataloader import * 30 | from nets.network_cnn import * 31 | 32 | print ("Train import done successfully") 33 | 34 | # input argmuments 35 | parser = argparse.ArgumentParser(description='Lung Sound Classification') 36 | parser.add_argument('--steth_id', default=-1.0, type=float, help='learning rate') 37 | parser.add_argument('--gpu_ids', default=[0,1], help='a list of gpus') 38 | parser.add_argument('--num_worker', default=4, type=int, help='numbers of worker') 39 | parser.add_argument('--batch_size', default=4, type=int, help='bacth size') 40 | parser.add_argument('--data_dir', type=str, help='data directory') 41 | parser.add_argument('--folds_file', type=str, help='folds text file') 42 | parser.add_argument('--test_fold', default=4, type=int, help='Test Fold ID') 43 | parser.add_argument('--checkpoint', default=None, type=str, help='load checkpoint') 44 | 45 | args = parser.parse_args() 46 | 47 | ############################################################################## 48 | def 
get_score(hits, counts, pflag=False): 49 | se = (hits[1] + hits[2] + hits[3]) / (counts[1] + counts[2] + counts[3]) 50 | sp = hits[0] / counts[0] 51 | sc = (se+sp) / 2.0 52 | 53 | if pflag: 54 | print("*************Metrics******************") 55 | print("Se: {}, Sp: {}, Score: {}".format(se, sp, sc)) 56 | print("Normal: {}, Crackle: {}, Wheeze: {}, Both: {}".format(hits[0]/counts[0], hits[1]/counts[1], 57 | hits[2]/counts[2], hits[3]/counts[3])) 58 | 59 | class Trainer: 60 | def __init__(self): 61 | self.args = args 62 | 63 | mean, std = [0.5091, 0.1739, 0.4363], [0.2288, 0.1285, 0.0743] 64 | self.input_transform = Compose([ToTensor(), Normalize(mean, std)]) 65 | 66 | test_dataset = image_loader(self.args.data_dir, self.args.folds_file, self.args.test_fold, 67 | False, "params_json", self.input_transform, self.args.steth_id) 68 | self.test_ids = np.array(test_dataset.identifiers) 69 | self.test_paths = test_dataset.filenames_with_labels 70 | 71 | # loading checkpoint 72 | self.net = model(num_classes=4).cuda() 73 | if self.args.checkpoint is not None: 74 | checkpoint = torch.load(self.args.checkpoint) 75 | self.net.load_state_dict(checkpoint) 76 | self.net.fine_tune(block_layer=5) 77 | print("Pre-trained Model Loaded:", self.args.checkpoint) 78 | self.net = nn.DataParallel(self.net, device_ids=self.args.gpu_ids) 79 | 80 | self.val_data_loader = DataLoader(test_dataset, num_workers=self.args.num_worker, 81 | batch_size=self.args.batch_size, shuffle=False) 82 | print("Test Size", len(test_dataset)) 83 | print("DATA LOADED") 84 | 85 | self.loss_func = nn.CrossEntropyLoss() 86 | self.loss_nored = nn.CrossEntropyLoss(reduction='none') 87 | 88 | def evaluate(self, net, epoch, iteration): 89 | 90 | self.net.eval() 91 | test_losses = [] 92 | class_hits = [0.0, 0.0, 0.0, 0.0] # normal, crackle, wheeze, both 93 | class_counts = [0.0, 0.0, 0.0+1e-7, 0.0+1e-7] # normal, crackle, wheeze, both 94 | running_corrects = 0.0 95 | denom = 0.0 96 | 97 | classwise_test_losses = [[], [], [], []] 98 | 99 | conf_label = [] 100 | conf_pred = [] 101 | for i, (image, label) in tqdm(enumerate(self.val_data_loader)): 102 | image, label = image.cuda(), label.cuda() 103 | output = self.net(image) 104 | 105 | # calculate loss from output 106 | loss = self.loss_func(output, label) 107 | loss_nored = self.loss_nored(output, label) 108 | test_losses.append(loss.data.cpu().numpy()) 109 | 110 | _, preds = torch.max(output, 1) 111 | running_corrects += torch.sum(preds == label.data) 112 | 113 | # updating denom 114 | denom += len(label.data) 115 | 116 | #class 117 | for idx in range(preds.shape[0]): 118 | class_counts[label[idx].item()] += 1.0 119 | conf_label.append(label[idx].item()) 120 | conf_pred.append(preds[idx].item()) 121 | if preds[idx].item() == label[idx].item(): 122 | class_hits[label[idx].item()] += 1.0 123 | 124 | classwise_test_losses[label[idx].item()].append(loss_nored[idx].item()) 125 | 126 | print("Val Accuracy: {}".format(running_corrects.double() / denom)) 127 | print("epoch {}, Validation BCE loss: {}".format(epoch, np.mean(test_losses))) 128 | 129 | #aggregating same id, majority voting 130 | conf_label = np.array(conf_label) 131 | conf_pred = np.array(conf_pred) 132 | y_pred, y_true = [], [] 133 | for pt in self.test_paths: 134 | y_pred.append(np.argmax(np.bincount(conf_pred[np.where(self.test_ids == pt)]))) 135 | y_true.append(int(pt.split('_')[-1])) 136 | 137 | conf_matrix = confusion_matrix(y_true, y_pred) 138 | acc = accuracy_score(y_true, y_pred) 139 | print("Confusion Matrix", conf_matrix) 140 | 
print("Accuracy Score", acc) 141 | 142 | conf_matrix = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:,np.newaxis] 143 | print("Classwise Scores", conf_matrix.diagonal()) 144 | return acc, np.mean(test_losses) 145 | 146 | if __name__ == "__main__": 147 | trainer = Trainer() 148 | acc, test_loss = trainer.evaluate(trainer.net, 0, 0) 149 | -------------------------------------------------------------------------------- /image_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com) 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | import cv2 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | import librosa 13 | from tqdm import tqdm 14 | 15 | from utils import * 16 | 17 | class image_loader(Dataset): 18 | def __init__(self, data_dir, folds_file, test_fold, train_flag, params_json, input_transform=None, stetho_id=-1, aug_scale=None): 19 | 20 | # getting device-wise information 21 | self.file_to_device = {} 22 | device_to_id = {} 23 | device_id = 0 24 | files = os.listdir(data_dir) 25 | device_patient_list = [] 26 | pats = [] 27 | for f in files: 28 | device = f.strip().split('_')[-1].split('.')[0] 29 | if device not in device_to_id: 30 | device_to_id[device] = device_id 31 | device_id += 1 32 | device_patient_list.append([]) 33 | self.file_to_device[f.strip().split('.')[0]] = device_to_id[device] 34 | pat = f.strip().split('_')[0] 35 | if pat not in device_patient_list[device_to_id[device]]: 36 | device_patient_list[device_to_id[device]].append(pat) 37 | if pat not in pats: 38 | pats.append(pat) 39 | 40 | print("DEVICE DICT", device_to_id) 41 | for idx in range(device_id): 42 | print("Device", idx, len(device_patient_list[idx])) 43 | 44 | # get patients dict in current fold based on train flag 45 | all_patients = open(folds_file).read().splitlines() 46 | patient_dict = {} 47 | for line in all_patients: 48 | idx, fold = line.strip().split(' ') 49 | if train_flag and int(fold) != test_fold: 50 | patient_dict[idx] = fold 51 | elif train_flag == False and int(fold) == test_fold: 52 | patient_dict[idx] = fold 53 | 54 | #extracting the audiofilenames and the data for breathing cycle and it's label 55 | print("Getting filenames ...") 56 | filenames, rec_annotations_dict = get_annotations(data_dir) 57 | if stetho_id >= 0: 58 | self.filenames = [s for s in filenames if s.split('_')[0] in patient_dict and self.file_to_device[s] == stetho_id] 59 | else: 60 | self.filenames = [s for s in filenames if s.split('_')[0] in patient_dict] 61 | 62 | self.audio_data = [] # each sample is a tuple with id_0: audio_data, id_1: label, id_2: file_name, id_3: cycle id, id_4: aug id, id_5: split id 63 | self.labels = [] 64 | self.train_flag = train_flag 65 | self.data_dir = data_dir 66 | self.input_transform = input_transform 67 | 68 | # parameters for spectrograms 69 | self.sample_rate = 4000 70 | self.desired_length = 8 71 | self.n_mels = 64 72 | self.nfft = 256 73 | self.hop = self.nfft//2 74 | self.f_max = 2000 75 | 76 | self.dump_images = False 77 | self.filenames_with_labels = [] 78 | 79 | # get individual breathing cycles from each audio file 80 | print("Exracting Individual Cycles") 81 | self.cycle_list = [] 82 | self.classwise_cycle_list = [[], [], [], []] 83 | for idx, file_name in tqdm(enumerate(self.filenames)): 84 | data = get_sound_samples(rec_annotations_dict[file_name], file_name, data_dir, self.sample_rate) 85 | cycles_with_labels = 
[(d[0], d[3], file_name, cycle_idx, 0) for cycle_idx, d in enumerate(data[1:])] 86 | self.cycle_list.extend(cycles_with_labels) 87 | for cycle_idx, d in enumerate(cycles_with_labels): 88 | self.filenames_with_labels.append(file_name+'_'+str(d[3])+'_'+str(d[1])) 89 | self.classwise_cycle_list[d[1]].append(d) 90 | 91 | # concatenation based augmentation scheme 92 | if train_flag and aug_scale: 93 | self.new_augment(scale=aug_scale) 94 | 95 | # split and pad each cycle to the desired length 96 | for idx, sample in enumerate(self.cycle_list): 97 | output = split_and_pad(sample, self.desired_length, self.sample_rate, types=1) 98 | self.audio_data.extend(output) 99 | 100 | self.device_wise = [] 101 | for idx in range(device_id): 102 | self.device_wise.append([]) 103 | self.class_probs = np.zeros(4) 104 | self.identifiers = [] 105 | for idx, sample in enumerate(self.audio_data): 106 | self.class_probs[sample[1]] += 1.0 107 | self.labels.append(sample[1]) 108 | self.identifiers.append(sample[2]+'_'+str(sample[3])+'_'+str(sample[1])) 109 | self.device_wise[self.file_to_device[sample[2]]].append(sample) 110 | 111 | if self.train_flag: 112 | print("TRAIN DETAILS") 113 | else: 114 | print("TEST DETAILS") 115 | 116 | print("CLASSWISE SAMPLE COUNTS:", self.class_probs) 117 | print("Device to ID", device_to_id) 118 | for idx in range(device_id): 119 | print("DEVICE ID", idx, "size", len(self.device_wise[idx])) 120 | self.class_probs = self.class_probs / sum(self.class_probs) 121 | print("CLASSWISE PROBS", self.class_probs) 122 | print("LEN AUDIO DATA", len(self.audio_data)) 123 | 124 | def new_augment(self, scale=1): 125 | 126 | # augment normal 127 | aug_nos = scale*len(self.classwise_cycle_list[0]) - len(self.classwise_cycle_list[0]) 128 | for idx in range(aug_nos): 129 | # normal_i + normal_j 130 | i = random.randint(0, len(self.classwise_cycle_list[0])-1) 131 | j = random.randint(0, len(self.classwise_cycle_list[0])-1) 132 | normal_i = self.classwise_cycle_list[0][i] 133 | normal_j = self.classwise_cycle_list[0][j] 134 | new_sample = np.concatenate([normal_i[0], normal_j[0]]) 135 | self.cycle_list.append((new_sample, 0, normal_i[2]+'-'+normal_j[2], 136 | idx, 0)) 137 | self.filenames_with_labels.append(normal_i[2]+'-'+normal_j[2]+'_'+str(idx)+'_0') 138 | 139 | # augment crackle 140 | aug_nos = scale*len(self.classwise_cycle_list[0]) - len(self.classwise_cycle_list[1]) 141 | for idx in range(aug_nos): 142 | aug_prob = random.random() 143 | 144 | if aug_prob < 0.6: 145 | # crackle_i + crackle_j 146 | i = random.randint(0, len(self.classwise_cycle_list[1])-1) 147 | j = random.randint(0, len(self.classwise_cycle_list[1])-1) 148 | sample_i = self.classwise_cycle_list[1][i] 149 | sample_j = self.classwise_cycle_list[1][j] 150 | elif aug_prob >= 0.6 and aug_prob < 0.8: 151 | # crackle_i + normal_j 152 | i = random.randint(0, len(self.classwise_cycle_list[1])-1) 153 | j = random.randint(0, len(self.classwise_cycle_list[0])-1) 154 | sample_i = self.classwise_cycle_list[1][i] 155 | sample_j = self.classwise_cycle_list[0][j] 156 | else: 157 | # normal_i + crackle_j 158 | i = random.randint(0, len(self.classwise_cycle_list[0])-1) 159 | j = random.randint(0, len(self.classwise_cycle_list[1])-1) 160 | sample_i = self.classwise_cycle_list[0][i] 161 | sample_j = self.classwise_cycle_list[1][j] 162 | 163 | new_sample = np.concatenate([sample_i[0], sample_j[0]]) 164 | self.cycle_list.append((new_sample, 1, sample_i[2]+'-'+sample_j[2], 165 | idx, 0)) 166 | 
self.filenames_with_labels.append(sample_i[2]+'-'+sample_j[2]+'_'+str(idx)+'_1') 167 | 168 | # augment wheeze 169 | aug_nos = scale*len(self.classwise_cycle_list[0]) - len(self.classwise_cycle_list[2]) 170 | for idx in range(aug_nos): 171 | aug_prob = random.random() 172 | 173 | if aug_prob < 0.6: 174 | # wheeze_i + wheeze_j 175 | i = random.randint(0, len(self.classwise_cycle_list[2])-1) 176 | j = random.randint(0, len(self.classwise_cycle_list[2])-1) 177 | sample_i = self.classwise_cycle_list[2][i] 178 | sample_j = self.classwise_cycle_list[2][j] 179 | elif aug_prob >= 0.6 and aug_prob < 0.8: 180 | # wheeze_i + normal_j 181 | i = random.randint(0, len(self.classwise_cycle_list[2])-1) 182 | j = random.randint(0, len(self.classwise_cycle_list[0])-1) 183 | sample_i = self.classwise_cycle_list[2][i] 184 | sample_j = self.classwise_cycle_list[0][j] 185 | else: 186 | # normal_i + wheeze_j 187 | i = random.randint(0, len(self.classwise_cycle_list[0])-1) 188 | j = random.randint(0, len(self.classwise_cycle_list[2])-1) 189 | sample_i = self.classwise_cycle_list[0][i] 190 | sample_j = self.classwise_cycle_list[2][j] 191 | 192 | new_sample = np.concatenate([sample_i[0], sample_j[0]]) 193 | self.cycle_list.append((new_sample, 2, sample_i[2]+'-'+sample_j[2], 194 | idx, 0)) 195 | self.filenames_with_labels.append(sample_i[2]+'-'+sample_j[2]+'_'+str(idx)+'_2') 196 | 197 | # augment both 198 | aug_nos = scale*len(self.classwise_cycle_list[0]) - len(self.classwise_cycle_list[3]) 199 | for idx in range(aug_nos): 200 | aug_prob = random.random() 201 | 202 | if aug_prob < 0.5: 203 | # both_i + both_j 204 | i = random.randint(0, len(self.classwise_cycle_list[3])-1) 205 | j = random.randint(0, len(self.classwise_cycle_list[3])-1) 206 | sample_i = self.classwise_cycle_list[3][i] 207 | sample_j = self.classwise_cycle_list[3][j] 208 | elif aug_prob >= 0.5 and aug_prob < 0.7: 209 | # crackle_i + wheeze_j 210 | i = random.randint(0, len(self.classwise_cycle_list[1])-1) 211 | j = random.randint(0, len(self.classwise_cycle_list[2])-1) 212 | sample_i = self.classwise_cycle_list[1][i] 213 | sample_j = self.classwise_cycle_list[2][j] 214 | elif aug_prob >=0.7 and aug_prob < 0.8: 215 | # wheeze_i + crackle_j 216 | i = random.randint(0, len(self.classwise_cycle_list[2])-1) 217 | j = random.randint(0, len(self.classwise_cycle_list[1])-1) 218 | sample_i = self.classwise_cycle_list[2][i] 219 | sample_j = self.classwise_cycle_list[1][j] 220 | elif aug_prob >=0.8 and aug_prob < 0.9: 221 | # both_i + normal_j 222 | i = random.randint(0, len(self.classwise_cycle_list[3])-1) 223 | j = random.randint(0, len(self.classwise_cycle_list[0])-1) 224 | sample_i = self.classwise_cycle_list[3][i] 225 | sample_j = self.classwise_cycle_list[0][j] 226 | else: 227 | # normal_i + both_j 228 | i = random.randint(0, len(self.classwise_cycle_list[0])-1) 229 | j = random.randint(0, len(self.classwise_cycle_list[3])-1) 230 | sample_i = self.classwise_cycle_list[0][i] 231 | sample_j = self.classwise_cycle_list[3][j] 232 | 233 | new_sample = np.concatenate([sample_i[0], sample_j[0]]) 234 | self.cycle_list.append((new_sample, 3, sample_i[2]+'-'+sample_j[2], 235 | idx, 0)) 236 | self.filenames_with_labels.append(sample_i[2]+'-'+sample_j[2]+'_'+str(idx)+'_3') 237 | 238 | def __getitem__(self, index): 239 | 240 | audio = self.audio_data[index][0] 241 | 242 | aug_prob = random.random() 243 | if self.train_flag and aug_prob > 0.5: 244 | # apply augmentation to audio 245 | audio = gen_augmented(audio, self.sample_rate) 246 | 247 | # pad incase smaller than 
desired length 248 | audio = split_and_pad([audio, 0,0,0,0], self.desired_length, self.sample_rate, types=1)[0][0] 249 | 250 | # roll audio sample 251 | roll_prob = random.random() 252 | if self.train_flag and roll_prob > 0.5: 253 | audio = rollAudio(audio) 254 | 255 | # convert audio signal to spectrogram 256 | # spectrograms resized to 3x of original size 257 | audio_image = cv2.cvtColor(create_mel_raw(audio, self.sample_rate, f_max=self.f_max, 258 | n_mels=self.n_mels, nfft=self.nfft, hop=self.hop, resz=3), cv2.COLOR_BGR2RGB) 259 | 260 | # blank region clipping 261 | audio_raw_gray = cv2.cvtColor(create_mel_raw(audio, self.sample_rate, f_max=self.f_max, 262 | n_mels=self.n_mels, nfft=self.nfft, hop=self.hop), cv2.COLOR_BGR2GRAY) 263 | 264 | audio_raw_gray[audio_raw_gray < 10] = 0 265 | for row in range(audio_raw_gray.shape[0]): 266 | black_percent = len(np.where(audio_raw_gray[row,:]==0)[0])/len(audio_raw_gray[row,:]) 267 | if black_percent < 0.80: 268 | break 269 | 270 | if (row+1)*3 < audio_image.shape[0]: 271 | audio_image = audio_image[(row+1)*3:, :, :] 272 | audio_image = cv2.resize(audio_image, (audio_image.shape[1], self.n_mels*3), interpolation=cv2.INTER_LINEAR) 273 | 274 | if self.dump_images: 275 | save_images((audio_image, self.audio_data[index][2], self.audio_data[index][3], 276 | self.audio_data[index][5], self.audio_data[index][1]), self.train_flag) 277 | 278 | # label 279 | label = self.audio_data[index][1] 280 | 281 | # apply image transform 282 | if self.input_transform is not None: 283 | audio_image = self.input_transform(audio_image) 284 | 285 | return audio_image, label 286 | 287 | def __len__(self): 288 | return len(self.audio_data) 289 | -------------------------------------------------------------------------------- /images/blank_region_clipping.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/images/blank_region_clipping.PNG -------------------------------------------------------------------------------- /images/concatenation_augmentation.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/images/concatenation_augmentation.PNG -------------------------------------------------------------------------------- /images/overview_large.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/images/overview_large.PNG -------------------------------------------------------------------------------- /models/ckpt_best.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/models/ckpt_best.pkl -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/nets/__init__.py -------------------------------------------------------------------------------- /nets/network_cnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com) 3 | import torch 4 | 
from torch.autograd import Variable 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision.models import resnet18, resnet34, resnet50, densenet121 8 | #from torchsummary import summary 9 | 10 | class model(nn.Module): 11 | def __init__(self, num_classes=10, drop_prob=0.5): 12 | super(model, self).__init__() 13 | 14 | # encoder 15 | self.model_ft = resnet34(pretrained=True) 16 | num_ftrs = self.model_ft.fc.in_features#*4 17 | self.model_ft.fc = nn.Sequential(nn.Dropout(drop_prob), nn.Linear(num_ftrs, 128), nn.ReLU(True), 18 | nn.Dropout(drop_prob), nn.Linear(128, 128), nn.ReLU(True)) 19 | self.cls_fc = nn.Linear(128, num_classes) 20 | 21 | def forward(self, x): 22 | x = self.model_ft(x) 23 | x = self.cls_fc(x) 24 | return x 25 | 26 | def fine_tune(self, block_layer=5): 27 | 28 | for idx, child in enumerate(self.model_ft.children()): 29 | if idx>block_layer: 30 | break 31 | for param in child.parameters(): 32 | param.requires_grad = False 33 | -------------------------------------------------------------------------------- /nets/network_hybrid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com) 3 | # Transformer part has been referred from the following resource: 4 | # https://pytorch.org/tutorials/beginner/transformer_tutorial.html 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torchvision.models import resnet18, resnet34, vgg16 10 | 11 | # Transformer Model 12 | class TransformerModel(nn.Module): 13 | 14 | def __init__(self, ninp, nhead, nhid, nlayers, dropout=0.5, nu_classes=10): 15 | super(TransformerModel, self).__init__() 16 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 17 | self.model_type = 'Transformer' 18 | self.src_mask = None 19 | self.pos_encoder = PositionalEncoding(ninp, dropout) 20 | encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) 21 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 22 | 23 | self.ninp = ninp 24 | self.linear = nn.Sequential(nn.Linear(ninp, 128), nn.ReLU(True)) 25 | self.classifier = nn.Linear(128, nu_classes) 26 | 27 | def _generate_square_subsequent_mask(self, sz): 28 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 29 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 30 | return mask 31 | 32 | def forward(self, src): 33 | src = src.float() 34 | if self.src_mask is None or self.src_mask.size(0) != len(src): 35 | device = src.device 36 | mask = self._generate_square_subsequent_mask(len(src)).to(device) 37 | self.src_mask = mask 38 | 39 | #src = self.encoder(src) * math.sqrt(self.ninp) 40 | src = src*math.sqrt(self.ninp) 41 | src = self.pos_encoder(src) 42 | output = self.transformer_encoder(src, self.src_mask) 43 | output = output.mean(axis=1) 44 | output = self.linear(output) 45 | output = self.classifier(output) 46 | 47 | return output 48 | 49 | 50 | ###################################################################### 51 | # ``PositionalEncoding`` module injects some information about the 52 | # relative or absolute position of the tokens in the sequence. The 53 | # positional encodings have the same dimension as the embeddings so that 54 | # the two can be summed. Here, we use ``sine`` and ``cosine`` functions of 55 | # different frequencies. 
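# For reference, the encoding built below is the standard sinusoidal one:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# so nearby positions receive similar, smoothly varying codes.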
56 | # 57 | 58 | class PositionalEncoding(nn.Module): 59 | 60 | def __init__(self, d_model, dropout=0.1, max_len=5000): 61 | super(PositionalEncoding, self).__init__() 62 | self.dropout = nn.Dropout(p=dropout) 63 | 64 | pe = torch.zeros(max_len, d_model) 65 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 66 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 67 | pe[:, 0::2] = torch.sin(position * div_term) 68 | pe[:, 1::2] = torch.cos(position * div_term) 69 | pe = pe.unsqueeze(0).transpose(0, 1) 70 | self.register_buffer('pe', pe) 71 | 72 | def forward(self, x): 73 | x = x + self.pe[:x.size(0), :] 74 | return self.dropout(x) 75 | 76 | class model(nn.Module): 77 | def __init__(self, num_classes=10, drop_prob=0.5, pretrained=True): 78 | super(model, self).__init__() 79 | 80 | resnet = resnet34(pretrained=True) 81 | self.conv_feats = nn.Sequential(*list(resnet.children())[:8]) 82 | self.avgpool = resnet.avgpool 83 | self.reduction = nn.Sequential(nn.Linear(128*24, 128), nn.ReLU(True), nn.BatchNorm1d(128)) 84 | 85 | #num_ftrs = self.embeddings.fc.in_features 86 | #self.embeddings.fc = nn.Sequential(nn.Dropout(drop_prob), nn.Linear(num_ftrs, 128), nn.ReLU(True), 87 | # nn.Dropout(drop_prob), nn.Linear(128, 128), nn.ReLU(True)) 88 | # transformer model 89 | self.trfr_classifier = TransformerModel(128, 2, 200, 4, drop_prob, num_classes) 90 | 91 | def forward(self, x): 92 | 93 | # first extract features from the CNN 94 | x = self.conv_feats(x) 95 | 96 | # reshape the conv features and preserve the width as time 97 | # initial shape of x is batch x channels x height x width 98 | b, c, h, w = x.shape 99 | x = x.view(b, -1, w) 100 | x = x.transpose(1,2) # x: batch x width x num_features; num_features = (channels x height) 101 | 102 | # reducing the higher dimension num_features to 128 103 | x_red = [] 104 | for t in range(x.size(1)): 105 | x_red.append(self.reduction(x[:,t,:])) 106 | 107 | x_red = torch.stack(x_red, dim=0).transpose_(0,1) # batch x time x reduced_num_features 108 | 109 | # pass x through the transformer 110 | return self.trfr_classifier(x_red) 111 | -------------------------------------------------------------------------------- /scripts/devicewise_script_run.sh: -------------------------------------------------------------------------------- 1 | python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 0 2 | #python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 1 3 | #python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 2 4 | #python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 3 5 | -------------------------------------------------------------------------------- 
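The device-wise script above runs stage-2 fine-tuning for one stethoscope at a time (the other three invocations are left commented out). A small loop such as the sketch below (not part of the repo, shown only as a convenience) would run all four device-specific jobs in sequence:

```
for dev in 0 1 2 3; do
    python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ \
        --folds_file ../data/patient_list_foldwise.txt --model_path models_out \
        --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 \
        --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id "$dev"
done
```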
/scripts/eval_script.sh: -------------------------------------------------------------------------------- 1 | python eval.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --batch_size 64 --num_worker 4 --test_fold 4 --checkpoint models/ckpt_best.pkl --steth_id -1 2 | -------------------------------------------------------------------------------- /scripts/train_script_run.sh: -------------------------------------------------------------------------------- 1 | python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-3 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 2 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)) 3 | 4 | import os 5 | import itertools 6 | import argparse 7 | import random 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from torch.optim import Adam, lr_scheduler 15 | 16 | 17 | import torchvision 18 | from torchvision.transforms import Compose, Normalize, ToTensor 19 | 20 | import numpy as np 21 | import pandas as pd 22 | import matplotlib 23 | matplotlib.use('Agg') 24 | import matplotlib.pyplot as plt 25 | 26 | # load external modules 27 | from utils import * 28 | from image_dataloader import * 29 | from nets.network_cnn import * 30 | #from nets.network_hybrid import * 31 | from sklearn.metrics import confusion_matrix, accuracy_score 32 | print ("Train import done successfully") 33 | 34 | # input argmuments 35 | parser = argparse.ArgumentParser(description='RespireNet: Lung Sound Classification') 36 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') 37 | parser.add_argument('--weight_decay', default=0.0005,help='weight decay value') 38 | parser.add_argument('--gpu_ids', default=[0,1], help='a list of gpus') 39 | parser.add_argument('--num_worker', default=4, type=int, help='numbers of worker') 40 | parser.add_argument('--batch_size', default=4, type=int, help='bacth size') 41 | parser.add_argument('--epochs', default=10, type=int, help='epochs') 42 | parser.add_argument('--start_epochs', default=0, type=int, help='start epochs') 43 | 44 | parser.add_argument('--data_dir', type=str, help='data directory') 45 | parser.add_argument('--folds_file', type=str, help='folds text file') 46 | parser.add_argument('--test_fold', default=4, type=int, help='Test Fold ID') 47 | parser.add_argument('--stetho_id', default=-1, type=int, help='Stethoscope device id') 48 | parser.add_argument('--aug_scale', default=None, type=float, help='Augmentation multiplier') 49 | parser.add_argument('--model_path',type=str, help='model saving directory') 50 | parser.add_argument('--checkpoint', default=None, type=str, help='load checkpoint') 51 | 52 | args = parser.parse_args() 53 | 54 | ################################MIXUP##################################### 55 | def mixup_data(x, y, alpha=1.0, use_cuda=True): 56 | '''Returns mixed inputs, pairs of targets, and lambda''' 57 | if alpha > 0: 58 | lam = np.random.beta(alpha, alpha) 59 | else: 60 | lam = 1 61 | 62 | batch_size = x.size()[0] 63 | if use_cuda: 64 | index = torch.randperm(batch_size).cuda() 65 | else: 66 | index = 
torch.randperm(batch_size) 67 | 68 | mixed_x = lam * x + (1 - lam) * x[index, :] 69 | y_a, y_b = y, y[index] 70 | return mixed_x, y_a, y_b, lam 71 | 72 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 73 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 74 | 75 | ############################################################################## 76 | def get_score(hits, counts, pflag=False): 77 | se = (hits[1] + hits[2] + hits[3]) / (counts[1] + counts[2] + counts[3]) 78 | sp = hits[0] / counts[0] 79 | sc = (se+sp) / 2.0 80 | 81 | if pflag: 82 | print("*************Metrics******************") 83 | print("Se: {}, Sp: {}, Score: {}".format(se, sp, sc)) 84 | print("Normal: {}, Crackle: {}, Wheeze: {}, Both: {}".format(hits[0]/counts[0], hits[1]/counts[1], 85 | hits[2]/counts[2], hits[3]/counts[3])) 86 | 87 | class Trainer: 88 | def __init__(self): 89 | self.args = args 90 | mean, std = get_mean_and_std(image_loader(self.args.data_dir, self.args.folds_file, 91 | self.args.test_fold, True, "Params_json", Compose([ToTensor()]), stetho_id=self.args.stetho_id)) 92 | print("MEAN", mean, "STD", std) 93 | 94 | self.input_transform = Compose([ToTensor(), Normalize(mean, std)]) 95 | train_dataset = image_loader(self.args.data_dir, self.args.folds_file, self.args.test_fold, 96 | True, "params_json", self.input_transform, stetho_id=self.args.stetho_id, aug_scale=self.args.aug_scale) 97 | test_dataset = image_loader(self.args.data_dir, self.args.folds_file, self.args.test_fold, 98 | False, "params_json", self.input_transform, stetho_id=self.args.stetho_id) 99 | self.test_ids = np.array(test_dataset.identifiers) 100 | self.test_paths = test_dataset.filenames_with_labels 101 | 102 | # loading checkpoint 103 | self.net = model(num_classes=4).cuda() 104 | if self.args.checkpoint is not None: 105 | checkpoint = torch.load(self.args.checkpoint) 106 | self.net.load_state_dict(checkpoint) 107 | # uncomment in case fine-tuning, specify block layer 108 | # before block_layer, all layers will be frozen durin training 109 | #self.net.fine_tune(block_layer=5) 110 | print("Pre-trained Model Loaded:", self.args.checkpoint) 111 | self.net = nn.DataParallel(self.net, device_ids=self.args.gpu_ids) 112 | 113 | # weighted sampler 114 | reciprocal_weights = [] 115 | for idx in range(len(train_dataset)): 116 | reciprocal_weights.append(train_dataset.class_probs[train_dataset.labels[idx]]) 117 | weights = (1 / torch.Tensor(reciprocal_weights)) 118 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(train_dataset)) 119 | 120 | self.train_data_loader = DataLoader(train_dataset, num_workers=self.args.num_worker, 121 | batch_size=self.args.batch_size, sampler=sampler) 122 | self.val_data_loader = DataLoader(test_dataset, num_workers=self.args.num_worker, 123 | batch_size=self.args.batch_size, shuffle=False) 124 | print("DATA LOADED") 125 | 126 | print("Params to learn:") 127 | params_to_update = [] 128 | for name,param in self.net.named_parameters(): 129 | if param.requires_grad == True: 130 | params_to_update.append(param) 131 | 132 | # Observe that all parameters are being optimized 133 | self.optimizer = optim.SGD(params_to_update, lr=self.args.lr, momentum=0.9, weight_decay=self.args.weight_decay) 134 | #self.optimizer = optim.Adam(params_to_update, lr=self.args.lr, weight_decay=self.args.weight_decay) 135 | 136 | # Decay LR by a factor 137 | #self.exp_lr_scheduler = lr_scheduler.StepLR(self.optimizer, step_size=20, gamma=0.33) 138 | 139 | # weights for the loss function 140 | weights = 
torch.tensor([3.0, 1.0, 1.0, 1.0], dtype=torch.float32) 141 | #weights = torch.tensor(train_dataset.class_probs, dtype=torch.float32) 142 | weights = weights / weights.sum() 143 | weights = 1.0 / weights 144 | weights = weights / weights.sum() 145 | weights = weights.cuda() 146 | self.loss_func = nn.CrossEntropyLoss(weight=weights) 147 | self.loss_nored = nn.CrossEntropyLoss(reduction='none') 148 | 149 | def evaluate(self, net, epoch, iteration): 150 | 151 | self.net.eval() 152 | test_losses = [] 153 | class_hits = [0.0, 0.0, 0.0, 0.0] # normal, crackle, wheeze, both 154 | class_counts = [0.0, 0.0, 0.0+1e-7, 0.0+1e-7] # normal, crackle, wheeze, both 155 | running_corrects = 0.0 156 | denom = 0.0 157 | 158 | classwise_test_losses = [[], [], [], []] 159 | conf_label, conf_pred = [], [] 160 | for i, (image, label) in tqdm(enumerate(self.val_data_loader)): 161 | image, label = image.cuda(), label.cuda() 162 | output = self.net(image) 163 | 164 | # calculate loss from output 165 | loss = self.loss_func(output, label) 166 | loss_nored = self.loss_nored(output, label) 167 | test_losses.append(loss.data.cpu().numpy()) 168 | 169 | _, preds = torch.max(output, 1) 170 | running_corrects += torch.sum(preds == label.data) 171 | 172 | # updating denom 173 | denom += len(label.data) 174 | 175 | #class 176 | for idx in range(preds.shape[0]): 177 | class_counts[label[idx].item()] += 1.0 178 | conf_label.append(label[idx].item()) 179 | conf_pred.append(preds[idx].item()) 180 | if preds[idx].item() == label[idx].item(): 181 | class_hits[label[idx].item()] += 1.0 182 | 183 | classwise_test_losses[label[idx].item()].append(loss_nored[idx].item()) 184 | 185 | print("Val Accuracy: {}".format(running_corrects.double() / denom)) 186 | print("epoch {}, Validation BCE loss: {}".format(epoch, np.mean(test_losses))) 187 | #print("Classwise_Losses Normal: {}, Crackle: {}, Wheeze: {}, Both: {}".format(np.mean(classwise_test_losses[0]), 188 | # np.mean(classwise_test_losses[1]), np.mean(classwise_test_losses[2]), np.mean(classwise_test_losses[3]))) 189 | #get_score(class_hits, class_counts, True) 190 | 191 | #aggregating same id, majority voting 192 | conf_label = np.array(conf_label) 193 | conf_pred = np.array(conf_pred) 194 | y_pred, y_true = [], [] 195 | for pt in self.test_paths: 196 | y_pred.append(np.argmax(np.bincount(conf_pred[np.where(self.test_ids == pt)]))) 197 | y_true.append(int(pt.split('_')[-1])) 198 | 199 | conf_matrix = confusion_matrix(y_true, y_pred) 200 | acc = accuracy_score(y_true, y_pred) 201 | print("Confusion Matrix", conf_matrix) 202 | print("Accuracy Score", acc) 203 | conf_matrix = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:,np.newaxis] 204 | print("Classwise Scores", conf_matrix.diagonal()) 205 | self.net.train() 206 | 207 | return acc, np.mean(test_losses) 208 | 209 | def train(self): 210 | train_losses = [] 211 | test_losses = [] 212 | test_acc = [] 213 | best_acc = -1 214 | 215 | for _, epoch in tqdm(enumerate(range(self.args.start_epochs, self.args.epochs))): 216 | losses = [] 217 | class_hits = [0.0, 0.0, 0.0, 0.0] 218 | class_counts = [0.0+1e-7, 0.0+1e-7, 0.0+1e-7, 0.0+1e-7] 219 | running_corrects = 0.0 220 | denom = 0.0 221 | classwise_train_losses = [[], [], [], []] 222 | 223 | for i, (image, label) in tqdm(enumerate(self.train_data_loader)): 224 | 225 | image, label = image.cuda(), label.cuda() 226 | # in case using mixup, uncomment 2 lines below 227 | #image, label_a, label_b, lam = mixup_data(image, label, alpha=0.5) 228 | #image, label_a, label_b = map(Variable, 
(image, label_a, label_b)) 229 | 230 | output = self.net(image) 231 | 232 | # calculate loss from output 233 | # in case using mixup, uncomment line below and comment the next line 234 | #loss = mixup_criterion(self.loss_func, output, label_a, label_b, lam) 235 | loss = self.loss_func(output, label) 236 | loss_nored = self.loss_nored(output, label) 237 | 238 | _, preds = torch.max(output, 1) 239 | running_corrects += torch.sum(preds == label.data) 240 | denom += len(label.data) 241 | 242 | #class 243 | for idx in range(preds.shape[0]): 244 | class_counts[label[idx].item()] += 1.0 245 | if preds[idx].item() == label[idx].item(): 246 | class_hits[label[idx].item()] += 1.0 247 | classwise_train_losses[label[idx].item()].append(loss_nored[idx].item()) 248 | 249 | self.optimizer.zero_grad() 250 | loss.backward() 251 | self.optimizer.step() 252 | losses.append(loss.data.cpu().numpy()) 253 | 254 | if i % 1000 == self.train_data_loader.__len__()-1: 255 | print("---------------------------------------------") 256 | print("epoch {} iter {}/{} Train Total loss: {}".format(epoch, 257 | i, len(self.train_data_loader), np.mean(losses))) 258 | print("Train Accuracy: {}".format(running_corrects.double() / denom)) 259 | print("Classwise_Losses Normal: {}, Crackle: {}, Wheeze: {}, Both: {}".format(np.mean(classwise_train_losses[0]), 260 | np.mean(classwise_train_losses[1]), np.mean(classwise_train_losses[2]), np.mean(classwise_train_losses[3]))) 261 | get_score(class_hits, class_counts, True) 262 | 263 | print("testing......") 264 | acc, test_loss = self.evaluate(self.net, epoch, i) 265 | 266 | if best_acc < acc: 267 | best_acc = acc 268 | torch.save(self.net.module.state_dict(), args.model_path+'/ckpt_best_'+str(self.args.epochs)+'_'+str(self.args.stetho_id)+'.pkl') 269 | print("Best ACC achieved......", best_acc.item()) 270 | print("BEST ACCURACY TILL NOW", best_acc) 271 | 272 | train_losses.append(np.mean(losses)) 273 | test_losses.append(test_loss) 274 | test_acc.append(acc) 275 | #self.exp_lr_scheduler.step() 276 | 277 | if __name__ == "__main__": 278 | trainer = Trainer() 279 | trainer.train() 280 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)) 3 | import numpy as np 4 | import os 5 | import io 6 | import math 7 | import random 8 | import pandas as pd 9 | 10 | import matplotlib.pyplot as plt 11 | import librosa 12 | import librosa.display 13 | import cv2 14 | import cmapy 15 | 16 | import nlpaug 17 | import nlpaug.augmenter.audio as naa 18 | 19 | import torch 20 | 21 | from scipy.signal import butter, lfilter 22 | 23 | def butter_bandpass(lowcut, highcut, fs, order=5): 24 | nyq = 0.5 * fs 25 | low = lwcut / nyq 26 | high = highcut / nyq 27 | b, a = butter(order, [low, high], btype='band') 28 | return b, a 29 | 30 | def butter_bandpass_filter(data, lowcut, highcut, fs, order=5): 31 | b, a = butter_bandpass(lowcut, highcut, fs, order=order) 32 | y = lfilter(b,a, data) 33 | return y 34 | 35 | def Extract_Annotation_Data(file_name, data_dir): 36 | tokens = file_name.split('_') 37 | recording_info = pd.DataFrame(data = [tokens], columns = ['Patient Number', 'Recording index', 'Chest location','Acquisition mode','Recording equipment']) 38 | recording_annotations = pd.read_csv(os.path.join(data_dir, file_name + '.txt'), names = ['Start', 'End', 'Crackles', 'Wheezes'], delimiter= '\t') 39 | return 
recording_info, recording_annotations 40 | 41 | # get annotations data and filenames 42 | def get_annotations(data_dir): 43 | filenames = [s.split('.')[0] for s in os.listdir(data_dir) if '.txt' in s] 44 | i_list = [] 45 | rec_annotations_dict = {} 46 | for s in filenames: 47 | i,a = Extract_Annotation_Data(s, data_dir) 48 | i_list.append(i) 49 | rec_annotations_dict[s] = a 50 | 51 | recording_info = pd.concat(i_list, axis = 0) 52 | recording_info.head() 53 | 54 | return filenames, rec_annotations_dict 55 | 56 | def slice_data(start, end, raw_data, sample_rate): 57 | max_ind = len(raw_data) 58 | start_ind = min(int(start * sample_rate), max_ind) 59 | end_ind = min(int(end * sample_rate), max_ind) 60 | return raw_data[start_ind: end_ind] 61 | 62 | def get_label(crackle, wheeze): 63 | if crackle == 0 and wheeze == 0: 64 | return 0 65 | elif crackle == 1 and wheeze == 0: 66 | return 1 67 | elif crackle == 0 and wheeze == 1: 68 | return 2 69 | else: 70 | return 3 71 | #Used to split each individual sound file into separate sound clips containing one respiratory cycle each 72 | #output: [filename, (sample_data:np.array, start:float, end:float, label:int (...) ] 73 | #label: 0-normal, 1-crackle, 2-wheeze, 3-both 74 | def get_sound_samples(recording_annotations, file_name, data_dir, sample_rate): 75 | sample_data = [file_name] 76 | # load file with specified sample rate (also converts to mono) 77 | data, rate = librosa.load(os.path.join(data_dir, file_name+'.wav'), sr=sample_rate) 78 | #print("Sample Rate", rate) 79 | 80 | for i in range(len(recording_annotations.index)): 81 | row = recording_annotations.loc[i] 82 | start = row['Start'] 83 | end = row['End'] 84 | crackles = row['Crackles'] 85 | wheezes = row['Wheezes'] 86 | audio_chunk = slice_data(start, end, data, rate) 87 | sample_data.append((audio_chunk, start,end, get_label(crackles, wheezes))) 88 | return sample_data 89 | 90 | 91 | # split samples according to desired length 92 | ''' 93 | types: 94 | * 0: simply pad by zeros 95 | * 1: pad with duplicate on both sides (half-n-half) 96 | * 2: pad with augmented sample on both sides (half-n-half) 97 | ''' 98 | def split_and_pad(original, desiredLength, sample_rate, types=0): 99 | if types==0: 100 | return split_and_pad_old(original, desiredLength, sample_rate) 101 | 102 | output_buffer_length = int(desiredLength*sample_rate) 103 | soundclip = original[0].copy() 104 | n_samples = len(soundclip) 105 | 106 | output = [] 107 | # if: the audio sample length > desiredLength, then split & pad 108 | # else: simply pad according to given type 1 or 2 109 | if n_samples > output_buffer_length: 110 | frames = librosa.util.frame(soundclip, frame_length=output_buffer_length, hop_length=output_buffer_length//2, axis=0) 111 | for i in range(frames.shape[0]): 112 | output.append((frames[i], original[1], original[2], original[3], original[4], i, 0)) 113 | 114 | last_id = frames.shape[0]*(output_buffer_length//2) 115 | last_sample = soundclip[last_id:]; pad_times = (output_buffer_length-len(last_sample))/len(last_sample) 116 | padded = generate_padded_samples(soundclip, last_sample, output_buffer_length, sample_rate, types) 117 | output.append((padded, original[1], original[2], original[3], original[4], i+1, pad_times)) 118 | 119 | else: 120 | padded = generate_padded_samples(soundclip, soundclip, output_buffer_length, sample_rate, types); pad_times = (output_buffer_length-len(soundclip))/len(soundclip) 121 | output.append((padded, original[1], original[2], original[3], original[4], 0, pad_times)) 122 | 123 | 
return output 124 | 125 | def split_and_pad_old(original, desiredLength, sample_rate): 126 | output_buffer_length = int(desiredLength * sample_rate) 127 | soundclip = original[0].copy() 128 | n_samples = len(soundclip) 129 | total_length = n_samples / sample_rate #length of cycle in seconds 130 | n_slices = int(math.ceil(total_length / desiredLength)) #get the minimum number of slices needed 131 | samples_per_slice = n_samples // n_slices 132 | src_start = 0 #Starting index of the samples to copy from the original buffer 133 | output = [] #Holds the resultant slices 134 | for i in range(n_slices): 135 | src_end = min(src_start + samples_per_slice, n_samples) 136 | length = src_end - src_start 137 | copy = generate_padded_samples_old(soundclip[src_start:src_end], output_buffer_length) 138 | output.append((copy, original[1], original[2], original[3], original[4], i)) 139 | src_start += length 140 | return output 141 | 142 | def generate_padded_samples_old(source, output_length): 143 | copy = np.zeros(output_length, dtype = np.float32) 144 | src_length = len(source) 145 | frac = src_length / output_length 146 | if(frac < 0.5): 147 | #tile forward sounds to fill empty space 148 | cursor = 0 149 | while(cursor + src_length) < output_length: 150 | copy[cursor:(cursor + src_length)] = source[:] 151 | cursor += src_length 152 | else: 153 | copy[:src_length] = source[:] 154 | return copy 155 | 156 | def generate_padded_samples(original, source, output_length, sample_rate, types): 157 | copy = np.zeros(output_length, dtype=np.float32) 158 | src_length = len(source) 159 | left = output_length-src_length # amount to be padded 160 | # pad front or back 161 | prob = random.random() 162 | if types == 1: 163 | aug = original 164 | else: 165 | aug = gen_augmented(original, sample_rate) 166 | 167 | while len(aug) < left: 168 | aug = np.concatenate([aug, aug]) 169 | 170 | if prob < 0.5: 171 | #pad back 172 | copy[left:] = source 173 | copy[:left] = aug[len(aug)-left:] 174 | else: 175 | #pad front 176 | copy[:src_length] = source[:] 177 | copy[src_length:] = aug[:left] 178 | 179 | return copy 180 | 181 | 182 | #**********************DATA AUGMENTATION*************************** 183 | #Returns an augmented copy of a clip using one randomly chosen augmentation (noise, speed, loudness, VTLP or pitch) 184 | def gen_augmented(original, sample_rate): 185 | # list of augmentors available from the nlpaug library 186 | augment_list = [ 187 | #naa.CropAug(sampling_rate=sample_rate) 188 | naa.NoiseAug(), 189 | naa.SpeedAug(), 190 | naa.LoudnessAug(factor=(0.5, 2)), 191 | naa.VtlpAug(sampling_rate=sample_rate, zone=(0.0, 1.0)), 192 | naa.PitchAug(sampling_rate=sample_rate, factor=(-1,3)) 193 | ] 194 | # sample augmentation randomly 195 | aug_idx = random.randint(0, len(augment_list)-1) 196 | augmented_data = augment_list[aug_idx].augment(original) 197 | return augmented_data 198 | 199 | #Same as above, but applies it to a list of samples 200 | def augment_list(audio_with_labels, sample_rate, n_repeats): 201 | augmented_samples = [] 202 | for i in range(n_repeats): 203 | addition = [(gen_augmented(t[0], sample_rate), t[1], t[2], t[3], t[4]+i+1 ) for t in audio_with_labels] 204 | augmented_samples.extend(addition) 205 | return augmented_samples 206 | 207 | def create_spectrograms(current_window, sample_rate, n_mels=128, f_min=50, f_max=4000, nfft=2048, hop=512): 208 | fig = plt.figure(figsize=[1.0, 1.0]) 209 | #fig = plt.figure(figsize=[0.72,0.72]) 210 | ax = fig.add_subplot(111) 211 | ax.axes.get_xaxis().set_visible(False) 212 | 
ax.axes.get_yaxis().set_visible(False) 213 | ax.set_frame_on(False) 214 | S = librosa.feature.melspectrogram(y=current_window, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max, n_fft=nfft, hop_length=hop) 215 | librosa.display.specshow(librosa.power_to_db(S, ref=np.max)) 216 | 217 | # There may be a better way to do the following, skipping it for now. 218 | buf = io.BytesIO() 219 | plt.savefig(buf, dpi=800, bbox_inches='tight', pad_inches=0) 220 | buf.seek(0) 221 | img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) 222 | buf.close() 223 | img = cv2.imdecode(img_arr, 1) 224 | plt.close('all') 225 | return img 226 | 227 | def create_spectrograms_raw(current_window, sample_rate, n_mels=128, f_min=50, f_max=4000, nfft=2048, hop=512): 228 | S = librosa.feature.melspectrogram(y=current_window, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max, n_fft=nfft, hop_length=hop) 229 | S = librosa.power_to_db(S, ref=np.max) 230 | S = (S-S.min()) / (S.max() - S.min()) 231 | S *= 255 232 | img = cv2.applyColorMap(S.astype(np.uint8), cmapy.cmap('viridis')) 233 | height, width, _ = img.shape 234 | img = cv2.resize(img, (width*3, height*3), interpolation=cv2.INTER_LINEAR) 235 | return img 236 | 237 | def create_mel_raw(current_window, sample_rate, n_mels=128, f_min=50, f_max=4000, nfft=2048, hop=512, resz=1): 238 | S = librosa.feature.melspectrogram(y=current_window, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max, n_fft=nfft, hop_length=hop) 239 | S = librosa.power_to_db(S, ref=np.max) 240 | S = (S-S.min()) / (S.max() - S.min()) 241 | S *= 255 242 | img = cv2.applyColorMap(S.astype(np.uint8), cmapy.cmap('magma')) 243 | height, width, _ = img.shape 244 | if resz > 0: 245 | img = cv2.resize(img, (width*resz, height*resz), interpolation=cv2.INTER_LINEAR) 246 | img = cv2.flip(img, 0) 247 | return img 248 | 249 | #Circularly shift (roll) the spectrogram along the time axis and add a channel dimension 250 | def rollFFT(fft): 251 | n_row, n_col = fft.shape[:2] 252 | pivot = np.random.randint(n_col) 253 | return np.reshape(np.roll(fft, pivot, axis = 1), (n_row, n_col, 1)) 254 | 255 | def rollAudio(audio): 256 | # expect audio to be 1 dimensional 257 | pivot = np.random.randint(audio.shape[0]) 258 | rolled_audio = np.roll(audio, pivot, axis=0) 259 | assert audio.shape[0] == rolled_audio.shape[0], "Roll audio shape mismatch" 260 | return rolled_audio 261 | 262 | # others 263 | def get_mean_and_std(dataset): 264 | '''Compute the mean and std value of dataset.''' 265 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 266 | mean = torch.zeros(3) 267 | std = torch.zeros(3) 268 | print('==> Computing mean and std..') 269 | for inputs, targets in dataloader: 270 | #inputs = inputs[:,0,:,:,:] 271 | for i in range(3): 272 | mean[i] += inputs[:,i,:,:].mean() 273 | std[i] += inputs[:,i,:,:].std() 274 | mean.div_(len(dataset)) 275 | std.div_(len(dataset)) 276 | return mean, std 277 | 278 | def init_params(net): 279 | '''Init layer parameters (torch.nn is referenced directly since nn/init are not imported at module level).''' 280 | for m in net.modules(): 281 | if isinstance(m, torch.nn.Conv2d): 282 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out') 283 | if m.bias is not None: 284 | torch.nn.init.constant_(m.bias, 0) 285 | elif isinstance(m, torch.nn.BatchNorm2d): 286 | torch.nn.init.constant_(m.weight, 1) 287 | torch.nn.init.constant_(m.bias, 0) 288 | elif isinstance(m, torch.nn.Linear): 289 | torch.nn.init.normal_(m.weight, std=1e-3) 290 | if m.bias is not None: 291 | torch.nn.init.constant_(m.bias, 0) 292 | 293 | # save images from dataloader in dump_image/train or dump_image/test based on train_flag 294 | # expect images to be in RGB format 295 | # image is a tuple: 
(spectrogram, filename, label, cycle, split) 296 | def save_images(image, train_flag): 297 | save_dir = 'dump_image' 298 | if not os.path.isdir(save_dir): 299 | os.makedirs(save_dir) 300 | 301 | if train_flag: 302 | save_dir = os.path.join(save_dir, 'train') 303 | if not os.path.isdir(save_dir): 304 | os.makedirs(save_dir) 305 | cv2.imwrite(os.path.join(save_dir, image[1]+'_'+str(image[2])+'_'+str(image[3])+'_'+str(image[4])+'.jpg'), cv2.cvtColor(image[0], cv2.COLOR_RGB2BGR)) 306 | else: 307 | save_dir = os.path.join(save_dir, 'test') 308 | if not os.path.isdir(save_dir): 309 | os.makedirs(save_dir) 310 | cv2.imwrite(os.path.join(save_dir, image[1]+'_'+str(image[2])+'_'+str(image[3])+'_'+str(image[4])+'.jpg'), cv2.cvtColor(image[0], cv2.COLOR_RGB2BGR)) 311 | 312 | 313 | 314 | 315 | --------------------------------------------------------------------------------
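The helpers in `utils.py` above form the preprocessing chain used by the dataloader: read the ICBHI annotations, cut each recording into respiratory cycles, pad/split every cycle to a fixed window, and render it as a mel-spectrogram image. Below is a minimal, hypothetical sketch of how they chain together; the data path, the resampling rate, the 7-second window, and the dummy cycle-index field appended before `split_and_pad` are assumptions for illustration, and the actual wiring lives in `image_dataloader.py`.

```python
import cv2

from utils import get_annotations, get_sound_samples, split_and_pad, create_mel_raw

data_dir = "data/icbhi_dataset/audio_text_data/"  # assumed location of the .wav/.txt pairs
sample_rate = 10000                               # assumed resampling rate for librosa.load

# map each recording name to its cycle-level annotation DataFrame
filenames, annotations = get_annotations(data_dir)
file_name = filenames[0]

# samples = [file_name, (audio_chunk, start, end, label), ...]
# label: 0-normal, 1-crackle, 2-wheeze, 3-both
samples = get_sound_samples(annotations[file_name], file_name, data_dir, sample_rate)

audio_chunk, start, end, label = samples[1]
# split_and_pad reads original[0]..original[4], so append a cycle index as a fifth field
cycle = (audio_chunk, start, end, label, 0)

# pad/split the cycle to a fixed 7-second window (types=1: pad with a duplicate of the clip)
for padded, *_ in split_and_pad(cycle, desiredLength=7, sample_rate=sample_rate, types=1):
    spec = create_mel_raw(padded, sample_rate)  # BGR mel-spectrogram image
    cv2.imwrite("example_cycle.png", spec)
    break
```

Passing `types=2` instead would pad with an augmented copy produced by `gen_augmented`, which is the concatenation-based augmentation illustrated in `images/concatenation_augmentation.PNG`.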