├── .gitignore ├── LICENSE ├── README.md ├── data └── icbhi_dataset │ ├── filename_differences.txt │ ├── filename_format.txt │ ├── metadata.txt │ ├── official_split.txt │ ├── patient_diagnosis.txt │ ├── patient_list_foldwise.txt │ └── script.sh ├── main.py ├── method ├── __init__.py ├── patchmix.py └── patchmix_cl.py ├── models ├── __init__.py ├── ast.py ├── cnn6.py ├── efficientnet.py ├── projector.py ├── resnet.py └── ssast.py ├── requirements.txt ├── scripts ├── icbhi_ce.sh ├── icbhi_patchmix_ce.sh └── icbhi_patchmix_cl.sh └── util ├── __init__.py ├── augmentation.py ├── icbhi_dataset.py ├── icbhi_util.py ├── misc.py └── time_warping.py /.gitignore: -------------------------------------------------------------------------------- 1 | save/ 2 | data/ 3 | *.swp 4 | __pycache__/ 5 | *.pyc 6 | *.jpg 7 | *.ipynb 8 | *.pth 9 | *.json 10 | *.sh 11 | .ipynb_checkpoints/ 12 | jupyter_notebook/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Sangmin Bae 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Patch-Mix Contrastive Learning (INTERSPEECH 2023) 2 | [arXiv](https://arxiv.org/abs/2305.14032) | [Conference](https://www.isca-archive.org/interspeech_2023/bae23b_interspeech.html#) | [BibTeX](#bibtex) 3 | 4 | 5 |

6 | 7 |

8 | 9 | **Patch-Mix Contrastive Learning with Audio Spectrogram Transformer on Respiratory Sound Classification**
10 | [Sangmin Bae](https://www.raymin0223.com)\*, 11 | [June-Woo Kim](https://github.com/kaen2891)\*, 12 | [Won-Yang Cho](https://github.com/wonyangcho), 13 | [Hyerim Baek](https://github.com/rimiyeyo), 14 | [Soyoun Son](https://github.com/soyounson), 15 | [Byungjo Lee](https://github.com/bzlee-bio), 16 | [Changwan Ha](https://github.com/cwh1981), 17 | [Kyongpil Tae](https://github.com/kyongpiltae), 18 | [Sungnyun Kim](https://github.com/sungnyun)$^\dagger$, 19 | [Se-Young Yun](https://fbsqkd.github.io)$^\dagger$
20 | \* equal contribution    $^\dagger$ corresponding authors 21 | 22 | - We demonstrate that the **pretrained model on large-scale visual and audio datasets** can be generalized to the respiratory sound classification task. 23 | - We introduce a straightforward **Patch-Mix augmentation**, which randomly mixes patches between different samples, with Audio Spectrogram Transformer (AST). 24 | - To overcome the label hierarchy in lung sound datasets, we propose an effective **Patch-Mix Contrastive Learning** to distinguish the mixed representations in the latent space. 25 | 26 | 27 | ## Requirements 28 | Install the necessary packages with: 29 | ``` 30 | $ pip install torch torchvision torchaudio 31 | $ pip install -r requirements.txt 32 | ``` 33 | 34 | 35 | ## Data Preparation 36 | Download the ICBHI dataset files from [official_page](https://bhichallenge.med.auth.gr/ICBHI_2017_Challenge). 37 | ```bash 38 | $ wget https://bhichallenge.med.auth.gr/sites/default/files/ICBHI_final_database/ICBHI_final_database.zip 39 | ``` 40 | All `*.wav` and `*.txt` should be saved in `data/icbhi_dataset/audio_test_data`. 41 | 42 | Note that ICBHI dataset consists of a total of 6,898 respiratory cycles, 43 | of which 1,864 contain crackles, 886 contain wheezes, and 506 contain both crackles and wheezes, in 920 annotated audio samples from 126 subjects. 44 | 45 | 46 | ## Training 47 | To simply train the model, run the shell files in `scripts/`. 48 | 1. **`scripts/icbhi_ce.sh`**: Cross-Entropy loss with AST model. 49 | 2. **`scripts/icbhi_patchmix_ce.sh`**: Patch-Mix loss with AST model, where the label depends on the interpolation ratio. 50 | 3. **`scripts/icbhi_patchmix_cl.sh`**: Patch-Mix contrastive loss with AST model. 51 | 52 | Important arguments for different data settings. 
53 | - `--dataset`: other lung sound or heart sound datasets can be implemented 54 | - `--class_split`: "lungsound" or "diagnosis" classification 55 | - `--n_cls`: set the number of classes to 4, or to 2 (normal / abnormal), for lungsound classification 56 | - `--test_fold`: "official" denotes the official 60/40% train/test split, and "0"~"4" denote 80/20% splits 57 | 58 | Important arguments for models. 59 | - `--model`: network architecture, see [models](models/) 60 | - `--from_sl_official`: load an ImageNet-pretrained checkpoint 61 | - `--audioset_pretrained`: load an AudioSet-pretrained checkpoint (supported only for AST and SSAST) 62 | 63 | Important arguments for evaluation. 64 | - `--eval`: switch to evaluation mode without any training 65 | - `--pretrained`: load a pretrained checkpoint; requires the `--pretrained_ckpt` argument 66 | - `--pretrained_ckpt`: path to the pretrained checkpoint 67 | 68 | Trained model checkpoints are saved at `save/[EXP_NAME]/best.pth`. 69 | 70 | ## Result 71 | 72 | Patch-Mix Contrastive Learning achieves state-of-the-art performance with a Score of 62.37%, which is +4.08% higher than the previous best Score. 73 |

74 | 75 |

76 | 77 | 78 | ## BibTeX 79 | If you find this repo useful for your research, please consider citing our paper: 80 | 81 | ``` 82 | @inproceedings{bae23b_interspeech, 83 | title = {Patch-Mix Contrastive Learning with Audio Spectrogram Transformer on Respiratory Sound Classification}, 84 | author = {Sangmin Bae and June-Woo Kim and Won-Yang Cho and Hyerim Baek and Soyoun Son and Byungjo Lee and Changwan Ha and Kyongpil Tae and Sungnyun Kim and Se-Young Yun}, 85 | year = {2023}, 86 | booktitle = {INTERSPEECH 2023}, 87 | pages = {5436--5440}, 88 | doi = {10.21437/Interspeech.2023-1426}, 89 | issn = {2958-1796}, 90 | } 91 | ``` 92 | 93 | ## Contact 94 | - Sangmin Bae: bsmn0223@kaist.ac.kr 95 | - June-Woo Kim: kaen2891@gmail.com 96 | -------------------------------------------------------------------------------- /data/icbhi_dataset/filename_differences.txt: -------------------------------------------------------------------------------- 1 | '101_1b1_Al_sc_AKGC417L' 2 | '101_1b1_Pr_sc_AKGC417L' 3 | '102_1b1_Ar_sc_AKGC417L' 4 | '105_1b1_Tc_sc_LittC2SE' 5 | '108_1b1_Al_sc_LittC2SE' 6 | '111_1b2_Tc_sc_LittC2SE' 7 | '111_1b3_Tc_sc_LittC2SE' 8 | '115_1b1_Ar_sc_LittC2SE' 9 | '116_1b2_Pl_sc_LittC2SE' 10 | '116_1b2_Tc_sc_LittC2SE' 11 | '119_1b1_Ar_sc_AKGC417L' 12 | '121_1b1_Tc_sc_LittC2SE' 13 | '121_1p1_Tc_sc_LittC2SE' 14 | '123_1b1_Al_sc_AKGC417L' 15 | '125_1b1_Tc_sc_LittC2SE' 16 | '126_1b1_Al_sc_AKGC417L' 17 | '127_1b1_Ar_sc_LittC2SE' 18 | '129_1b1_Ar_sc_LittC2SE' 19 | '131_1b1_Al_sc_LittC2SE' 20 | '136_1b1_Ar_sc_AKGC417L' 21 | '137_1b1_Ar_sc_LittC2SE' 22 | '137_1b1_Ll_sc_LittC2SE' 23 | '143_1b1_Al_sc_AKGC417L' 24 | '144_1b1_Al_sc_AKGC417L' 25 | '144_1b1_Tc_sc_AKGC417L' 26 | '148_1b1_Al_sc_LittC2SE' 27 | '149_1b1_Al_sc_LittC2SE' 28 | '149_1b1_Lr_sc_LittC2SE' 29 | '149_1b1_Pl_sc_LittC2SE' 30 | '150_1b2_Al_sc_AKGC417L' 31 | '152_1b1_Al_sc_LittC2SE' 32 | '153_1b1_Al_sc_LittC2SE' 33 | '159_1b1_Al_sc_AKGC417L' 34 | '159_1b1_Ar_sc_AKGC417L' 35 | '159_1b1_Ll_sc_AKGC417L' 36 | 
'159_1b1_Pr_sc_AKGC417L' 37 | '161_1b1_Al_sc_LittC2SE' 38 | '161_1b1_Pl_sc_LittC2SE' 39 | '164_1b1_Ll_sc_LittC2SE' 40 | '165_1b1_Ar_sc_AKGC417L' 41 | '165_1b1_Pl_sc_AKGC417L' 42 | '165_1b1_Pr_sc_AKGC417L' 43 | '167_1b1_Al_sc_LittC2SE' 44 | '167_1b1_Pr_sc_LittC2SE' 45 | '168_1b1_Al_sc_LittC2SE' 46 | '169_1b1_Lr_sc_AKGC417L' 47 | '169_1b2_Ll_sc_AKGC417L' 48 | '171_1b1_Al_sc_AKGC417L' 49 | '173_1b1_Al_sc_AKGC417L' 50 | '179_1b1_Al_sc_LittC2SE' 51 | '179_1b1_Tc_sc_LittC2SE' 52 | '182_1b1_Tc_sc_LittC2SE' 53 | '183_1b1_Pl_sc_AKGC417L' 54 | '183_1b1_Tc_sc_AKGC417L' 55 | '184_1b1_Ar_sc_LittC2SE' 56 | '187_1b1_Ll_sc_AKGC417L' 57 | '188_1b1_Al_sc_LittC2SE' 58 | '188_1b1_Ar_sc_LittC2SE' 59 | '188_1b1_Pl_sc_LittC2SE' 60 | '188_1b1_Tc_sc_LittC2SE' 61 | '190_1b1_Tc_sc_AKGC417L' 62 | '194_1b1_Lr_sc_AKGC417L' 63 | '194_1b1_Pr_sc_AKGC417L' 64 | '196_1b1_Pr_sc_LittC2SE' 65 | '197_1b1_Al_sc_AKGC417L' 66 | '197_1b1_Tc_sc_AKGC417L' 67 | '201_1b1_Al_sc_LittC2SE' 68 | '201_1b1_Ar_sc_LittC2SE' 69 | '201_1b2_Al_sc_LittC2SE' 70 | '201_1b2_Ar_sc_LittC2SE' 71 | '201_1b3_Al_sc_LittC2SE' 72 | '201_1b3_Ar_sc_LittC2SE' 73 | '202_1b1_Ar_sc_AKGC417L' 74 | '206_1b1_Ar_sc_LittC2SE' 75 | '206_1b1_Lr_sc_LittC2SE' 76 | '206_1b1_Pl_sc_LittC2SE' 77 | '208_1b1_Ll_sc_LittC2SE' 78 | '209_1b1_Tc_sc_LittC2SE' 79 | '210_1b1_Al_sc_LittC2SE' 80 | '210_1b1_Ar_sc_LittC2SE' 81 | '214_1b1_Ar_sc_AKGC417L' 82 | '215_1b2_Ar_sc_LittC2SE' 83 | '215_1b3_Tc_sc_LittC2SE' 84 | '216_1b1_Al_sc_AKGC417L' 85 | '216_1b1_Pl_sc_AKGC417L' 86 | '217_1b1_Tc_sc_LittC2SE' 87 | '224_1b1_Tc_sc_AKGC417L' 88 | '224_1b2_Al_sc_AKGC417L' 89 | '225_1b1_Pl_sc_AKGC417L' 90 | '226_1b1_Al_sc_LittC2SE' 91 | '226_1b1_Ll_sc_LittC2SE' 92 | -------------------------------------------------------------------------------- /data/icbhi_dataset/filename_format.txt: -------------------------------------------------------------------------------- 1 | Elements contained in the filenames: 2 | 3 | Patient number (101,102,...,226) 4 | Recording index 5 | Chest 
location (Trachea (Tc), {Anterior (A), Posterior (P), Lateral (L)}{left (l), right (r)}) 6 | Acquisition mode (sequential/single channel (sc), simultaneous/multichannel (mc)) 7 | Recording equipment (AKG C417L Microphone, 3M Littmann Classic II SE Stethoscope, 3M Litmmann 3200 Electronic Stethoscope, WelchAllyn Meditron Master Elite Electronic Stethoscope) -------------------------------------------------------------------------------- /data/icbhi_dataset/metadata.txt: -------------------------------------------------------------------------------- 1 | 101 3.0 F NA 19.0 99.0 Pr 2 | 102 0.75 F NA 9.8 73.0 Ar 3 | 103 70.0 F 33.0 NA NA Ar 4 | 104 70.0 F 28.47 NA NA Al 5 | 105 7.0 F NA 32.0 135.0 Tc 6 | 106 73.0 F 21.0 NA NA Pl 7 | 107 75.0 F 33.7 NA NA Ll 8 | 108 3.0 M NA NA NA Al 9 | 109 84.0 F 33.53 NA NA Lr 10 | 110 75.0 M 25.21 NA NA Pr 11 | 111 63.0 M 28.4 NA NA Tc 12 | 112 60.0 M 22.86 NA NA Pr 13 | 113 58.0 M 28.41 NA NA Lr 14 | 114 77.0 M 23.12 NA NA Pr 15 | 115 0.58 M NA 7.14 64.0 Ar 16 | 116 56.0 M 28.58 NA NA Pl 17 | 117 68.0 M 24.4 NA NA Tc 18 | 118 81.0 M 36.76 NA NA Pr 19 | 119 2.0 F NA 15.2 94.0 Ar 20 | 120 78.0 M 35.14 NA NA Ar 21 | 121 13.0 F NA 65.0 170.0 Tc 22 | 122 66.0 M 33.0 NA NA Ar 23 | 123 5.0 M NA 25.0 125.0 Al 24 | 124 65.0 M 29.07 NA NA Al 25 | 125 14.0 M NA 62.0 170.0 Tc 26 | 126 1.0 F NA 10.18 80.0 Al 27 | 127 2.0 M NA 12.6 98.0 Ar 28 | 128 65.0 F 24.3 NA NA Tc 29 | 129 6.0 M NA 23.0 119.0 Ar 30 | 130 85.0 F 17.1 NA NA Pl 31 | 131 3.0 M NA 14.0 97.0 Al 32 | 132 71.0 M 34.0 NA NA Lr 33 | 133 68.0 M 27.4 NA NA Pl 34 | 134 61.0 M 32.0 NA NA Al 35 | 135 70.0 M 21.0 NA NA Ar 36 | 136 5.0 M NA 16.2 110.0 Ar 37 | 137 4.0 M NA 18.0 104.0 Ar 38 | 138 56.0 F 21.6 NA NA Ll 39 | 139 61.0 M 28.68 NA NA Lr 40 | 140 79.0 F 23.0 NA NA Tc 41 | 141 66.0 M 22.4 NA NA Pr 42 | 142 78.0 M 26.1 NA NA Pl 43 | 143 0.25 F NA 8.24 68.0 Al 44 | 144 3.0 M NA 16.7 100.0 Al 45 | 145 69.0 M 23.4 NA NA Lr 46 | 146 67.0 M 28.0 NA NA Pl 47 | 147 77.0 M 25.7 NA NA Ar 48 | 
148 4.0 M NA 33.0 110.0 Al 49 | 149 0.67 M NA 9.5 70.0 Lr 50 | 150 0.67 F NA 8.12 74.0 Al 51 | 151 75.0 M 28.4 NA NA Ll 52 | 152 16.0 M NA 70.0 183.0 Al 53 | 153 3.0 M NA 16.7 103.0 Al 54 | 154 65.0 M 28.1 NA NA Ll 55 | 155 69.0 M 26.0 NA NA Al 56 | 156 80.0 M 22.9 NA NA Al 57 | 157 62.0 M 53.5 NA NA Pr 58 | 158 63.0 M 16.5 NA NA Lr 59 | 159 0.83 F NA 11.0 80.0 Al 60 | 160 74.0 M 27.4 NA NA Tc 61 | 161 2.0 F NA 12.0 85.0 Pl 62 | 162 67.0 F 24.9 NA NA Ll 63 | 163 62.0 M 28.3 NA NA Pl 64 | 164 1.0 M NA 13.0 NA Ll 65 | 165 2.0 F NA 12.7 97.0 Pr 66 | 166 71.0 M 25.06 NA NA Ar 67 | 167 1.0 F NA 11.5 86.4 Pr 68 | 168 19.0 F 17.35 NA NA Al 69 | 169 50.0 F 28.81 NA NA Ll 70 | 170 79.0 M 22.6 NA NA Pr 71 | 171 9.0 M NA 32.0 133.0 Al 72 | 172 73.0 M 29.3 NA NA Lr 73 | 173 3.0 M NA 17.3 NA Al 74 | 174 68.0 M 26.4 NA NA Al 75 | 175 63.0 M 28.34 NA NA Pl 76 | 176 65.0 M 30.1 NA NA Pl 77 | 177 56.0 M 22.1 NA NA Pr 78 | 178 58.0 M 30.1 NA NA Tc 79 | 179 10.0 F NA 15.0 104.0 Al 80 | 180 93.0 M 29.03 NA NA Lr 81 | 181 65.0 M 26.4 NA NA Ar 82 | 182 11.0 M NA 33.0 136.0 Tc 83 | 183 14.0 F NA NA NA Tc 84 | 184 2.0 F NA 15.0 100.0 Ar 85 | 185 75.0 M 27.7 NA NA Pl 86 | 186 71.0 M 30.0 NA NA Al 87 | 187 0.5 F NA 8.26 71.0 Ll 88 | 188 3.0 M NA 16.0 100.0 Al 89 | 189 75.0 F 26.2 NA NA Lr 90 | 190 3.0 F NA NA NA Tc 91 | 191 74.0 F 36.0 NA NA Pr 92 | 192 69.0 M 28.0 NA NA Ar 93 | 193 77.0 M 26.3 NA NA Al 94 | 194 2.0 M NA 12.8 86.0 Pr 95 | 195 67.0 M 29.41 NA NA Al 96 | 196 21.0 F 25.5 NA NA Pr 97 | 197 16.0 F NA NA NA Al 98 | 198 71.0 M 18.6 NA NA Pr 99 | 199 71.0 M 20.0 NA NA Ll 100 | 200 72.0 F 27.8 NA NA Ar 101 | 201 73.0 F 28.52 NA NA Al 102 | 202 2.0 M NA 11.84 87.0 Ar 103 | 203 57.0 F 24.0 NA NA Pr 104 | 204 66.0 M 29.76 NA NA Ll 105 | 205 45.0 M 20.1 NA NA Al 106 | 206 3.0 M NA 13.0 92.0 Pl 107 | 207 63.0 F 29.6 NA NA Pl 108 | 208 5.0 F NA 24.1 117.0 Ll 109 | 209 14.0 F NA 80.0 183.0 Tc 110 | 210 1.0 F NA 12.96 76.0 Al 111 | 211 70.0 F 31.1 NA NA Ar 112 | 212 83.0 M 23.0 NA NA Tc 113 
| 213 58.0 F 24.7 NA NA Al 114 | 214 5.0 M NA 30.0 118.0 Ar 115 | 215 56.0 F 25.35 NA NA Ar 116 | 216 1.0 M NA 10.25 78.0 Pl 117 | 217 12.0 F NA NA NA Tc 118 | 218 75.0 M 26.29 NA NA Al 119 | 219 81.0 M 26.0 NA NA Ar 120 | 220 66.0 M 35.4 NA NA Al 121 | 221 74.0 F 29.0 NA NA Pl 122 | 222 60.0 M NA NA NA Ar 123 | 223 NA NA NA NA NA Ll 124 | 224 10.0 F NA 32.3 143.0 Tc 125 | 225 0.83 M NA 7.8 74.0 Pl 126 | 226 4.0 M NA 16.7 103.0 Ll 127 | -------------------------------------------------------------------------------- /data/icbhi_dataset/official_split.txt: -------------------------------------------------------------------------------- 1 | 101_1b1_Al_sc_Meditron test 2 | 101_1b1_Pr_sc_Meditron test 3 | 102_1b1_Ar_sc_Meditron test 4 | 103_2b2_Ar_mc_LittC2SE train 5 | 104_1b1_Al_sc_Litt3200 test 6 | 104_1b1_Ar_sc_Litt3200 test 7 | 104_1b1_Ll_sc_Litt3200 test 8 | 104_1b1_Lr_sc_Litt3200 test 9 | 104_1b1_Pl_sc_Litt3200 test 10 | 104_1b1_Pr_sc_Litt3200 test 11 | 105_1b1_Tc_sc_Meditron train 12 | 106_2b1_Pl_mc_LittC2SE train 13 | 106_2b1_Pr_mc_LittC2SE train 14 | 107_2b3_Al_mc_AKGC417L train 15 | 107_2b3_Ar_mc_AKGC417L train 16 | 107_2b3_Ll_mc_AKGC417L train 17 | 107_2b3_Lr_mc_AKGC417L train 18 | 107_2b3_Pl_mc_AKGC417L train 19 | 107_2b3_Pr_mc_AKGC417L train 20 | 107_2b3_Tc_mc_AKGC417L train 21 | 107_2b4_Al_mc_AKGC417L train 22 | 107_2b4_Ar_mc_AKGC417L train 23 | 107_2b4_Ll_mc_AKGC417L train 24 | 107_2b4_Lr_mc_AKGC417L train 25 | 107_2b4_Pl_mc_AKGC417L train 26 | 107_2b4_Pr_mc_AKGC417L train 27 | 107_2b4_Tc_mc_AKGC417L train 28 | 107_2b5_Al_mc_AKGC417L train 29 | 107_2b5_Ar_mc_AKGC417L train 30 | 107_2b5_Ll_mc_AKGC417L train 31 | 107_2b5_Lr_mc_AKGC417L train 32 | 107_2b5_Pl_mc_AKGC417L train 33 | 107_2b5_Pr_mc_AKGC417L train 34 | 107_2b5_Tc_mc_AKGC417L train 35 | 107_3p2_Al_mc_AKGC417L train 36 | 107_3p2_Ar_mc_AKGC417L train 37 | 107_3p2_Ll_mc_AKGC417L train 38 | 107_3p2_Lr_mc_AKGC417L train 39 | 107_3p2_Pl_mc_AKGC417L train 40 | 107_3p2_Pr_mc_AKGC417L train 41 | 
107_3p2_Tc_mc_AKGC417L train 42 | 108_1b1_Al_sc_Meditron train 43 | 109_1b1_Al_sc_Litt3200 test 44 | 109_1b1_Ar_sc_Litt3200 test 45 | 109_1b1_Ll_sc_Litt3200 test 46 | 109_1b1_Lr_sc_Litt3200 test 47 | 109_1b1_Pl_sc_Litt3200 test 48 | 109_1b1_Pr_sc_Litt3200 test 49 | 110_1b1_Pr_sc_Meditron train 50 | 110_1p1_Al_sc_Meditron train 51 | 110_1p1_Ll_sc_Meditron train 52 | 110_1p1_Lr_sc_Meditron train 53 | 110_1p1_Pr_sc_Meditron train 54 | 111_1b2_Tc_sc_Meditron train 55 | 111_1b3_Tc_sc_Meditron train 56 | 112_1b1_Ar_sc_Meditron train 57 | 112_1b1_Lr_sc_Meditron train 58 | 112_1p1_Ll_sc_Litt3200 train 59 | 112_1p1_Pl_sc_Litt3200 train 60 | 112_1p1_Pr_sc_Litt3200 train 61 | 113_1b1_Al_sc_Litt3200 test 62 | 113_1b1_Ar_sc_Litt3200 test 63 | 113_1b1_Ll_sc_Litt3200 test 64 | 113_1b1_Lr_sc_Litt3200 test 65 | 113_1b1_Pl_sc_Litt3200 test 66 | 113_1b1_Pr_sc_Litt3200 test 67 | 114_1b4_Al_mc_AKGC417L train 68 | 114_1b4_Ar_mc_AKGC417L train 69 | 114_1b4_Lr_mc_AKGC417L train 70 | 114_1b4_Pl_mc_AKGC417L train 71 | 114_1b4_Pr_mc_AKGC417L train 72 | 115_1b1_Ar_sc_Meditron train 73 | 116_1b2_Pl_sc_Meditron train 74 | 116_1b2_Tc_sc_Meditron train 75 | 117_1b2_Tc_mc_LittC2SE train 76 | 117_1b3_Tc_mc_LittC2SE train 77 | 118_1b1_Al_sc_Litt3200 test 78 | 118_1b1_Ar_sc_Litt3200 test 79 | 118_1b1_Ll_sc_Litt3200 test 80 | 118_1b1_Lr_sc_Litt3200 test 81 | 118_1b1_Pl_sc_Litt3200 test 82 | 118_1b1_Pr_sc_Litt3200 test 83 | 119_1b1_Ar_sc_Meditron test 84 | 120_1b1_Al_sc_Meditron train 85 | 120_1b1_Ar_sc_Meditron train 86 | 120_1b1_Lr_sc_Meditron train 87 | 120_1b1_Pl_sc_Meditron train 88 | 120_1b1_Pr_sc_Meditron train 89 | 121_1b1_Tc_sc_Meditron train 90 | 121_1p1_Tc_sc_Meditron train 91 | 122_2b1_Al_mc_LittC2SE train 92 | 122_2b1_Ar_mc_LittC2SE train 93 | 122_2b1_Tc_mc_LittC2SE train 94 | 122_2b2_Al_mc_LittC2SE train 95 | 122_2b2_Ar_mc_LittC2SE train 96 | 122_2b2_Tc_mc_LittC2SE train 97 | 122_2b3_Al_mc_LittC2SE train 98 | 122_2b3_Ar_mc_LittC2SE train 99 | 122_2b3_Tc_mc_LittC2SE train 100 | 
123_1b1_Al_sc_Meditron test 101 | 124_1b1_Al_sc_Litt3200 test 102 | 124_1b1_Ar_sc_Litt3200 test 103 | 124_1b1_Ll_sc_Litt3200 test 104 | 124_1b1_Lr_sc_Litt3200 test 105 | 124_1b1_Pl_sc_Litt3200 test 106 | 124_1b1_Pr_sc_Litt3200 test 107 | 125_1b1_Tc_sc_Meditron train 108 | 126_1b1_Al_sc_Meditron test 109 | 127_1b1_Ar_sc_Meditron train 110 | 128_1b3_Tc_mc_LittC2SE train 111 | 129_1b1_Ar_sc_Meditron train 112 | 130_1p2_Al_mc_AKGC417L train 113 | 130_1p2_Ar_mc_AKGC417L train 114 | 130_1p2_Ll_mc_AKGC417L train 115 | 130_1p2_Lr_mc_AKGC417L train 116 | 130_1p2_Pl_mc_AKGC417L train 117 | 130_1p2_Pr_mc_AKGC417L train 118 | 130_1p2_Tc_mc_AKGC417L train 119 | 130_1p3_Al_mc_AKGC417L train 120 | 130_1p3_Ar_mc_AKGC417L train 121 | 130_1p3_Ll_mc_AKGC417L train 122 | 130_1p3_Lr_mc_AKGC417L train 123 | 130_1p3_Pl_mc_AKGC417L train 124 | 130_1p3_Pr_mc_AKGC417L train 125 | 130_1p3_Tc_mc_AKGC417L train 126 | 130_1p4_Al_mc_AKGC417L train 127 | 130_1p4_Ar_mc_AKGC417L train 128 | 130_1p4_Ll_mc_AKGC417L train 129 | 130_1p4_Lr_mc_AKGC417L train 130 | 130_1p4_Pl_mc_AKGC417L train 131 | 130_1p4_Pr_mc_AKGC417L train 132 | 130_1p4_Tc_mc_AKGC417L train 133 | 130_2b2_Al_mc_AKGC417L train 134 | 130_2b2_Ar_mc_AKGC417L train 135 | 130_2b2_Ll_mc_AKGC417L train 136 | 130_2b2_Lr_mc_AKGC417L train 137 | 130_2b2_Pl_mc_AKGC417L train 138 | 130_2b2_Pr_mc_AKGC417L train 139 | 130_2b2_Tc_mc_AKGC417L train 140 | 130_2b3_Al_mc_AKGC417L train 141 | 130_2b3_Ar_mc_AKGC417L train 142 | 130_2b3_Ll_mc_AKGC417L train 143 | 130_2b3_Lr_mc_AKGC417L train 144 | 130_2b3_Pl_mc_AKGC417L train 145 | 130_2b3_Pr_mc_AKGC417L train 146 | 130_2b3_Tc_mc_AKGC417L train 147 | 130_2b4_Al_mc_AKGC417L train 148 | 130_2b4_Ar_mc_AKGC417L train 149 | 130_2b4_Ll_mc_AKGC417L train 150 | 130_2b4_Lr_mc_AKGC417L train 151 | 130_2b4_Pl_mc_AKGC417L train 152 | 130_2p3_Pl_mc_AKGC417L train 153 | 130_2p5_Al_mc_AKGC417L train 154 | 130_2p5_Ar_mc_AKGC417L train 155 | 130_2p5_Lr_mc_AKGC417L train 156 | 130_2p5_Pl_mc_AKGC417L train 157 | 
130_2p5_Pr_mc_AKGC417L train 158 | 130_2p5_Tc_mc_AKGC417L train 159 | 130_3b3_Ll_mc_AKGC417L train 160 | 130_3b4_Al_mc_AKGC417L train 161 | 130_3b4_Ar_mc_AKGC417L train 162 | 130_3b4_Lr_mc_AKGC417L train 163 | 130_3b4_Pl_mc_AKGC417L train 164 | 130_3b4_Pr_mc_AKGC417L train 165 | 130_3p2_Al_mc_AKGC417L train 166 | 130_3p2_Ar_mc_AKGC417L train 167 | 130_3p2_Pl_mc_AKGC417L train 168 | 130_3p2_Pr_mc_AKGC417L train 169 | 130_3p2_Tc_mc_AKGC417L train 170 | 130_3p3_Al_mc_AKGC417L train 171 | 130_3p3_Pl_mc_AKGC417L train 172 | 130_3p3_Pr_mc_AKGC417L train 173 | 130_3p3_Tc_mc_AKGC417L train 174 | 130_3p4_Al_mc_AKGC417L train 175 | 130_3p4_Pl_mc_AKGC417L train 176 | 130_3p4_Pr_mc_AKGC417L train 177 | 130_3p4_Tc_mc_AKGC417L train 178 | 131_1b1_Al_sc_Meditron train 179 | 132_2b1_Lr_mc_LittC2SE train 180 | 132_2b2_Lr_mc_LittC2SE train 181 | 133_2p2_Al_mc_AKGC417L test 182 | 133_2p2_Ar_mc_AKGC417L test 183 | 133_2p2_Pl_mc_AKGC417L test 184 | 133_2p2_Tc_mc_AKGC417L test 185 | 133_2p3_Al_mc_AKGC417L test 186 | 133_2p3_Ar_mc_AKGC417L test 187 | 133_2p3_Pl_mc_AKGC417L test 188 | 133_2p3_Pr_mc_AKGC417L test 189 | 133_2p3_Tc_mc_AKGC417L test 190 | 133_2p4_Al_mc_AKGC417L test 191 | 133_2p4_Ar_mc_AKGC417L test 192 | 133_2p4_Pl_mc_AKGC417L test 193 | 133_2p4_Pr_mc_AKGC417L test 194 | 133_2p4_Tc_mc_AKGC417L test 195 | 133_3p2_Al_mc_AKGC417L test 196 | 133_3p2_Ar_mc_AKGC417L test 197 | 133_3p2_Pl_mc_AKGC417L test 198 | 133_3p2_Pr_mc_AKGC417L test 199 | 133_3p4_Tc_mc_AKGC417L test 200 | 134_2b1_Al_mc_LittC2SE train 201 | 134_2b1_Ar_mc_LittC2SE train 202 | 134_2b2_Al_mc_LittC2SE train 203 | 134_2b2_Ar_mc_LittC2SE train 204 | 134_2b3_Ar_mc_LittC2SE train 205 | 135_2b1_Al_mc_LittC2SE train 206 | 135_2b1_Ar_mc_LittC2SE train 207 | 135_2b1_Pl_mc_LittC2SE train 208 | 135_2b1_Tc_mc_LittC2SE train 209 | 135_2b2_Al_mc_LittC2SE train 210 | 135_2b2_Ar_mc_LittC2SE train 211 | 135_2b2_Pl_mc_LittC2SE train 212 | 135_2b2_Tc_mc_LittC2SE train 213 | 135_2b3_Al_mc_LittC2SE train 214 | 135_2b3_Ar_mc_LittC2SE 
train 215 | 135_2b3_Pl_mc_LittC2SE train 216 | 135_2b3_Pr_mc_LittC2SE train 217 | 135_2b3_Tc_mc_LittC2SE train 218 | 136_1b1_Ar_sc_Meditron test 219 | 137_1b1_Ar_sc_Meditron train 220 | 137_1b1_Ll_sc_Meditron train 221 | 138_1p2_Al_mc_AKGC417L train 222 | 138_1p2_Ar_mc_AKGC417L train 223 | 138_1p2_Ll_mc_AKGC417L train 224 | 138_1p2_Lr_mc_AKGC417L train 225 | 138_1p2_Pl_mc_AKGC417L train 226 | 138_1p2_Pr_mc_AKGC417L train 227 | 138_1p2_Tc_mc_AKGC417L train 228 | 138_1p3_Al_mc_AKGC417L train 229 | 138_1p3_Ar_mc_AKGC417L train 230 | 138_1p3_Ll_mc_AKGC417L train 231 | 138_1p3_Lr_mc_AKGC417L train 232 | 138_1p3_Pl_mc_AKGC417L train 233 | 138_1p3_Pr_mc_AKGC417L train 234 | 138_1p3_Tc_mc_AKGC417L train 235 | 138_1p4_Ar_mc_AKGC417L train 236 | 138_1p4_Ll_mc_AKGC417L train 237 | 138_1p4_Lr_mc_AKGC417L train 238 | 138_1p4_Pl_mc_AKGC417L train 239 | 138_1p4_Pr_mc_AKGC417L train 240 | 138_1p4_Tc_mc_AKGC417L train 241 | 138_2p2_Al_mc_AKGC417L train 242 | 138_2p2_Ar_mc_AKGC417L train 243 | 138_2p2_Ll_mc_AKGC417L train 244 | 138_2p2_Lr_mc_AKGC417L train 245 | 138_2p2_Pl_mc_AKGC417L train 246 | 138_2p2_Pr_mc_AKGC417L train 247 | 138_2p2_Tc_mc_AKGC417L train 248 | 139_1b1_Al_sc_Litt3200 test 249 | 139_1b1_Ar_sc_Litt3200 test 250 | 139_1b1_Ll_sc_Litt3200 test 251 | 139_1b1_Lr_sc_Litt3200 test 252 | 139_1b1_Pl_sc_Litt3200 test 253 | 139_1b1_Pr_sc_Litt3200 test 254 | 140_2b2_Ll_mc_LittC2SE train 255 | 140_2b2_Tc_mc_LittC2SE train 256 | 140_2b3_Ll_mc_LittC2SE train 257 | 140_2b3_Tc_mc_LittC2SE train 258 | 141_1b1_Pr_mc_LittC2SE train 259 | 141_1b2_Ar_mc_LittC2SE train 260 | 141_1b2_Lr_mc_LittC2SE train 261 | 141_1b2_Pr_mc_LittC2SE train 262 | 141_1b2_Tc_mc_LittC2SE train 263 | 141_1b3_Al_mc_LittC2SE train 264 | 141_1b3_Ar_mc_LittC2SE train 265 | 141_1b3_Pr_mc_LittC2SE train 266 | 142_1b1_Pl_mc_LittC2SE train 267 | 143_1b1_Al_sc_Meditron test 268 | 144_1b1_Al_sc_Meditron test 269 | 144_1b1_Tc_sc_Meditron test 270 | 145_2b2_Al_mc_AKGC417L test 271 | 145_2b2_Ar_mc_AKGC417L test 272 | 
145_2b2_Lr_mc_AKGC417L test 273 | 145_2b2_Pr_mc_AKGC417L test 274 | 145_3b2_Ar_mc_AKGC417L test 275 | 145_3b2_Lr_mc_AKGC417L test 276 | 145_3b4_Pl_mc_AKGC417L test 277 | 146_2b2_Pl_mc_AKGC417L test 278 | 146_2b4_Al_mc_AKGC417L test 279 | 146_2b4_Ar_mc_AKGC417L test 280 | 146_2b4_Ll_mc_AKGC417L test 281 | 146_2b4_Lr_mc_AKGC417L test 282 | 146_2b4_Pr_mc_AKGC417L test 283 | 146_8p3_Al_mc_AKGC417L test 284 | 146_8p3_Ar_mc_AKGC417L test 285 | 146_8p3_Lr_mc_AKGC417L test 286 | 146_8p3_Pl_mc_AKGC417L test 287 | 146_8p3_Pr_mc_AKGC417L test 288 | 147_1b2_Tc_mc_AKGC417L test 289 | 147_1b3_Tc_mc_AKGC417L test 290 | 147_1b4_Tc_mc_AKGC417L test 291 | 147_2b2_Al_mc_AKGC417L test 292 | 147_2b2_Ar_mc_AKGC417L test 293 | 147_2b2_Pl_mc_AKGC417L test 294 | 147_2b3_Al_mc_AKGC417L test 295 | 147_2b3_Ar_mc_AKGC417L test 296 | 147_2b3_Ll_mc_AKGC417L test 297 | 147_2b3_Lr_mc_AKGC417L test 298 | 147_2b3_Pl_mc_AKGC417L test 299 | 147_2b4_Al_mc_AKGC417L test 300 | 147_2b4_Ar_mc_AKGC417L test 301 | 147_2b4_Ll_mc_AKGC417L test 302 | 147_2b4_Lr_mc_AKGC417L test 303 | 147_2b4_Pl_mc_AKGC417L test 304 | 148_1b1_Al_sc_Meditron train 305 | 149_1b1_Al_sc_Meditron test 306 | 149_1b1_Lr_sc_Meditron test 307 | 149_1b1_Pl_sc_Meditron test 308 | 150_1b2_Al_sc_Meditron test 309 | 151_2p2_Al_mc_AKGC417L test 310 | 151_2p2_Ar_mc_AKGC417L test 311 | 151_2p2_Ll_mc_AKGC417L test 312 | 151_2p2_Lr_mc_AKGC417L test 313 | 151_2p2_Pl_mc_AKGC417L test 314 | 151_2p2_Pr_mc_AKGC417L test 315 | 151_2p2_Tc_mc_AKGC417L test 316 | 151_2p3_Al_mc_AKGC417L test 317 | 151_2p3_Ar_mc_AKGC417L test 318 | 151_2p3_Ll_mc_AKGC417L test 319 | 151_2p3_Lr_mc_AKGC417L test 320 | 151_2p3_Pl_mc_AKGC417L test 321 | 151_2p3_Pr_mc_AKGC417L test 322 | 151_2p3_Tc_mc_AKGC417L test 323 | 151_2p4_Al_mc_AKGC417L test 324 | 151_2p4_Ar_mc_AKGC417L test 325 | 151_2p4_Ll_mc_AKGC417L test 326 | 151_2p4_Lr_mc_AKGC417L test 327 | 151_2p4_Pl_mc_AKGC417L test 328 | 151_2p4_Pr_mc_AKGC417L test 329 | 151_2p4_Tc_mc_AKGC417L test 330 | 151_3p2_Al_mc_AKGC417L 
test 331 | 151_3p2_Ar_mc_AKGC417L test 332 | 151_3p2_Lr_mc_AKGC417L test 333 | 151_3p2_Pl_mc_AKGC417L test 334 | 151_3p2_Pr_mc_AKGC417L test 335 | 151_3p2_Tc_mc_AKGC417L test 336 | 151_3p3_Ll_mc_AKGC417L test 337 | 152_1b1_Al_sc_Meditron train 338 | 153_1b1_Al_sc_Meditron train 339 | 154_1b3_Al_mc_AKGC417L train 340 | 154_1b3_Ar_mc_AKGC417L train 341 | 154_1b3_Ll_mc_AKGC417L train 342 | 154_1b3_Lr_mc_AKGC417L train 343 | 154_1b3_Pl_mc_AKGC417L train 344 | 154_1b3_Pr_mc_AKGC417L train 345 | 154_1b3_Tc_mc_AKGC417L train 346 | 154_2b4_Al_mc_AKGC417L train 347 | 154_2b4_Ar_mc_AKGC417L train 348 | 154_2b4_Ll_mc_AKGC417L train 349 | 154_2b4_Lr_mc_AKGC417L train 350 | 154_2b4_Pl_mc_AKGC417L train 351 | 154_2b4_Pr_mc_AKGC417L train 352 | 154_2b4_Tc_mc_AKGC417L train 353 | 154_3b3_Al_mc_AKGC417L train 354 | 154_3b3_Ar_mc_AKGC417L train 355 | 154_3b3_Ll_mc_AKGC417L train 356 | 154_4b4_Al_mc_AKGC417L train 357 | 154_4b4_Ar_mc_AKGC417L train 358 | 154_4b4_Ll_mc_AKGC417L train 359 | 154_4b4_Lr_mc_AKGC417L train 360 | 154_4b4_Pl_mc_AKGC417L train 361 | 154_4b4_Pr_mc_AKGC417L train 362 | 155_2b1_Al_mc_LittC2SE train 363 | 156_2b3_Al_mc_AKGC417L test 364 | 156_2b3_Ar_mc_AKGC417L train 365 | 156_2b3_Ll_mc_AKGC417L train 366 | 156_2b3_Lr_mc_AKGC417L test 367 | 156_2b3_Pl_mc_AKGC417L test 368 | 156_2b3_Pr_mc_AKGC417L train 369 | 156_5b3_Al_mc_AKGC417L train 370 | 156_5b3_Ar_mc_AKGC417L test 371 | 156_5b3_Ll_mc_AKGC417L test 372 | 156_5b3_Lr_mc_AKGC417L train 373 | 156_5b3_Pl_mc_AKGC417L test 374 | 156_5b3_Pr_mc_AKGC417L train 375 | 156_8b3_Al_mc_AKGC417L test 376 | 156_8b3_Ar_mc_AKGC417L train 377 | 156_8b3_Ll_mc_AKGC417L train 378 | 156_8b3_Lr_mc_AKGC417L test 379 | 156_8b3_Pl_mc_AKGC417L train 380 | 157_1b1_Al_sc_Meditron train 381 | 157_1b1_Ar_sc_Meditron train 382 | 157_1b1_Lr_sc_Meditron train 383 | 157_1b1_Pl_sc_Meditron train 384 | 157_1b1_Pr_sc_Meditron train 385 | 158_1b3_Ar_mc_LittC2SE train 386 | 158_1p2_Al_mc_AKGC417L train 387 | 158_1p2_Ar_mc_AKGC417L train 388 | 
158_1p2_Ll_mc_AKGC417L train 389 | 158_1p2_Lr_mc_AKGC417L train 390 | 158_1p2_Pl_mc_AKGC417L train 391 | 158_1p2_Pr_mc_AKGC417L train 392 | 158_1p2_Tc_mc_AKGC417L train 393 | 158_1p3_Al_mc_AKGC417L train 394 | 158_1p3_Ar_mc_AKGC417L train 395 | 158_1p3_Ll_mc_AKGC417L train 396 | 158_1p3_Lr_mc_AKGC417L train 397 | 158_1p3_Pl_mc_AKGC417L train 398 | 158_1p3_Pr_mc_AKGC417L train 399 | 158_1p3_Tc_mc_AKGC417L train 400 | 158_1p4_Al_mc_AKGC417L train 401 | 158_1p4_Ar_mc_AKGC417L train 402 | 158_1p4_Lr_mc_AKGC417L train 403 | 158_1p4_Pl_mc_AKGC417L train 404 | 158_1p4_Pr_mc_AKGC417L train 405 | 158_1p4_Tc_mc_AKGC417L train 406 | 158_2p2_Ar_mc_AKGC417L train 407 | 158_2p3_Lr_mc_AKGC417L train 408 | 158_2p3_Tc_mc_AKGC417L train 409 | 159_1b1_Al_sc_Meditron train 410 | 159_1b1_Ar_sc_Meditron train 411 | 159_1b1_Ll_sc_Meditron train 412 | 159_1b1_Pr_sc_Meditron train 413 | 160_1b2_Al_mc_AKGC417L test 414 | 160_1b2_Ar_mc_AKGC417L test 415 | 160_1b2_Lr_mc_AKGC417L test 416 | 160_1b2_Pl_mc_AKGC417L test 417 | 160_1b2_Pr_mc_AKGC417L test 418 | 160_1b2_Tc_mc_AKGC417L test 419 | 160_1b3_Al_mc_AKGC417L test 420 | 160_1b3_Ar_mc_AKGC417L test 421 | 160_1b3_Lr_mc_AKGC417L test 422 | 160_1b3_Pl_mc_AKGC417L test 423 | 160_1b3_Pr_mc_AKGC417L test 424 | 160_1b3_Tc_mc_AKGC417L test 425 | 160_1b4_Al_mc_AKGC417L test 426 | 160_1b4_Ar_mc_AKGC417L test 427 | 160_1b4_Lr_mc_AKGC417L test 428 | 160_1b4_Pl_mc_AKGC417L test 429 | 160_1b4_Pr_mc_AKGC417L test 430 | 160_1b4_Tc_mc_AKGC417L test 431 | 160_2b3_Lr_mc_AKGC417L test 432 | 160_2b4_Ar_mc_AKGC417L test 433 | 160_2b4_Pl_mc_AKGC417L test 434 | 160_2b4_Pr_mc_AKGC417L test 435 | 160_2b4_Tc_mc_AKGC417L test 436 | 161_1b1_Al_sc_Meditron train 437 | 161_1b1_Pl_sc_Meditron train 438 | 162_1b2_Al_mc_AKGC417L train 439 | 162_1b2_Ar_mc_AKGC417L train 440 | 162_1b2_Ll_mc_AKGC417L train 441 | 162_1b2_Lr_mc_AKGC417L train 442 | 162_1b2_Pl_mc_AKGC417L train 443 | 162_1b2_Pr_mc_AKGC417L train 444 | 162_1b2_Tc_mc_AKGC417L train 445 | 162_2b2_Al_mc_AKGC417L 
train 446 | 162_2b2_Ar_mc_AKGC417L train 447 | 162_2b2_Pl_mc_AKGC417L train 448 | 162_2b2_Pr_mc_AKGC417L train 449 | 162_2b2_Tc_mc_AKGC417L train 450 | 162_2b3_Al_mc_AKGC417L train 451 | 162_2b3_Ar_mc_AKGC417L train 452 | 162_2b3_Lr_mc_AKGC417L train 453 | 162_2b3_Pl_mc_AKGC417L train 454 | 162_2b3_Pr_mc_AKGC417L train 455 | 162_2b3_Tc_mc_AKGC417L train 456 | 162_2b4_Al_mc_AKGC417L train 457 | 162_2b4_Ar_mc_AKGC417L train 458 | 162_2b4_Lr_mc_AKGC417L train 459 | 162_2b4_Pl_mc_AKGC417L train 460 | 162_2b4_Pr_mc_AKGC417L train 461 | 162_2b4_Tc_mc_AKGC417L train 462 | 163_2b2_Al_mc_AKGC417L train 463 | 163_2b2_Ar_mc_AKGC417L train 464 | 163_2b2_Ll_mc_AKGC417L train 465 | 163_2b2_Lr_mc_AKGC417L train 466 | 163_2b2_Pl_mc_AKGC417L train 467 | 163_2b2_Pr_mc_AKGC417L train 468 | 163_2b2_Tc_mc_AKGC417L train 469 | 163_8b3_Al_mc_AKGC417L train 470 | 163_8b3_Ar_mc_AKGC417L train 471 | 163_8b3_Ll_mc_AKGC417L train 472 | 163_8b3_Lr_mc_AKGC417L train 473 | 163_8b3_Pl_mc_AKGC417L train 474 | 163_8b3_Pr_mc_AKGC417L train 475 | 164_1b1_Ll_sc_Meditron train 476 | 165_1b1_Ar_sc_Meditron test 477 | 165_1b1_Pl_sc_Meditron test 478 | 165_1b1_Pr_sc_Meditron test 479 | 166_1p1_Al_sc_Meditron train 480 | 166_1p1_Ar_sc_Meditron train 481 | 166_1p1_Ll_sc_Meditron train 482 | 166_1p1_Pl_sc_Meditron train 483 | 166_1p1_Pr_sc_Meditron train 484 | 167_1b1_Al_sc_Meditron train 485 | 167_1b1_Pr_sc_Meditron train 486 | 168_1b1_Al_sc_Meditron train 487 | 169_1b1_Lr_sc_Meditron test 488 | 169_1b2_Ll_sc_Meditron test 489 | 170_1b2_Al_mc_AKGC417L test 490 | 170_1b2_Ar_mc_AKGC417L test 491 | 170_1b2_Lr_mc_AKGC417L test 492 | 170_1b2_Pl_mc_AKGC417L test 493 | 170_1b2_Pr_mc_AKGC417L test 494 | 170_1b2_Tc_mc_AKGC417L test 495 | 170_1b3_Al_mc_AKGC417L test 496 | 170_1b3_Ar_mc_AKGC417L test 497 | 170_1b3_Ll_mc_AKGC417L test 498 | 170_1b3_Lr_mc_AKGC417L test 499 | 170_1b3_Pl_mc_AKGC417L test 500 | 170_1b3_Pr_mc_AKGC417L test 501 | 170_1b3_Tc_mc_AKGC417L test 502 | 170_1b4_Al_mc_AKGC417L test 503 | 
170_1b4_Ar_mc_AKGC417L test 504 | 170_1b4_Lr_mc_AKGC417L test 505 | 170_1b4_Pl_mc_AKGC417L test 506 | 170_1b4_Pr_mc_AKGC417L test 507 | 170_1b4_Tc_mc_AKGC417L test 508 | 170_2b2_Al_mc_AKGC417L test 509 | 170_2b2_Ar_mc_AKGC417L test 510 | 170_2b2_Lr_mc_AKGC417L test 511 | 170_2b2_Pl_mc_AKGC417L test 512 | 170_2b2_Pr_mc_AKGC417L test 513 | 170_2b2_Tc_mc_AKGC417L test 514 | 171_1b1_Al_sc_Meditron test 515 | 172_1b3_Al_mc_AKGC417L train 516 | 172_1b3_Ar_mc_AKGC417L train 517 | 172_1b3_Ll_mc_AKGC417L train 518 | 172_1b3_Lr_mc_AKGC417L train 519 | 172_1b3_Pl_mc_AKGC417L train 520 | 172_1b3_Pr_mc_AKGC417L train 521 | 172_1b3_Tc_mc_AKGC417L train 522 | 172_1b4_Al_mc_AKGC417L train 523 | 172_1b4_Ar_mc_AKGC417L train 524 | 172_1b4_Ll_mc_AKGC417L train 525 | 172_1b4_Lr_mc_AKGC417L train 526 | 172_1b4_Pl_mc_AKGC417L train 527 | 172_1b4_Pr_mc_AKGC417L train 528 | 172_1b4_Tc_mc_AKGC417L train 529 | 172_1b5_Al_mc_AKGC417L train 530 | 172_1b5_Ar_mc_AKGC417L train 531 | 172_1b5_Ll_mc_AKGC417L train 532 | 172_1b5_Lr_mc_AKGC417L train 533 | 172_1b5_Pl_mc_AKGC417L train 534 | 172_1b5_Pr_mc_AKGC417L train 535 | 172_1b5_Tc_mc_AKGC417L train 536 | 172_2b5_Al_mc_AKGC417L train 537 | 172_2b5_Ar_mc_AKGC417L train 538 | 172_2b5_Lr_mc_AKGC417L train 539 | 172_2b5_Pl_mc_AKGC417L train 540 | 172_2b5_Pr_mc_AKGC417L train 541 | 172_2b5_Tc_mc_AKGC417L train 542 | 173_1b1_Al_sc_Meditron test 543 | 174_1p2_Ar_mc_AKGC417L test 544 | 174_1p2_Ll_mc_AKGC417L test 545 | 174_1p2_Lr_mc_AKGC417L test 546 | 174_1p2_Pl_mc_AKGC417L test 547 | 174_1p2_Pr_mc_AKGC417L test 548 | 174_1p2_Tc_mc_AKGC417L test 549 | 174_1p3_Ar_mc_AKGC417L test 550 | 174_1p3_Ll_mc_AKGC417L test 551 | 174_1p3_Lr_mc_AKGC417L test 552 | 174_1p3_Pl_mc_AKGC417L test 553 | 174_1p3_Pr_mc_AKGC417L test 554 | 174_1p3_Tc_mc_AKGC417L test 555 | 174_1p4_Ar_mc_AKGC417L test 556 | 174_1p4_Ll_mc_AKGC417L test 557 | 174_1p4_Lr_mc_AKGC417L test 558 | 174_1p4_Pl_mc_AKGC417L test 559 | 174_1p4_Pr_mc_AKGC417L test 560 | 174_1p4_Tc_mc_AKGC417L test 561 | 
174_2p3_Al_mc_AKGC417L test 562 | 174_2p3_Ar_mc_AKGC417L test 563 | 174_2p3_Pl_mc_AKGC417L test 564 | 174_2p3_Pr_mc_AKGC417L test 565 | 174_2p3_Tc_mc_AKGC417L test 566 | 175_1b1_Al_sc_Litt3200 test 567 | 175_1b1_Ar_sc_Litt3200 test 568 | 175_1b1_Ll_sc_Litt3200 test 569 | 175_1b1_Lr_sc_Litt3200 test 570 | 175_1b1_Pl_sc_Litt3200 test 571 | 175_1b1_Pr_sc_Litt3200 test 572 | 176_1b3_Al_mc_AKGC417L test 573 | 176_1b3_Ar_mc_AKGC417L test 574 | 176_1b3_Ll_mc_AKGC417L test 575 | 176_1b3_Lr_mc_AKGC417L test 576 | 176_1b3_Pl_mc_AKGC417L test 577 | 176_1b3_Pr_mc_AKGC417L test 578 | 176_1b3_Tc_mc_AKGC417L test 579 | 176_1b4_Al_mc_AKGC417L test 580 | 176_1b4_Ar_mc_AKGC417L test 581 | 176_1b4_Ll_mc_AKGC417L test 582 | 176_1b4_Lr_mc_AKGC417L test 583 | 176_1b4_Pl_mc_AKGC417L test 584 | 176_1b4_Pr_mc_AKGC417L test 585 | 176_1b4_Tc_mc_AKGC417L test 586 | 176_2b3_Al_mc_AKGC417L test 587 | 176_2b3_Ar_mc_AKGC417L test 588 | 176_2b3_Ll_mc_AKGC417L test 589 | 176_2b3_Lr_mc_AKGC417L test 590 | 176_2b3_Pl_mc_AKGC417L test 591 | 176_2b3_Pr_mc_AKGC417L test 592 | 176_2b3_Tc_mc_AKGC417L test 593 | 177_1b2_Al_mc_AKGC417L test 594 | 177_1b2_Ar_mc_AKGC417L test 595 | 177_1b2_Lr_mc_AKGC417L test 596 | 177_1b2_Pl_mc_AKGC417L test 597 | 177_1b2_Pr_mc_AKGC417L test 598 | 177_1b2_Tc_mc_AKGC417L test 599 | 177_1b4_Al_mc_AKGC417L test 600 | 177_1b4_Ar_mc_AKGC417L test 601 | 177_1b4_Lr_mc_AKGC417L test 602 | 177_1b4_Pl_mc_AKGC417L test 603 | 177_1b4_Pr_mc_AKGC417L test 604 | 177_1b4_Tc_mc_AKGC417L test 605 | 177_2b4_Al_mc_AKGC417L test 606 | 177_2b4_Lr_mc_AKGC417L test 607 | 177_2b4_Pl_mc_AKGC417L test 608 | 177_2b4_Pr_mc_AKGC417L test 609 | 177_2b4_Tc_mc_AKGC417L test 610 | 178_1b2_Al_mc_AKGC417L test 611 | 178_1b2_Ar_mc_AKGC417L test 612 | 178_1b2_Lr_mc_AKGC417L test 613 | 178_1b2_Pl_mc_AKGC417L test 614 | 178_1b2_Pr_mc_AKGC417L test 615 | 178_1b2_Tc_mc_AKGC417L test 616 | 178_1b3_Al_mc_AKGC417L test 617 | 178_1b3_Ar_mc_AKGC417L test 618 | 178_1b3_Lr_mc_AKGC417L test 619 | 178_1b3_Pl_mc_AKGC417L test 
620 | 178_1b3_Pr_mc_AKGC417L test 621 | 178_1b3_Tc_mc_AKGC417L test 622 | 178_1b6_Al_mc_AKGC417L test 623 | 178_1b6_Ar_mc_AKGC417L test 624 | 178_1b6_Ll_mc_AKGC417L test 625 | 178_1b6_Lr_mc_AKGC417L test 626 | 178_1b6_Pl_mc_AKGC417L test 627 | 178_1b6_Pr_mc_AKGC417L test 628 | 178_1b6_Tc_mc_AKGC417L test 629 | 178_2b2_Al_mc_AKGC417L test 630 | 178_2b2_Ar_mc_AKGC417L test 631 | 178_2b2_Lr_mc_AKGC417L test 632 | 178_2b2_Pr_mc_AKGC417L test 633 | 178_2b2_Tc_mc_AKGC417L test 634 | 179_1b1_Al_sc_Meditron train 635 | 179_1b1_Tc_sc_Meditron train 636 | 180_1b4_Al_mc_AKGC417L train 637 | 180_1b4_Ar_mc_AKGC417L train 638 | 180_1b4_Lr_mc_AKGC417L train 639 | 180_1b4_Pl_mc_AKGC417L train 640 | 180_1b4_Pr_mc_AKGC417L train 641 | 181_1b1_Ar_mc_LittC2SE train 642 | 181_1b1_Tc_mc_LittC2SE train 643 | 181_1b2_Ar_mc_LittC2SE train 644 | 181_1b3_Tc_mc_LittC2SE train 645 | 182_1b1_Tc_sc_Meditron test 646 | 183_1b1_Pl_sc_Meditron train 647 | 183_1b1_Tc_sc_Meditron train 648 | 184_1b1_Ar_sc_Meditron train 649 | 185_1b1_Al_sc_Litt3200 test 650 | 185_1b1_Ar_sc_Litt3200 test 651 | 185_1b1_Ll_sc_Litt3200 test 652 | 185_1b1_Lr_sc_Litt3200 test 653 | 185_1b1_Pl_sc_Litt3200 test 654 | 185_1b1_Pr_sc_Litt3200 test 655 | 186_2b2_Al_mc_AKGC417L train 656 | 186_2b2_Ar_mc_AKGC417L train 657 | 186_2b2_Lr_mc_AKGC417L train 658 | 186_2b2_Pl_mc_AKGC417L train 659 | 186_2b2_Pr_mc_AKGC417L train 660 | 186_2b2_Tc_mc_AKGC417L train 661 | 186_2b3_Al_mc_AKGC417L train 662 | 186_2b3_Ar_mc_AKGC417L train 663 | 186_2b3_Lr_mc_AKGC417L train 664 | 186_2b3_Pl_mc_AKGC417L train 665 | 186_2b3_Pr_mc_AKGC417L train 666 | 186_2b3_Tc_mc_AKGC417L train 667 | 186_2b4_Al_mc_AKGC417L train 668 | 186_2b4_Ar_mc_AKGC417L train 669 | 186_2b4_Lr_mc_AKGC417L train 670 | 186_2b4_Pl_mc_AKGC417L train 671 | 186_2b4_Pr_mc_AKGC417L train 672 | 186_2b4_Tc_mc_AKGC417L train 673 | 186_3b3_Al_mc_AKGC417L train 674 | 186_3b3_Ar_mc_AKGC417L train 675 | 186_3b3_Lr_mc_AKGC417L train 676 | 186_3b3_Pl_mc_AKGC417L train 677 | 
186_3b3_Pr_mc_AKGC417L train 678 | 186_3b3_Tc_mc_AKGC417L train 679 | 187_1b1_Ll_sc_Meditron test 680 | 188_1b1_Al_sc_Meditron train 681 | 188_1b1_Ar_sc_Meditron train 682 | 188_1b1_Pl_sc_Meditron train 683 | 188_1b1_Tc_sc_Meditron train 684 | 189_1b2_Lr_mc_LittC2SE train 685 | 190_1b1_Tc_sc_Meditron train 686 | 191_2b1_Pl_mc_LittC2SE train 687 | 191_2b1_Pr_mc_LittC2SE train 688 | 191_2b2_Tc_mc_LittC2SE train 689 | 192_2b1_Al_mc_LittC2SE train 690 | 192_2b1_Ar_mc_LittC2SE train 691 | 192_2b2_Al_mc_LittC2SE train 692 | 192_2b2_Ar_mc_LittC2SE train 693 | 192_2b3_Al_mc_LittC2SE train 694 | 192_2b3_Ar_mc_LittC2SE train 695 | 193_1b2_Al_mc_AKGC417L train 696 | 193_1b2_Ar_mc_AKGC417L train 697 | 193_1b2_Ll_mc_AKGC417L train 698 | 193_1b2_Pl_mc_AKGC417L train 699 | 193_1b2_Pr_mc_AKGC417L train 700 | 193_1b2_Tc_mc_AKGC417L train 701 | 193_1b4_Lr_mc_AKGC417L train 702 | 193_7b3_Al_mc_AKGC417L train 703 | 193_7b3_Ar_mc_AKGC417L train 704 | 193_7b3_Ll_mc_AKGC417L train 705 | 193_7b3_Lr_mc_AKGC417L train 706 | 193_7b3_Pl_mc_AKGC417L train 707 | 193_7b3_Pr_mc_AKGC417L train 708 | 193_7b3_Tc_mc_AKGC417L train 709 | 194_1b1_Lr_sc_Meditron test 710 | 194_1b1_Pr_sc_Meditron test 711 | 195_1b1_Al_sc_Litt3200 test 712 | 195_1b1_Ar_sc_Litt3200 test 713 | 195_1b1_Ll_sc_Litt3200 test 714 | 195_1b1_Lr_sc_Litt3200 test 715 | 195_1b1_Pl_sc_Litt3200 test 716 | 195_1b1_Pr_sc_Litt3200 test 717 | 196_1b1_Pr_sc_Meditron train 718 | 197_1b1_Al_sc_Meditron train 719 | 197_1b1_Tc_sc_Meditron train 720 | 198_1b5_Al_mc_AKGC417L test 721 | 198_1b5_Ar_mc_AKGC417L test 722 | 198_1b5_Ll_mc_AKGC417L test 723 | 198_1b5_Lr_mc_AKGC417L test 724 | 198_1b5_Pl_mc_AKGC417L test 725 | 198_1b5_Pr_mc_AKGC417L test 726 | 198_1b5_Tc_mc_AKGC417L test 727 | 198_6p1_Al_mc_AKGC417L test 728 | 198_6p1_Ar_mc_AKGC417L test 729 | 198_6p1_Ll_mc_AKGC417L test 730 | 198_6p1_Lr_mc_AKGC417L test 731 | 198_6p1_Pl_mc_AKGC417L test 732 | 198_6p1_Pr_mc_AKGC417L test 733 | 198_6p1_Tc_mc_AKGC417L test 734 | 199_2b1_Ll_mc_LittC2SE 
train 735 | 199_2b3_Ll_mc_LittC2SE train 736 | 200_2p2_Al_mc_AKGC417L train 737 | 200_2p2_Ar_mc_AKGC417L train 738 | 200_2p2_Lr_mc_AKGC417L train 739 | 200_2p2_Pl_mc_AKGC417L train 740 | 200_2p2_Pr_mc_AKGC417L train 741 | 200_2p2_Tc_mc_AKGC417L train 742 | 200_2p3_Al_mc_AKGC417L train 743 | 200_2p3_Ar_mc_AKGC417L train 744 | 200_2p3_Lr_mc_AKGC417L train 745 | 200_2p3_Pl_mc_AKGC417L train 746 | 200_2p3_Pr_mc_AKGC417L train 747 | 200_2p3_Tc_mc_AKGC417L train 748 | 200_2p4_Al_mc_AKGC417L train 749 | 200_2p4_Ar_mc_AKGC417L train 750 | 200_2p4_Lr_mc_AKGC417L train 751 | 200_2p4_Pl_mc_AKGC417L train 752 | 200_2p4_Pr_mc_AKGC417L train 753 | 200_2p4_Tc_mc_AKGC417L train 754 | 200_3p4_Al_mc_AKGC417L train 755 | 200_3p4_Ar_mc_AKGC417L train 756 | 200_3p4_Pl_mc_AKGC417L train 757 | 200_3p4_Pr_mc_AKGC417L train 758 | 200_3p4_Tc_mc_AKGC417L train 759 | 201_1b1_Al_sc_Meditron train 760 | 201_1b1_Ar_sc_Meditron train 761 | 201_1b2_Al_sc_Meditron train 762 | 201_1b2_Ar_sc_Meditron train 763 | 201_1b3_Al_sc_Meditron train 764 | 201_1b3_Ar_sc_Meditron train 765 | 202_1b1_Ar_sc_Meditron test 766 | 203_1p2_Al_mc_AKGC417L train 767 | 203_1p2_Ar_mc_AKGC417L train 768 | 203_1p2_Lr_mc_AKGC417L train 769 | 203_1p2_Pl_mc_AKGC417L train 770 | 203_1p2_Pr_mc_AKGC417L train 771 | 203_1p2_Tc_mc_AKGC417L train 772 | 203_1p3_Al_mc_AKGC417L train 773 | 203_1p3_Ar_mc_AKGC417L train 774 | 203_1p3_Pl_mc_AKGC417L train 775 | 203_1p3_Pr_mc_AKGC417L train 776 | 203_1p3_Tc_mc_AKGC417L train 777 | 203_1p4_Al_mc_AKGC417L train 778 | 203_1p4_Ar_mc_AKGC417L train 779 | 203_1p4_Pl_mc_AKGC417L train 780 | 203_1p4_Pr_mc_AKGC417L train 781 | 203_1p4_Tc_mc_AKGC417L train 782 | 203_2p3_Al_mc_AKGC417L train 783 | 203_2p3_Ar_mc_AKGC417L train 784 | 203_2p3_Pl_mc_AKGC417L train 785 | 203_2p3_Pr_mc_AKGC417L train 786 | 203_2p3_Tc_mc_AKGC417L train 787 | 204_2b5_Al_mc_AKGC417L train 788 | 204_2b5_Ar_mc_AKGC417L train 789 | 204_2b5_Ll_mc_AKGC417L train 790 | 204_7p5_Al_mc_AKGC417L train 791 | 204_7p5_Ar_mc_AKGC417L train 
792 | 204_7p5_Ll_mc_AKGC417L train 793 | 204_7p5_Lr_mc_AKGC417L train 794 | 204_7p5_Pr_mc_AKGC417L train 795 | 204_7p5_Tc_mc_AKGC417L train 796 | 205_1b3_Al_mc_AKGC417L test 797 | 205_1b3_Ar_mc_AKGC417L test 798 | 205_1b3_Ll_mc_AKGC417L test 799 | 205_1b3_Lr_mc_AKGC417L test 800 | 205_1b3_Pl_mc_AKGC417L test 801 | 205_1b3_Pr_mc_AKGC417L test 802 | 205_2b2_Pr_mc_AKGC417L test 803 | 205_2b3_Al_mc_AKGC417L test 804 | 205_2b3_Ar_mc_AKGC417L test 805 | 205_2b3_Ll_mc_AKGC417L test 806 | 205_2b4_Pl_mc_AKGC417L test 807 | 205_3b4_Al_mc_AKGC417L test 808 | 205_3b4_Ar_mc_AKGC417L test 809 | 205_3b4_Pl_mc_AKGC417L test 810 | 205_3b4_Pr_mc_AKGC417L test 811 | 205_4b2_Al_mc_AKGC417L test 812 | 205_4b2_Ar_mc_AKGC417L test 813 | 205_4b2_Lr_mc_AKGC417L test 814 | 205_4b2_Pl_mc_AKGC417L test 815 | 205_4b2_Pr_mc_AKGC417L test 816 | 206_1b1_Ar_sc_Meditron train 817 | 206_1b1_Lr_sc_Meditron train 818 | 206_1b1_Pl_sc_Meditron train 819 | 207_2b2_Al_mc_AKGC417L test 820 | 207_2b2_Ar_mc_AKGC417L test 821 | 207_2b2_Pl_mc_AKGC417L test 822 | 207_2b2_Pr_mc_AKGC417L test 823 | 207_2b2_Tc_mc_AKGC417L test 824 | 207_2b3_Al_mc_AKGC417L test 825 | 207_2b3_Ar_mc_AKGC417L test 826 | 207_2b3_Pl_mc_AKGC417L test 827 | 207_2b3_Pr_mc_AKGC417L test 828 | 207_2b3_Tc_mc_AKGC417L test 829 | 207_2b4_Al_mc_AKGC417L test 830 | 207_2b4_Ar_mc_AKGC417L test 831 | 207_2b4_Pl_mc_AKGC417L test 832 | 207_2b4_Pr_mc_AKGC417L test 833 | 207_2b4_Tc_mc_AKGC417L test 834 | 207_3b2_Al_mc_AKGC417L test 835 | 207_3b2_Ar_mc_AKGC417L test 836 | 207_3b2_Lr_mc_AKGC417L test 837 | 207_3b2_Pl_mc_AKGC417L test 838 | 207_3b2_Pr_mc_AKGC417L test 839 | 207_3b2_Tc_mc_AKGC417L test 840 | 208_1b1_Ll_sc_Meditron train 841 | 209_1b1_Tc_sc_Meditron train 842 | 210_1b1_Al_sc_Meditron train 843 | 210_1b1_Ar_sc_Meditron train 844 | 211_1p2_Ar_mc_AKGC417L test 845 | 211_1p2_Pl_mc_AKGC417L test 846 | 211_1p2_Pr_mc_AKGC417L test 847 | 211_1p3_Ar_mc_AKGC417L test 848 | 211_1p5_Ar_mc_AKGC417L test 849 | 211_2p2_Tc_mc_AKGC417L test 850 | 
211_2p3_Tc_mc_AKGC417L test 851 | 211_2p4_Tc_mc_AKGC417L test 852 | 212_2b2_Tc_mc_LittC2SE train 853 | 213_1p2_Al_mc_AKGC417L train 854 | 213_1p2_Ar_mc_AKGC417L train 855 | 213_1p2_Lr_mc_AKGC417L train 856 | 213_1p2_Pl_mc_AKGC417L train 857 | 213_1p2_Pr_mc_AKGC417L train 858 | 213_1p2_Tc_mc_AKGC417L train 859 | 213_1p3_Al_mc_AKGC417L train 860 | 213_1p3_Ar_mc_AKGC417L train 861 | 213_1p3_Pl_mc_AKGC417L train 862 | 213_1p3_Pr_mc_AKGC417L train 863 | 213_1p5_Al_mc_AKGC417L train 864 | 213_1p5_Ar_mc_AKGC417L train 865 | 213_1p5_Pl_mc_AKGC417L train 866 | 213_1p5_Pr_mc_AKGC417L train 867 | 213_1p5_Tc_mc_AKGC417L train 868 | 213_2p2_Al_mc_AKGC417L train 869 | 213_2p2_Ar_mc_AKGC417L train 870 | 213_2p2_Pl_mc_AKGC417L train 871 | 213_2p2_Pr_mc_AKGC417L train 872 | 213_2p2_Tc_mc_AKGC417L train 873 | 214_1b1_Ar_sc_Meditron test 874 | 215_1b2_Ar_sc_Meditron train 875 | 215_1b3_Tc_sc_Meditron train 876 | 216_1b1_Al_sc_Meditron test 877 | 216_1b1_Pl_sc_Meditron test 878 | 217_1b1_Tc_sc_Meditron train 879 | 218_1b1_Al_sc_Meditron test 880 | 218_1b1_Ar_sc_Meditron train 881 | 218_1b1_Lr_sc_Meditron train 882 | 218_1b1_Pl_sc_Meditron test 883 | 218_1b1_Pr_sc_Meditron test 884 | 218_1p1_Ar_sc_Litt3200 train 885 | 218_1p1_Pl_sc_Litt3200 test 886 | 218_1p1_Pr_sc_Litt3200 train 887 | 219_2b1_Ar_mc_LittC2SE train 888 | 219_2b1_Tc_mc_LittC2SE train 889 | 219_2b2_Ar_mc_LittC2SE train 890 | 219_2b2_Tc_mc_LittC2SE train 891 | 219_2b3_Tc_mc_LittC2SE train 892 | 220_1b1_Tc_mc_LittC2SE train 893 | 220_1b2_Al_mc_LittC2SE train 894 | 221_2b1_Al_mc_LittC2SE train 895 | 221_2b1_Ar_mc_LittC2SE train 896 | 221_2b1_Lr_mc_LittC2SE train 897 | 221_2b1_Pl_mc_LittC2SE train 898 | 221_2b2_Al_mc_LittC2SE train 899 | 221_2b2_Ar_mc_LittC2SE train 900 | 221_2b2_Lr_mc_LittC2SE train 901 | 221_2b2_Pl_mc_LittC2SE train 902 | 221_2b3_Al_mc_LittC2SE train 903 | 221_2b3_Ar_mc_LittC2SE train 904 | 221_2b3_Lr_mc_LittC2SE train 905 | 221_2b3_Pr_mc_LittC2SE train 906 | 222_1b1_Ar_sc_Meditron train 907 | 
222_1b1_Lr_sc_Meditron train 908 | 222_1b1_Pr_sc_Meditron train 909 | 223_1b1_Pr_sc_Meditron test 910 | 223_1b1_Pl_sc_Meditron test 911 | 223_1b1_Ar_sc_Meditron test 912 | 223_1b1_Al_sc_Meditron test 913 | 223_1b1_Lr_sc_Meditron test 914 | 223_1b1_Ll_sc_Meditron test 915 | 224_1b1_Tc_sc_Meditron test 916 | 224_1b2_Al_sc_Meditron test 917 | 225_1b1_Pl_sc_Meditron test 918 | 226_1b1_Al_sc_Meditron train 919 | 226_1b1_Ll_sc_Meditron train 920 | 226_1b1_Pl_sc_LittC2SE train -------------------------------------------------------------------------------- /data/icbhi_dataset/patient_diagnosis.txt: -------------------------------------------------------------------------------- 1 | 101 URTI 2 | 102 Healthy 3 | 103 Asthma 4 | 104 COPD 5 | 105 URTI 6 | 106 COPD 7 | 107 COPD 8 | 108 LRTI 9 | 109 COPD 10 | 110 COPD 11 | 111 Bronchiectasis 12 | 112 COPD 13 | 113 COPD 14 | 114 COPD 15 | 115 LRTI 16 | 116 Bronchiectasis 17 | 117 COPD 18 | 118 COPD 19 | 119 URTI 20 | 120 COPD 21 | 121 Healthy 22 | 122 Pneumonia 23 | 123 Healthy 24 | 124 COPD 25 | 125 Healthy 26 | 126 Healthy 27 | 127 Healthy 28 | 128 COPD 29 | 129 URTI 30 | 130 COPD 31 | 131 URTI 32 | 132 COPD 33 | 133 COPD 34 | 134 COPD 35 | 135 Pneumonia 36 | 136 Healthy 37 | 137 URTI 38 | 138 COPD 39 | 139 COPD 40 | 140 Pneumonia 41 | 141 COPD 42 | 142 COPD 43 | 143 Healthy 44 | 144 Healthy 45 | 145 COPD 46 | 146 COPD 47 | 147 COPD 48 | 148 URTI 49 | 149 Bronchiolitis 50 | 150 URTI 51 | 151 COPD 52 | 152 Healthy 53 | 153 Healthy 54 | 154 COPD 55 | 155 COPD 56 | 156 COPD 57 | 157 COPD 58 | 158 COPD 59 | 159 Healthy 60 | 160 COPD 61 | 161 Bronchiolitis 62 | 162 COPD 63 | 163 COPD 64 | 164 URTI 65 | 165 URTI 66 | 166 COPD 67 | 167 Bronchiolitis 68 | 168 Bronchiectasis 69 | 169 Bronchiectasis 70 | 170 COPD 71 | 171 Healthy 72 | 172 COPD 73 | 173 Bronchiolitis 74 | 174 COPD 75 | 175 COPD 76 | 176 COPD 77 | 177 COPD 78 | 178 COPD 79 | 179 Healthy 80 | 180 COPD 81 | 181 COPD 82 | 182 Healthy 83 | 183 Healthy 84 | 184 Healthy 85 | 185 
COPD 86 | 186 COPD 87 | 187 Healthy 88 | 188 URTI 89 | 189 COPD 90 | 190 URTI 91 | 191 Pneumonia 92 | 192 COPD 93 | 193 COPD 94 | 194 Healthy 95 | 195 COPD 96 | 196 Bronchiectasis 97 | 197 URTI 98 | 198 COPD 99 | 199 COPD 100 | 200 COPD 101 | 201 Bronchiectasis 102 | 202 Healthy 103 | 203 COPD 104 | 204 COPD 105 | 205 COPD 106 | 206 Bronchiolitis 107 | 207 COPD 108 | 208 Healthy 109 | 209 Healthy 110 | 210 URTI 111 | 211 COPD 112 | 212 COPD 113 | 213 COPD 114 | 214 Healthy 115 | 215 Bronchiectasis 116 | 216 Bronchiolitis 117 | 217 Healthy 118 | 218 COPD 119 | 219 Pneumonia 120 | 220 COPD 121 | 221 COPD 122 | 222 COPD 123 | 223 COPD 124 | 224 Healthy 125 | 225 Healthy 126 | 226 Pneumonia -------------------------------------------------------------------------------- /data/icbhi_dataset/patient_list_foldwise.txt: -------------------------------------------------------------------------------- 1 | 158 0 2 | 193 0 3 | 177 0 4 | 170 0 5 | 180 0 6 | 211 0 7 | 147 0 8 | 107 0 9 | 162 0 10 | 156 0 11 | 146 0 12 | 200 0 13 | 138 0 14 | 160 0 15 | 203 0 16 | 204 0 17 | 172 0 18 | 207 0 19 | 163 0 20 | 205 0 21 | 213 0 22 | 114 0 23 | 130 0 24 | 154 0 25 | 186 0 26 | 184 0 27 | 153 1 28 | 115 1 29 | 224 1 30 | 223 1 31 | 201 1 32 | 218 1 33 | 127 1 34 | 137 1 35 | 215 1 36 | 161 1 37 | 206 1 38 | 101 1 39 | 168 1 40 | 131 1 41 | 216 1 42 | 120 1 43 | 188 1 44 | 167 1 45 | 210 1 46 | 197 1 47 | 183 1 48 | 152 1 49 | 173 1 50 | 108 1 51 | 208 1 52 | 105 2 53 | 110 2 54 | 116 2 55 | 196 2 56 | 182 2 57 | 222 2 58 | 166 2 59 | 209 2 60 | 144 2 61 | 111 2 62 | 165 2 63 | 148 2 64 | 164 2 65 | 159 2 66 | 121 2 67 | 157 2 68 | 217 2 69 | 123 2 70 | 169 2 71 | 179 2 72 | 190 2 73 | 125 2 74 | 129 2 75 | 225 2 76 | 136 2 77 | 118 3 78 | 185 3 79 | 112 3 80 | 124 3 81 | 104 3 82 | 195 3 83 | 175 3 84 | 212 3 85 | 140 3 86 | 219 3 87 | 132 3 88 | 142 3 89 | 220 3 90 | 122 3 91 | 191 3 92 | 128 3 93 | 226 3 94 | 141 3 95 | 103 3 96 | 134 3 97 | 117 3 98 | 192 3 99 | 106 3 100 | 155 3 
101 | 199 3 102 | 174 4 103 | 145 4 104 | 151 4 105 | 176 4 106 | 178 4 107 | 133 4 108 | 198 4 109 | 214 4 110 | 149 4 111 | 143 4 112 | 187 4 113 | 202 4 114 | 119 4 115 | 194 4 116 | 126 4 117 | 150 4 118 | 171 4 119 | 102 4 120 | 109 4 121 | 113 4 122 | 139 4 123 | 189 4 124 | 181 4 125 | 221 4 126 | 135 4 127 | -------------------------------------------------------------------------------- /data/icbhi_dataset/script.sh: -------------------------------------------------------------------------------- 1 | n=1 2 | appendix="Meditron" 3 | newname="Blank new name" 4 | while IFS= read -r oldname; do 5 | # echo $oldname $newname $appendix 6 | truncatedoldname=${oldname%?????????} 7 | 8 | newname=$truncatedoldname$appendix 9 | 10 | newname=${newname#"'"} 11 | oldname=${oldname#"'"} 12 | oldname=${oldname%"'"} 13 | 14 | mv -vn "$oldname".txt "$newname".txt 15 | mv -vn "$oldname".wav "$newname".wav 16 | done < "filename_differences.txt" -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import os 3 | import sys 4 | import json 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | 8 | import math 9 | import time 10 | import random 11 | import pickle 12 | import argparse 13 | import numpy as np 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.optim as optim 18 | import torch.backends.cudnn as cudnn 19 | from torchvision import transforms 20 | 21 | from util.icbhi_dataset import ICBHIDataset 22 | from util.icbhi_util import get_score 23 | from util.augmentation import SpecAugment 24 | from util.misc import adjust_learning_rate, warmup_learning_rate, set_optimizer, update_moving_average 25 | from util.misc import AverageMeter, accuracy, save_model, update_json 26 | from models import get_backbone_class, Projector 27 | from method import PatchMixLoss, PatchMixConLoss 28 | 29 | 30 | def 
parse_args(): 31 | parser = argparse.ArgumentParser('argument for supervised training') 32 | 33 | parser.add_argument('--seed', type=int, default=0) 34 | parser.add_argument('--print_freq', type=int, default=10) 35 | parser.add_argument('--save_freq', type=int, default=100) 36 | parser.add_argument('--save_dir', type=str, default='./save') 37 | parser.add_argument('--tag', type=str, default='', 38 | help='tag for experiment name') 39 | parser.add_argument('--resume', type=str, default=None, 40 | help='path of model checkpoint to resume') 41 | parser.add_argument('--eval', action='store_true', 42 | help='only evaluation with pretrained encoder and classifier') 43 | parser.add_argument('--two_cls_eval', action='store_true', 44 | help='evaluate with two classes') 45 | 46 | # optimization 47 | parser.add_argument('--optimizer', type=str, default='adam') 48 | parser.add_argument('--epochs', type=int, default=400) 49 | parser.add_argument('--learning_rate', type=float, default=1e-3) 50 | parser.add_argument('--lr_decay_epochs', type=str, default='120,160') 51 | parser.add_argument('--lr_decay_rate', type=float, default=0.1) 52 | parser.add_argument('--weight_decay', type=float, default=1e-4) 53 | parser.add_argument('--momentum', type=float, default=0.9) 54 | parser.add_argument('--cosine', action='store_true', 55 | help='using cosine annealing') 56 | parser.add_argument('--warm', action='store_true', 57 | help='warm-up for large batch training') 58 | parser.add_argument('--warm_epochs', type=int, default=0, 59 | help='warmup epochs') 60 | parser.add_argument('--weighted_loss', action='store_true', 61 | help='weighted cross entropy loss (higher weights on abnormal class)') 62 | parser.add_argument('--mix_beta', default=1.0, type=float, 63 | help='patch-mix interpolation coefficient') 64 | parser.add_argument('--time_domain', action='store_true', 65 | help='patchmix for the specific time domain') 66 | 67 | # dataset 68 | parser.add_argument('--dataset', type=str, 
default='icbhi') 69 | parser.add_argument('--data_folder', type=str, default='./data/') 70 | parser.add_argument('--batch_size', type=int, default=128) 71 | parser.add_argument('--num_workers', type=int, default=8) 72 | # icbhi dataset 73 | parser.add_argument('--class_split', type=str, default='lungsound', 74 | help='lungsound: (normal, crackles, wheezes, both), diagnosis: (healthy, chronic diseases, non-chronic diseases)') 75 | parser.add_argument('--n_cls', type=int, default=4, 76 | help='set k-way classification problem') 77 | parser.add_argument('--test_fold', type=str, default='official', choices=['official', '0', '1', '2', '3', '4'], 78 | help='test fold to use official 60-40 split or 80-20 split from RespireNet') 79 | parser.add_argument('--weighted_sampler', action='store_true', 80 | help='weighted sampler inversly proportional to class ratio') 81 | parser.add_argument('--stetho_id', type=int, default=-1, 82 | help='stethoscope device id, use only when finetuning on each stethoscope data') 83 | parser.add_argument('--sample_rate', type=int, default=16000, 84 | help='sampling rate when load audio data, and it denotes the number of samples per one second') 85 | parser.add_argument('--butterworth_filter', type=int, default=None, 86 | help='apply specific order butterworth band-pass filter') 87 | parser.add_argument('--desired_length', type=int, default=8, 88 | help='fixed length size of individual cycle') 89 | parser.add_argument('--nfft', type=int, default=1024, 90 | help='the frequency size of fast fourier transform') 91 | parser.add_argument('--n_mels', type=int, default=128, 92 | help='the number of mel filter banks') 93 | parser.add_argument('--concat_aug_scale', type=float, default=0, 94 | help='to control the number (scale) of concatenation-based augmented samples') 95 | parser.add_argument('--pad_types', type=str, default='repeat', 96 | help='zero: zero-padding, repeat: padding with duplicated samples, aug: padding with augmented samples') 97 | 
parser.add_argument('--resz', type=float, default=1, 98 | help='resize the scale of mel-spectrogram') 99 | parser.add_argument('--raw_augment', type=int, default=0, 100 | help='control how many number of augmented raw audio samples') 101 | parser.add_argument('--blank_region_clip', action='store_true', 102 | help='remove the blank region, high frequency region') 103 | parser.add_argument('--specaug_policy', type=str, default='icbhi_ast_sup', 104 | help='policy (argument values) for SpecAugment') 105 | parser.add_argument('--specaug_mask', type=str, default='mean', 106 | help='specaug mask value', choices=['mean', 'zero']) 107 | 108 | # model 109 | parser.add_argument('--model', type=str, default='ast') 110 | parser.add_argument('--pretrained', action='store_true') 111 | parser.add_argument('--pretrained_ckpt', type=str, default=None, 112 | help='path to pre-trained encoder model') 113 | parser.add_argument('--from_sl_official', action='store_true', 114 | help='load from supervised imagenet-pretrained model (official PyTorch)') 115 | parser.add_argument('--ma_update', action='store_true', 116 | help='whether to use moving average update for model') 117 | parser.add_argument('--ma_beta', type=float, default=0, 118 | help='moving average value') 119 | # for AST 120 | parser.add_argument('--audioset_pretrained', action='store_true', 121 | help='load from imagenet- and audioset-pretrained model') 122 | # for SSAST 123 | parser.add_argument('--ssast_task', type=str, default='ft_avgtok', 124 | help='pretraining or fine-tuning task', choices=['ft_avgtok', 'ft_cls']) 125 | parser.add_argument('--fshape', type=int, default=16, 126 | help='fshape of SSAST') 127 | parser.add_argument('--tshape', type=int, default=16, 128 | help='tshape of SSAST') 129 | parser.add_argument('--ssast_pretrained_type', type=str, default='Patch', 130 | help='pretrained ckpt version of SSAST model') 131 | 132 | parser.add_argument('--method', type=str, default='ce') 133 | # Patch-Mix CL loss 134 | 
parser.add_argument('--proj_dim', type=int, default=768) 135 | parser.add_argument('--temperature', type=float, default=0.06) 136 | parser.add_argument('--alpha', type=float, default=1.0) 137 | parser.add_argument('--negative_pair', type=str, default='all', 138 | help='the method for selecting negative pair', choices=['all', 'diff_label']) 139 | parser.add_argument('--target_type', type=str, default='grad_block', 140 | help='how to make target representation', choices=['grad_block', 'grad_flow', 'project_block', 'project_flow']) 141 | 142 | args = parser.parse_args() 143 | 144 | iterations = args.lr_decay_epochs.split(',') 145 | args.lr_decay_epochs = list([]) 146 | for it in iterations: 147 | args.lr_decay_epochs.append(int(it)) 148 | 149 | args.model_name = '{}_{}_{}'.format(args.dataset, args.model, args.method) 150 | if args.tag: 151 | args.model_name += '_{}'.format(args.tag) 152 | 153 | if args.method in ['patchmix', 'patchmix_cl']: 154 | assert args.model in ['ast', 'ssast'] 155 | 156 | args.save_folder = os.path.join(args.save_dir, args.model_name) 157 | if not os.path.isdir(args.save_folder): 158 | os.makedirs(args.save_folder) 159 | 160 | if args.warm: 161 | args.warmup_from = args.learning_rate * 0.1 162 | args.warm_epochs = 10 163 | if args.cosine: 164 | eta_min = args.learning_rate * (args.lr_decay_rate ** 3) 165 | args.warmup_to = eta_min + (args.learning_rate - eta_min) * ( 166 | 1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2 167 | else: 168 | args.warmup_to = args.learning_rate 169 | 170 | if args.dataset == 'icbhi': 171 | if args.class_split == 'lungsound': 172 | if args.n_cls == 4: 173 | args.cls_list = ['normal', 'crackle', 'wheeze', 'both'] 174 | elif args.n_cls == 2: 175 | args.cls_list = ['normal', 'abnormal'] 176 | elif args.class_split == 'diagnosis': 177 | if args.n_cls == 3: 178 | args.cls_list = ['healthy', 'chronic_diseases', 'non-chronic_diseases'] 179 | elif args.n_cls == 2: 180 | args.cls_list = ['healthy', 'unhealthy'] 
181 | else: 182 | raise NotImplementedError 183 | 184 | return args 185 | 186 | 187 | def set_loader(args): 188 | if args.dataset == 'icbhi': 189 | # get rawo information and calculate mean and std for normalization 190 | # dataset = ICBHIDataset(train_flag=True, transform=transforms.Compose([transforms.ToTensor()]), args=args, print_flag=False, mean_std=True) 191 | # mean, std = get_mean_and_std(dataset) 192 | # args.h, args.w = dataset.h, dataset.w 193 | 194 | # print('*' * 20) 195 | # print('[Raw dataset information]') 196 | # print('Stethoscope device number: {}, and patience number without overlap: {}'.format(len(dataset.device_to_id), len(set(sum(dataset.device_id_to_patient.values(), []))) )) 197 | # for device, id in dataset.device_to_id.items(): 198 | # print('Device {} ({}): {} number of patience'.format(id, device, len(dataset.device_id_to_patient[id]))) 199 | # print('Spectrogram shpae on ICBHI dataset: {} (height) and {} (width)'.format(args.h, args.w)) 200 | # print('Mean and std of ICBHI dataset: {} (mean) and {} (std)'.format(round(mean.item(), 2), round(std.item(), 2))) 201 | 202 | args.h, args.w = 798, 128 203 | train_transform = [transforms.ToTensor(), 204 | SpecAugment(args), 205 | transforms.Resize(size=(int(args.h * args.resz), int(args.w * args.resz)))] 206 | val_transform = [transforms.ToTensor(), 207 | transforms.Resize(size=(int(args.h * args.resz), int(args.w * args.resz)))] 208 | # train_transform.append(transforms.Normalize(mean=mean, std=std)) 209 | # val_transform.append(transforms.Normalize(mean=mean, std=std)) 210 | 211 | train_transform = transforms.Compose(train_transform) 212 | val_transform = transforms.Compose(val_transform) 213 | 214 | train_dataset = ICBHIDataset(train_flag=True, transform=train_transform, args=args, print_flag=True) 215 | val_dataset = ICBHIDataset(train_flag=False, transform=val_transform, args=args, print_flag=True) 216 | 217 | # for weighted_loss 218 | args.class_nums = train_dataset.class_nums 219 | 
else: 220 | raise NotImplemented 221 | 222 | if args.weighted_sampler: 223 | reciprocal_weights = [] 224 | for idx in range(len(train_dataset)): 225 | reciprocal_weights.append(train_dataset.class_ratio[train_dataset.labels[idx]]) 226 | weights = (1 / torch.Tensor(reciprocal_weights)) 227 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(train_dataset)) 228 | else: 229 | sampler = None 230 | 231 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=sampler is None, 232 | num_workers=args.num_workers, pin_memory=True, sampler=sampler, drop_last=True) 233 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 234 | num_workers=args.num_workers, pin_memory=True, sampler=None) 235 | 236 | return train_loader, val_loader, args 237 | 238 | 239 | def set_model(args): 240 | kwargs = {} 241 | if args.model == 'ast': 242 | kwargs['input_fdim'] = int(args.h * args.resz) 243 | kwargs['input_tdim'] = int(args.w * args.resz) 244 | kwargs['label_dim'] = args.n_cls 245 | kwargs['imagenet_pretrain'] = args.from_sl_official 246 | kwargs['audioset_pretrain'] = args.audioset_pretrained 247 | kwargs['mix_beta'] = args.mix_beta # for Patch-MixCL 248 | elif args.model == 'ssast': 249 | kwargs['label_dim'] = args.n_cls 250 | kwargs['fshape'], kwargs['tshape'] = args.fshape, args.tshape 251 | kwargs['fstride'], kwargs['tstride'] = 10, 10 252 | kwargs['input_tdim'] = 798 253 | kwargs['task'] = args.ssast_task 254 | kwargs['pretrain_stage'] = not args.audioset_pretrained 255 | kwargs['load_pretrained_mdl_path'] = args.ssast_pretrained_type 256 | kwargs['mix_beta'] = args.mix_beta # for Patch-MixCL 257 | 258 | model = get_backbone_class(args.model)(**kwargs) 259 | classifier = nn.Linear(model.final_feat_dim, args.n_cls) if args.model not in ['ast', 'ssast'] else deepcopy(model.mlp_head) 260 | 261 | if not args.weighted_loss: 262 | weights = None 263 | criterion = nn.CrossEntropyLoss() 
264 | else: 265 | weights = torch.tensor(args.class_nums, dtype=torch.float32) 266 | weights = 1.0 / (weights / weights.sum()) 267 | weights /= weights.sum() 268 | 269 | criterion = nn.CrossEntropyLoss(weight=weights) 270 | 271 | if args.model not in ['ast', 'ssast'] and args.from_sl_official: 272 | model.load_sl_official_weights() 273 | print('pretrained model loaded from PyTorch ImageNet-pretrained') 274 | 275 | # load SSL pretrained checkpoint for linear evaluation 276 | if args.pretrained and args.pretrained_ckpt is not None: 277 | ckpt = torch.load(args.pretrained_ckpt, map_location='cpu') 278 | state_dict = ckpt['model'] 279 | 280 | # HOTFIX: always use dataparallel during SSL pretraining 281 | new_state_dict = {} 282 | for k, v in state_dict.items(): 283 | if "module." in k: 284 | k = k.replace("module.", "") 285 | if "backbone." in k: 286 | k = k.replace("backbone.", "") 287 | 288 | new_state_dict[k] = v 289 | state_dict = new_state_dict 290 | model.load_state_dict(state_dict, strict=False) 291 | 292 | if ckpt.get('classifier', None) is not None: 293 | classifier.load_state_dict(ckpt['classifier'], strict=True) 294 | 295 | print('pretrained model loaded from: {}'.format(args.pretrained_ckpt)) 296 | 297 | projector = Projector(model.final_feat_dim, args.proj_dim) if args.method == 'patchmix_cl' else nn.Identity() 298 | 299 | if args.method == 'ce': 300 | criterion = [criterion.cuda()] 301 | elif args.method == 'patchmix': 302 | criterion = [criterion.cuda(), PatchMixLoss(criterion=criterion).cuda()] 303 | elif args.method == 'patchmix_cl': 304 | criterion = [criterion.cuda(), PatchMixConLoss(temperature=args.temperature).cuda()] 305 | 306 | if torch.cuda.device_count() > 1: 307 | model = torch.nn.DataParallel(model) 308 | 309 | model.cuda() 310 | classifier.cuda() 311 | projector.cuda() 312 | 313 | optim_params = list(model.parameters()) + list(classifier.parameters()) + list(projector.parameters()) 314 | optimizer = set_optimizer(args, optim_params) 315 | 
316 | return model, classifier, projector, criterion, optimizer 317 | 318 | 319 | def train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args, scaler=None): 320 | model.train() 321 | classifier.train() 322 | projector.train() 323 | 324 | batch_time = AverageMeter() 325 | data_time = AverageMeter() 326 | losses = AverageMeter() 327 | top1 = AverageMeter() 328 | 329 | end = time.time() 330 | for idx, (images, labels, metadata) in enumerate(train_loader): 331 | if args.ma_update: 332 | # store the previous iter checkpoint 333 | with torch.no_grad(): 334 | ma_ckpt = [deepcopy(model.state_dict()), deepcopy(classifier.state_dict()), deepcopy(projector.state_dict())] 335 | 336 | data_time.update(time.time() - end) 337 | 338 | images = images.cuda(non_blocking=True) 339 | labels = labels.cuda(non_blocking=True) 340 | bsz = labels.shape[0] 341 | 342 | warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer) 343 | 344 | with torch.cuda.amp.autocast(): 345 | if args.method == 'ce': 346 | features = model(images) 347 | output = classifier(features) 348 | loss = criterion[0](output, labels) 349 | 350 | elif args.method == 'patchmix': 351 | mix_images, labels_a, labels_b, lam, index = model(images, y=labels, patch_mix=True, time_domain=args.time_domain) 352 | output = classifier(mix_images) 353 | loss = criterion[1](output, labels_a, labels_b, lam) 354 | 355 | elif args.method == 'patchmix_cl': 356 | features = model(images) 357 | output = classifier(features) 358 | loss = criterion[0](output, labels) 359 | 360 | if args.target_type == 'grad_block': 361 | proj1 = deepcopy(features.detach()) 362 | elif args.target_type == 'grad_flow': 363 | proj1 = features 364 | elif args.target_type == 'project_block': 365 | proj1 = deepcopy(projector(features).detach()) 366 | elif args.target_type == 'project_flow': 367 | proj1 = projector(features) 368 | 369 | # use 'patchmix_cl' for augmentation 370 | mix_images, labels_a, labels_b, lam, index = 
model(images, y=labels, patch_mix=True, time_domain=args.time_domain) 371 | proj2 = projector(mix_images) 372 | loss += args.alpha * criterion[1](proj1, proj2, labels, labels_b, lam, index, args) 373 | 374 | losses.update(loss.item(), bsz) 375 | [acc1], _ = accuracy(output[:bsz], labels, topk=(1,)) 376 | top1.update(acc1[0], bsz) 377 | 378 | optimizer.zero_grad() 379 | scaler.scale(loss).backward() 380 | scaler.step(optimizer) 381 | scaler.update() 382 | 383 | # measure elapsed time 384 | batch_time.update(time.time() - end) 385 | end = time.time() 386 | 387 | if args.ma_update: 388 | with torch.no_grad(): 389 | # exponential moving average update 390 | model = update_moving_average(args.ma_beta, model, ma_ckpt[0]) 391 | classifier = update_moving_average(args.ma_beta, classifier, ma_ckpt[1]) 392 | projector = update_moving_average(args.ma_beta, projector, ma_ckpt[2]) 393 | 394 | # print info 395 | if (idx + 1) % args.print_freq == 0: 396 | print('Train: [{0}][{1}/{2}]\t' 397 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 398 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 399 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 400 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 401 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 402 | data_time=data_time, loss=losses, top1=top1)) 403 | sys.stdout.flush() 404 | 405 | return losses.avg, top1.avg 406 | 407 | 408 | def validate(val_loader, model, classifier, criterion, args, best_acc, best_model=None): 409 | save_bool = False 410 | model.eval() 411 | classifier.eval() 412 | 413 | batch_time = AverageMeter() 414 | losses = AverageMeter() 415 | top1 = AverageMeter() 416 | hits, counts = [0.0] * args.n_cls, [0.0] * args.n_cls 417 | 418 | with torch.no_grad(): 419 | end = time.time() 420 | for idx, (images, labels, metadata) in enumerate(val_loader): 421 | images = images.cuda(non_blocking=True) 422 | labels = labels.cuda(non_blocking=True) 423 | bsz = labels.shape[0] 424 | 425 | with 
torch.cuda.amp.autocast(): 426 | features = model(images) 427 | output = classifier(features) 428 | loss = criterion[0](output, labels) 429 | 430 | losses.update(loss.item(), bsz) 431 | [acc1], _ = accuracy(output, labels, topk=(1,)) 432 | top1.update(acc1[0], bsz) 433 | 434 | _, preds = torch.max(output, 1) 435 | for idx in range(preds.shape[0]): 436 | counts[labels[idx].item()] += 1.0 437 | if not args.two_cls_eval: 438 | if preds[idx].item() == labels[idx].item(): 439 | hits[labels[idx].item()] += 1.0 440 | else: # only when args.n_cls == 4 441 | if labels[idx].item() == 0 and preds[idx].item() == labels[idx].item(): 442 | hits[labels[idx].item()] += 1.0 443 | elif labels[idx].item() != 0 and preds[idx].item() > 0: # abnormal 444 | hits[labels[idx].item()] += 1.0 445 | 446 | sp, se, sc = get_score(hits, counts) 447 | 448 | batch_time.update(time.time() - end) 449 | end = time.time() 450 | 451 | if (idx + 1) % args.print_freq == 0: 452 | print('Test: [{0}/{1}]\t' 453 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 454 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 455 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 456 | idx + 1, len(val_loader), batch_time=batch_time, 457 | loss=losses, top1=top1)) 458 | 459 | if sc > best_acc[-1] and se > 5: 460 | save_bool = True 461 | best_acc = [sp, se, sc] 462 | best_model = [deepcopy(model.state_dict()), deepcopy(classifier.state_dict())] 463 | 464 | print(' * S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f} (Best S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f})'.format(sp, se, sc, best_acc[0], best_acc[1], best_acc[-1])) 465 | print(' * Acc@1 {top1.avg:.2f}'.format(top1=top1)) 466 | 467 | return best_acc, best_model, save_bool 468 | 469 | 470 | def main(): 471 | args = parse_args() 472 | with open(os.path.join(args.save_folder, 'train_args.json'), 'w') as f: 473 | json.dump(vars(args), f, indent=4) 474 | 475 | # fix seed 476 | random.seed(args.seed) 477 | np.random.seed(args.seed) 478 | torch.manual_seed(args.seed) 479 | 
torch.cuda.manual_seed(args.seed) 480 | cudnn.deterministic = True 481 | cudnn.benchmark = True 482 | 483 | best_model = None 484 | if args.dataset == 'icbhi': 485 | best_acc = [0, 0, 0] # Specificity, Sensitivity, Score 486 | 487 | train_loader, val_loader, args = set_loader(args) 488 | model, classifier, projector, criterion, optimizer = set_model(args) 489 | 490 | if args.resume: 491 | if os.path.isfile(args.resume): 492 | print("=> loading checkpoint '{}'".format(args.resume)) 493 | checkpoint = torch.load(args.resume) 494 | args.start_epoch = checkpoint['epoch'] 495 | model.load_state_dict(checkpoint['model']) 496 | optimizer.load_state_dict(checkpoint['optimizer']) 497 | args.start_epoch += 1 498 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 499 | else: 500 | print("=> no checkpoint found at '{}'".format(args.resume)) 501 | else: 502 | args.start_epoch = 1 503 | 504 | # use mix_precision: 505 | scaler = torch.cuda.amp.GradScaler() 506 | 507 | print('*' * 20) 508 | if not args.eval: 509 | print('Training for {} epochs on {} dataset'.format(args.epochs, args.dataset)) 510 | for epoch in range(args.start_epoch, args.epochs+1): 511 | adjust_learning_rate(args, optimizer, epoch) 512 | 513 | # train for one epoch 514 | time1 = time.time() 515 | loss, acc = train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args, scaler) 516 | time2 = time.time() 517 | print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format( 518 | epoch, time2-time1, acc)) 519 | 520 | # eval for one epoch 521 | best_acc, best_model, save_bool = validate(val_loader, model, classifier, criterion, args, best_acc, best_model) 522 | 523 | # save a checkpoint of model and classifier when the best score is updated 524 | if save_bool: 525 | save_file = os.path.join(args.save_folder, 'best_epoch_{}.pth'.format(epoch)) 526 | print('Best ckpt is modified with Score = {:.2f} when Epoch = {}'.format(best_acc[2], epoch)) 527 | 
save_model(model, optimizer, args, epoch, save_file, classifier) 528 | 529 | if epoch % args.save_freq == 0: 530 | save_file = os.path.join(args.save_folder, 'epoch_{}.pth'.format(epoch)) 531 | save_model(model, optimizer, args, epoch, save_file, classifier) 532 | 533 | # save a checkpoint of classifier with the best accuracy or score 534 | save_file = os.path.join(args.save_folder, 'best.pth') 535 | model.load_state_dict(best_model[0]) 536 | classifier.load_state_dict(best_model[1]) 537 | save_model(model, optimizer, args, epoch, save_file, classifier) 538 | else: 539 | print('Testing the pretrained checkpoint on {} dataset'.format(args.dataset)) 540 | best_acc, _, _ = validate(val_loader, model, classifier, criterion, args, best_acc) 541 | 542 | update_json('%s' % args.model_name, best_acc, path=os.path.join(args.save_dir, 'results.json')) 543 | 544 | if __name__ == '__main__': 545 | main() 546 | -------------------------------------------------------------------------------- /method/__init__.py: -------------------------------------------------------------------------------- 1 | from .patchmix import PatchMixLoss 2 | from .patchmix_cl import PatchMixConLoss -------------------------------------------------------------------------------- /method/patchmix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class PatchMixLoss(nn.Module): 8 | """ Refer to models/ast.py for the code of patch mixing 9 | """ 10 | def __init__(self, criterion): 11 | super().__init__() 12 | self.criterion = criterion 13 | 14 | def forward(self, pred, y_a, y_b, lam): 15 | loss = lam * self.criterion(pred, y_a) + (1 - lam) * self.criterion(pred, y_b) 16 | return loss 17 | -------------------------------------------------------------------------------- /method/patchmix_cl.py: 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchMixConLoss(nn.Module):
    """Contrastive loss between clean and patch-mixed projections (Patch-Mix CL).

    Row ``i`` of the mixed view is matched against the clean view. Its positives are itself
    (weight ``lam``) and its mixing partner ``index[i]`` (weight ``1 - lam``).

    Args:
        temperature: softmax temperature applied to the cosine similarities.
    """
    def __init__(self, temperature=0.06):
        super().__init__()
        self.temperature = temperature

    def forward(self, projection1, projection2, labels_a, labels_b, lam, index, args):
        """Compute the loss.

        Args:
            projection1: clean-view projections, one row per sample.
            projection2: patch-mixed-view projections, row-aligned with ``projection1``.
            labels_a: labels of the clean samples; used only when
                ``args.negative_pair == 'diff_label'``.
            labels_b: labels of the mixing partners (unused; kept for interface compatibility).
            lam: fraction of each mixed sample that came from itself.
            index: permutation such that mixed sample ``i`` was mixed with sample ``index[i]``.
            args: namespace; with ``negative_pair == 'diff_label'`` only different-label samples
                (plus the two positives) enter the denominator.

        Returns:
            Scalar loss averaged over the batch.
        """
        batch_size = projection1.shape[0]
        # build masks on the inputs' device instead of hard-coding CUDA
        device = projection1.device
        projection1, projection2 = F.normalize(projection1), F.normalize(projection2)
        anchor_dot_contrast = torch.div(torch.matmul(projection2, projection1.T), self.temperature)

        # mask_a marks each sample itself; mask_b marks its mixing partner index[i]
        mask_a = torch.eye(batch_size, device=device)
        mask_b = torch.zeros(batch_size, batch_size, device=device)
        mask_b[torch.arange(batch_size, device=device), index.view(-1).to(device)] = 1

        mask = lam * mask_a + (1 - lam) * mask_b

        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()  # for numerical stability

        exp_logits = torch.exp(logits)
        if args.negative_pair == 'diff_label':
            labels_a = labels_a.contiguous().view(-1, 1)
            # keep different-label negatives plus both positives of every row
            logits_mask = torch.ne(labels_a, labels_a.T).to(device) | mask_a.bool() | mask_b.bool()
            exp_logits *= logits_mask.float()

        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        return (-mean_log_prob_pos).mean()

