├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── data
│   ├── Readme.txt
│   └── patient_list_foldwise.txt
├── eval.py
├── image_dataloader.py
├── images
│   ├── blank_region_clipping.PNG
│   ├── concatenation_augmentation.PNG
│   └── overview_large.PNG
├── models
│   ├── 67_sota_training_progress.txt
│   └── ckpt_best.pkl
├── nets
│   ├── __init__.py
│   ├── network_cnn.py
│   └── network_hybrid.py
├── scripts
│   ├── devicewise_script_run.sh
│   ├── eval_script.sh
│   └── train_script_run.sh
├── train.py
└── utils.py
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RespireNet 2020
2 | This is the official implementation of the work **[RespireNet](https://arxiv.org/abs/2011.00196)**.
3 |
4 | ## Dependencies:
5 |
6 | ```
7 | * Python3
8 | * Pytorch (torch, torchvision and other dependencies for Pytorch)
9 | * Numpy
10 | * Librosa (0.8.0)
11 | * nlpaug (0.0.14)
12 | * OpenCV (4.2.0)
13 | * Pandas (0.22.0)
14 | * scikit-learn (0.23.1)
15 | * tqdm (4.48.0)
16 | * cudnn (CUDA for training on GPU)
17 | ```
18 |
19 | These are all easily installable via, e.g., `pip install numpy`. Any reasonably recent version of these packages should work.
20 | It is recommended to use a Python virtual environment to set up the dependencies and the code.
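For example, a minimal setup might look like the sketch below (the environment name `respirenet-env` is arbitrary; package names follow the list above, plus `cmapy` and `matplotlib`, which the code imports but the list omits; pin other versions only if you hit incompatibilities):

```
# create and activate a virtual environment (any environment manager works)
python3 -m venv respirenet-env
source respirenet-env/bin/activate

# install the dependencies listed above
pip install torch torchvision numpy librosa==0.8.0 nlpaug==0.0.14 \
    opencv-python pandas scikit-learn tqdm cmapy matplotlib
```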
21 |
22 | ## Dataset
23 | * The dataset used is the **[ICBHI Respiratory Challenge 2017 dataset](https://bhichallenge.med.auth.gr/ICBHI_2017_Challenge)**.
24 | * To find out more details about the dataset please visit the official ICBHI Challenge [webpage](https://bhichallenge.med.auth.gr/ICBHI_2017_Challenge).
25 | * Download from above and place the dataset in the `data` folder.
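The commands in the next section assume a layout like the sketch below; the `data/icbhi_dataset/audio_text_data/` path is only a convention taken from the provided scripts, so adjust `--data_dir`/`--folds_file` if you keep the data elsewhere:

```
mkdir -p data/icbhi_dataset/audio_text_data
# place the ICBHI files here:
#   data/icbhi_dataset/audio_text_data/*.wav   (recordings)
#   data/icbhi_dataset/audio_text_data/*.txt   (per-recording annotations)
# the patient-wise fold split already ships with this repo:
#   data/patient_list_foldwise.txt
```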
26 |
27 | ## Train and Test Script
28 | * We follow both the official `60-40` train-test split and the `80-20` split.
29 | * For training we employ a 2-stage training protocol.
30 | * Stage 1: The model is trained end-to-end on train data from all 4 devices.
31 | * Stage 2: The model is fine-tuned (with a lower learning rate `1e-4`) on only the data from a single device. We do this separately for each device (device-specific fine-tuning).
32 |
33 | ### Train Command:
34 |
35 | Stage 1
36 |
37 | `python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-3 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4`
38 |
39 | Stage 2 (Device specific fine-tuning)
40 |
41 | `python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 50 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 0`
42 |
43 |
44 | Replace `--stetho_id` with `0`, `1`, `2`, or `3` to fine-tune on the corresponding device (devices 0-3).
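For convenience, here is a small sketch that runs Stage 2 once per device, mirroring `scripts/devicewise_script_run.sh` but with the Stage 2 settings above (the `DEVICE_ID` variable is just illustrative):

```
for DEVICE_ID in 0 1 2 3; do
    python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ \
        --folds_file ../data/patient_list_foldwise.txt --model_path models_out \
        --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 50 \
        --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id "$DEVICE_ID"
done
```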
45 |
46 | Please go through our paper for more details.
47 |
48 | ### Test Command:
49 |
50 | Evaluation script
51 |
52 | `python eval.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --batch_size 64 --num_worker 4 --test_fold 4 --checkpoint models/ckpt_best.pkl`
53 |
54 | ## RespireNet Overview
55 |
56 | ![RespireNet overview](images/overview_large.PNG)
57 |
58 |
59 |
60 | ## Quantitative Results
61 |
62 | ### Performance of the proposed model
63 |
64 | | Split & Task | Method | Sp | Se | Score |
65 | |--------------|:------:|:---:|:---:|:-----:|
66 | | 60-40 split & 4-class | CNN | 71.4% | 39.0% | 55.2% |
67 | | | CNN+CBA+BRC | 71.8% | 39.6% | 55.7% |
68 | | | CNN+CBA+BRC+FT | 72.3% | 40.1% | 56.2% |
69 | | 80-20 split & 4-class | CNN | 78.8% | 53.6% | 66.2% |
70 | | | CNN+CBA+BRC | 79.7% | 54.4% | 67.1% |
71 | | | CNN+CBA+BRC+FT | 83.3% | 53.7% | 68.5% |
72 | | 80-20 split & 2-class | CNN | 83.3% | 60.5% | 71.9% |
73 | | | CNN+CBA+BRC | 76.4% | 71.0% | 73.7% |
74 | | | CNN+CBA+BRC+FT | 80.9% | 73.1% | 77.0% |
75 |
76 | ```
77 | Performance comparison of the proposed model with the state-of-the-art systems following random splits. We see significant improvements from our proposed techniques: concatenation-based augmentation (CBA), blank region clipping (BRC) and device specific fine-tuning (FT).
78 | ```
79 |
80 | ### Effect of time window
81 | |Length | 1 sec | 2 sec | 3 sec | 4 sec | 5 sec | 6 sec | 7 sec | 8 sec | 9 sec|
82 | |----------|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|
83 | |Scores | 56.6% | 59.0% | 60.3% | 61.1% | 62.3% | 64.4% | 66.2% | 65.1% | 65.5%|
84 |
85 | ### Some additional results
86 |
87 | | Method | Sp | Se | Score |
88 | |--------:|:---:|:---:|:-----:|
89 | | CNN + Mixup | 73.9% | 48.9% | 61.4%|
90 | | CNN + VGGish*| 76.0% | 42.2% | 59.1%|
91 | | Hybrid CNN + Transformer | 75.3% | 49.9% | 63.2%|
92 |
93 | ```
94 | We also experimented with Mixup augmentation, pre-trained VGGish features, and a novel Hybrid CNN + Transformer architecture; however, they did not prove to be very useful.
95 | Transformers with appropriate pretraining have been found to be useful in many applications (especially NLP tasks) and may prove useful here in the future.
96 | ```
97 | *[VGGish link](https://github.com/tensorflow/models/tree/master/research/audioset/vggish)*
98 |
99 | ### About the code
100 |
101 | * `train.py`: Main trainer code.
102 | * `image_dataloader.py`: Dataloader module.
103 | * `utils.py`: Contains a lot of utility functions for audio file loading, feature extraction, augmentations, etc.
104 | * `eval.py`: Evaluation script for a trained model.
105 | * `scripts`: Directory which contains the runner scripts.
106 | * `nets`: Contains the different network modules.
107 | * `data`: Contains the train-test split file (`patient_list_foldwise.txt`); place the ICBHI data here.
108 | * `models`: Contains the trained checkpoint for our proposed framework.
109 |
110 | ### Blank Region Clipping Scheme
111 |
112 | ![Blank region clipping](images/blank_region_clipping.PNG)
113 |
114 |
115 |
116 | ### Concatenation Based Augmentation
117 |
118 | ![Concatenation based augmentation](images/concatenation_augmentation.PNG)
119 |
120 |
121 |
122 |
123 |
124 | ## To cite this work:
125 | ```
126 | @misc{gairola2020respirenet,
127 | title={RespireNet: A Deep Neural Network for Accurately Detecting Abnormal Lung Sounds in Limited Data Setting},
128 | author={Siddhartha Gairola and Francis Tom and Nipun Kwatra and Mohit Jain},
129 | year={2020},
130 | eprint={2011.00196},
131 | archivePrefix={arXiv},
132 | primaryClass={cs.SD}
133 | }
134 | ```
135 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/data/Readme.txt:
--------------------------------------------------------------------------------
1 | Keep the data here.
2 |
--------------------------------------------------------------------------------
/data/patient_list_foldwise.txt:
--------------------------------------------------------------------------------
1 | 158 0
2 | 193 0
3 | 177 0
4 | 170 0
5 | 180 0
6 | 211 0
7 | 147 0
8 | 107 0
9 | 162 0
10 | 156 0
11 | 146 0
12 | 200 0
13 | 138 0
14 | 160 0
15 | 203 0
16 | 204 0
17 | 172 0
18 | 207 0
19 | 163 0
20 | 205 0
21 | 213 0
22 | 114 0
23 | 130 0
24 | 154 0
25 | 186 0
26 | 184 0
27 | 153 1
28 | 115 1
29 | 224 1
30 | 223 1
31 | 201 1
32 | 218 1
33 | 127 1
34 | 137 1
35 | 215 1
36 | 161 1
37 | 206 1
38 | 101 1
39 | 168 1
40 | 131 1
41 | 216 1
42 | 120 1
43 | 188 1
44 | 167 1
45 | 210 1
46 | 197 1
47 | 183 1
48 | 152 1
49 | 173 1
50 | 108 1
51 | 208 1
52 | 105 2
53 | 110 2
54 | 116 2
55 | 196 2
56 | 182 2
57 | 222 2
58 | 166 2
59 | 209 2
60 | 144 2
61 | 111 2
62 | 165 2
63 | 148 2
64 | 164 2
65 | 159 2
66 | 121 2
67 | 157 2
68 | 217 2
69 | 123 2
70 | 169 2
71 | 179 2
72 | 190 2
73 | 125 2
74 | 129 2
75 | 225 2
76 | 136 2
77 | 118 3
78 | 185 3
79 | 112 3
80 | 124 3
81 | 104 3
82 | 195 3
83 | 175 3
84 | 212 3
85 | 140 3
86 | 219 3
87 | 132 3
88 | 142 3
89 | 220 3
90 | 122 3
91 | 191 3
92 | 128 3
93 | 226 3
94 | 141 3
95 | 103 3
96 | 134 3
97 | 117 3
98 | 192 3
99 | 106 3
100 | 155 3
101 | 199 3
102 | 174 4
103 | 145 4
104 | 151 4
105 | 176 4
106 | 178 4
107 | 133 4
108 | 198 4
109 | 214 4
110 | 149 4
111 | 143 4
112 | 187 4
113 | 202 4
114 | 119 4
115 | 194 4
116 | 126 4
117 | 150 4
118 | 171 4
119 | 102 4
120 | 109 4
121 | 113 4
122 | 139 4
123 | 189 4
124 | 181 4
125 | 221 4
126 | 135 4
127 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)
3 |
4 | import os
5 | import itertools
6 | import argparse
7 | import random
8 | from tqdm import tqdm
9 |
10 | import torch
11 | import torch.optim as optim
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from torch.optim import Adam, lr_scheduler
15 |
16 |
17 | import torchvision
18 | from torchvision.transforms import Compose, Normalize, ToTensor
19 |
20 | import numpy as np
21 | import pandas as pd
22 | import matplotlib
23 | matplotlib.use('Agg')
24 | import matplotlib.pyplot as plt
25 | from sklearn.metrics import confusion_matrix, accuracy_score
26 |
27 | # load external modules
28 | from utils import *
29 | from image_dataloader import *
30 | from nets.network_cnn import *
31 |
32 | print ("Train import done successfully")
33 |
34 | # input arguments
35 | parser = argparse.ArgumentParser(description='Lung Sound Classification')
36 | parser.add_argument('--steth_id', default=-1, type=int, help='stethoscope device id')
37 | parser.add_argument('--gpu_ids', default=[0,1], help='a list of gpus')
38 | parser.add_argument('--num_worker', default=4, type=int, help='numbers of worker')
39 | parser.add_argument('--batch_size', default=4, type=int, help='batch size')
40 | parser.add_argument('--data_dir', type=str, help='data directory')
41 | parser.add_argument('--folds_file', type=str, help='folds text file')
42 | parser.add_argument('--test_fold', default=4, type=int, help='Test Fold ID')
43 | parser.add_argument('--checkpoint', default=None, type=str, help='load checkpoint')
44 |
45 | args = parser.parse_args()
46 |
47 | ##############################################################################
48 | def get_score(hits, counts, pflag=False):
49 | se = (hits[1] + hits[2] + hits[3]) / (counts[1] + counts[2] + counts[3])
50 | sp = hits[0] / counts[0]
51 | sc = (se+sp) / 2.0
52 |
53 | if pflag:
54 | print("*************Metrics******************")
55 | print("Se: {}, Sp: {}, Score: {}".format(se, sp, sc))
56 | print("Normal: {}, Crackle: {}, Wheeze: {}, Both: {}".format(hits[0]/counts[0], hits[1]/counts[1],
57 | hits[2]/counts[2], hits[3]/counts[3]))
58 |
59 | class Trainer:
60 | def __init__(self):
61 | self.args = args
62 |
63 | mean, std = [0.5091, 0.1739, 0.4363], [0.2288, 0.1285, 0.0743]
64 | self.input_transform = Compose([ToTensor(), Normalize(mean, std)])
65 |
66 | test_dataset = image_loader(self.args.data_dir, self.args.folds_file, self.args.test_fold,
67 | False, "params_json", self.input_transform, self.args.steth_id)
68 | self.test_ids = np.array(test_dataset.identifiers)
69 | self.test_paths = test_dataset.filenames_with_labels
70 |
71 | # loading checkpoint
72 | self.net = model(num_classes=4).cuda()
73 | if self.args.checkpoint is not None:
74 | checkpoint = torch.load(self.args.checkpoint)
75 | self.net.load_state_dict(checkpoint)
76 | self.net.fine_tune(block_layer=5)
77 | print("Pre-trained Model Loaded:", self.args.checkpoint)
78 | self.net = nn.DataParallel(self.net, device_ids=self.args.gpu_ids)
79 |
80 | self.val_data_loader = DataLoader(test_dataset, num_workers=self.args.num_worker,
81 | batch_size=self.args.batch_size, shuffle=False)
82 | print("Test Size", len(test_dataset))
83 | print("DATA LOADED")
84 |
85 | self.loss_func = nn.CrossEntropyLoss()
86 | self.loss_nored = nn.CrossEntropyLoss(reduction='none')
87 |
88 | def evaluate(self, net, epoch, iteration):
89 |
90 | self.net.eval()
91 | test_losses = []
92 | class_hits = [0.0, 0.0, 0.0, 0.0] # normal, crackle, wheeze, both
93 | class_counts = [0.0, 0.0, 0.0+1e-7, 0.0+1e-7] # normal, crackle, wheeze, both
94 | running_corrects = 0.0
95 | denom = 0.0
96 |
97 | classwise_test_losses = [[], [], [], []]
98 |
99 | conf_label = []
100 | conf_pred = []
101 | for i, (image, label) in tqdm(enumerate(self.val_data_loader)):
102 | image, label = image.cuda(), label.cuda()
103 | output = self.net(image)
104 |
105 | # calculate loss from output
106 | loss = self.loss_func(output, label)
107 | loss_nored = self.loss_nored(output, label)
108 | test_losses.append(loss.data.cpu().numpy())
109 |
110 | _, preds = torch.max(output, 1)
111 | running_corrects += torch.sum(preds == label.data)
112 |
113 | # updating denom
114 | denom += len(label.data)
115 |
116 | #class
117 | for idx in range(preds.shape[0]):
118 | class_counts[label[idx].item()] += 1.0
119 | conf_label.append(label[idx].item())
120 | conf_pred.append(preds[idx].item())
121 | if preds[idx].item() == label[idx].item():
122 | class_hits[label[idx].item()] += 1.0
123 |
124 | classwise_test_losses[label[idx].item()].append(loss_nored[idx].item())
125 |
126 | print("Val Accuracy: {}".format(running_corrects.double() / denom))
127 | print("epoch {}, Validation BCE loss: {}".format(epoch, np.mean(test_losses)))
128 |
129 | #aggregating same id, majority voting
130 | conf_label = np.array(conf_label)
131 | conf_pred = np.array(conf_pred)
132 | y_pred, y_true = [], []
133 | for pt in self.test_paths:
134 | y_pred.append(np.argmax(np.bincount(conf_pred[np.where(self.test_ids == pt)])))
135 | y_true.append(int(pt.split('_')[-1]))
136 |
137 | conf_matrix = confusion_matrix(y_true, y_pred)
138 | acc = accuracy_score(y_true, y_pred)
139 | print("Confusion Matrix", conf_matrix)
140 | print("Accuracy Score", acc)
141 |
142 | conf_matrix = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:,np.newaxis]
143 | print("Classwise Scores", conf_matrix.diagonal())
144 | return acc, np.mean(test_losses)
145 |
146 | if __name__ == "__main__":
147 | trainer = Trainer()
148 | acc, test_loss = trainer.evaluate(trainer.net, 0, 0)
149 |
--------------------------------------------------------------------------------
/image_dataloader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)
3 |
4 | import os
5 | import random
6 | import numpy as np
7 | import cv2
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 | import librosa
13 | from tqdm import tqdm
14 |
15 | from utils import *
16 |
17 | class image_loader(Dataset):
18 | def __init__(self, data_dir, folds_file, test_fold, train_flag, params_json, input_transform=None, stetho_id=-1, aug_scale=None):
19 |
20 | # getting device-wise information
21 | self.file_to_device = {}
22 | device_to_id = {}
23 | device_id = 0
24 | files = os.listdir(data_dir)
25 | device_patient_list = []
26 | pats = []
27 | for f in files:
28 | device = f.strip().split('_')[-1].split('.')[0]
29 | if device not in device_to_id:
30 | device_to_id[device] = device_id
31 | device_id += 1
32 | device_patient_list.append([])
33 | self.file_to_device[f.strip().split('.')[0]] = device_to_id[device]
34 | pat = f.strip().split('_')[0]
35 | if pat not in device_patient_list[device_to_id[device]]:
36 | device_patient_list[device_to_id[device]].append(pat)
37 | if pat not in pats:
38 | pats.append(pat)
39 |
40 | print("DEVICE DICT", device_to_id)
41 | for idx in range(device_id):
42 | print("Device", idx, len(device_patient_list[idx]))
43 |
44 | # get patients dict in current fold based on train flag
45 | all_patients = open(folds_file).read().splitlines()
46 | patient_dict = {}
47 | for line in all_patients:
48 | idx, fold = line.strip().split(' ')
49 | if train_flag and int(fold) != test_fold:
50 | patient_dict[idx] = fold
51 | elif train_flag == False and int(fold) == test_fold:
52 | patient_dict[idx] = fold
53 |
54 | # extract the audio filenames and the data for each breathing cycle and its label
55 | print("Getting filenames ...")
56 | filenames, rec_annotations_dict = get_annotations(data_dir)
57 | if stetho_id >= 0:
58 | self.filenames = [s for s in filenames if s.split('_')[0] in patient_dict and self.file_to_device[s] == stetho_id]
59 | else:
60 | self.filenames = [s for s in filenames if s.split('_')[0] in patient_dict]
61 |
62 | self.audio_data = [] # each sample is a tuple with id_0: audio_data, id_1: label, id_2: file_name, id_3: cycle id, id_4: aug id, id_5: split id
63 | self.labels = []
64 | self.train_flag = train_flag
65 | self.data_dir = data_dir
66 | self.input_transform = input_transform
67 |
68 | # parameters for spectrograms
69 | self.sample_rate = 4000
70 | self.desired_length = 8
71 | self.n_mels = 64
72 | self.nfft = 256
73 | self.hop = self.nfft//2
74 | self.f_max = 2000
75 |
76 | self.dump_images = False
77 | self.filenames_with_labels = []
78 |
79 | # get individual breathing cycles from each audio file
80 | print("Exracting Individual Cycles")
81 | self.cycle_list = []
82 | self.classwise_cycle_list = [[], [], [], []]
83 | for idx, file_name in tqdm(enumerate(self.filenames)):
84 | data = get_sound_samples(rec_annotations_dict[file_name], file_name, data_dir, self.sample_rate)
85 | cycles_with_labels = [(d[0], d[3], file_name, cycle_idx, 0) for cycle_idx, d in enumerate(data[1:])]
86 | self.cycle_list.extend(cycles_with_labels)
87 | for cycle_idx, d in enumerate(cycles_with_labels):
88 | self.filenames_with_labels.append(file_name+'_'+str(d[3])+'_'+str(d[1]))
89 | self.classwise_cycle_list[d[1]].append(d)
90 |
91 | # concatenation-based augmentation (CBA): generate additional samples by concatenating two breathing cycles (see new_augment below)
92 | if train_flag and aug_scale:
93 | self.new_augment(scale=aug_scale)
94 |
95 | # split and pad each cycle to the desired length
96 | for idx, sample in enumerate(self.cycle_list):
97 | output = split_and_pad(sample, self.desired_length, self.sample_rate, types=1)
98 | self.audio_data.extend(output)
99 |
100 | self.device_wise = []
101 | for idx in range(device_id):
102 | self.device_wise.append([])
103 | self.class_probs = np.zeros(4)
104 | self.identifiers = []
105 | for idx, sample in enumerate(self.audio_data):
106 | self.class_probs[sample[1]] += 1.0
107 | self.labels.append(sample[1])
108 | self.identifiers.append(sample[2]+'_'+str(sample[3])+'_'+str(sample[1]))
109 | self.device_wise[self.file_to_device[sample[2]]].append(sample)
110 |
111 | if self.train_flag:
112 | print("TRAIN DETAILS")
113 | else:
114 | print("TEST DETAILS")
115 |
116 | print("CLASSWISE SAMPLE COUNTS:", self.class_probs)
117 | print("Device to ID", device_to_id)
118 | for idx in range(device_id):
119 | print("DEVICE ID", idx, "size", len(self.device_wise[idx]))
120 | self.class_probs = self.class_probs / sum(self.class_probs)
121 | print("CLASSWISE PROBS", self.class_probs)
122 | print("LEN AUDIO DATA", len(self.audio_data))
123 |
124 | def new_augment(self, scale=1):
125 |
126 | # augment normal
127 | aug_nos = int(scale*len(self.classwise_cycle_list[0])) - len(self.classwise_cycle_list[0])
128 | for idx in range(aug_nos):
129 | # normal_i + normal_j
130 | i = random.randint(0, len(self.classwise_cycle_list[0])-1)
131 | j = random.randint(0, len(self.classwise_cycle_list[0])-1)
132 | normal_i = self.classwise_cycle_list[0][i]
133 | normal_j = self.classwise_cycle_list[0][j]
134 | new_sample = np.concatenate([normal_i[0], normal_j[0]])
135 | self.cycle_list.append((new_sample, 0, normal_i[2]+'-'+normal_j[2],
136 | idx, 0))
137 | self.filenames_with_labels.append(normal_i[2]+'-'+normal_j[2]+'_'+str(idx)+'_0')
138 |
139 | # augment crackle
140 | aug_nos = int(scale*len(self.classwise_cycle_list[0])) - len(self.classwise_cycle_list[1])
141 | for idx in range(aug_nos):
142 | aug_prob = random.random()
143 |
144 | if aug_prob < 0.6:
145 | # crackle_i + crackle_j
146 | i = random.randint(0, len(self.classwise_cycle_list[1])-1)
147 | j = random.randint(0, len(self.classwise_cycle_list[1])-1)
148 | sample_i = self.classwise_cycle_list[1][i]
149 | sample_j = self.classwise_cycle_list[1][j]
150 | elif aug_prob >= 0.6 and aug_prob < 0.8:
151 | # crackle_i + normal_j
152 | i = random.randint(0, len(self.classwise_cycle_list[1])-1)
153 | j = random.randint(0, len(self.classwise_cycle_list[0])-1)
154 | sample_i = self.classwise_cycle_list[1][i]
155 | sample_j = self.classwise_cycle_list[0][j]
156 | else:
157 | # normal_i + crackle_j
158 | i = random.randint(0, len(self.classwise_cycle_list[0])-1)
159 | j = random.randint(0, len(self.classwise_cycle_list[1])-1)
160 | sample_i = self.classwise_cycle_list[0][i]
161 | sample_j = self.classwise_cycle_list[1][j]
162 |
163 | new_sample = np.concatenate([sample_i[0], sample_j[0]])
164 | self.cycle_list.append((new_sample, 1, sample_i[2]+'-'+sample_j[2],
165 | idx, 0))
166 | self.filenames_with_labels.append(sample_i[2]+'-'+sample_j[2]+'_'+str(idx)+'_1')
167 |
168 | # augment wheeze
169 | aug_nos = int(scale*len(self.classwise_cycle_list[0])) - len(self.classwise_cycle_list[2])
170 | for idx in range(aug_nos):
171 | aug_prob = random.random()
172 |
173 | if aug_prob < 0.6:
174 | # wheeze_i + wheeze_j
175 | i = random.randint(0, len(self.classwise_cycle_list[2])-1)
176 | j = random.randint(0, len(self.classwise_cycle_list[2])-1)
177 | sample_i = self.classwise_cycle_list[2][i]
178 | sample_j = self.classwise_cycle_list[2][j]
179 | elif aug_prob >= 0.6 and aug_prob < 0.8:
180 | # wheeze_i + normal_j
181 | i = random.randint(0, len(self.classwise_cycle_list[2])-1)
182 | j = random.randint(0, len(self.classwise_cycle_list[0])-1)
183 | sample_i = self.classwise_cycle_list[2][i]
184 | sample_j = self.classwise_cycle_list[0][j]
185 | else:
186 | # normal_i + wheeze_j
187 | i = random.randint(0, len(self.classwise_cycle_list[0])-1)
188 | j = random.randint(0, len(self.classwise_cycle_list[2])-1)
189 | sample_i = self.classwise_cycle_list[0][i]
190 | sample_j = self.classwise_cycle_list[2][j]
191 |
192 | new_sample = np.concatenate([sample_i[0], sample_j[0]])
193 | self.cycle_list.append((new_sample, 2, sample_i[2]+'-'+sample_j[2],
194 | idx, 0))
195 | self.filenames_with_labels.append(sample_i[2]+'-'+sample_j[2]+'_'+str(idx)+'_2')
196 |
197 | # augment both
198 | aug_nos = int(scale*len(self.classwise_cycle_list[0])) - len(self.classwise_cycle_list[3])
199 | for idx in range(aug_nos):
200 | aug_prob = random.random()
201 |
202 | if aug_prob < 0.5:
203 | # both_i + both_j
204 | i = random.randint(0, len(self.classwise_cycle_list[3])-1)
205 | j = random.randint(0, len(self.classwise_cycle_list[3])-1)
206 | sample_i = self.classwise_cycle_list[3][i]
207 | sample_j = self.classwise_cycle_list[3][j]
208 | elif aug_prob >= 0.5 and aug_prob < 0.7:
209 | # crackle_i + wheeze_j
210 | i = random.randint(0, len(self.classwise_cycle_list[1])-1)
211 | j = random.randint(0, len(self.classwise_cycle_list[2])-1)
212 | sample_i = self.classwise_cycle_list[1][i]
213 | sample_j = self.classwise_cycle_list[2][j]
214 | elif aug_prob >=0.7 and aug_prob < 0.8:
215 | # wheeze_i + crackle_j
216 | i = random.randint(0, len(self.classwise_cycle_list[2])-1)
217 | j = random.randint(0, len(self.classwise_cycle_list[1])-1)
218 | sample_i = self.classwise_cycle_list[2][i]
219 | sample_j = self.classwise_cycle_list[1][j]
220 | elif aug_prob >=0.8 and aug_prob < 0.9:
221 | # both_i + normal_j
222 | i = random.randint(0, len(self.classwise_cycle_list[3])-1)
223 | j = random.randint(0, len(self.classwise_cycle_list[0])-1)
224 | sample_i = self.classwise_cycle_list[3][i]
225 | sample_j = self.classwise_cycle_list[0][j]
226 | else:
227 | # normal_i + both_j
228 | i = random.randint(0, len(self.classwise_cycle_list[0])-1)
229 | j = random.randint(0, len(self.classwise_cycle_list[3])-1)
230 | sample_i = self.classwise_cycle_list[0][i]
231 | sample_j = self.classwise_cycle_list[3][j]
232 |
233 | new_sample = np.concatenate([sample_i[0], sample_j[0]])
234 | self.cycle_list.append((new_sample, 3, sample_i[2]+'-'+sample_j[2],
235 | idx, 0))
236 | self.filenames_with_labels.append(sample_i[2]+'-'+sample_j[2]+'_'+str(idx)+'_3')
237 |
238 | def __getitem__(self, index):
239 |
240 | audio = self.audio_data[index][0]
241 |
242 | aug_prob = random.random()
243 | if self.train_flag and aug_prob > 0.5:
244 | # apply augmentation to audio
245 | audio = gen_augmented(audio, self.sample_rate)
246 |
247 | # pad in case smaller than desired length
248 | audio = split_and_pad([audio, 0,0,0,0], self.desired_length, self.sample_rate, types=1)[0][0]
249 |
250 | # roll audio sample
251 | roll_prob = random.random()
252 | if self.train_flag and roll_prob > 0.5:
253 | audio = rollAudio(audio)
254 |
255 | # convert audio signal to spectrogram
256 | # spectrograms resized to 3x of original size
257 | audio_image = cv2.cvtColor(create_mel_raw(audio, self.sample_rate, f_max=self.f_max,
258 | n_mels=self.n_mels, nfft=self.nfft, hop=self.hop, resz=3), cv2.COLOR_BGR2RGB)
259 |
260 | # blank region clipping: compute a grayscale spectrogram, then drop the mostly-blank (>=80% black) rows at its top
261 | audio_raw_gray = cv2.cvtColor(create_mel_raw(audio, self.sample_rate, f_max=self.f_max,
262 | n_mels=self.n_mels, nfft=self.nfft, hop=self.hop), cv2.COLOR_BGR2GRAY)
263 |
264 | audio_raw_gray[audio_raw_gray < 10] = 0
265 | for row in range(audio_raw_gray.shape[0]):
266 | black_percent = len(np.where(audio_raw_gray[row,:]==0)[0])/len(audio_raw_gray[row,:])
267 | if black_percent < 0.80:
268 | break
269 |
270 | if (row+1)*3 < audio_image.shape[0]:
271 | audio_image = audio_image[(row+1)*3:, :, :]
272 | audio_image = cv2.resize(audio_image, (audio_image.shape[1], self.n_mels*3), interpolation=cv2.INTER_LINEAR)
273 |
274 | if self.dump_images:
275 | save_images((audio_image, self.audio_data[index][2], self.audio_data[index][3],
276 | self.audio_data[index][5], self.audio_data[index][1]), self.train_flag)
277 |
278 | # label
279 | label = self.audio_data[index][1]
280 |
281 | # apply image transform
282 | if self.input_transform is not None:
283 | audio_image = self.input_transform(audio_image)
284 |
285 | return audio_image, label
286 |
287 | def __len__(self):
288 | return len(self.audio_data)
289 |
--------------------------------------------------------------------------------
/images/blank_region_clipping.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/images/blank_region_clipping.PNG
--------------------------------------------------------------------------------
/images/concatenation_augmentation.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/images/concatenation_augmentation.PNG
--------------------------------------------------------------------------------
/images/overview_large.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/images/overview_large.PNG
--------------------------------------------------------------------------------
/models/ckpt_best.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/models/ckpt_best.pkl
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/RespireNet/5e3f41aaec93e87eba9e06a67fe93c56b391b5a8/nets/__init__.py
--------------------------------------------------------------------------------
/nets/network_cnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)
3 | import torch
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torchvision.models import resnet18, resnet34, resnet50, densenet121
8 | #from torchsummary import summary
9 |
10 | class model(nn.Module):
11 | def __init__(self, num_classes=10, drop_prob=0.5):
12 | super(model, self).__init__()
13 |
14 | # encoder
15 | self.model_ft = resnet34(pretrained=True)
16 | num_ftrs = self.model_ft.fc.in_features#*4
17 | self.model_ft.fc = nn.Sequential(nn.Dropout(drop_prob), nn.Linear(num_ftrs, 128), nn.ReLU(True),
18 | nn.Dropout(drop_prob), nn.Linear(128, 128), nn.ReLU(True))
19 | self.cls_fc = nn.Linear(128, num_classes)
20 |
21 | def forward(self, x):
22 | x = self.model_ft(x)
23 | x = self.cls_fc(x)
24 | return x
25 |
26 | def fine_tune(self, block_layer=5):
27 |
28 | for idx, child in enumerate(self.model_ft.children()):
29 | if idx>block_layer:
30 | break
31 | for param in child.parameters():
32 | param.requires_grad = False
33 |
--------------------------------------------------------------------------------
/nets/network_hybrid.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)
3 | # Transformer part has been referred from the following resource:
4 | # https://pytorch.org/tutorials/beginner/transformer_tutorial.html
5 | import math
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torchvision.models import resnet18, resnet34, vgg16
10 |
11 | # Transformer Model
12 | class TransformerModel(nn.Module):
13 |
14 | def __init__(self, ninp, nhead, nhid, nlayers, dropout=0.5, nu_classes=10):
15 | super(TransformerModel, self).__init__()
16 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
17 | self.model_type = 'Transformer'
18 | self.src_mask = None
19 | self.pos_encoder = PositionalEncoding(ninp, dropout)
20 | encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
21 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
22 |
23 | self.ninp = ninp
24 | self.linear = nn.Sequential(nn.Linear(ninp, 128), nn.ReLU(True))
25 | self.classifier = nn.Linear(128, nu_classes)
26 |
27 | def _generate_square_subsequent_mask(self, sz):
28 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
29 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
30 | return mask
31 |
32 | def forward(self, src):
33 | src = src.float()
34 | if self.src_mask is None or self.src_mask.size(0) != len(src):
35 | device = src.device
36 | mask = self._generate_square_subsequent_mask(len(src)).to(device)
37 | self.src_mask = mask
38 |
39 | #src = self.encoder(src) * math.sqrt(self.ninp)
40 | src = src*math.sqrt(self.ninp)
41 | src = self.pos_encoder(src)
42 | output = self.transformer_encoder(src, self.src_mask)
43 | output = output.mean(axis=1)
44 | output = self.linear(output)
45 | output = self.classifier(output)
46 |
47 | return output
48 |
49 |
50 | ######################################################################
51 | # ``PositionalEncoding`` module injects some information about the
52 | # relative or absolute position of the tokens in the sequence. The
53 | # positional encodings have the same dimension as the embeddings so that
54 | # the two can be summed. Here, we use ``sine`` and ``cosine`` functions of
55 | # different frequencies.
56 | #
57 |
58 | class PositionalEncoding(nn.Module):
59 |
60 | def __init__(self, d_model, dropout=0.1, max_len=5000):
61 | super(PositionalEncoding, self).__init__()
62 | self.dropout = nn.Dropout(p=dropout)
63 |
64 | pe = torch.zeros(max_len, d_model)
65 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
66 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
67 | pe[:, 0::2] = torch.sin(position * div_term)
68 | pe[:, 1::2] = torch.cos(position * div_term)
69 | pe = pe.unsqueeze(0).transpose(0, 1)
70 | self.register_buffer('pe', pe)
71 |
72 | def forward(self, x):
73 | x = x + self.pe[:x.size(0), :]
74 | return self.dropout(x)
75 |
76 | class model(nn.Module):
77 | def __init__(self, num_classes=10, drop_prob=0.5, pretrained=True):
78 | super(model, self).__init__()
79 |
80 | resnet = resnet34(pretrained=True)
81 | self.conv_feats = nn.Sequential(*list(resnet.children())[:8])
82 | self.avgpool = resnet.avgpool
83 | self.reduction = nn.Sequential(nn.Linear(128*24, 128), nn.ReLU(True), nn.BatchNorm1d(128))
84 |
85 | #num_ftrs = self.embeddings.fc.in_features
86 | #self.embeddings.fc = nn.Sequential(nn.Dropout(drop_prob), nn.Linear(num_ftrs, 128), nn.ReLU(True),
87 | # nn.Dropout(drop_prob), nn.Linear(128, 128), nn.ReLU(True))
88 | # transformer model
89 | self.trfr_classifier = TransformerModel(128, 2, 200, 4, drop_prob, num_classes)
90 |
91 | def forward(self, x):
92 |
93 | # first extract features from the CNN
94 | x = self.conv_feats(x)
95 |
96 | # reshape the conv features and preserve the width as time
97 | # initial shape of x is batch x channels x height x width
98 | b, c, h, w = x.shape
99 | x = x.view(b, -1, w)
100 | x = x.transpose(1,2) # x: batch x width x num_features; num_features = (channels x height)
101 |
102 | # reducing the higher dimension num_features to 128
103 | x_red = []
104 | for t in range(x.size(1)):
105 | x_red.append(self.reduction(x[:,t,:]))
106 |
107 | x_red = torch.stack(x_red, dim=0).transpose_(0,1) # batch x time x reduced_num_features
108 |
109 | # pass x through the transformer
110 | return self.trfr_classifier(x_red)
111 |
--------------------------------------------------------------------------------
/scripts/devicewise_script_run.sh:
--------------------------------------------------------------------------------
1 | python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 0
2 | #python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 1
3 | #python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 2
4 | #python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-4 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4 --checkpoint models/ckpt_best.pkl --stetho_id 3
5 |
--------------------------------------------------------------------------------
/scripts/eval_script.sh:
--------------------------------------------------------------------------------
1 | python eval.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --batch_size 64 --num_worker 4 --test_fold 4 --checkpoint models/ckpt_best.pkl --steth_id -1
2 |
--------------------------------------------------------------------------------
/scripts/train_script_run.sh:
--------------------------------------------------------------------------------
1 | python train.py --data_dir ../data/icbhi_dataset/audio_text_data/ --folds_file ../data/patient_list_foldwise.txt --model_path models_out --lr 1e-3 --batch_size 64 --num_worker 4 --start_epochs 0 --epochs 200 --test_fold 4
2 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)
3 |
4 | import os
5 | import itertools
6 | import argparse
7 | import random
8 | from tqdm import tqdm
9 |
10 | import torch
11 | import torch.optim as optim
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from torch.optim import Adam, lr_scheduler
15 |
16 |
17 | import torchvision
18 | from torchvision.transforms import Compose, Normalize, ToTensor
19 |
20 | import numpy as np
21 | import pandas as pd
22 | import matplotlib
23 | matplotlib.use('Agg')
24 | import matplotlib.pyplot as plt
25 |
26 | # load external modules
27 | from utils import *
28 | from image_dataloader import *
29 | from nets.network_cnn import *
30 | #from nets.network_hybrid import *
31 | from sklearn.metrics import confusion_matrix, accuracy_score
32 | print ("Train import done successfully")
33 |
34 | # input arguments
35 | parser = argparse.ArgumentParser(description='RespireNet: Lung Sound Classification')
36 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
37 | parser.add_argument('--weight_decay', default=0.0005, type=float, help='weight decay value')
38 | parser.add_argument('--gpu_ids', default=[0,1], help='a list of gpus')
39 | parser.add_argument('--num_worker', default=4, type=int, help='numbers of worker')
40 | parser.add_argument('--batch_size', default=4, type=int, help='batch size')
41 | parser.add_argument('--epochs', default=10, type=int, help='epochs')
42 | parser.add_argument('--start_epochs', default=0, type=int, help='start epochs')
43 |
44 | parser.add_argument('--data_dir', type=str, help='data directory')
45 | parser.add_argument('--folds_file', type=str, help='folds text file')
46 | parser.add_argument('--test_fold', default=4, type=int, help='Test Fold ID')
47 | parser.add_argument('--stetho_id', default=-1, type=int, help='Stethoscope device id')
48 | parser.add_argument('--aug_scale', default=None, type=float, help='Augmentation multiplier')
49 | parser.add_argument('--model_path',type=str, help='model saving directory')
50 | parser.add_argument('--checkpoint', default=None, type=str, help='load checkpoint')
51 |
52 | args = parser.parse_args()
53 |
54 | ################################MIXUP#####################################
55 | def mixup_data(x, y, alpha=1.0, use_cuda=True):
56 | '''Returns mixed inputs, pairs of targets, and lambda'''
57 | if alpha > 0:
58 | lam = np.random.beta(alpha, alpha)
59 | else:
60 | lam = 1
61 |
62 | batch_size = x.size()[0]
63 | if use_cuda:
64 | index = torch.randperm(batch_size).cuda()
65 | else:
66 | index = torch.randperm(batch_size)
67 |
68 | mixed_x = lam * x + (1 - lam) * x[index, :]
69 | y_a, y_b = y, y[index]
70 | return mixed_x, y_a, y_b, lam
71 |
72 | def mixup_criterion(criterion, pred, y_a, y_b, lam):
73 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
74 |
75 | ##############################################################################
76 | def get_score(hits, counts, pflag=False):
77 | se = (hits[1] + hits[2] + hits[3]) / (counts[1] + counts[2] + counts[3])
78 | sp = hits[0] / counts[0]
79 | sc = (se+sp) / 2.0
80 |
81 | if pflag:
82 | print("*************Metrics******************")
83 | print("Se: {}, Sp: {}, Score: {}".format(se, sp, sc))
84 | print("Normal: {}, Crackle: {}, Wheeze: {}, Both: {}".format(hits[0]/counts[0], hits[1]/counts[1],
85 | hits[2]/counts[2], hits[3]/counts[3]))
86 |
87 | class Trainer:
88 | def __init__(self):
89 | self.args = args
90 | mean, std = get_mean_and_std(image_loader(self.args.data_dir, self.args.folds_file,
91 | self.args.test_fold, True, "Params_json", Compose([ToTensor()]), stetho_id=self.args.stetho_id))
92 | print("MEAN", mean, "STD", std)
93 |
94 | self.input_transform = Compose([ToTensor(), Normalize(mean, std)])
95 | train_dataset = image_loader(self.args.data_dir, self.args.folds_file, self.args.test_fold,
96 | True, "params_json", self.input_transform, stetho_id=self.args.stetho_id, aug_scale=self.args.aug_scale)
97 | test_dataset = image_loader(self.args.data_dir, self.args.folds_file, self.args.test_fold,
98 | False, "params_json", self.input_transform, stetho_id=self.args.stetho_id)
99 | self.test_ids = np.array(test_dataset.identifiers)
100 | self.test_paths = test_dataset.filenames_with_labels
101 |
102 | # loading checkpoint
103 | self.net = model(num_classes=4).cuda()
104 | if self.args.checkpoint is not None:
105 | checkpoint = torch.load(self.args.checkpoint)
106 | self.net.load_state_dict(checkpoint)
107 | # uncomment in case fine-tuning, specify block layer
108 | # before block_layer, all layers will be frozen during training
109 | #self.net.fine_tune(block_layer=5)
110 | print("Pre-trained Model Loaded:", self.args.checkpoint)
111 | self.net = nn.DataParallel(self.net, device_ids=self.args.gpu_ids)
112 |
113 | # weighted sampler
114 | reciprocal_weights = []
115 | for idx in range(len(train_dataset)):
116 | reciprocal_weights.append(train_dataset.class_probs[train_dataset.labels[idx]])
117 | weights = (1 / torch.Tensor(reciprocal_weights))
118 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(train_dataset))
119 |
120 | self.train_data_loader = DataLoader(train_dataset, num_workers=self.args.num_worker,
121 | batch_size=self.args.batch_size, sampler=sampler)
122 | self.val_data_loader = DataLoader(test_dataset, num_workers=self.args.num_worker,
123 | batch_size=self.args.batch_size, shuffle=False)
124 | print("DATA LOADED")
125 |
126 | print("Params to learn:")
127 | params_to_update = []
128 | for name,param in self.net.named_parameters():
129 | if param.requires_grad == True:
130 | params_to_update.append(param)
131 |
132 | # Observe that all parameters are being optimized
133 | self.optimizer = optim.SGD(params_to_update, lr=self.args.lr, momentum=0.9, weight_decay=self.args.weight_decay)
134 | #self.optimizer = optim.Adam(params_to_update, lr=self.args.lr, weight_decay=self.args.weight_decay)
135 |
136 | # Decay LR by a factor
137 | #self.exp_lr_scheduler = lr_scheduler.StepLR(self.optimizer, step_size=20, gamma=0.33)
138 |
139 | # weights for the loss function
140 | weights = torch.tensor([3.0, 1.0, 1.0, 1.0], dtype=torch.float32)
141 | #weights = torch.tensor(train_dataset.class_probs, dtype=torch.float32)
142 | weights = weights / weights.sum()
143 | weights = 1.0 / weights
144 | weights = weights / weights.sum()
145 | weights = weights.cuda()
146 | self.loss_func = nn.CrossEntropyLoss(weight=weights)
147 | self.loss_nored = nn.CrossEntropyLoss(reduction='none')
148 |
149 | def evaluate(self, net, epoch, iteration):
150 |
151 | self.net.eval()
152 | test_losses = []
153 | class_hits = [0.0, 0.0, 0.0, 0.0] # normal, crackle, wheeze, both
154 | class_counts = [0.0, 0.0, 0.0+1e-7, 0.0+1e-7] # normal, crackle, wheeze, both
155 | running_corrects = 0.0
156 | denom = 0.0
157 |
158 | classwise_test_losses = [[], [], [], []]
159 | conf_label, conf_pred = [], []
160 | for i, (image, label) in tqdm(enumerate(self.val_data_loader)):
161 | image, label = image.cuda(), label.cuda()
162 | output = self.net(image)
163 |
164 | # calculate loss from output
165 | loss = self.loss_func(output, label)
166 | loss_nored = self.loss_nored(output, label)
167 | test_losses.append(loss.data.cpu().numpy())
168 |
169 | _, preds = torch.max(output, 1)
170 | running_corrects += torch.sum(preds == label.data)
171 |
172 | # updating denom
173 | denom += len(label.data)
174 |
175 | #class
176 | for idx in range(preds.shape[0]):
177 | class_counts[label[idx].item()] += 1.0
178 | conf_label.append(label[idx].item())
179 | conf_pred.append(preds[idx].item())
180 | if preds[idx].item() == label[idx].item():
181 | class_hits[label[idx].item()] += 1.0
182 |
183 | classwise_test_losses[label[idx].item()].append(loss_nored[idx].item())
184 |
185 | print("Val Accuracy: {}".format(running_corrects.double() / denom))
186 | print("epoch {}, Validation BCE loss: {}".format(epoch, np.mean(test_losses)))
187 | #print("Classwise_Losses Normal: {}, Crackle: {}, Wheeze: {}, Both: {}".format(np.mean(classwise_test_losses[0]),
188 | # np.mean(classwise_test_losses[1]), np.mean(classwise_test_losses[2]), np.mean(classwise_test_losses[3])))
189 | #get_score(class_hits, class_counts, True)
190 |
191 | #aggregating same id, majority voting
192 | conf_label = np.array(conf_label)
193 | conf_pred = np.array(conf_pred)
194 | y_pred, y_true = [], []
195 | for pt in self.test_paths:
196 | y_pred.append(np.argmax(np.bincount(conf_pred[np.where(self.test_ids == pt)])))
197 | y_true.append(int(pt.split('_')[-1]))
198 |
199 | conf_matrix = confusion_matrix(y_true, y_pred)
200 | acc = accuracy_score(y_true, y_pred)
201 | print("Confusion Matrix", conf_matrix)
202 | print("Accuracy Score", acc)
203 | conf_matrix = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:,np.newaxis]
204 | print("Classwise Scores", conf_matrix.diagonal())
205 | self.net.train()
206 |
207 | return acc, np.mean(test_losses)
208 |
209 | def train(self):
210 | train_losses = []
211 | test_losses = []
212 | test_acc = []
213 | best_acc = -1
214 |
215 | for _, epoch in tqdm(enumerate(range(self.args.start_epochs, self.args.epochs))):
216 | losses = []
217 | class_hits = [0.0, 0.0, 0.0, 0.0]
218 | class_counts = [0.0+1e-7, 0.0+1e-7, 0.0+1e-7, 0.0+1e-7]
219 | running_corrects = 0.0
220 | denom = 0.0
221 | classwise_train_losses = [[], [], [], []]
222 |
223 | for i, (image, label) in tqdm(enumerate(self.train_data_loader)):
224 |
225 | image, label = image.cuda(), label.cuda()
226 | # in case using mixup, uncomment 2 lines below
227 | #image, label_a, label_b, lam = mixup_data(image, label, alpha=0.5)
228 | #image, label_a, label_b = map(Variable, (image, label_a, label_b))
229 |
230 | output = self.net(image)
231 |
232 | # calculate loss from output
233 | # in case using mixup, uncomment line below and comment the next line
234 | #loss = mixup_criterion(self.loss_func, output, label_a, label_b, lam)
235 | loss = self.loss_func(output, label)
236 | loss_nored = self.loss_nored(output, label)
237 |
238 | _, preds = torch.max(output, 1)
239 | running_corrects += torch.sum(preds == label.data)
240 | denom += len(label.data)
241 |
242 | #class
243 | for idx in range(preds.shape[0]):
244 | class_counts[label[idx].item()] += 1.0
245 | if preds[idx].item() == label[idx].item():
246 | class_hits[label[idx].item()] += 1.0
247 | classwise_train_losses[label[idx].item()].append(loss_nored[idx].item())
248 |
249 | self.optimizer.zero_grad()
250 | loss.backward()
251 | self.optimizer.step()
252 | losses.append(loss.data.cpu().numpy())
253 |
254 | if i % 1000 == self.train_data_loader.__len__()-1:
255 | print("---------------------------------------------")
256 | print("epoch {} iter {}/{} Train Total loss: {}".format(epoch,
257 | i, len(self.train_data_loader), np.mean(losses)))
258 | print("Train Accuracy: {}".format(running_corrects.double() / denom))
259 | print("Classwise_Losses Normal: {}, Crackle: {}, Wheeze: {}, Both: {}".format(np.mean(classwise_train_losses[0]),
260 | np.mean(classwise_train_losses[1]), np.mean(classwise_train_losses[2]), np.mean(classwise_train_losses[3])))
261 | get_score(class_hits, class_counts, True)
262 |
263 | print("testing......")
264 | acc, test_loss = self.evaluate(self.net, epoch, i)
265 |
266 | if best_acc < acc:
267 | best_acc = acc
268 | torch.save(self.net.module.state_dict(), args.model_path+'/ckpt_best_'+str(self.args.epochs)+'_'+str(self.args.stetho_id)+'.pkl')
269 | print("Best ACC achieved......", best_acc.item())
270 | print("BEST ACCURACY TILL NOW", best_acc)
271 |
272 | train_losses.append(np.mean(losses))
273 | test_losses.append(test_loss)
274 | test_acc.append(acc)
275 | #self.exp_lr_scheduler.step()
276 |
277 | if __name__ == "__main__":
278 | trainer = Trainer()
279 | trainer.train()
280 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Author: Siddhartha Gairola (t-sigai at microsoft dot com)
3 | import numpy as np
4 | import os
5 | import io
6 | import math
7 | import random
8 | import pandas as pd
9 |
10 | import matplotlib.pyplot as plt
11 | import librosa
12 | import librosa.display
13 | import cv2
14 | import cmapy
15 |
16 | import nlpaug
17 | import nlpaug.augmenter.audio as naa
18 |
19 | import torch
20 |
21 | from scipy.signal import butter, lfilter
22 |
23 | def butter_bandpass(lowcut, highcut, fs, order=5):
24 | nyq = 0.5 * fs
25 | low = lowcut / nyq
26 | high = highcut / nyq
27 | b, a = butter(order, [low, high], btype='band')
28 | return b, a
29 |
30 | def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
31 | b, a = butter_bandpass(lowcut, highcut, fs, order=order)
32 | y = lfilter(b,a, data)
33 | return y
34 |
35 | def Extract_Annotation_Data(file_name, data_dir):
36 | tokens = file_name.split('_')
37 | recording_info = pd.DataFrame(data = [tokens], columns = ['Patient Number', 'Recording index', 'Chest location','Acquisition mode','Recording equipment'])
38 | recording_annotations = pd.read_csv(os.path.join(data_dir, file_name + '.txt'), names = ['Start', 'End', 'Crackles', 'Wheezes'], delimiter= '\t')
39 | return recording_info, recording_annotations
40 |
41 | # get annotations data and filenames
42 | def get_annotations(data_dir):
43 | filenames = [s.split('.')[0] for s in os.listdir(data_dir) if '.txt' in s]
44 | i_list = []
45 | rec_annotations_dict = {}
46 | for s in filenames:
47 | i,a = Extract_Annotation_Data(s, data_dir)
48 | i_list.append(i)
49 | rec_annotations_dict[s] = a
50 |
51 | recording_info = pd.concat(i_list, axis = 0)
52 | recording_info.head()
53 |
54 | return filenames, rec_annotations_dict
55 |
56 | def slice_data(start, end, raw_data, sample_rate):
57 | max_ind = len(raw_data)
58 | start_ind = min(int(start * sample_rate), max_ind)
59 | end_ind = min(int(end * sample_rate), max_ind)
60 | return raw_data[start_ind: end_ind]
61 |
62 | def get_label(crackle, wheeze):
63 | if crackle == 0 and wheeze == 0:
64 | return 0
65 | elif crackle == 1 and wheeze == 0:
66 | return 1
67 | elif crackle == 0 and wheeze == 1:
68 | return 2
69 | else:
70 | return 3
71 | #Used to split each individual sound file into separate sound clips containing one respiratory cycle each
72 | #output: [filename, (sample_data:np.array, start:float, end:float, label:int (...) ]
73 | #label: 0-normal, 1-crackle, 2-wheeze, 3-both
74 | def get_sound_samples(recording_annotations, file_name, data_dir, sample_rate):
75 | sample_data = [file_name]
76 | # load file with specified sample rate (also converts to mono)
77 | data, rate = librosa.load(os.path.join(data_dir, file_name+'.wav'), sr=sample_rate)
78 | #print("Sample Rate", rate)
79 |
80 | for i in range(len(recording_annotations.index)):
81 | row = recording_annotations.loc[i]
82 | start = row['Start']
83 | end = row['End']
84 | crackles = row['Crackles']
85 | wheezes = row['Wheezes']
86 | audio_chunk = slice_data(start, end, data, rate)
87 | sample_data.append((audio_chunk, start,end, get_label(crackles, wheezes)))
88 | return sample_data
89 |
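# Hedged usage sketch for get_sound_samples (directory and sample-rate values are
# placeholders, not values asserted by this code):
#   filenames, annotations = get_annotations(data_dir)
#   cycles = get_sound_samples(annotations[filenames[0]], filenames[0], data_dir, sample_rate)
#   # cycles[0] is the file name; cycles[1:] are (audio_chunk, start, end, label) tuples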
90 |
91 | # split samples according to desired length
92 | '''
93 | types (see the hedged usage sketch after this function):
94 | * 0: simply pad by zeros
95 | * 1: pad with duplicate on both sides (half-n-half)
96 | * 2: pad with augmented sample on both sides (half-n-half)
97 | '''
98 | def split_and_pad(original, desiredLength, sample_rate, types=0):
99 | if types==0:
100 | return split_and_pad_old(original, desiredLength, sample_rate)
101 |
102 | output_buffer_length = int(desiredLength*sample_rate)
103 | soundclip = original[0].copy()
104 | n_samples = len(soundclip)
105 |
106 | output = []
107 | # if: the audio sample length > desiredLength, then split & pad
108 | # else: simply pad according to given type 1 or 2
109 | if n_samples > output_buffer_length:
110 | frames = librosa.util.frame(soundclip, frame_length=output_buffer_length, hop_length=output_buffer_length//2, axis=0)
111 | for i in range(frames.shape[0]):
112 | output.append((frames[i], original[1], original[2], original[3], original[4], i, 0))
113 |
114 | last_id = frames.shape[0]*(output_buffer_length//2)
115 | last_sample = soundclip[last_id:]; pad_times = (output_buffer_length-len(last_sample))/len(last_sample)
116 | padded = generate_padded_samples(soundclip, last_sample, output_buffer_length, sample_rate, types)
117 | output.append((padded, original[1], original[2], original[3], original[4], i+1, pad_times))
118 |
119 | else:
120 | padded = generate_padded_samples(soundclip, soundclip, output_buffer_length, sample_rate, types); pad_times = (output_buffer_length-len(soundclip))/len(soundclip)
121 | output.append((padded, original[1], original[2], original[3], original[4], 0, pad_times))
122 |
123 | return output
124 |
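# Hedged usage sketch for split_and_pad (the 7-second length is illustrative; `sample`
# is assumed to be a tuple whose first element is raw audio and whose next four fields
# are metadata carried through unchanged):
#   chunks = split_and_pad(sample, desiredLength=7, sample_rate=sample_rate, types=2)
#   # each chunk: (padded_audio, meta1, meta2, meta3, meta4, slice_index, pad_times)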
125 | def split_and_pad_old(original, desiredLength, sample_rate):
126 | output_buffer_length = int(desiredLength * sample_rate)
127 | soundclip = original[0].copy()
128 | n_samples = len(soundclip)
129 | total_length = n_samples / sample_rate #length of cycle in seconds
130 | n_slices = int(math.ceil(total_length / desiredLength)) #get the minimum number of slices needed
131 | samples_per_slice = n_samples // n_slices
132 |     src_start = 0 # Starting index of the samples to copy from the original buffer
133 | output = [] #Holds the resultant slices
134 | for i in range(n_slices):
135 | src_end = min(src_start + samples_per_slice, n_samples)
136 | length = src_end - src_start
137 | copy = generate_padded_samples_old(soundclip[src_start:src_end], output_buffer_length)
138 | output.append((copy, original[1], original[2], original[3], original[4], i))
139 | src_start += length
140 | return output
141 |
142 | def generate_padded_samples_old(source, output_length):
143 | copy = np.zeros(output_length, dtype = np.float32)
144 | src_length = len(source)
145 | frac = src_length / output_length
146 | if(frac < 0.5):
147 | #tile forward sounds to fill empty space
148 | cursor = 0
149 | while(cursor + src_length) < output_length:
150 | copy[cursor:(cursor + src_length)] = source[:]
151 | cursor += src_length
152 | else:
153 | copy[:src_length] = source[:]
154 | return copy
155 |
156 | def generate_padded_samples(original, source, output_length, sample_rate, types):
157 | copy = np.zeros(output_length, dtype=np.float32)
158 | src_length = len(source)
159 | left = output_length-src_length # amount to be padded
160 | # pad front or back
161 | prob = random.random()
162 | if types == 1:
163 | aug = original
164 | else:
165 | aug = gen_augmented(original, sample_rate)
166 |
167 | while len(aug) < left:
168 | aug = np.concatenate([aug, aug])
169 |
170 | if prob < 0.5:
171 |         # pad the front: source sits at the end, augmented/duplicate audio before it
172 | copy[left:] = source
173 | copy[:left] = aug[len(aug)-left:]
174 | else:
175 |         # pad the back: source sits at the front, augmented/duplicate audio after it
176 | copy[:src_length] = source[:]
177 | copy[src_length:] = aug[:left]
178 |
179 | return copy
180 |
181 |
182 | #**********************DATA AUGMENTATION***************************
183 | # Returns an augmented copy of the input audio using one randomly chosen augmenter (noise, speed, loudness, VTLP or pitch)
184 | def gen_augmented(original, sample_rate):
185 | # list of augmentors available from the nlpaug library
186 | augment_list = [
187 | #naa.CropAug(sampling_rate=sample_rate)
188 | naa.NoiseAug(),
189 | naa.SpeedAug(),
190 | naa.LoudnessAug(factor=(0.5, 2)),
191 | naa.VtlpAug(sampling_rate=sample_rate, zone=(0.0, 1.0)),
192 | naa.PitchAug(sampling_rate=sample_rate, factor=(-1,3))
193 | ]
194 | # sample augmentation randomly
195 | aug_idx = random.randint(0, len(augment_list)-1)
196 | augmented_data = augment_list[aug_idx].augment(original)
197 | return augmented_data
198 |
199 | # Same as gen_augmented above, but applied to every sample in a list, n_repeats times (see the hedged sketch after this function)
200 | def augment_list(audio_with_labels, sample_rate, n_repeats):
201 | augmented_samples = []
202 | for i in range(n_repeats):
203 | addition = [(gen_augmented(t[0], sample_rate), t[1], t[2], t[3], t[4]+i+1 ) for t in audio_with_labels]
204 | augmented_samples.extend(addition)
205 | return augmented_samples
206 |
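# Hedged usage sketch for augment_list (variable names are placeholders): given cycle
# tuples shaped like those above, two augmented copies of every sample could be added with
#   extra = augment_list(audio_with_labels, sample_rate, n_repeats=2)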
207 | def create_spectrograms(current_window, sample_rate, n_mels=128, f_min=50, f_max=4000, nfft=2048, hop=512):
208 | fig = plt.figure(figsize=[1.0, 1.0])
209 | #fig = plt.figure(figsize=[0.72,0.72])
210 | ax = fig.add_subplot(111)
211 | ax.axes.get_xaxis().set_visible(False)
212 | ax.axes.get_yaxis().set_visible(False)
213 | ax.set_frame_on(False)
214 | S = librosa.feature.melspectrogram(y=current_window, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max, n_fft=nfft, hop_length=hop)
215 | librosa.display.specshow(librosa.power_to_db(S, ref=np.max))
216 |
217 | # There may be a better way to do the following, skipping it for now.
218 | buf = io.BytesIO()
219 | plt.savefig(buf, dpi=800, bbox_inches='tight',pad_inches=0)
220 | buf.seek(0)
221 | img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
222 | buf.close()
223 | img = cv2.imdecode(img_arr, 1)
224 | plt.close('all')
225 | return img
226 |
227 | def create_spectrograms_raw(current_window, sample_rate, n_mels=128, f_min=50, f_max=4000, nfft=2048, hop=512):
228 | S = librosa.feature.melspectrogram(y=current_window, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max, n_fft=nfft, hop_length=hop)
229 | S = librosa.power_to_db(S, ref=np.max)
230 | S = (S-S.min()) / (S.max() - S.min())
231 | S *= 255
232 | img = cv2.applyColorMap(S.astype(np.uint8), cmapy.cmap('viridis'))
233 | height, width, _ = img.shape
234 | img = cv2.resize(img, (width*3, height*3), interpolation=cv2.INTER_LINEAR)
235 | return img
236 |
237 | def create_mel_raw(current_window, sample_rate, n_mels=128, f_min=50, f_max=4000, nfft=2048, hop=512, resz=1):
238 | S = librosa.feature.melspectrogram(y=current_window, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max, n_fft=nfft, hop_length=hop)
239 | S = librosa.power_to_db(S, ref=np.max)
240 | S = (S-S.min()) / (S.max() - S.min())
241 | S *= 255
242 | img = cv2.applyColorMap(S.astype(np.uint8), cmapy.cmap('magma'))
243 | height, width, _ = img.shape
244 | if resz > 0:
245 | img = cv2.resize(img, (width*resz, height*resz), interpolation=cv2.INTER_LINEAR)
246 | img = cv2.flip(img, 0)
247 | return img
248 |
249 | # Circularly roll the spectrogram along the time axis by a random offset and add a trailing channel dimension
250 | def rollFFT(fft):
251 | n_row, n_col = fft.shape[:2]
252 | pivot = np.random.randint(n_col)
253 | return np.reshape(np.roll(fft, pivot, axis = 1), (n_row, n_col, 1))
254 |
255 | def rollAudio(audio):
256 | # expect audio to be 1 dimensional
257 | pivot = np.random.randint(audio.shape[0])
258 | rolled_audio = np.roll(audio, pivot, axis=0)
259 | assert audio.shape[0] == rolled_audio.shape[0], "Roll audio shape mismatch"
260 | return rolled_audio
261 |
262 | # others
263 | def get_mean_and_std(dataset):
264 |     '''Compute the per-channel mean and std of a dataset (see the hedged sketch after this function).'''
265 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
266 | mean = torch.zeros(3)
267 | std = torch.zeros(3)
268 | print('==> Computing mean and std..')
269 | for inputs, targets in dataloader:
270 | #inputs = inputs[:,0,:,:,:]
271 | for i in range(3):
272 | mean[i] += inputs[:,i,:,:].mean()
273 | std[i] += inputs[:,i,:,:].std()
274 | mean.div_(len(dataset))
275 | std.div_(len(dataset))
276 | return mean, std
277 |
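# Hedged usage sketch for get_mean_and_std (the dataset object is a placeholder; the
# returned per-channel statistics can be fed to torchvision.transforms.Normalize):
#   mean, std = get_mean_and_std(train_dataset)
#   normalize = torchvision.transforms.Normalize(mean.tolist(), std.tolist())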
278 | def init_params(net):
279 | '''Init layer parameters.'''
280 | for m in net.modules():
281 | if isinstance(m, nn.Conv2d):
282 |             nn.init.kaiming_normal_(m.weight, mode='fan_out')
283 |             if m.bias is not None:
284 |                 nn.init.constant_(m.bias, 0)
285 |         elif isinstance(m, nn.BatchNorm2d):
286 |             nn.init.constant_(m.weight, 1)
287 |             nn.init.constant_(m.bias, 0)
288 |         elif isinstance(m, nn.Linear):
289 |             nn.init.normal_(m.weight, std=1e-3)
290 |             if m.bias is not None:
291 |                 nn.init.constant_(m.bias, 0)
292 |
293 | # save images from dataloader in dump_image/train or dump_image/test based on train_flag
294 | # expect images to be in RGB format
295 | # image is a tuple: (spectrogram, filename, label, cycle, split)
296 | def save_images(image, train_flag):
297 | save_dir = 'dump_image'
298 | if not os.path.isdir(save_dir):
299 | os.makedirs(save_dir)
300 |
301 | if train_flag:
302 | save_dir = os.path.join(save_dir, 'train')
303 | if not os.path.isdir(save_dir):
304 | os.makedirs(save_dir)
305 | cv2.imwrite(os.path.join(save_dir, image[1]+'_'+str(image[2])+'_'+str(image[3])+'_'+str(image[4])+'.jpg'), cv2.cvtColor(image[0], cv2.COLOR_RGB2BGR))
306 | else:
307 | save_dir = os.path.join(save_dir, 'test')
308 | if not os.path.isdir(save_dir):
309 | os.makedirs(save_dir)
310 | cv2.imwrite(os.path.join(save_dir, image[1]+'_'+str(image[2])+'_'+str(image[3])+'_'+str(image[4])+'.jpg'), cv2.cvtColor(image[0], cv2.COLOR_RGB2BGR))
311 |
312 |
313 |
314 |
315 |
--------------------------------------------------------------------------------