# ==== /models/__init__.py ====
from .cnn6 import CNN6
from .resnet import ResNet10, ResNet18, ResNet34, ResNet50, ResNet101
from .efficientnet import EfficientNet_B0, EfficientNet_B1, EfficientNet_B2
from .ast import ASTModel
from .ssast import SSASTModel
from .projector import Projector

# backbone name (e.g. main.py's args.model) -> backbone class
_backbone_class_map = {
    'cnn6': CNN6,
    'resnet10': ResNet10,
    'resnet18': ResNet18,
    'resnet34': ResNet34,
    'resnet50': ResNet50,
    'resnet101': ResNet101,
    'efficientnet_b0': EfficientNet_B0,
    'efficientnet_b1': EfficientNet_B1,
    'efficientnet_b2': EfficientNet_B2,
    'ast': ASTModel,
    'ssast': SSASTModel
}


def get_backbone_class(key):
    """Return the backbone class registered under *key*.

    Raises:
        ValueError: if *key* is not a registered backbone name.
    """
    try:
        return _backbone_class_map[key]
    except KeyError:
        raise ValueError('Invalid backbone: {}'.format(key)) from None

# ==== /models/ast.py ====
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import os
import wget
import timm
from copy import deepcopy
from timm.models.layers import to_2tuple,trunc_normal_


# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    """Patch embedding without timm's fixed-input-size check: conv projection, then flatten to tokens."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ASTModel(nn.Module):
    """
    The AST model.
    :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
    :param fstride: the stride of patch spliting on the frequency dimension, for 16*16 patchs, fstride=16 means no overlap, fstride=10 means overlap of 6
    :param tstride: the stride of patch spliting on the time dimension, for 16*16 patchs, tstride=16 means no overlap, tstride=10 means overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: if use ImageNet pretrained model
    :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384], base224 and base 384 are same model, but are trained differently during ImageNet pretraining.
    :param verbose: print a model summary while building
    :param mix_beta: Beta(mix_beta, mix_beta) parameter used to sample the Patch-Mix ratio; None or <= 0 disables mixing
    """
    def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True, mix_beta=None):
        super(ASTModel, self).__init__()
        assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'

        if verbose:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))
        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        self.final_feat_dim = 768
        self.mix_beta = mix_beta

        # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
        if not audioset_pretrain:
            if model_size == 'tiny224':
                self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'small224':
                self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base224':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base384':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
            else:
                raise ValueError('Model size must be one of tiny224, small224, base224, base384.')
            self.original_num_patches = self.v.patch_embed.num_patches
            # (attribute name typo kept for compatibility)
            self.oringal_hw = int(self.original_num_patches ** 0.5)
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            # automatcially get the intermediate shape
            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose:
                print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # the linear projection layer (single input channel; ImageNet RGB weights are summed over channels)
            new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
            if imagenet_pretrain:
                new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
                new_proj.bias = self.v.patch_embed.proj.bias
            self.v.patch_embed.proj = new_proj

            # the positional embedding
            if imagenet_pretrain:
                # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
                new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
                # cut (from middle) or interpolate the second dimension of the positional embedding
                if t_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
                # cut (from middle) or interpolate the first dimension of the positional embedding
                if f_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
                # flatten the positional embedding
                new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
                # concatenate the above positional embedding with the cls token and distillation token of the deit model.
                self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
            else:
                # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding
                # TODO can use sinusoidal positional embedding instead
                new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
                self.v.pos_embed = new_pos_embed
                trunc_normal_(self.v.pos_embed, std=.02)

        # now load a model that is pretrained on both ImageNet and AudioSet
        else:
            if not imagenet_pretrain:
                raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
            if model_size != 'base384':
                raise ValueError('currently only has base384 AudioSet pretrained model.')
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            out_dir = './pretrained_models/'
            os.makedirs(out_dir, exist_ok=True)

            if not os.path.exists(os.path.join(out_dir, 'audioset_10_10_0.4593.pth')):
                # this model performs 0.4593 mAP on the audioset eval set
                audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
                wget.download(audioset_mdl_url, out=os.path.join(out_dir, 'audioset_10_10_0.4593.pth'))

            sd = torch.load(os.path.join(out_dir, 'audioset_10_10_0.4593.pth'), map_location=device)
            # NOTE(review): the loaded patch-embedding conv keeps stride (10, 10) regardless of
            # fstride/tstride, while get_shape() below uses the requested strides.
            audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
            audio_model = torch.nn.DataParallel(audio_model)
            audio_model.load_state_dict(sd, strict=False)
            self.v = audio_model.module.v
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose:
                print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # pretrained grid is 12 (freq) x 101 (time) patches after the cls/dist tokens
            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, 1212, 768).transpose(1, 2).reshape(1, 768, 12, 101)
            # if the input has fewer time patches than the original AudioSet input (101), cut the positional embedding from the middle
            if t_dim < 101:
                new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
            # otherwise interpolate
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
            if f_dim < 12:
                new_pos_embed = new_pos_embed[:, :, 6 - int(f_dim/2): 6 - int(f_dim/2) + f_dim, :]
            # otherwise interpolate
            elif f_dim > 12:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
        """Return the (freq, time) patch-grid size a 16x16 conv with the given strides produces."""
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        return f_dim, t_dim

    def square_patch(self, patch, hw_num_patch):
        """Reshape a (B, h*w, dim) token sequence into a (B, h, w, dim) grid."""
        h, w = hw_num_patch
        B, _, dim = patch.size()
        square = patch.reshape(B, h, w, dim)
        return square

    def flatten_patch(self, square):
        """Inverse of square_patch(): (B, h, w, dim) -> (B, h*w, dim)."""
        B, h, w, dim = square.shape
        patch = square.reshape(B, h * w, dim)
        return patch

    def patch_mix(self, image, target, time_domain=False, hw_num_patch=None):
        """Replace a random subset of each sample's patch tokens with those of a permuted partner.

        Args:
            image: patch tokens of shape (B, num_patch, dim). Modified in place on the
                non-time-domain path.
            target: labels of the batch.
            time_domain: when True, whole columns of the (h, w) token grid are swapped instead of
                individual tokens.
            hw_num_patch: ``[h, w]`` token-grid size; required when ``time_domain`` is True.

        Returns:
            ``(mixed_tokens, y_a, y_b, lam, index)``, where ``lam`` is the fraction of tokens
            (or columns) kept from the sample itself and ``index`` is the partner permutation.
        """
        # mixing is disabled (lam = 1) when mix_beta is unset (the default) or non-positive
        if self.mix_beta is not None and self.mix_beta > 0:
            lam = np.random.beta(self.mix_beta, self.mix_beta)
        else:
            lam = 1

        batch_size, num_patch, _ = image.size()
        device = image.device

        index = torch.randperm(batch_size).to(device)

        if not time_domain:
            num_mask = int(num_patch * (1. - lam))
            mask = torch.randperm(num_patch)[:num_mask].to(device)

            image[:, mask, :] = image[index][:, mask, :]
            lam = 1 - (num_mask / num_patch)
        else:
            squared_1 = self.square_patch(image, hw_num_patch)
            squared_2 = self.square_patch(image[index], hw_num_patch)

            w_size = squared_1.size()[2]
            num_mask = int(w_size * (1. - lam))
            mask = torch.randperm(w_size)[:num_mask].to(device)

            squared_1[:, :, mask, :] = squared_2[:, :, mask, :]
            image = self.flatten_patch(squared_1)
            lam = 1 - (num_mask / w_size)

        y_a, y_b = target, target[index]
        return image, y_a, y_b, lam, index

    @autocast()
    def forward(self, x, y=None, patch_mix=False, time_domain=False):
        """Encode a batch into pooled features, optionally applying Patch-Mix to the tokens.

        :param x: 4-D input; its last two axes are transposed before patch embedding (the
            original AST ``unsqueeze(1)`` is disabled), presumably (batch, 1, time, freq) -- TODO confirm
        :param y: labels, required when ``patch_mix`` is True
        :param patch_mix: apply patch_mix() to the patch tokens
        :param time_domain: forwarded to patch_mix()
        :return: the average of the cls and distillation tokens (``final_feat_dim`` features),
            or ``(features, y_a, y_b, lam, index)`` when ``patch_mix`` is True
        """
        # x = x.unsqueeze(1)
        x = x.transpose(2, 3)
        # Token-grid size produced by the patch-embedding conv. It is derived from the conv itself
        # rather than hard-coding kernel 16 / stride 10, so non-default fstride/tstride work.
        proj = self.v.patch_embed.proj
        h_patch = (x.size(2) - proj.kernel_size[0]) // proj.stride[0] + 1
        w_patch = (x.size(3) - proj.kernel_size[1]) // proj.stride[1] + 1

        B = x.shape[0]
        x = self.v.patch_embed(x)

        if patch_mix:
            x, y_a, y_b, lam, index = self.patch_mix(x, y, time_domain=time_domain, hw_num_patch=[h_patch, w_patch])

        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        x = self.v.norm(x)
        x = (x[:, 0] + x[:, 1]) / 2
        # no mlp_head here: main.py uses a copy of self.mlp_head as the external classifier

        if not patch_mix:
            return x
        else:
            return x, y_a, y_b, lam, index

# ==== /models/cnn6.py ====
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_layer(layer):
    """Initialize a Linear or Convolutional layer. """
    nn.init.xavier_uniform_(layer.weight)
    if getattr(layer, 'bias', None) is not None:
        layer.bias.data.fill_(0.)


def init_bn(bn):
    """Initialize a Batchnorm layer. """
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


class ConvBlock5x5(nn.Module):  # for CNN6
    """5x5 conv -> BatchNorm -> ReLU, followed by the pooling selected in forward()."""
    def __init__(self, in_channels, out_channels, stride=(1,1)):

        super(ConvBlock5x5, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(5, 5), stride=stride,
                               padding=(2, 2), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply the block; ``pool_type`` is one of 'max', 'avg' or 'avg+max'."""
        x = F.relu_(self.bn1(self.conv1(input)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise ValueError('Incorrect argument!')

        return x


class CNN6(nn.Module):
    """Four ConvBlock5x5 stages producing a 512-dim embedding (no classification layer)."""
    def __init__(self):
        super(CNN6, self).__init__()
        self.final_feat_dim = 512

        self.do_dropout = False
        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64, stride=(1,1))
        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128, stride=(1,1))
        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256, stride=(1,1))
        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512, stride=(1,1))
        self.dropout = nn.Dropout(0.2)
        # self.linear = nn.Linear(512, num_classes, bias=True)

    def load_sl_official_weights(self):
        """ download AudioSet pretrained CNN6 in https://zenodo.org/record/3960586#.Y8dz8y_kEiY

        Only checkpoint entries whose keys exist in this model are loaded (non-strict).
        """
        # load on CPU so a GPU-saved checkpoint also works on CPU-only hosts; load_state_dict copies into place
        weights = torch.load('pretrained_models/Cnn6_mAP=0.343.pth', map_location='cpu')['model']
        own_keys = self.state_dict().keys()  # computed once instead of per checkpoint entry
        state_dict = {k: v for k, v in weights.items() if k in own_keys}
        self.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        for block in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
            x = block(x, pool_size=(2, 2), pool_type='avg')
            if self.do_dropout:
                x = self.dropout(x)

        x = torch.mean(x, dim=3)  # mean over time dim
        (x1, _) = torch.max(x, dim=2)  # max over freq dim
        x2 = torch.mean(x, dim=2)  # mean over freq dim (after mean over time)
        x = x1 + x2

        return x

# ==== /models/efficientnet.py ====
import torch
import torchvision
from torch import Tensor
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from torchvision.models.efficientnet import _efficientnet_conf


# torchvision.__version__ == '0.11.0+cu113' (arguments for _efficientnet_conf are different in latest version)
class EfficientNet_B0(torchvision.models.efficientnet.EfficientNet):
    """EfficientNet-B0 feature extractor with a 1-channel stem and no classifier (1280-dim output)."""
    def __init__(self):
        inverted_residual_setting, _last_channel = _efficientnet_conf('efficientnet_b0', width_mult=1.0, depth_mult=1.0)
        super().__init__(inverted_residual_setting, 0.2)

        del self.classifier
        self.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.final_feat_dim = 1280

    def load_sl_official_weights(self, progress=True):
        state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
                                              progress=progress)
        # the pretrained stem expects a different number of input channels than the 1-channel stem
        del state_dict['features.0.0.weight']
        self.load_state_dict(state_dict, strict=False)

    def load_ssl_official_weights(self, progress=True):
        raise NotImplementedError

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x


class EfficientNet_B1(torchvision.models.efficientnet.EfficientNet):
    """EfficientNet-B1 feature extractor with a 1-channel stem and no classifier (1280-dim output)."""
    def __init__(self):
        inverted_residual_setting, _last_channel = _efficientnet_conf('efficientnet_b1', width_mult=1.0, depth_mult=1.1)
        super().__init__(inverted_residual_setting, 0.2)

        del self.classifier
        self.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.final_feat_dim = 1280

    def load_sl_official_weights(self, progress=True):
        state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
                                              progress=progress)
        # the pretrained stem expects a different number of input channels than the 1-channel stem
        del state_dict['features.0.0.weight']
        self.load_state_dict(state_dict, strict=False)

    def load_ssl_official_weights(self, progress=True):
        raise NotImplementedError

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x


class EfficientNet_B2(torchvision.models.efficientnet.EfficientNet):
    """EfficientNet-B2 feature extractor with a 1-channel stem and no classifier (1408-dim output)."""
    def __init__(self):
        inverted_residual_setting, _last_channel = _efficientnet_conf('efficientnet_b2', width_mult=1.1, depth_mult=1.2)
        super().__init__(inverted_residual_setting, 0.3)

        del self.classifier
        self.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.final_feat_dim = 1408

    def load_sl_official_weights(self, progress=True):
        state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
                                              progress=progress)
        # the pretrained stem expects a different number of input channels than the 1-channel stem
        del state_dict['features.0.0.weight']
        self.load_state_dict(state_dict, strict=False)

    def load_ssl_official_weights(self, progress=True):
        raise NotImplementedError

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x

# ==== /models/projector.py ====
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Projector(nn.Module): 5 | def __init__(self, in_dim, out_dim=768, apply_bn=True): 6 | super(Projector, self).__init__() 7 | self.linear1 = nn.Linear(in_dim, in_dim) 8 | self.linear2 = nn.Linear(in_dim, out_dim) 9 | self.bn = nn.BatchNorm1d(in_dim) 10 | self.relu = nn.ReLU() 11 | if apply_bn: 12 | self.projector = nn.Sequential(self.linear1, self.bn, self.relu, self.linear2) 13 | else: 14 | self.projector = nn.Sequential(self.linear1, self.relu, self.linear2) 15 | 16 | def forward(self, x): 17 | return self.projector(x) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import Tensor 4 | import torch.nn as nn 5 | import math 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from torch.hub import load_state_dict_from_url 9 | 10 | 11 | class ResNet10(torchvision.models.resnet.ResNet): 12 | def __init__(self, track_bn=True): 13 | def norm_layer(*args, **kwargs): 14 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn) 15 | super().__init__(torchvision.models.resnet.BasicBlock, [1, 1, 1, 1], norm_layer=norm_layer) 16 | del self.fc 17 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 18 | self.final_feat_dim = 512 19 | 20 | def load_sl_official_weights(self, progress=True): 21 | raise NotImplemented 22 | 23 | def load_ssl_official_weights(self, progress=True): 24 | raise NotImplemented 25 | 26 | def _forward_impl(self, x: Tensor) -> Tensor: 27 | # See note [TorchScript super()] 28 | x = self.conv1(x) 29 | x = self.bn1(x) 30 | x = self.relu(x) 31 | x = self.maxpool(x) 32 | 33 | x = self.layer1(x) 34 | x = self.layer2(x) 35 | x = self.layer3(x) 36 | x = self.layer4(x) 37 | 38 | x = self.avgpool(x) 39 | x = 
torch.flatten(x, 1) 40 | # x = self.fc(x) 41 | 42 | return x 43 | 44 | 45 | class ResNet18(torchvision.models.resnet.ResNet): 46 | def __init__(self, track_bn=True): 47 | def norm_layer(*args, **kwargs): 48 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn) 49 | super().__init__(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], norm_layer=norm_layer) 50 | del self.fc 51 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 52 | self.final_feat_dim = 512 53 | 54 | def load_sl_official_weights(self, progress=True): 55 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet18'], 56 | progress=progress) 57 | 58 | del state_dict['conv1.weight'] 59 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 60 | # if len(missing) > 0: 61 | # raise AssertionError('Model code may be incorrect') 62 | 63 | def load_ssl_official_weights(self, progress=True): 64 | raise NotImplemented 65 | 66 | def _forward_impl(self, x: Tensor) -> Tensor: 67 | # See note [TorchScript super()] 68 | x = self.conv1(x) 69 | x = self.bn1(x) 70 | x = self.relu(x) 71 | x = self.maxpool(x) 72 | 73 | x = self.layer1(x) 74 | x = self.layer2(x) 75 | x = self.layer3(x) 76 | x = self.layer4(x) 77 | 78 | x = self.avgpool(x) 79 | x = torch.flatten(x, 1) 80 | # x = self.fc(x) 81 | 82 | return x 83 | 84 | 85 | class ResNet34(torchvision.models.resnet.ResNet): 86 | def __init__(self, track_bn=True): 87 | def norm_layer(*args, **kwargs): 88 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn) 89 | super().__init__(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], norm_layer=norm_layer) 90 | del self.fc 91 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 92 | self.final_feat_dim = 512 93 | 94 | def load_sl_official_weights(self, progress=True): 95 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet34'], 96 | progress=progress) 97 | 98 | del 
state_dict['conv1.weight'] 99 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 100 | # if len(missing) > 0: 101 | # raise AssertionError('Model code may be incorrect') 102 | 103 | def load_ssl_official_weights(self, progress=True): 104 | raise NotImplemented 105 | 106 | def _forward_impl(self, x: Tensor) -> Tensor: 107 | # See note [TorchScript super()] 108 | x = self.conv1(x) 109 | x = self.bn1(x) 110 | x = self.relu(x) 111 | x = self.maxpool(x) 112 | 113 | x = self.layer1(x) 114 | x = self.layer2(x) 115 | x = self.layer3(x) 116 | x = self.layer4(x) 117 | 118 | x = self.avgpool(x) 119 | x = torch.flatten(x, 1) 120 | # x = self.fc(x) 121 | 122 | return x 123 | 124 | # finetune according to the stethoscope types (code from RespireNet) 125 | # def finetune(self, block_layer=5): 126 | # for idx, child in enumerate(self.children()): 127 | # if idx > block_layer: 128 | # break 129 | # for param in child.parameters(): 130 | # param.requires_grad = False 131 | 132 | 133 | class ResNet50(torchvision.models.resnet.ResNet): 134 | def __init__(self, track_bn=True): 135 | def norm_layer(*args, **kwargs): 136 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn) 137 | super().__init__(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], norm_layer=norm_layer) 138 | del self.fc 139 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 140 | self.final_feat_dim = 2048 141 | 142 | def load_sl_official_weights(self, progress=True): 143 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet50'], 144 | progress=progress) 145 | 146 | del state_dict['conv1.weight'] 147 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 148 | # if len(missing) > 0: 149 | # raise AssertionError('Model code may be incorrect') 150 | 151 | def load_ssl_official_weights(self, progress=True): 152 | # only SimCLR is available 153 | from pl_bolts.models.self_supervised import SimCLR 154 | 155 | 
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt' 156 | simclr = SimCLR.load_from_checkpoint(weight_path, strict=False) 157 | 158 | state_dict = {} 159 | for k, v in simclr.state_dict().items(): 160 | if 'encoder.' in k: 161 | k = k.replace('encoder.', '') 162 | if 'fc' not in k or 'project' not in k: 163 | state_dict[k] = v 164 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 165 | # non_linear_evaluator.block_forward is a pretrained MLP classifier for SimCLR 166 | # refer to https://github.com/Lightning-AI/lightning-bolts/blob/bcbbf6ab6c36430946dd8a416ddc7e697e8507fc/pl_bolts/models/self_supervised/evaluator.py#L7 167 | if len(missing) > 0: 168 | raise AssertionError('Model code may be incorrect') 169 | 170 | def _forward_impl(self, x: Tensor) -> Tensor: 171 | # See note [TorchScript super()] 172 | x = self.conv1(x) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | x = self.maxpool(x) 176 | 177 | x = self.layer1(x) 178 | x = self.layer2(x) 179 | x = self.layer3(x) 180 | x = self.layer4(x) 181 | 182 | x = self.avgpool(x) 183 | x = torch.flatten(x, 1) 184 | # x = self.fc(x) 185 | 186 | return x 187 | 188 | 189 | class ResNet101(torchvision.models.resnet.ResNet): 190 | def __init__(self, track_bn=True): 191 | def norm_layer(*args, **kwargs): 192 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn) 193 | super().__init__(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], norm_layer=norm_layer) 194 | del self.fc 195 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 196 | self.final_feat_dim = 2048 197 | 198 | def load_sl_official_weights(self, progress=True): 199 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet101'], 200 | progress=progress) 201 | 202 | del state_dict['conv1.weight'] 203 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 204 | # if len(missing) > 0: 205 | # 
raise AssertionError('Model code may be incorrect') 206 | 207 | def load_ssl_official_weights(self, progress=True): 208 | raise NotImplemented 209 | 210 | def _forward_impl(self, x: Tensor) -> Tensor: 211 | # See note [TorchScript super()] 212 | x = self.conv1(x) 213 | x = self.bn1(x) 214 | x = self.relu(x) 215 | x = self.maxpool(x) 216 | 217 | x = self.layer1(x) 218 | x = self.layer2(x) 219 | x = self.layer3(x) 220 | x = self.layer4(x) 221 | 222 | x = self.avgpool(x) 223 | x = torch.flatten(x, 1) 224 | # x = self.fc(x) 225 | 226 | return x 227 | -------------------------------------------------------------------------------- /models/ssast.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import sys 4 | from timm.models.layers import trunc_normal_ 5 | import timm 6 | import numpy as np 7 | import os 8 | import wget 9 | from timm.models.layers import to_2tuple 10 | from random import randrange 11 | from matplotlib import pyplot as plt 12 | import random 13 | 14 | 15 | # override the timm package to relax the input shape constraint. 
16 | class PatchEmbed(nn.Module): 17 | """ Image to Patch Embedding 18 | """ 19 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 20 | super().__init__() 21 | img_size = to_2tuple(img_size) 22 | patch_size = to_2tuple(patch_size) 23 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 24 | self.img_size = img_size 25 | self.patch_size = patch_size 26 | self.num_patches = num_patches 27 | 28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 29 | 30 | def forward(self, x): 31 | x = self.proj(x).flatten(2).transpose(1, 2) 32 | return x 33 | 34 | 35 | def get_sinusoid_encoding(n_position, d_hid): 36 | ''' Sinusoid position encoding table ''' 37 | 38 | def get_position_angle_vec(position): 39 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 40 | 41 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 42 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 43 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 44 | 45 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 46 | 47 | 48 | class SSASTModel(nn.Module): 49 | def __init__(self, label_dim=527, fshape=128, tshape=2, fstride=128, tstride=2, 50 | input_fdim=128, input_tdim=1024, model_size='base', task='ft_avgtok', 51 | pretrain_stage=True, load_pretrained_mdl_path=None, mix_beta=None): 52 | 53 | super(SSASTModel, self).__init__() 54 | assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.' 
55 | 56 | # override timm input shape restriction 57 | timm.models.vision_transformer.PatchEmbed = PatchEmbed 58 | 59 | self.task = task 60 | 61 | # pretrain the AST models 62 | if pretrain_stage == True: 63 | if load_pretrained_mdl_path != None: 64 | raise ValueError('Setting load_pretrained_mdl_path at pretraining stage is useless, pretraining is always from scratch, please change it to None.') 65 | if fstride != fshape or tstride != tshape: 66 | raise ValueError('fstride != fshape or tstride != tshape, they must be same at the pretraining stage, patch split overlapping is not supported.') 67 | 68 | # if AudioSet pretraining is not used (but ImageNet pretraining may still apply) 69 | if model_size == 'tiny': 70 | self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=False) 71 | self.heads, self.depth = 3, 12 72 | self.cls_token_num = 2 73 | elif model_size == 'small': 74 | self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=False) 75 | self.heads, self.depth = 6, 12 76 | self.cls_token_num = 2 77 | elif model_size == 'base': 78 | self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=False) 79 | self.heads, self.depth = 12, 12 80 | self.cls_token_num = 2 81 | elif model_size == 'base_nokd': 82 | self.v = timm.create_model('vit_deit_base_patch16_384', pretrained=False) 83 | self.heads, self.depth = 12, 12 84 | self.cls_token_num = 1 85 | else: 86 | raise Exception('Model size must be one of tiny, small, base, base_nokd') 87 | 88 | self.original_num_patches = self.v.patch_embed.num_patches 89 | self.oringal_hw = int(self.original_num_patches ** 0.5) 90 | self.original_embedding_dim = self.v.pos_embed.shape[2] 91 | 92 | # SSL Pretraining Code 93 | self.softmax = nn.Softmax(dim=-1) 94 | self.lsoftmax = nn.LogSoftmax(dim=-1) 95 | self.fshape, self.tshape = fshape, tshape 96 | self.fstride, self.tstride = fstride, tstride 97 | self.input_fdim, self.input_tdim = input_fdim, input_tdim 98 | # this is 
a trick to make state_dict to track pretraining input_fdim and input_tdim and save them by using torch.save 99 | self.p_input_fdim, self.p_input_tdim = nn.Parameter(torch.tensor(input_fdim), requires_grad=False), nn.Parameter(torch.tensor(input_tdim), requires_grad=False) 100 | 101 | # masked patch classification (discriminative objective) layer 102 | # we use two layers for pretext task, but using a single layer has similar performance. 103 | # we map the output of transformer (768-dim for base models) to 256-dim patch input space, and then dot product with flattened patch input (also 256-dim) to calculate loss. 104 | # alternatively, you can map the output of transformer to 768-dim patch embedding space, and dot product with patch embedding. Performance-wise they are similar, but map to 256 space is more efficient. 105 | self.cpredlayer = nn.Sequential(nn.Linear(self.original_embedding_dim, self.original_embedding_dim), nn.ReLU(), nn.Linear(self.original_embedding_dim, 256)) 106 | # masked patch reconstruction (generative objective) layer 107 | self.gpredlayer = nn.Sequential(nn.Linear(self.original_embedding_dim, self.original_embedding_dim), nn.ReLU(), nn.Linear(self.original_embedding_dim, 256)) 108 | self.unfold = torch.nn.Unfold(kernel_size=(fshape, tshape), stride=(fstride, tstride)) 109 | 110 | # we use learnable mask embedding (follow the BEIT paper), but using a fixed mask embedding (e.g., 0) leads to same performance. 
111 | self.mask_embed = nn.Parameter(torch.zeros([1, 1, self.original_embedding_dim])) 112 | self.mask_embed = torch.nn.init.xavier_normal_(self.mask_embed) 113 | 114 | # get the intermediate shape 115 | self.p_f_dim, self.p_t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape) 116 | num_patches = self.p_f_dim * self.p_t_dim 117 | self.num_patches = num_patches 118 | self.v.patch_embed.num_patches = num_patches 119 | print('pretraining patch split stride: frequency={:d}, time={:d}'.format(fstride, tstride)) 120 | print('pretraining patch shape: frequency={:d}, time={:d}'.format(fshape, tshape)) 121 | print('pretraining patch array dimension: frequency={:d}, time={:d}'.format(self.p_f_dim, self.p_t_dim)) 122 | print('pretraining number of patches={:d}'.format(num_patches)) 123 | 124 | # the linear patch projection layer, use 1 channel for spectrogram rather than the original 3 channels for RGB images. 125 | new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride)) 126 | self.v.patch_embed.proj = new_proj 127 | 128 | # use trainable positional embedding 129 | new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + self.cls_token_num, self.original_embedding_dim)) 130 | self.v.pos_embed = new_pos_embed 131 | trunc_normal_(self.v.pos_embed, std=.02) 132 | 133 | # use a pretrained models for finetuning 134 | elif pretrain_stage == False: 135 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 136 | if load_pretrained_mdl_path == None: 137 | raise ValueError('Please set load_pretrained_mdl_path to load a pretrained models.') 138 | 139 | # added kaen2891 140 | out_dir = './pretrained_models/' 141 | if not os.path.exists(out_dir): 142 | os.makedirs(out_dir, exist_ok=True) 143 | 144 | if load_pretrained_mdl_path == 'Patch': 145 | if os.path.exists(os.path.join(out_dir, 'SSAST-Base-Patch-400.pth')) == False: 146 | # this model performs 59.9 on the Avg 
Audio Performance 147 | # more details are below: https://github.com/YuanGongND/ssast 148 | audioset_mdl_url = 'https://www.dropbox.com/s/ewrzpco95n9jdz6/SSAST-Base-Patch-400.pth?dl=1' 149 | print('Download SSAST-Base-Patch-400.pth \n') 150 | wget.download(audioset_mdl_url, out=os.path.join(out_dir, 'SSAST-Base-Patch-400.pth')) 151 | sd = torch.load(os.path.join(out_dir, 'SSAST-Base-Patch-400.pth'), map_location=device) 152 | print('Loaded SSAST-Base-Patch-400.pth successfully.') 153 | 154 | elif load_pretrained_mdl_path == 'Frame': 155 | if os.path.exists(os.path.join(out_dir, 'SSAST-Base-Frame-400.pth')) == False: 156 | # this model performs 57.6 on the Avg Audio Performance 157 | # more details are below: https://github.com/YuanGongND/ssast 158 | audioset_mdl_url = 'https://www.dropbox.com/s/nx6nl4d4bl71sm8/SSAST-Base-Frame-400.pth?dl=1' 159 | print('Download SSAST-Base-Frame-400.pth \n') 160 | wget.download(audioset_mdl_url, out=os.path.join(out_dir, 'SSAST-Base-Frame-400.pth')) 161 | sd = torch.load(os.path.join(out_dir, 'SSAST-Base-Frame-400.pth'), map_location=device) 162 | print('\n Loaded SSAST-Base-Frame-400.pth successfully.') 163 | 164 | # get the fshape and tshape, input_fdim and input_tdim in the pretraining stage 165 | try: 166 | p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3] 167 | p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item() 168 | except: 169 | raise ValueError('The model loaded is not from a torch.nn.Dataparallel object. Wrap it with torch.nn.Dataparallel and try again.') 170 | 171 | print('now load a SSL pretrained models from ' + load_pretrained_mdl_path) 172 | # during pretraining, fstride=fshape and tstride=tshape because no patch overlapping is used 173 | # here, input_fdim and input_tdim should be that used in pretraining, not that in the fine-tuning. 
174 | # we need to know input_fdim and input_tdim to do positional embedding cut/interpolation. 175 | # generally it should be better to use same input_fdim during pretraining and finetuning, but input_tdim can be safely different 176 | audio_model = SSASTModel(fstride=p_fshape, tstride=p_tshape, fshape=p_fshape, tshape=p_tshape, 177 | input_fdim=p_input_fdim, input_tdim=p_input_tdim, pretrain_stage=True, model_size=model_size) 178 | audio_model = torch.nn.DataParallel(audio_model) 179 | audio_model.load_state_dict(sd, strict=False) 180 | 181 | self.v = audio_model.module.v 182 | self.final_feat_dim = self.original_embedding_dim = self.v.pos_embed.shape[2] 183 | self.mix_beta = mix_beta 184 | self.cls_token_num = audio_model.module.cls_token_num 185 | 186 | # mlp head for fine-tuning 187 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), 188 | nn.Linear(self.original_embedding_dim, label_dim)) 189 | 190 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape) 191 | # patch array dimension during pretraining 192 | p_f_dim, p_t_dim = audio_model.module.p_f_dim, audio_model.module.p_t_dim 193 | num_patches = f_dim * t_dim 194 | p_num_patches = p_f_dim * p_t_dim 195 | self.v.patch_embed.num_patches = num_patches 196 | print('fine-tuning patch split stride: frequncey={:d}, time={:d}'.format(fstride, tstride)) 197 | print('fine-tuning number of patches={:d}'.format(num_patches)) 198 | 199 | # patch shape should be same for pretraining and fine-tuning 200 | if fshape != p_fshape or tshape != p_tshape: 201 | raise ValueError('The patch shape of pretraining and fine-tuning is not consistant, pretraining: f={:d}, t={:d}, finetuning: f={:d}, t={:d}'.format(p_fshape, p_tshape, fshape, tshape)) 202 | 203 | # patch split stride generally should be different for pretraining and fine-tuning, as patch split overlapping is only used in finetuning 204 | # during pretraining, p_fshape = p_fstride and p_tshape = p_tstride 205 | 
if fstride != p_fshape or tstride != p_tshape: 206 | # initialize a new patch embedding layer with desired new stride. 207 | new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride)) 208 | # but the weights of patch embedding layer is still got from the pretrained models 209 | new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1)) 210 | new_proj.bias = self.v.patch_embed.proj.bias 211 | self.v.patch_embed.proj = new_proj 212 | 213 | new_pos_embed = self.v.pos_embed[:, self.cls_token_num:, :].detach().reshape(1, p_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, p_f_dim, p_t_dim) 214 | # cut or interpolate the positional embedding 215 | if t_dim < p_t_dim: 216 | new_pos_embed = new_pos_embed[:, :, :, int(p_t_dim/2) - int(t_dim / 2): int(p_t_dim/2) - int(t_dim / 2) + t_dim] 217 | else: 218 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(8, t_dim), mode='bilinear') 219 | if f_dim < p_f_dim: 220 | new_pos_embed = new_pos_embed[:, :, int(p_f_dim/2) - int(f_dim / 2): int(p_f_dim/2) - int(f_dim / 2) + t_dim, :] 221 | else: 222 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear') 223 | 224 | new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1, 2) 225 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :self.cls_token_num, :].detach(), new_pos_embed], dim=1)) 226 | 227 | # get the shape of intermediate representation. 
228 | def get_shape(self, fstride, tstride, input_fdim, input_tdim, fshape, tshape): 229 | test_input = torch.randn(1, 1, input_fdim, input_tdim) 230 | test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride)) 231 | test_out = test_proj(test_input) 232 | f_dim = test_out.shape[2] 233 | t_dim = test_out.shape[3] 234 | return f_dim, t_dim 235 | 236 | # generate mask for 16*16 patch 237 | def gen_maskid_patch(self, sequence_len=512, mask_size=100, cluster=3): 238 | mask_id = [] 239 | 240 | # randomize clutering factor in [3,6) 241 | cur_clus = randrange(cluster) + 3 242 | 243 | while len(list(set(mask_id))) <= mask_size: 244 | start_id = randrange(sequence_len) 245 | 246 | # this improves the efficiency, but might change the pretrained model 247 | # while start_id in mask_id: 248 | # start_id = randrange(sequence_len) 249 | 250 | cur_mask = [] 251 | for i in range(0, cur_clus): 252 | for j in range(0, cur_clus): 253 | mask_cand = start_id + self.p_t_dim * i + j 254 | if mask_cand > 0 and mask_cand < sequence_len: 255 | cur_mask.append(mask_cand) 256 | mask_id = mask_id + cur_mask 257 | mask_id = list(set(mask_id))[:mask_size] 258 | return torch.tensor(mask_id) 259 | 260 | # using cluster for frame masking hurts the performance, so just use the naive random sampling 261 | def gen_maskid_frame(self, sequence_len=512, mask_size=100): 262 | mask_id = random.sample(range(0, sequence_len), mask_size) 263 | return torch.tensor(mask_id) 264 | 265 | def finetuningavgtok(self, x, y=None, patch_mix=False): 266 | B = x.shape[0] 267 | x = self.v.patch_embed(x) 268 | 269 | if patch_mix: 270 | x, y_a, y_b, lam, index = self.patch_mix(x, y) 271 | 272 | if self.cls_token_num == 2: 273 | cls_tokens = self.v.cls_token.expand(B, -1, -1) 274 | dist_token = self.v.dist_token.expand(B, -1, -1) 275 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 276 | else: 277 | cls_tokens = self.v.cls_token.expand(B, -1, -1) 278 | x = 
torch.cat((cls_tokens, x), dim=1) 279 | 280 | x = x + self.v.pos_embed 281 | x = self.v.pos_drop(x) 282 | 283 | for blk_id, blk in enumerate(self.v.blocks): 284 | x = blk(x) 285 | x = self.v.norm(x) 286 | 287 | # average output of all tokens except cls token(s) 288 | x = torch.mean(x[:, self.cls_token_num:, :], dim=1) 289 | # x = self.mlp_head(x) 290 | 291 | if not patch_mix: 292 | return x 293 | else: 294 | return x, y_a, y_b, lam, index 295 | 296 | def finetuningcls(self, x): 297 | B = x.shape[0] 298 | x = self.v.patch_embed(x) 299 | if self.cls_token_num == 2: 300 | cls_tokens = self.v.cls_token.expand(B, -1, -1) 301 | dist_token = self.v.dist_token.expand(B, -1, -1) 302 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 303 | else: 304 | cls_tokens = self.v.cls_token.expand(B, -1, -1) 305 | x = torch.cat((cls_tokens, x), dim=1) 306 | x = x + self.v.pos_embed 307 | x = self.v.pos_drop(x) 308 | 309 | for blk_id, blk in enumerate(self.v.blocks): 310 | x = blk(x) 311 | x = self.v.norm(x) 312 | 313 | # if models has two cls tokens (DEIT), average as the clip-level representation 314 | if self.cls_token_num == 2: 315 | x = (x[:, 0] + x[:, 1]) / 2 316 | else: 317 | x = x[:, 0] 318 | # x = self.mlp_head(x) 319 | return x 320 | 321 | # masked patch pretraining with discriminative objective 322 | def mpc(self, x, mask_patch, cluster, show_mask=False): 323 | input = self.unfold(x).transpose(1, 2) 324 | B = x.shape[0] 325 | # x in shape (batch_size, sequence_len, embedding dim) 326 | x = self.v.patch_embed(x) 327 | 328 | # encode the patch 329 | # size 12(batch_size) * 100(#mask_patch) * 768(hidden_dim), prepare to save the true values of masked samples 330 | encode_samples = torch.empty((B, mask_patch, 256), device=x.device, requires_grad=False).float() 331 | # size 12(batch_size) * 100(#mask_patch), index of masked patches 332 | mask_index = torch.empty((B, mask_patch), device=x.device, requires_grad=False).long() 333 | # size 12(batch_size) * 512(sequence_len) * 
768(hidden_dim) 334 | mask_dense = torch.ones([x.shape[0], x.shape[1], x.shape[2]], device=x.device) 335 | 336 | # for each audio clip in the batch 337 | for i in range(B): 338 | # randomly generate #mask_patch mask indexes without duplicate 339 | if cluster == True: 340 | # use this if you are masking e.g. 16*16 patches 341 | mask_index[i] = self.gen_maskid_patch(self.num_patches, mask_patch) 342 | else: 343 | # use this if you are masking frame, i.e., 128*2 patches 344 | mask_index[i] = self.gen_maskid_frame(self.num_patches, mask_patch) 345 | # copy the masked embeddings, note gradients are stopped in this path 346 | encode_samples[i] = input[i, mask_index[i], :].clone().detach() 347 | # mask the encode samples with 0 348 | mask_dense[i, mask_index[i], :] = 0 349 | 350 | # follow BEIT paper, mask with learnable masking embedding, but no performance diff observed compared with masking with 0s. 351 | mask_tokens = self.mask_embed.expand(B, x.shape[1], -1) 352 | 353 | # mask the patch 354 | x = x * mask_dense + (1-mask_dense) * mask_tokens 355 | 356 | # pass through the Transformer layers 357 | cls_tokens = self.v.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 358 | dist_token = self.v.dist_token.expand(B, -1, -1) 359 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 360 | x = x + self.v.pos_embed 361 | x = self.v.pos_drop(x) 362 | for blk in self.v.blocks: 363 | x = blk(x) 364 | x = self.v.norm(x) 365 | 366 | # prediction of the masked patch 367 | pred = torch.empty((B, mask_patch, 256), device=x.device).float() # e.g. size 12*100*768 368 | for i in range(B): 369 | # +2 for indexes because skipping the cls and dis token 370 | # we map the output of transformer (768-dim for base models) to 256-dim patch input space, and then dot product with flattened patch input (also 256-dim) to calculate loss. 371 | # alternatively, you can map the output of transformer to 768-dim patch embedding space, and dot product with patch embedding. 
Performance-wise they are similar, but map to 256 space is more efficient. 372 | pred[i] = self.cpredlayer(x[i, mask_index[i] + self.cls_token_num, :]) 373 | 374 | # calculate the NCE loss 375 | nce = torch.tensor(0.0).to(x.device) 376 | correct = torch.tensor(0.0).to(x.device) 377 | for i in np.arange(0, B): 378 | # negative samples are from the same batch 379 | # 8/12/2022: has a difference with equation (1) in the ssast paper but (likely) performance-wise similar, see https://github.com/YuanGongND/ssast/issues/13 380 | total = torch.mm(encode_samples[i], torch.transpose(pred[i], 0, 1)) # e.g. size 100*100 381 | correct += torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=0), torch.arange(0, mask_patch, device=x.device))) # correct is a tensor 382 | nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor 383 | acc = 1. * correct / (B * mask_patch) 384 | nce = nce / (-1. * B * mask_patch) 385 | 386 | # visualize the masked area, for probing test only, set show_mask = False for any training/inference. 
387 | if show_mask == False: 388 | return acc, nce 389 | else: 390 | if B > 1: 391 | raise Exception('Currently only support single spectrogram probing test.') 392 | 393 | self.mask_correct = torch.nn.Parameter(torch.arange(0, mask_patch), requires_grad=False) 394 | 395 | pred = input.clone() # [B, 512, 256] 396 | masked = input.clone() 397 | 398 | for i in range(B): 399 | result = [float(t) * 99 for t in torch.eq(torch.argmax(self.softmax(total), dim=0), self.mask_correct)] 400 | pred[i, mask_index[i], :] = torch.tensor(result).reshape(mask_patch, 1).expand(mask_patch, 256) 401 | masked[i, mask_index[i], :] = 99.0 402 | 403 | # print(total) 404 | # print(self.softmax(total)) 405 | # print(torch.argmax(self.softmax(total), dim=0)) 406 | # print(self.mask_correct) 407 | # print(torch.eq(torch.argmax(self.softmax(total), dim=0), self.mask_correct)) 408 | # print([float(t)*99 for t in torch.eq(torch.argmax(self.softmax(total), dim=0), self.mask_correct)]) 409 | 410 | fold = torch.nn.Fold(output_size=([self.input_fdim, self.input_tdim]), kernel_size=(self.fshape, self.tshape), stride=(self.fstride, self.tstride)) 411 | pred = fold(pred.transpose(1, 2)) 412 | masked = fold(masked.transpose(1, 2)) 413 | 414 | return pred, masked 415 | 416 | # # masked patch pretraining with generative objective 417 | def mpg(self, input, mask_patch, cluster): 418 | B = input.shape[0] 419 | x = self.v.patch_embed(input) 420 | input = self.unfold(input).transpose(1, 2) 421 | 422 | # size 12(batch_size) * 100(#mask_patch), index of masked patches 423 | mask_index = torch.empty((B, mask_patch), device=x.device, requires_grad=False).long() 424 | # size 12(batch_size) * 512(sequence_len) * 768(hidden_dim) 425 | mask_dense = torch.ones([x.shape[0], x.shape[1], x.shape[2]], device=x.device) 426 | for i in range(B): 427 | # randomly generate #mask_patch mask indexes without duplicate 428 | if cluster == True: 429 | # use this if you are masking e.g. 
16*16 patches 430 | mask_index[i] = self.gen_maskid_patch(self.num_patches, mask_patch) 431 | else: 432 | # use this if you are masking frame, i.e., 128*2 patches 433 | mask_index[i] = self.gen_maskid_frame(self.num_patches, mask_patch) 434 | mask_dense[i, mask_index[i], :] = 0 435 | 436 | mask_tokens = self.mask_embed.expand(B, x.shape[1], -1) 437 | 438 | # follow BEIT paper, mask with learnable masking embedding, but no performance diff observed compared with masking with 0s. 439 | x = x * mask_dense + (1-mask_dense) * mask_tokens 440 | 441 | # go through the Transformer layers 442 | cls_tokens = self.v.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 443 | dist_token = self.v.dist_token.expand(B, -1, -1) 444 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 445 | x = x + self.v.pos_embed 446 | x = self.v.pos_drop(x) 447 | for blk in self.v.blocks: 448 | x = blk(x) 449 | x = self.v.norm(x) 450 | 451 | pred = torch.empty((B, mask_patch, self.fshape * self.tshape), device=x.device).float() # e.g. size 12*100*256 452 | target = torch.empty((B, mask_patch, self.fshape * self.tshape), device=x.device).float() # e.g. size 12*100*256 453 | 454 | for i in range(B): 455 | # +2 for indexes because cls and dis token 456 | pred[i] = self.gpredlayer(x[i, mask_index[i] + self.cls_token_num, :]) 457 | target[i] = input[i, mask_index[i], :] 458 | 459 | # calculate the MSE loss 460 | mse = torch.mean((pred - target) ** 2) 461 | 462 | return mse 463 | 464 | def patch_mix(self, image, target): 465 | if self.mix_beta > 0: 466 | lam = np.random.beta(self.mix_beta, self.mix_beta) 467 | else: 468 | lam = 1 469 | 470 | batch_size, num_patch, dim = image.size() 471 | device = image.device 472 | 473 | index = torch.randperm(batch_size).to(device) 474 | 475 | num_mask = int(num_patch * (1. 
- lam)) 476 | mask = torch.randperm(num_patch)[:num_mask].to(device) 477 | 478 | image[:, mask, :] = image[index][:, mask, :] 479 | lam = 1 - (num_mask / num_patch) 480 | 481 | y_a, y_b = target, target[index] 482 | return image, y_a, y_b, lam, index 483 | 484 | def forward(self, x, y=None, patch_mix=False): 485 | cluster = True 486 | mask_patch = 400 487 | 488 | # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 489 | # x = x.unsqueeze(1) # (12, 1, 1024, 128) 490 | x = x.transpose(2, 3) # (12, 1, 128, 1024) 491 | 492 | # finetuning (ft), use the mean of all token (patch) output as clip-level representation. 493 | # this is default for SSAST fine-tuning as during pretraining, supervision signal is given to each token, not the [cls] token 494 | if self.task == 'ft_avgtok': 495 | return self.finetuningavgtok(x, y, patch_mix) 496 | # alternatively, use the [cls] token output as clip-level representation. 497 | elif self.task == 'ft_cls': 498 | return self.finetuningcls(x) 499 | # pretraining, masked patch classification (discriminative objective) 500 | elif self.task == 'pretrain_mpc': 501 | return self.mpc(x, mask_patch=mask_patch, cluster=cluster) 502 | # pretraining, masked patch reconstruction (generative objective) 503 | elif self.task == 'pretrain_mpg': 504 | return self.mpg(x, mask_patch=mask_patch, cluster=cluster) 505 | elif self.task == 'visualize_mask': 506 | return self.mpc(x, mask_patch=mask_patch, cluster=cluster, show_mask=True) 507 | else: 508 | raise Exception('Task unrecognized.') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | tqdm 3 | librosa 4 | pandas 5 | nlpaug 6 | cmapy 7 | audiomentations 8 | timm==0.4.5 9 | wget 10 | cmake==3.18.4 11 | MulticoreTSNE 12 | seaborn 13 | numpy -------------------------------------------------------------------------------- 
/scripts/icbhi_ce.sh: -------------------------------------------------------------------------------- 1 | 2 | MODEL="ast" 3 | SEED="1" 4 | 5 | for s in $SEED 6 | do 7 | for m in $MODEL 8 | do 9 | TAG="bs8_lr5e-5_ep50_seed${s}" 10 | CUDA_VISIBLE_DEVICES=0 python main.py --tag $TAG \ 11 | --dataset icbhi \ 12 | --seed $s \ 13 | --class_split lungsound \ 14 | --n_cls 4 \ 15 | --epochs 50 \ 16 | --batch_size 8 \ 17 | --optimizer adam \ 18 | --learning_rate 5e-5 \ 19 | --weight_decay 1e-6 \ 20 | --weighted_loss \ 21 | --cosine \ 22 | --model $m \ 23 | --test_fold official \ 24 | --pad_types repeat \ 25 | --resz 1 \ 26 | --n_mels 128 \ 27 | --ma_update \ 28 | --ma_beta 0.5 \ 29 | --from_sl_official \ 30 | --audioset_pretrained \ 31 | --method ce 32 | 33 | # only for evaluation, add the following arguments 34 | # --eval \ 35 | # --pretrained \ 36 | # --pretrained_ckpt ./save/icbhi_ast_ce_bs8_lr5e-5_ep50_seed1/best.pth 37 | 38 | done 39 | done 40 | -------------------------------------------------------------------------------- /scripts/icbhi_patchmix_ce.sh: -------------------------------------------------------------------------------- 1 | 2 | MODEL="ast" 3 | SEED="1" 4 | 5 | for s in $SEED 6 | do 7 | for m in $MODEL 8 | do 9 | TAG="bs8_lr5e-5_ep50_seed${s}" 10 | CUDA_VISIBLE_DEVICES=0 python main.py --tag $TAG \ 11 | --dataset icbhi \ 12 | --seed $s \ 13 | --class_split lungsound \ 14 | --n_cls 4 \ 15 | --epochs 50 \ 16 | --batch_size 8 \ 17 | --optimizer adam \ 18 | --learning_rate 5e-5 \ 19 | --weight_decay 1e-6 \ 20 | --cosine \ 21 | --model $m \ 22 | --test_fold official \ 23 | --pad_types repeat \ 24 | --resz 1 \ 25 | --n_mels 128 \ 26 | --ma_update \ 27 | --ma_beta 0.5 \ 28 | --from_sl_official \ 29 | --audioset_pretrained \ 30 | --method patchmix \ 31 | --mix_beta 1.0 32 | 33 | # only for evaluation, add the following arguments 34 | # --eval \ 35 | # --pretrained \ 36 | # --pretrained_ckpt ./save/icbhi_ast_patchmix_bs8_lr5e-5_ep50_seed1/best.pth 37 | 38 | done 
39 | done 40 | 41 | -------------------------------------------------------------------------------- /scripts/icbhi_patchmix_cl.sh: -------------------------------------------------------------------------------- 1 | 2 | MODEL="ast" 3 | SEED="1" 4 | 5 | for s in $SEED 6 | do 7 | for m in $MODEL 8 | do 9 | TAG="bs8_lr5e-5_ep50_seed${s}_best_param" 10 | CUDA_VISIBLE_DEVICES=0 python main.py --tag $TAG \ 11 | --dataset icbhi \ 12 | --seed $s \ 13 | --class_split lungsound \ 14 | --n_cls 4 \ 15 | --epochs 50 \ 16 | --batch_size 8 \ 17 | --optimizer adam \ 18 | --learning_rate 5e-5 \ 19 | --weight_decay 1e-6 \ 20 | --cosine \ 21 | --model $m \ 22 | --test_fold official \ 23 | --pad_types repeat \ 24 | --resz 1 \ 25 | --n_mels 128 \ 26 | --ma_update \ 27 | --ma_beta 0.5 \ 28 | --from_sl_official \ 29 | --audioset_pretrained \ 30 | --method patchmix_cl \ 31 | --temperature 0.06 \ 32 | --proj_dim 768 \ 33 | --alpha 1.0 \ 34 | --mix_beta 1.0 35 | 36 | # only for evaluation, add the following arguments 37 | # --eval \ 38 | # --pretrained \ 39 | # --pretrained_ckpt ./save/icbhi_ast_patchmix_cl_bs8_lr5e-5_ep50_seed1_best_param/best.pth 40 | 41 | done 42 | done 43 | 44 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raymin0223/patch-mix_contrastive_learning/836b09fea1b70eb29fe0b25afa481286b56f5104/util/__init__.py -------------------------------------------------------------------------------- /util/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import nlpaug.augmenter.audio as naa 4 | 5 | import torch 6 | from torchvision.utils import _log_api_usage_once 7 | from torchvision.transforms import transforms 8 | from torchaudio import transforms as T 9 | from .time_warping import sparse_image_warp 10 | 11 | __all__ = 
['augment_raw_audio', 'SpecAugment'] 12 | 13 | 14 | def augment_raw_audio(sample, sample_rate, args): 15 | """ 16 | Raw audio data augmentation technique 17 | you can utilize any library code 18 | 1) nlpaug 19 | 2) audiomentations 20 | 3) librosa 21 | """ 22 | 23 | """ 1) nlpaug """ 24 | augment_list = [ 25 | # naa.CropAug(sampling_rate=sample_rate) 26 | naa.NoiseAug(), # apply noise injection operation 27 | naa.SpeedAug(), # apply speed adjustment operation 28 | naa.LoudnessAug(factor=(0.5, 2)), # apply adjusting loudness operation 29 | naa.VtlpAug(sampling_rate=sample_rate, zone=(0.0, 1.0)), # apply vocal tract length perturbation (VTLP) operation 30 | naa.PitchAug(sampling_rate=sample_rate, factor=(-1,3)) # apply pitch adjustment operation 31 | ] 32 | 33 | # randomly sample augmentation 34 | aug_idx = random.randint(0, len(augment_list)-1) 35 | sample = augment_list[aug_idx].augment(sample) 36 | 37 | # apply all augmentations 38 | # for aug_idx in range(len(augment_list)): 39 | # sample = augment_list[aug_idx].augment(sample) 40 | 41 | """ 2) audiomentations """ 42 | # import audiomentations 43 | # from audiomentations import AddGaussianSNR, TimeStretch, PitchShift, Shift 44 | 45 | # # when using audiomentations library (not DEBUG yet) 46 | # audio_transforms = audiomentations.Compose([ 47 | # AddGaussianSNR(min_snr_in_db=5, max_snr_in_db=40.0, p=0.5), 48 | # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5), 49 | # PitchShift(min_semitones=-4, max_semitones=4, p=0.5), 50 | # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5), 51 | # ]) 52 | 53 | # sample = audio_transforms(samples=sample, sample_rate=sample_rate) 54 | 55 | """ 3) librosa """ 56 | # import librosa 57 | 58 | # def _noise(data): 59 | # noise_amp = 0.035 * np.random.uniform() * np.amax(data) 60 | # data = data + noise_amp * np.random.normal(size=data.shape[0]) 61 | # return data 62 | 63 | # def _stretch(data, rate=0.8): 64 | # return librosa.effects.time_stretch(data, rate) 65 | 66 | # def 
_shift(data): 67 | # shift_range = int(np.random.uniform(low=-5, high=5) * 1000) 68 | # return np.roll(data, shift_range) 69 | 70 | # def _pitch(data, sampling_rate, pitch_factor=0.7): 71 | # return librosa.effects.pitch_shift(data, sampling_rate, pitch_factor) 72 | 73 | # sample = _noise(sample) 74 | # sample = _stretch(sample) 75 | # sample = _shift(sample) 76 | # sample = _pitch(sample, sample_rate) 77 | 78 | if type(sample) == list: 79 | return sample[0] 80 | else: 81 | return sample 82 | 83 | 84 | # Use this Class when you load dataset with librosa 85 | class SpecAugment(torch.nn.Module): 86 | ''' 87 | Unofficial Implementation of SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition 88 | Paper: https://arxiv.org/pdf/1904.08779.pdf 89 | Ref. github: https://github.com/pyyush/SpecAugment/blob/219fc6e9ed4838fe9700295700040b1da283c536/augment.py#L10 90 | 91 | Augmentation Parameters for policies 92 | ----------------------------------------- 93 | Policy | W | F | m_F | T | p | m_T 94 | ----------------------------------------- 95 | None | 0 | 0 | - | 0 | - | - 96 | ----------------------------------------- 97 | LB | 80 | 27 | 1 | 100 | 1.0 | 1 98 | ----------------------------------------- 99 | LD | 80 | 27 | 2 | 100 | 1.0 | 2 100 | ----------------------------------------- 101 | SM | 40 | 15 | 2 | 70 | 0.2 | 2 102 | ----------------------------------------- 103 | SS | 40 | 27 | 2 | 70 | 0.2 | 2 104 | ----------------------------------------- 105 | 106 | LB : LibriSpeech basic 107 | LD : LibriSpeech double 108 | SM : Switchboard mild 109 | SS : Switchboard strong 110 | W : Time Warp parameter 111 | F : Frequency Mask parameter 112 | m_F : Number of Frequency masks 113 | T : Time Mask parameter 114 | p : Parameter for calculating upper bound for time mask 115 | m_T : Number of time masks 116 | ''' 117 | #def __init__(self, policy, zero_mean_normalized=False): 118 | def __init__(self, args): 119 | super().__init__() 120 | 
_log_api_usage_once(self) 121 | 122 | self.policy = args.specaug_policy 123 | self.mask = args.specaug_mask 124 | 125 | # Policy Specific Parameters 126 | if self.policy == 'LB': 127 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 80, 27, 1, 100, 1.0, 1 128 | elif self.policy == 'LD': 129 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 80, 27, 2, 100, 1.0, 2 130 | elif self.policy == 'SM': 131 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 40, 15, 2, 70, 0.2, 2 132 | elif self.policy == 'SS': 133 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 40, 27, 2, 70, 0.2, 2 134 | elif self.policy == 'icbhi_sup': 135 | # following https://github.com/ilyassmoummad/scl_icbhi2017 136 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 0, 20, 2, 50, 1.0, 2 137 | elif self.policy == 'icbhi_ast_sup': 138 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 0, 48, 2, 160, 1.0, 2 139 | 140 | # mimic SpecAugment by using torchaudio.transforms 141 | # self.spec_aug = torch.nn.Sequential( 142 | # T.TimeStretch(0.8, fixed_rate=True), 143 | # T.FrequencyMasking(freq_mask_param=self.F), 144 | # T.TimeMasking(time_mask_param=self.T) 145 | # ) 146 | 147 | def time_warp(self): 148 | """ Tensorflow version """ 149 | # v, tau = self.mel_spectrogram.shape[1], self.mel_spectrogram.shape[2] 150 | 151 | # horiz_line_thru_ctr = self.mel_spectrogram[0][v//2] 152 | 153 | # random_pt = horiz_line_thru_ctr[random.randrange(self.W, tau - self.W)] # random point along the horizontal/time axis 154 | # w = np.random.uniform((-self.W), self.W) # distance 155 | 156 | # src_points = [[[v//2, random_pt[0]]]] # Source Points 157 | # dest_points = [[[v//2, random_pt[0] + w]]] # Destination Points 158 | # self.mel_spectrogram, _ = sparse_image_warp(self.mel_spectrogram, src_points, dest_points, num_boundary_points=2) 159 | # self.mel_spectrogram = self.mel_spectrogram.numpy() 160 | 161 | """ Pytorch version """ 162 | # refer to 
https://github.com/zcaceres/spec_augment/blob/master/SpecAugment.ipynb 163 | num_rows = self.mel_spectrogram.shape[2] 164 | spec_len = self.mel_spectrogram.shape[1] 165 | device = self.mel_spectrogram.device 166 | 167 | # adapted from https://github.com/DemisEom/SpecAugment/ 168 | pt = (num_rows - 2 * self.W) * torch.rand([1], dtype=torch.float) + self.W # random point along the time axis 169 | src_ctr_pt_freq = torch.arange(0, spec_len // 2) # control points on freq-axis 170 | src_ctr_pt_time = torch.ones_like(src_ctr_pt_freq) * pt # control points on time-axis 171 | src_ctr_pts = torch.stack((src_ctr_pt_freq, src_ctr_pt_time), dim=-1) 172 | src_ctr_pts = src_ctr_pts.float().to(device) 173 | 174 | # Destination 175 | w = 2 * self.W * torch.rand([1], dtype=torch.float) - self.W # distance 176 | dest_ctr_pt_freq = src_ctr_pt_freq 177 | dest_ctr_pt_time = src_ctr_pt_time + w 178 | dest_ctr_pts = torch.stack((dest_ctr_pt_freq, dest_ctr_pt_time), dim=-1) 179 | dest_ctr_pts = dest_ctr_pts.float().to(device) 180 | 181 | # warp 182 | source_control_point_locations = torch.unsqueeze(src_ctr_pts, 0) # (1, v//2, 2) 183 | dest_control_point_locations = torch.unsqueeze(dest_ctr_pts, 0) # (1, v//2, 2) 184 | warped_spectro, dense_flows = sparse_image_warp(self.mel_spectrogram, source_control_point_locations, dest_control_point_locations) 185 | 186 | return warped_spectro.squeeze(3) 187 | 188 | def freq_mask(self): 189 | if self.mask == 'mean': 190 | # maksing to mean value 191 | mask_value = self.mel_spectrogram.mean() 192 | elif self.mask == 'zero': 193 | # maksing to zero value 194 | mask_value = 0. 195 | 196 | v = self.mel_spectrogram.shape[1] # no. 
of mel bins 197 | 198 | # apply m_F frequency masks to the mel spectrogram 199 | for i in range(self.m_F): 200 | f = int(np.random.uniform(0, self.F)) # [0, F) 201 | f0 = random.randint(0, v - f) # [0, v - f) 202 | self.mel_spectrogram[:, f0:f0 + f, :] = mask_value 203 | 204 | return self.mel_spectrogram 205 | 206 | def time_mask(self): 207 | if self.mask == 'mean': 208 | # maksing to mean value 209 | mask_value = self.mel_spectrogram.mean() 210 | elif self.mask == 'zero': 211 | # maksing to zero value 212 | mask_value = 0. 213 | 214 | tau = self.mel_spectrogram.shape[2] # time frames 215 | 216 | # apply m_T time masks to the mel spectrogram 217 | for i in range(self.m_T): 218 | t = int(np.random.uniform(0, self.T)) # [0, T) 219 | t0 = random.randint(0, tau - t) # [0, tau - t) 220 | self.mel_spectrogram[:, :, t0:t0 + t] = mask_value 221 | 222 | return self.mel_spectrogram 223 | 224 | def forward(self, img): 225 | """ 226 | Args: 227 | img (Tensor): Mel-spectrogram to be specaugmented. 228 | Returns: 229 | Tensor: Time-warped, time masked and freq masked image. 230 | """ 231 | # self.mel_spectrogram = img # np.array [time, freq, channel] 232 | self.mel_spectrogram = img # torch.tensor [channel, time, freq] 233 | self.mel_spectrogram = self.mel_spectrogram.transpose(2, 1) # torch.tensor [channel, freq, time] 234 | 235 | if self.p >= torch.randn(1): 236 | if self.W: 237 | try: 238 | # input shape of time_warp should be [sample_size, time, freq] 239 | # assume that channel == 1 and augment each "one" sample 240 | self.mel_spectrogram= self.time_warp() 241 | except Exception as e: 242 | # torch.linalg.solve: (Batch element 0): The solver failed because the input matrix is singular. 
243 | # print(e) 244 | pass 245 | 246 | self.mel_spectrogram = self.freq_mask() 247 | self.mel_spectrogram = self.time_mask() 248 | 249 | return self.mel_spectrogram.transpose(2, 1) 250 | 251 | def __repr__(self) -> str: 252 | return f"{self.__class__.__name__}()" -------------------------------------------------------------------------------- /util/icbhi_dataset.py: -------------------------------------------------------------------------------- 1 | from curses import meta 2 | import os 3 | import cv2 4 | import pickle 5 | import random 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | import librosa 11 | import torch 12 | from torch.utils.data import Dataset 13 | from copy import deepcopy 14 | from PIL import Image 15 | 16 | from .icbhi_util import get_annotations, save_image, generate_fbank, get_individual_cycles_librosa, split_pad_sample, generate_mel_spectrogram, concat_augmentation 17 | from .icbhi_util import get_individual_cycles_torchaudio, cut_pad_sample_torchaudio 18 | from .augmentation import augment_raw_audio 19 | 20 | 21 | class ICBHIDataset(Dataset): 22 | def __init__(self, train_flag, transform, args, print_flag=True, mean_std=False): 23 | data_folder = os.path.join(args.data_folder, 'icbhi_dataset/audio_test_data') 24 | folds_file = os.path.join(args.data_folder, 'icbhi_dataset/patient_list_foldwise.txt') 25 | official_folds_file = os.path.join(args.data_folder, 'icbhi_dataset/official_split.txt') 26 | test_fold = args.test_fold 27 | 28 | self.data_folder = data_folder 29 | self.train_flag = train_flag 30 | self.split = 'train' if train_flag else 'test' 31 | self.transform = transform 32 | self.args = args 33 | self.mean_std = mean_std 34 | 35 | # parameters for spectrograms 36 | self.sample_rate = args.sample_rate 37 | self.desired_length = args.desired_length 38 | self.pad_types = args.pad_types 39 | self.nfft = args.nfft 40 | self.hop = self.nfft // 2 41 | self.n_mels = args.n_mels 42 | self.f_min = 50 43 | 
self.f_max = 2000 44 | self.dump_images = False 45 | 46 | # ========================================================================== 47 | """ get ICBHI dataset meta information """ 48 | # store stethoscope device information for each file or patient 49 | self.file_to_device = {} 50 | self.device_to_id = {'Meditron': 0, 'LittC2SE': 1, 'Litt3200': 2, 'AKGC417L': 3} 51 | self.device_id_to_patient = {0: [], 1: [], 2: [], 3: []} 52 | 53 | filenames = os.listdir(data_folder) 54 | filenames =set([f.strip().split('.')[0] for f in filenames if '.wav' in f or '.txt' in f]) 55 | for f in filenames: 56 | f += '.wav' 57 | # get the total number of devices from original dataset (icbhi dataset has 4 stethoscope devices) 58 | device = f.strip().split('_')[-1].split('.')[0] 59 | # if device not in self.device_to_id: 60 | # self.device_to_id[device] = device_id 61 | # self.device_id_to_patient[device_id] = [] 62 | # device_id += 1 63 | 64 | # get the device information for each wav file 65 | self.file_to_device[f.strip().split('.')[0]] = self.device_to_id[device] 66 | 67 | pat_id = f.strip().split('_')[0] 68 | if pat_id not in self.device_id_to_patient[self.device_to_id[device]]: 69 | self.device_id_to_patient[self.device_to_id[device]].append(pat_id) 70 | 71 | # store all metadata (age, sex, adult_BMI, child_weight, child_height, device_index) 72 | self.file_to_metadata = {} 73 | meta_file = pd.read_csv(os.path.join(args.data_folder, 'icbhi_dataset/metadata.txt'), names=['age', 'sex', 'adult_BMI', 'child_weight', 'child_height', 'chest_location'], delimiter= '\t') 74 | meta_file['chest_location'].replace({'Tc':0, 'Al':1, 'Ar':2, 'Pl':3, 'Pr':4, 'Ll':5, 'Lr':6}, inplace=True) 75 | for f in filenames: 76 | pat_idx = int(f.strip().split('_')[0]) 77 | info = list(meta_file.loc[pat_idx]) 78 | info[1] = 0 if info[1] == 'M' else 1 79 | 80 | info = np.array(info) 81 | for idx in np.argwhere(np.isnan(info)): 82 | info[idx] = -1 83 | 84 | self.file_to_metadata[f] = 
torch.tensor(np.append(info, self.file_to_device[f.strip()])) 85 | # ========================================================================== 86 | 87 | # ========================================================================== 88 | """ train-test split based on train_flag and test_fold """ 89 | if test_fold in ['0', '1', '2', '3', '4']: # from RespireNet, 80-20% split 90 | patient_dict = {} 91 | all_patients = open(folds_file).read().splitlines() 92 | for line in all_patients: 93 | idx, fold = line.strip().split(' ') 94 | if train_flag and int(fold) != int(test_fold): 95 | patient_dict[idx] = fold 96 | elif train_flag == False and int(fold) == int(test_fold): 97 | patient_dict[idx] = fold 98 | 99 | if print_flag: 100 | print('*' * 20) 101 | print('Train and test 80-20% split with test_fold {}'.format(test_fold)) 102 | print('Patience number in {} dataset: {}'.format(self.split, len(patient_dict))) 103 | else: 104 | """ 105 | args.test_fold == 'official', 60-40% split 106 | two patient dataset contain both train and test samples 107 | """ 108 | patient_dict = {} 109 | all_fpath = open(official_folds_file).read().splitlines() 110 | for line in all_fpath: 111 | fpath, fold = line.strip().split('\t') 112 | if train_flag and fold == 'train': 113 | # idx = fpath.strip().split('_')[0] 114 | patient_dict[fpath] = fold 115 | elif not train_flag and fold == 'test': 116 | # idx = fpath.strip().split('_')[0] 117 | patient_dict[fpath] = fold 118 | 119 | if print_flag: 120 | print('*' * 20) 121 | print('Train and test 60-40% split with test_fold {}'.format(test_fold)) 122 | print('File number in {} dataset: {}'.format(self.split, len(patient_dict))) 123 | # ========================================================================== 124 | 125 | # dict {filename: annotations}, annotation is for breathing cycle 126 | annotation_dict = get_annotations(args, data_folder) 127 | 128 | self.filenames = [] 129 | for f in filenames: 130 | # for 'official' test_fold, two patient dataset 
contain both train and test samples 131 | idx = f.split('_')[0] if test_fold in ['0', '1', '2', '3', '4'] else f 132 | if args.stetho_id >= 0: # extract specific device dataset 133 | if idx in patient_dict and self.file_to_device[f] == args.stetho_id: 134 | self.filenames.append(f) 135 | else: # use all dataset 136 | if idx in patient_dict: 137 | self.filenames.append(f) 138 | 139 | self.audio_data = [] # each sample is a tuple with (audio_data, label, filename) 140 | self.metadata = [] # (age, sex, adult_BMI, child_weight, child_height, device_idx) 141 | self.labels = [] 142 | 143 | if print_flag: 144 | print('*' * 20) 145 | print("Extracting individual breathing cycles..") 146 | 147 | self.cycle_list = [] 148 | self.filename_to_label = {} 149 | self.classwise_cycle_list = [[] for _ in range(args.n_cls)] 150 | 151 | # ========================================================================== 152 | """ extract individual cycles by librosa or torchaudio """ 153 | for idx, filename in enumerate(self.filenames): 154 | # you can use self.filename_to_label to get statistics of original sample labels (will not be used on other function) 155 | self.filename_to_label[filename] = [] 156 | 157 | # "RespireNet" version: get original cycles 6,898 by librosa 158 | # sample_data = get_individual_cycles_librosa(args, annotation_dict[filename], data_folder, filename, args.sample_rate, args.n_cls, args.butterworth_filter) 159 | 160 | # "SCL" version: get original cycles 6,898 by torchaudio and cut_pad samples 161 | sample_data = get_individual_cycles_torchaudio(args, annotation_dict[filename], self.file_to_metadata[filename], data_folder, filename, args.sample_rate, args.n_cls) 162 | 163 | # cycles_with_labels: [(audio_chunk, label, metadata), (...)] 164 | cycles_with_labels = [(data[0], data[1], self.file_to_metadata[filename]) for data in sample_data] 165 | 166 | self.cycle_list.extend(cycles_with_labels) 167 | for d in cycles_with_labels: 168 | # {filename: [label for cycle 1, 
...]} 169 | self.filename_to_label[filename].append(d[1]) 170 | self.classwise_cycle_list[d[1]].append(d) 171 | 172 | # concatenation based augmentation scheme from "RespireNet" paper.. 173 | # TODO: how to decide the meta information of generated cycles 174 | # if train_flag and args.concat_aug_scale and args.class_split == 'lungsound' and args.n_cls == 4: 175 | # self.classwise_cycle_list, self.cycle_list = concat_augmentation(self.classwise_cycle_list, self.cycle_list, scale=args.concat_aug_scale) 176 | 177 | for sample in self.cycle_list: 178 | self.metadata.append(sample[2]) 179 | 180 | # "RespireNet" version: split and pad each cycle to the desired length (cycle numbers can be more than 6,898) 181 | # output = split_pad_sample(sample, args.desired_length, args.sample_rate, types=args.pad_types) 182 | # self.audio_data.extend(output) 183 | 184 | # "SCL" version 185 | self.audio_data.append(sample) 186 | # ========================================================================== 187 | 188 | self.class_nums = np.zeros(args.n_cls) 189 | for sample in self.audio_data: 190 | self.class_nums[sample[1]] += 1 191 | self.labels.append(sample[1]) 192 | self.class_ratio = self.class_nums / sum(self.class_nums) * 100 193 | 194 | if print_flag: 195 | print('[Preprocessed {} dataset information]'.format(self.split)) 196 | print('total number of audio data: {}'.format(len(self.audio_data))) 197 | for i, (n, p) in enumerate(zip(self.class_nums, self.class_ratio)): 198 | print('Class {} {:<9}: {:<4} ({:.1f}%)'.format(i, '('+args.cls_list[i]+')', int(n), p)) 199 | 200 | # ========================================================================== 201 | """ convert mel-spectrogram """ 202 | self.audio_images = [] 203 | for index in range(len(self.audio_data)): 204 | audio, label = self.audio_data[index][0], self.audio_data[index][1] 205 | 206 | audio_image = [] 207 | # self.aug_times = 1 + 5 * self.args.augment_times # original + five naa augmentations * augment_times (optional) 
208 | for aug_idx in range(self.args.raw_augment+1): 209 | if aug_idx > 0: 210 | if self.train_flag and not mean_std: 211 | audio = augment_raw_audio(audio, self.sample_rate, self.args) 212 | 213 | # "RespireNet" version: pad incase smaller than desired length 214 | # audio = split_pad_sample([audio, 0,0], self.desired_length, self.sample_rate, types=self.pad_types)[0][0] 215 | 216 | # "SCL" version: cut longer sample or pad sample 217 | audio = cut_pad_sample_torchaudio(torch.tensor(audio), args) 218 | else: 219 | audio_image.append(None) 220 | continue 221 | 222 | image = generate_fbank(audio, self.sample_rate, n_mels=self.n_mels) 223 | # image = generate_mel_spectrogram(audio.squeeze(0).numpy(), self.sample_rate, n_mels=self.n_mels, f_max=self.f_max, nfft=self.nfft, hop=self.hop, args=self.args) # image [n_mels, 251, 1] 224 | 225 | # blank region clipping from "RespireNet" paper.. 226 | if self.args.blank_region_clip: 227 | image_copy = deepcopy(generate_fbank(audio, self.sample_rate, n_mels=self.n_mels)) 228 | # image_copy = deepcopy(generate_mel_spectrogram(audio.squeeze(0).numpy(), self.sample_rate, n_mels=self.n_mels, f_max=self.f_max, nfft=self.nfft, hop=self.hop, args=self.args)) # image [n_mels, 251, 1] 229 | 230 | image_copy[image_copy < 10] = 0 231 | for row in range(image_copy.shape[0]): 232 | black_percent = len(np.where(image_copy[row,:] == 0)[0]) / len(image_copy[row,:]) 233 | # if there is row that is filled by more than 20% regions, stop and remember that `row` 234 | if black_percent < 0.80: 235 | break 236 | 237 | # delete black percent 238 | if row + 1 < image.shape[0]: 239 | image = image[row+1:,:,:] 240 | image = cv2.resize(image, (image.shape[1], self.n_mels), interpolation=cv2.INTER_LINEAR) 241 | image = image[..., np.newaxis] 242 | 243 | audio_image.append(image) 244 | self.audio_images.append((audio_image, label)) 245 | 246 | if self.dump_images: 247 | save_image(audio_image, './') 248 | self.dump_images = False 249 | 250 | self.h, self.w, 
_ = self.audio_images[0][0][0].shape 251 | # ========================================================================== 252 | 253 | def __getitem__(self, index): 254 | audio_images, label, metadata = self.audio_images[index][0], self.audio_images[index][1], self.metadata[index] 255 | 256 | if self.args.raw_augment and self.train_flag and not self.mean_std: 257 | aug_idx = random.randint(0, self.args.raw_augment) 258 | audio_image = audio_images[aug_idx] 259 | else: 260 | audio_image = audio_images[0] 261 | 262 | if self.transform is not None: 263 | audio_image = self.transform(audio_image) 264 | 265 | return audio_image, label, metadata 266 | 267 | def __len__(self): 268 | return len(self.audio_data) -------------------------------------------------------------------------------- /util/icbhi_util.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import os 3 | import math 4 | import random 5 | from tkinter import W 6 | import pandas as pd 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import cv2 11 | import cmapy 12 | import librosa 13 | import torch 14 | import torchaudio 15 | from torchaudio import transforms as T 16 | from scipy.signal import butter, lfilter 17 | 18 | from .augmentation import augment_raw_audio 19 | 20 | __all__ = ['get_annotations', 'save_image', 'get_mean_and_std', 'get_individual_cycles_librosa', 'get_individual_cycles_torchaudio', 'split_pad_sample', 'generate_mel_spectrogram', 'generate_fbank', 'concat_augmentation', 'get_score'] 21 | 22 | 23 | # ========================================================================== 24 | """ ICBHI dataset information """ 25 | def _extract_lungsound_annotation(file_name, data_folder): 26 | tokens = file_name.strip().split('_') 27 | recording_info = pd.DataFrame(data = [tokens], columns = ['Patient Number', 'Recording index', 'Chest location','Acquisition mode','Recording equipment']) 28 | recording_annotations = 
pd.read_csv(os.path.join(data_folder, file_name + '.txt'), names = ['Start', 'End', 'Crackles', 'Wheezes'], delimiter= '\t') 29 | 30 | return recording_info, recording_annotations 31 | 32 | 33 | def get_annotations(args, data_folder): 34 | if args.class_split == 'lungsound' or args.class_split in ['lungsound_meta', 'meta']: 35 | filenames = [f.strip().split('.')[0] for f in os.listdir(data_folder) if '.txt' in f] 36 | 37 | annotation_dict = {} 38 | for f in filenames: 39 | info, ann = _extract_lungsound_annotation(f, data_folder) 40 | annotation_dict[f] = ann 41 | 42 | elif args.class_split == 'diagnosis': 43 | filenames = [f.strip().split('.')[0] for f in os.listdir(data_folder) if '.txt' in f] 44 | tmp = pd.read_csv(os.path.join(args.data_folder, 'icbhi_dataset/patient_diagnosis.txt'), names=['Disease'], delimiter='\t') 45 | 46 | annotation_dict = {} 47 | for f in filenames: 48 | info, ann = _extract_lungsound_annotation(f, data_folder) 49 | ann.drop(['Crackles', 'Wheezes'], axis=1, inplace=True) 50 | 51 | disease = tmp.loc[int(f.strip().split('_')[0]), 'Disease'] 52 | ann['Disease'] = disease 53 | 54 | annotation_dict[f] = ann 55 | 56 | return annotation_dict 57 | 58 | 59 | def save_image(image, fpath): 60 | save_dir = os.path.join(fpath, 'image.jpg') 61 | cv2.imwrite(save_dir, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 62 | 63 | 64 | def get_mean_and_std(dataset): 65 | """ Compute the mean and std value of mel-spectrogram """ 66 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8) 67 | 68 | cnt = 0 69 | fst_moment = torch.zeros(1) 70 | snd_moment = torch.zeros(1) 71 | for inputs, _, _ in dataloader: 72 | b, c, h, w = inputs.shape 73 | nb_pixels = b * h * w 74 | 75 | fst_moment += torch.sum(inputs, dim=[0,2,3]) 76 | snd_moment += torch.sum(inputs**2, dim=[0,2,3]) 77 | cnt += nb_pixels 78 | 79 | mean = fst_moment / cnt 80 | std = torch.sqrt(snd_moment/cnt - mean**2) 81 | 82 | return mean, std 83 | # 
# ==========================================================================
""" data preprocessing """
def _butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Args:
        lowcut: lower cutoff frequency (Hz).
        highcut: upper cutoff frequency (Hz).
        fs: sampling rate (Hz).
        order: filter order.

    Returns:
        ``(b, a)`` coefficients as returned by ``scipy.signal.butter``.
    """
    nyq = 0.5 * fs
    # butter() expects cutoffs normalised by the Nyquist frequency
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')

    return b, a


def _butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter ``data`` with a Butterworth filter via ``scipy.signal.lfilter``."""
    b, a = _butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)

    return y


def _slice_data_librosa(start, end, data, sample_rate):
    """Cut the ``[start, end)`` window (seconds) out of a 1-D signal.

    From the RespireNet paper. ``sample_rate`` is the number of samples per
    second; indices past ``len(data)`` are clipped to the signal length.
    """
    max_ind = len(data)
    start_ind = min(int(start * sample_rate), max_ind)
    end_ind = min(int(end * sample_rate), max_ind)

    return data[start_ind: end_ind]


def _get_lungsound_label(crackle, wheeze, n_cls):
    """Map crackle / wheeze flags to a lung-sound class index.

    n_cls == 4: 0 = neither, 1 = crackle only, 2 = wheeze only, 3 = both.
    n_cls == 2: 0 = neither, 1 = any other flag combination.
    Returns None for any other ``n_cls`` (and, with 4 classes, for flags
    outside {0, 1}).
    """
    if n_cls == 4:
        if crackle == 0 and wheeze == 0:
            return 0
        elif crackle == 1 and wheeze == 0:
            return 1
        elif crackle == 0 and wheeze == 1:
            return 2
        elif crackle == 1 and wheeze == 1:
            return 3

    elif n_cls == 2:
        if crackle == 0 and wheeze == 0:
            return 0
        else:
            return 1


def _get_diagnosis_label(disease, n_cls):
    """Map a diagnosis string to a class index.

    n_cls == 3: 1 = COPD / Bronchiectasis / Asthma,
    2 = URTI / LRTI / Pneumonia / Bronchiolitis, 0 = anything else.
    n_cls == 2: 0 = 'Healthy', 1 = anything else.
    Returns None for any other ``n_cls``.
    """
    if n_cls == 3:
        if disease in ['COPD', 'Bronchiectasis', 'Asthma']:
            return 1
        elif disease in ['URTI', 'LRTI', 'Pneumonia', 'Bronchiolitis']:
            return 2
        else:
            return 0

    elif n_cls == 2:
        if disease == 'Healthy':
            return 0
        else:
            return 1


def get_individual_cycles_librosa(args, recording_annotations, data_folder, filename, sample_rate, n_cls, butterworth_filter=None):
    """
    RespireNet paper..
    Split one recording into clips that each contain one annotated respiratory cycle.

    Args:
        args: namespace providing ``class_split`` ('lungsound' or 'diagnosis').
        recording_annotations: DataFrame with 'Start'/'End' (seconds) plus
            'Crackles'/'Wheezes' or 'Disease' columns.
        data_folder: directory containing ``<filename>.wav``.
        filename: recording name without extension.
        sample_rate: rate requested from ``librosa.load``.
        n_cls: number of classes passed to the label mapping.
        butterworth_filter: if truthy, the order of a 200-1800 Hz band-pass
            filter applied to the whole recording before slicing.

    Returns:
        list of ``(audio_chunk: np.ndarray, label)``. Rows produce nothing
        when ``class_split`` is neither 'lungsound' nor 'diagnosis'.
    """
    sample_data = []

    # load file with specified sample rate (also converts to mono)
    data, rate = librosa.load(os.path.join(data_folder, filename+'.wav'), sr=sample_rate)

    if butterworth_filter:
        # butter bandpass filter; the signal is the first positional argument
        data = _butter_bandpass_filter(data, lowcut=200, highcut=1800, fs=sample_rate, order=butterworth_filter)

    for idx in recording_annotations.index:
        row = recording_annotations.loc[idx]

        start = row['Start']  # time (second)
        end = row['End']  # time (second)
        audio_chunk = _slice_data_librosa(start, end, data, rate)

        if args.class_split == 'lungsound':
            crackles = row['Crackles']
            wheezes = row['Wheezes']
            sample_data.append((audio_chunk, _get_lungsound_label(crackles, wheezes, n_cls)))
        elif args.class_split == 'diagnosis':
            disease = row['Disease']
            sample_data.append((audio_chunk, _get_diagnosis_label(disease, n_cls)))

    return sample_data


def _slice_data_torchaudio(start, end, data, sample_rate):
    """
    SCL paper..
    Cut the ``[start, end)`` window (seconds) out of a ``[channels, samples]``
    tensor. ``sample_rate`` is the number of samples per second; indices past
    the last sample are clipped.
    """
    max_ind = data.shape[1]
    start_ind = min(int(start * sample_rate), max_ind)
    end_ind = min(int(end * sample_rate), max_ind)

    return data[:, start_ind: end_ind]


def cut_pad_sample_torchaudio(data, args):
    """Truncate or pad a clip to ``args.desired_length`` seconds.

    Clips longer than the target are truncated. Shorter clips are padded
    according to ``args.pad_types``:
      * 'zero'   : centred in a ``(1, target)`` zero buffer (assumes one channel).
      * 'repeat' : tiled along time, cut to the target, then linearly faded out
                   over the last ``sample_rate / 16`` samples.
    Any other pad type returns short clips unchanged.
    """
    fade_samples_ratio = 16
    fade_samples = int(args.sample_rate / fade_samples_ratio)
    fade_out = T.Fade(fade_in_len=0, fade_out_len=fade_samples, fade_shape='linear')
    # slicing and torch.zeros need an integer sample count, even if
    # desired_length is fractional
    target_duration = int(args.desired_length * args.sample_rate)

    if data.shape[-1] > target_duration:
        data = data[..., :target_duration]
    else:
        if args.pad_types == 'zero':
            tmp = torch.zeros(1, target_duration, dtype=torch.float32)
            diff = target_duration - data.shape[-1]
            tmp[..., diff//2:data.shape[-1]+diff//2] = data
            data = tmp
        elif args.pad_types == 'repeat':
            ratio = math.ceil(target_duration / data.shape[-1])
            data = data.repeat(1, ratio)
            data = data[..., :target_duration]
            data = fade_out(data)

    return data


def get_individual_cycles_torchaudio(args, recording_annotations, metadata, data_folder, filename, sample_rate, n_cls):
    """
    SCL paper..
    Split one recording into per-cycle clips of a fixed length.

    The recording is resampled to ``sample_rate`` if needed and given a linear
    fade-in/out of ``sample_rate / 16`` samples. It is then sliced at the
    annotated cycle boundaries, and each slice is cut/padded with
    ``cut_pad_sample_torchaudio``. ``metadata`` is accepted but not used here.

    Returns:
        list of ``(audio_chunk: torch.Tensor, label)``. Rows produce nothing
        when ``class_split`` is neither 'lungsound' nor 'diagnosis'.
    """
    sample_data = []
    fpath = os.path.join(data_folder, filename+'.wav')

    sr = librosa.get_samplerate(fpath)
    data, _ = torchaudio.load(fpath)

    if sr != sample_rate:
        resample = T.Resample(sr, sample_rate)
        data = resample(data)

    fade_samples_ratio = 16
    fade_samples = int(sample_rate / fade_samples_ratio)

    fade = T.Fade(fade_in_len=fade_samples, fade_out_len=fade_samples, fade_shape='linear')

    data = fade(data)
    for idx in recording_annotations.index:
        row = recording_annotations.loc[idx]

        start = row['Start']  # time (second)
        end = row['End']  # time (second)
        audio_chunk = _slice_data_torchaudio(start, end, data, sample_rate)

        if args.class_split == 'lungsound':
            crackles = row['Crackles']
            wheezes = row['Wheezes']
            sample_data.append((audio_chunk, _get_lungsound_label(crackles, wheezes, n_cls)))
        elif args.class_split == 'diagnosis':
            disease = row['Disease']
            sample_data.append((audio_chunk, _get_diagnosis_label(disease, n_cls)))

    return [(cut_pad_sample_torchaudio(chunk, args), label) for chunk, label in sample_data]


def _zero_padding(source, output_length):
    """Place ``source`` into a float32 zero buffer of ``output_length`` samples.

    If ``source`` is shorter than half the output, it is tiled forward while a
    whole copy still fits, and the remainder stays zero. Otherwise it is copied
    once at the start. Assumes ``len(source) <= output_length``. An empty source
    yields an all-zero buffer.
    """
    copy = np.zeros(output_length, dtype=np.float32)
    src_length = len(source)
    if src_length == 0:
        # nothing to tile: the loop below would never advance the cursor
        return copy

    frac = src_length / output_length
    if frac < 0.5:
        # tile forward sounds to fill empty space
        cursor = 0
        while (cursor + src_length) < output_length:
            copy[cursor:(cursor + src_length)] = source[:]
            cursor += src_length
    else:
        # [src_length:] part will be zeros
        copy[:src_length] = source[:]

    return copy


def _equally_slice_pad_sample(sample, desired_length, sample_rate):
    """
    pad_type == 0: zero-padding
    Split ``sample = (audio, label, filename)`` into the minimum number of equal
    slices no longer than ``desired_length`` seconds, and zero-pad each slice with
    ``_zero_padding``. Up to ``n_slices - 1`` trailing samples that do not divide
    evenly are dropped. An empty clip yields a single all-zero slice.

    Returns:
        list of ``(padded_audio, label, filename)``.
    """
    output_length = int(desired_length * sample_rate)  # desired_length is second
    soundclip = sample[0].copy()
    n_samples = len(soundclip)

    total_length = n_samples / sample_rate  # length of cycle in seconds
    # minimum number of slices needed; at least one so an empty clip does not
    # divide by zero below
    n_slices = max(1, int(math.ceil(total_length / desired_length)))
    samples_per_slice = n_samples // n_slices

    output = []  # holds the resultant slices
    src_start = 0  # starting index of the samples to copy from the sample buffer
    for _ in range(n_slices):
        src_end = min(src_start + samples_per_slice, n_samples)
        copy = _zero_padding(soundclip[src_start:src_end], output_length)
        output.append((copy, sample[1], sample[2]))
        src_start = src_end

    return output


def _duplicate_padding(sample, source, output_length, sample_rate, types):
    """Pad ``source`` to ``output_length`` using copies of ``sample``.

    pad_type == 1 or 2. With ``types == 'repeat'``, ``sample`` itself is the
    padding material. Otherwise ``augment_raw_audio(sample, sample_rate)`` is
    used. The material is doubled until it covers the gap. With probability
    0.5 the source sits at the end (padding taken from the material's tail),
    otherwise at the start (padding taken from the material's head). If there
    is no padding material, the source is returned zero-padded.
    """
    copy = np.zeros(output_length, dtype=np.float32)
    src_length = len(source)
    left = output_length - src_length  # amount to be padded

    if types == 'repeat':
        aug = sample
    else:
        aug = augment_raw_audio(sample, sample_rate)

    if len(aug) == 0 and left > 0:
        # nothing to duplicate: the doubling loop below would never terminate
        copy[:src_length] = source
        return copy

    while len(aug) < left:
        aug = np.concatenate([aug, aug])

    prob = random.random()
    if prob < 0.5:
        # pad the back part of original sample
        copy[left:] = source
        copy[:left] = aug[len(aug)-left:]
    else:
        # pad the front part of original sample
        copy[:src_length] = source[:]
        copy[src_length:] = aug[:left]

    return copy


def split_pad_sample(sample, desired_length, sample_rate, types='repeat'):
    """
    if the audio sample length > desired_length, then split and pad samples
    else simply pad samples according to pad_types
    * types 'zero'   : simply pad by zeros (zero-padding)
    * types 'repeat' : pad with duplicate on both sides (half-n-half)
    * types 'aug'    : pad with augmented sample on both sides (half-n-half)

    ``sample`` is ``(audio, label, filename)``; returns a list of tuples of the
    same form, each with exactly ``int(desired_length * sample_rate)`` samples.
    """
    if types == 'zero':
        return _equally_slice_pad_sample(sample, desired_length, sample_rate)

    output_length = int(desired_length * sample_rate)
    soundclip = sample[0].copy()
    n_samples = len(soundclip)

    output = []
    if n_samples > output_length:
        # long clip: 50%-overlapping windows of output_length are used as-is,
        # and the leftover tail is padded according to the padding type
        # frames[j] = x[j * hop_length : j * hop_length + frame_length]
        frames = librosa.util.frame(soundclip, frame_length=output_length, hop_length=output_length//2, axis=0)
        for i in range(frames.shape[0]):
            output.append((frames[i], sample[1], sample[2]))

        # get the last sample
        last_id = frames.shape[0] * (output_length//2)
        last_sample = soundclip[last_id:]

        padded = _duplicate_padding(soundclip, last_sample, output_length, sample_rate, types)
        output.append((padded, sample[1], sample[2]))
    else:  # only pad
        padded = _duplicate_padding(soundclip, soundclip, output_length, sample_rate, types)
        output.append((padded, sample[1], sample[2]))

    return output


def generate_mel_spectrogram(audio, sample_rate, n_mels=64, f_min=50, f_max=2000, nfft=1024, hop=512, args=None):
    """
    use librosa library to compute a dB-scaled mel-spectrogram, min-max
    normalised to [0, 1], returned as an ``(n_mels, frames, 1)`` array.

    If ``args.blank_region_clip`` is set, the spectrogram is flipped up-down.
    ``args`` may be None, which disables the flip.
    """
    S = librosa.feature.melspectrogram(y=audio, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max, n_fft=nfft, hop_length=hop)
    # convert scale to dB from magnitude
    S = librosa.power_to_db(S, ref=np.max)
    # min-max normalisation; a constant spectrogram (e.g. silence) would
    # otherwise divide by zero and produce NaNs
    s_range = S.max() - S.min()
    S = (S - S.min()) / s_range if s_range > 0 else np.zeros_like(S)

    if args is not None and args.blank_region_clip:
        S = cv2.flip(S, 0)  # up-down flip
    img = S.reshape(S.shape[0], S.shape[1], 1)

    return img


def generate_fbank(audio, sample_rate, n_mels=128):
    """
    use torchaudio library to convert mel fbank for AST model

    ``audio`` is passed straight to ``torchaudio.compliance.kaldi.fbank``
    (presumably a ``(1, samples)`` float tensor — confirm at call sites). The
    result is normalised with fixed constants and returned as a numpy array
    with a trailing channel axis.
    """
    assert sample_rate == 16000, 'input audio sampling rate must be 16kHz'
    fbank = torchaudio.compliance.kaldi.fbank(audio, htk_compat=True, sample_frequency=sample_rate, use_energy=False, window_type='hanning', num_mel_bins=n_mels, dither=0.0, frame_shift=10)

    mean, std = -4.2677393, 4.5689974
    fbank = (fbank - mean) / (std * 2)  # mean / std
    fbank = fbank.unsqueeze(-1).numpy()
    return fbank


def concat_augmentation(classwise_cycle_list, cycle_list, scale=1.):
    """ From "RespireNet" paper..
    Balance the four lung-sound classes by appending concatenated cycle pairs.

    For class ``k`` (0 normal, 1 crackle, 2 wheeze, 3 both), this appends
    ``int(scale * len(class 0) - len(class k))`` new samples to ``cycle_list``.
    Each new sample joins the audio of two randomly drawn cycles whose source
    classes follow the per-class mixture below.

    Args:
        classwise_cycle_list: four lists of ``(audio, label, filename, ...)``
            tuples, indexed by class.
        cycle_list: list that the new ``(audio, label, 'file_i-file_j')``
            tuples are appended to (mutated in place).
        scale: target size of each class relative to the normal class.

    Returns:
        ``(classwise_cycle_list, cycle_list)``; ``classwise_cycle_list`` is
        not modified.
    """

    def _get_random_cycles(classwise_cycle_list, idx1, idx2):
        i = random.randint(0, len(classwise_cycle_list[idx1])-1)
        j = random.randint(0, len(classwise_cycle_list[idx2])-1)

        sample_i = classwise_cycle_list[idx1][i]
        sample_j = classwise_cycle_list[idx2][j]

        return sample_i, sample_j

    # per target class: (name for the log line,
    #                    [(upper bound of a uniform draw, class of 1st cycle, class of 2nd cycle), ...])
    # A single-entry mixture is applied without drawing a random number.
    mixtures = [
        ('normal', [(float('inf'), 0, 0)]),
        ('crackle', [(0.6, 1, 1), (0.8, 1, 0), (float('inf'), 0, 1)]),
        ('wheeze', [(0.6, 2, 2), (0.8, 2, 0), (float('inf'), 0, 2)]),
        ('both', [(0.5, 3, 3), (0.7, 1, 2), (0.8, 2, 1), (0.9, 3, 0), (float('inf'), 0, 3)]),
    ]

    print('*' * 20)
    for label, (name, mixture) in enumerate(mixtures):
        aug_nums = int(scale*len(classwise_cycle_list[0]) - len(classwise_cycle_list[label]))
        print('# of concatenation-based augmentation for {} class is {}'.format(name, aug_nums))

        for _ in range(aug_nums):
            if len(mixture) == 1:
                _, idx1, idx2 = mixture[0]
            else:
                aug_prob = random.random()
                idx1, idx2 = next((i, j) for bound, i, j in mixture if aug_prob < bound)
            sample_i, sample_j = _get_random_cycles(classwise_cycle_list, idx1, idx2)

            new_sample = np.concatenate([sample_i[0], sample_j[0]])
            # sample_*[2] is the source filename
            cycle_list.append((new_sample, label, sample_i[2]+'-'+sample_j[2]))

    return classwise_cycle_list, cycle_list
# ==========================================================================


# ==========================================================================
""" evaluation metric """
def get_score(hits, counts, pflag=False):
    """ICBHI specificity, sensitivity and their mean, in percent.

    Args:
        hits: per-class correct-prediction counts; index 0 is the normal class.
        counts: per-class sample counts, same indexing.
        pflag: print the three values when True.

    Returns:
        ``(sp, se, sc)``: normal-class accuracy, accuracy over all other
        classes pooled, and their average. A tiny epsilon keeps empty classes
        from dividing by zero.
    """
    # normal accuracy
    sp = hits[0] / (counts[0] + 1e-10) * 100
    # abnormal accuracy
    se = sum(hits[1:]) / (sum(counts[1:]) + 1e-10) * 100
    sc = (sp + se) / 2.0

    if pflag:
        print("S_p: {}, S_e: {}, Score: {}".format(sp, se, sc))

    return sp, se, sc
# ==========================================================================


# --------------------------------------------------------------------------
# /util/misc.py
# --------------------------------------------------------------------------
import os
import json
import math
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim


""" train util """
def _compute_lr(args, epoch):
    """Learning rate for ``epoch``: cosine annealing to ``lr * decay**3``, or step decay.

    Step decay multiplies the base rate by ``lr_decay_rate`` once for every
    milestone in ``args.lr_decay_epochs`` that ``epoch`` has passed.
    """
    lr = args.learning_rate
    if args.cosine:
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
            1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)
    return lr


def adjust_learning_rate(args, optimizer, epoch):
    """Set every param group's learning rate to the scheduled value for ``epoch``."""
    lr = _compute_lr(args, epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def adjust_lr_wd(args, optimizer, epoch):
    """Schedule both learning rate and weight decay for ``epoch``.

    The learning rate follows the same schedule as ``adjust_learning_rate``.
    When ``args.wd_scheduler`` is set, weight decay is cosine-annealed from
    ``args.weight_decay`` to ``args.weight_decay_end``; it is written to the
    first param group only.
    """
    lr = _compute_lr(args, epoch)

    wd = args.weight_decay
    if args.wd_scheduler:
        wd_min = args.weight_decay_end
        wd = wd_min + (wd - wd_min) * (
            1 + math.cos(math.pi * epoch / args.epochs)) / 2

    for i, param_group in enumerate(optimizer.param_groups):
        param_group['lr'] = lr
        if i == 0:  # in case of DINO and ViT, only wd for regularized params
            param_group['weight_decay'] = wd


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the learning rate from ``warmup_from`` to ``warmup_to``.

    Active only when ``args.warm`` is set and ``epoch <= args.warm_epochs``
    (epochs are 1-based here); the ramp advances per batch.
    """
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


def set_optimizer(args, optim_params=None, model=None):
    """Build an SGD or Adam optimizer from ``args``.

    Args:
        args: namespace with ``optimizer`` ('sgd' or 'adam'), ``learning_rate``,
            ``weight_decay`` and, for SGD, ``momentum``.
        optim_params: parameters / param groups to optimize.
        model: if given, its ``parameters()`` are used instead of ``optim_params``.

    Raises:
        NotImplementedError: for any other ``args.optimizer``.
    """
    if model is not None:
        optim_params = model.parameters()

    if args.optimizer == 'sgd':
        optimizer = optim.SGD(optim_params,
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(optim_params,
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)
    else:
        raise NotImplementedError('unsupported optimizer: {}'.format(args.optimizer))

    return optimizer


class MA:
    """Exponential moving average with decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when there is no ``old``."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def update_moving_average(ma_beta, current_model, ma_ckpt):
    """Load the EMA of ``ma_ckpt`` and ``current_model``'s weights into ``current_model``.

    Each entry becomes ``ma_beta * ma_ckpt[k] + (1 - ma_beta) * current[k]``.
    Entries are paired by position, and their keys must match. ``ma_ckpt`` is
    not modified.
    """
    ma_updater = MA(ma_beta)
    new_state_dict = {}
    for (k1, current_params), (k2, ma_params) in zip(current_model.state_dict().items(), ma_ckpt.items()):
        assert k1 == k2
        old_weight, up_weight = ma_params.data, current_params.data
        new_state_dict[k1] = ma_updater.update_average(old_weight, up_weight)

    current_model.load_state_dict(new_state_dict)
    return current_model


""" eval util """
class AverageMeter(object):
    """ Computes and stores the average and current value """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: ``(batch, n_cls)`` scores.
        target: ``(batch,)`` class indices.
        topk: values of k; any k larger than ``n_cls`` reports 0.

    Returns:
        ``(res, bsz)``: one 1-element tensor of percent accuracy per k, and
        the batch size.
    """
    with torch.no_grad():
        n_cls = output.shape[1]
        valid_topk = [k for k in topk if k <= n_cls]

        maxk = max(valid_topk)
        bsz = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            if k in valid_topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / bsz))
            else:
                res.append(torch.tensor([0.]))

        return res, bsz


def save_model(model, optimizer, args, epoch, save_file, classifier):
    """Save args, epoch and the model / optimizer / classifier state dicts to ``save_file``."""
    print('==> Saving...')
    state = {
        'args': args,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'classifier': classifier.state_dict()
    }

    torch.save(state, save_file)
    del state


def update_json(exp_name, acc, path='./save/results.json'):
    """Record ``acc`` (rounded to 2 decimals) under ``exp_name`` in the JSON file at ``path``.

    The file is created with an empty object if it does not exist; other
    entries are preserved.
    """
    acc = [round(a, 2) for a in acc]
    if not os.path.exists(path):
        with open(path, 'w', encoding="UTF-8") as f:
            json.dump({}, f)

    with open(path, 'r', encoding="UTF-8") as f:
        result_dict = json.load(f)
    result_dict[exp_name] = acc

    with open(path, 'w', encoding="UTF-8") as f:
        json.dump(result_dict, f)

    print('best Score: {} (sp, se, sc)'.format(acc))
    print('results updated to %s' % path)


# --------------------------------------------------------------------------
# /util/time_warping.py
# --------------------------------------------------------------------------
"""
This implementation is from https://github.com/zcaceres/spec_augment
"""
import torch

__all__ = ['sparse_image_warp']


def sparse_image_warp(img_tensor,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundaries_points=0):
    """Warp ``img_tensor`` so that the source control points move to the destination points.

    A polyharmonic spline is fitted to the control-point displacements and
    evaluated at every pixel to obtain a dense flow field. The image is then
    resampled along that flow with bilinear interpolation.

    Args:
        img_tensor: ``[batch, height, width]`` tensor. The spline helpers
            squeeze / build a batch dimension of 1, so batch must be 1.
        source_control_point_locations: ``[1, n, 2]`` (row, col) points.
        dest_control_point_locations: ``[1, n, 2]`` (row, col) points.
        interpolation_order: polyharmonic order (see ``phi``).
        regularization_weight: ridge term added to the spline system.
        num_boundaries_points: accepted for API compatibility; no zero-flow
            boundary points are added in this implementation.

    Returns:
        ``(warped_image [batch, height, width, 1], dense_flows [batch, height, width, 2])``.
    """
    device = img_tensor.device
    control_point_flows = (dest_control_point_locations - source_control_point_locations)

    batch_size, image_height, image_width = img_tensor.shape
    flattened_grid_locations = get_flat_grid_locations(image_height, image_width, device)

    flattened_flows = interpolate_spline(
        dest_control_point_locations,
        control_point_flows,
        flattened_grid_locations,
        interpolation_order,
        regularization_weight)

    dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width)

    warped_image = dense_image_warp(img_tensor, dense_flows)

    return warped_image, dense_flows


def get_grid_locations(image_height, image_width, device):
    """``[height, width, 2]`` tensor whose entry ``[i, j]`` is ``(i, j)`` as floats."""
    y_range = torch.linspace(0, image_height - 1, image_height, device=device)
    x_range = torch.linspace(0, image_width - 1, image_width, device=device)
    y_grid, x_grid = torch.meshgrid(y_range, x_range)
    return torch.stack((y_grid, x_grid), -1)


def flatten_grid_locations(grid_locations, image_height, image_width):
    """Reshape a ``[height, width, 2]`` grid to ``[height * width, 2]`` (row-major)."""
    return torch.reshape(grid_locations, [image_height * image_width, 2])


def get_flat_grid_locations(image_height, image_width, device):
    """All pixel (row, col) coordinates as a ``[height * width, 2]`` float tensor."""
    grid = get_grid_locations(image_height, image_width, device)
    return flatten_grid_locations(grid, image_height, image_width)


def create_dense_flows(flattened_flows, batch_size, image_height, image_width):
    """Reshape per-pixel flows to ``[batch, height, width, 2]``."""
    # possibly .view
    return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2])


def interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0,):
    """Fit a polyharmonic spline to ``(train_points, train_values)`` and evaluate it at ``query_points``.

    ``train_points`` / ``train_values`` are ``[1, n, d]`` / ``[1, n, k]``;
    ``query_points`` is ``[m, d]``. Returns ``[1, m, k]``.
    """
    # First, fit the spline to the observed data.
    w, v = solve_interpolation(train_points, train_values, order, regularization_weight)
    # Then, evaluate the spline at the query locations.
    query_values = apply_interpolation(query_points, train_points, w, v, order)

    return query_values


def solve_interpolation(train_points, train_values, order, regularization_weight, eps=1e-7):
    """Solve for polyharmonic spline coefficients.

    Solves the block system ``[[A, B], [B^T, 0]] [w; v] = [f; 0]`` from
    https://en.wikipedia.org/wiki/Polyharmonic_spline, where
    ``A = phi(||c_i - c_j||^2)`` and ``B = [c, 1]``.

    Args:
        train_points: ``[1, n, d]`` interpolation centres ``c``.
        train_values: ``[1, n, k]`` values ``f`` at the centres.
        order: spline order passed to ``phi``.
        regularization_weight: added to the diagonal of ``A`` when > 0.
        eps: scale of the random values used in place of the zero block.

    Returns:
        ``(w [1, n, k], v [1, d + 1, k])``.
    """
    device = train_points.device
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    c = train_points
    f = train_values.float()

    matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0)  # [b, n, n]
    if regularization_weight > 0:
        # ridge regularisation: trades exact interpolation for smoothness
        eye = torch.eye(n, dtype=matrix_a.dtype, device=device).unsqueeze(0)
        matrix_a = matrix_a + regularization_weight * eye

    # Append ones to the feature values for the bias term in the linear model.
    ones = torch.ones(n, dtype=train_points.dtype, device=device).view([-1, n, 1])
    matrix_b = torch.cat((c, ones), 2).float()  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)

    num_b_cols = matrix_b.shape[2]  # d + 1

    # In Tensorflow, zeros are used here. Pytorch solve fails with zeros for some reason we don't understand.
    # So instead we use very tiny randn values (zero mean, scaled by eps) on one side of our multiplication.
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols), device=device) * eps
    right_block = torch.cat((matrix_b, lhs_zeros), 1)  # [b, n + d + 1, d + 1]
    lhs = torch.cat((left_block, right_block), 2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype, device=device).float()
    rhs = torch.cat((f, rhs_zeros), 1)  # [b, n + d + 1, k]

    # Then, solve the linear system and unpack the results.
    X = torch.linalg.solve(lhs, rhs)
    w = X[:, :n, :]
    v = X[:, n:, :]
    return w, v


def cross_squared_distance_matrix(x, y):
    """Pairwise squared distance between the rows of two point sets.

    Args:
        x: ``[1, n, d]`` float tensor (the leading batch dim of 1 is squeezed).
        y: ``[1, m, d]`` float tensor.
    Returns:
        squared_dists: ``[n, m]`` float tensor with
        ``squared_dists[i, j] = ||x[0, i, :] - y[0, j, :]||^2``.
    """
    x_rows = x.squeeze(0)
    y_rows = y.squeeze(0)

    # per-row squared norms, shaped to broadcast over the [n, m] result
    x_norm_squared = torch.sum(x_rows * x_rows, dim=-1, keepdim=True)  # [n, 1]
    y_norm_squared = torch.sum(y_rows * y_rows, dim=-1).unsqueeze(0)  # [1, m]

    x_y_transpose = torch.matmul(x_rows, y_rows.transpose(0, 1))  # [n, m]

    # ||x_i - y_j||^2 = x_i'x_i - 2 x_i'y_j + y_j'y_j
    squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared

    return squared_dists.float()


def phi(r, order):
    """Coordinate-wise nonlinearity used to define the order of the interpolation.
    See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
    Args:
        r: squared distances
        order: interpolation order
    Returns:
        phi_k evaluated coordinate-wise on r, for k = order
    """
    EPSILON = torch.tensor(1e-10, device=r.device)
    # using EPSILON prevents log(0), sqrt(0), etc.
    # sqrt(0) is well-defined, but its gradient is not
    if order == 1:
        r = torch.max(r, EPSILON)
        r = torch.sqrt(r)
        return r
    elif order == 2:
        return 0.5 * r * torch.log(torch.max(r, EPSILON))
    elif order == 4:
        return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON))
    elif order % 2 == 0:
        r = torch.max(r, EPSILON)
        return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r)
    else:
        r = torch.max(r, EPSILON)
        return torch.pow(r, 0.5 * order)


def apply_interpolation(query_points, train_points, w, v, order):
    """Apply polyharmonic interpolation model to data.
    Given coefficients w and v for the interpolation model, we evaluate
    interpolated function values at query_points.
    Args:
        query_points: ``[m, d]`` x values to evaluate the interpolation at
            (a batch dim of 1 is added here)
        train_points: ``[1, n, d]`` x values that act as the interpolation centers
            (the c variables in the wikipedia article)
        w: ``[1, n, k]`` weights on each interpolation center
        v: ``[1, d + 1, k]`` weights on each input dimension plus the bias
        order: order of the interpolation
    Returns:
        ``[1, m, k]`` polyharmonic interpolation evaluated at query_points.
    """
    query_points = query_points.unsqueeze(0)
    # First, compute the contribution from the rbf term.
    pairwise_dists = cross_squared_distance_matrix(query_points.float(), train_points.float())
    phi_pairwise_dists = phi(pairwise_dists, order)

    rbf_term = torch.matmul(phi_pairwise_dists, w)

    # Then, compute the contribution from the linear term.
    # Pad query_points with ones, for the bias term in the linear model.
    ones = torch.ones_like(query_points[..., :1])
    query_points_pad = torch.cat((
        query_points,
        ones
    ), 2).float()
    linear_term = torch.matmul(query_points_pad, v)

    return rbf_term + linear_term


def dense_image_warp(image, flow):
    """Image warping using per-pixel flow vectors.

    The output pixel at ``[b, j, i]`` takes the value of the input at
    ``(j - flow[b, j, i, 0], i - flow[b, j, i, 1])``, obtained by bilinear
    interpolation of the 4 nearest pixels. Locations outside the image are
    clamped to the border.

    Args:
        image: ``[batch, height, width]`` float tensor (a channel dim of 1 is added).
        flow: ``[batch, height, width, 2]`` float tensor of (row, col) offsets.
    Returns:
        ``[batch, height, width, 1]`` warped image.
    """
    image = image.unsqueeze(3)  # add a single channel dimension to image tensor
    batch_size, height, width, channels = image.shape
    device = image.device

    # The flow is defined on the image grid. Turn the flow into a list of query
    # points in the grid space.
    grid_x, grid_y = torch.meshgrid(
        torch.arange(width, device=device), torch.arange(height, device=device))

    # [width, height, 2] of (row, col) -> [1, height, width, 2]
    stacked_grid = torch.stack((grid_y, grid_x), dim=2).float()

    batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2)

    query_points_on_grid = batched_grid - flow
    query_points_flattened = torch.reshape(query_points_on_grid,
                                           [batch_size, height * width, 2])
    # Compute values at the query points, then reshape the result back to the
    # image grid.
    interpolated = interpolate_bilinear(image, query_points_flattened)
    interpolated = torch.reshape(interpolated,
                                 [batch_size, height, width, channels])
    return interpolated


def interpolate_bilinear(grid,
                         query_points,
                         name='interpolate_bilinear',
                         indexing='ij'):
    """Similar to Matlab's interp2 function.
    Finds values for query points on a grid using bilinear interpolation.
    Query coordinates are clamped so that every lookup stays inside the grid.
    Args:
        grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
        query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`.
        name: unused; kept for API compatibility.
        indexing: whether the query points are specified as row and column (ij),
            or Cartesian coordinates (xy).
    Returns:
        values: a 3-D `Tensor` with shape `[batch, N, channels]`
    Raises:
        ValueError: if the indexing mode is invalid, or if the grid is not 4-D.
    """
    if indexing != 'ij' and indexing != 'xy':
        raise ValueError('Indexing mode must be \'ij\' or \'xy\'')

    shape = grid.shape
    if len(shape) != 4:
        msg = 'Grid must be 4 dimensional. Received size: '
        raise ValueError(msg + str(grid.shape))

    batch_size, height, width, channels = grid.shape

    shape = [batch_size, height, width, channels]
    query_type = query_points.dtype
    grid_type = grid.dtype
    grid_device = grid.device

    num_queries = query_points.shape[1]

    alphas = []
    floors = []
    ceils = []
    index_order = [0, 1] if indexing == 'ij' else [1, 0]
    unstacked_query_points = query_points.unbind(2)

    for dim in index_order:
        queries = unstacked_query_points[dim]

        size_in_indexing_dimension = shape[dim + 1]

        # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
        # is still a valid index into the grid.
        max_floor = torch.tensor(size_in_indexing_dimension - 2, dtype=query_type, device=grid_device)
        min_floor = torch.tensor(0.0, dtype=query_type, device=grid_device)
        maxx = torch.max(min_floor, torch.floor(queries))
        floor = torch.min(maxx, max_floor)
        int_floor = floor.long()
        floors.append(int_floor)
        ceil = int_floor + 1
        ceils.append(ceil)

        # alpha has the same type as the grid, as we will directly use alpha
        # when taking linear combinations of pixel values from the image.
        alpha = (queries - floor).clone().detach().type(grid_type)
        min_alpha = torch.tensor(0.0, dtype=grid_type, device=grid_device)
        max_alpha = torch.tensor(1.0, dtype=grid_type, device=grid_device)
        alpha = torch.min(torch.max(min_alpha, alpha), max_alpha)

        # Expand alpha to [b, n, 1] so we can use broadcasting
        # (since the alpha values don't depend on the channel).
        alpha = torch.unsqueeze(alpha, 2)
        alphas.append(alpha)

    flattened_grid = torch.reshape(
        grid, [batch_size * height * width, channels])
    batch_offsets = torch.reshape(
        torch.arange(batch_size, device=grid_device) * height * width, [batch_size, 1])

    # Pixels are addressed by their linear index into the flattened grid
    # (batch offset + row * width + col), then reshaped back to [b, N, channels].
    def gather(y_coords, x_coords, name):
        linear_coordinates = batch_offsets + y_coords * width + x_coords
        gathered_values = torch.gather(flattened_grid.t(), 1, linear_coordinates)
        return torch.reshape(gathered_values,
                             [batch_size, num_queries, channels])

    # grab the pixel values in the 4 corners around each query point
    top_left = gather(floors[0], floors[1], 'top_left')
    top_right = gather(floors[0], ceils[1], 'top_right')
    bottom_left = gather(ceils[0], floors[1], 'bottom_left')
    bottom_right = gather(ceils[0], ceils[1], 'bottom_right')

    interp_top = alphas[1] * (top_right - top_left) + top_left
    interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
    interp = alphas[0] * (interp_bottom - interp_top) + interp_top

    return interp