├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── icbhi_dataset
│       ├── filename_differences.txt
│       ├── filename_format.txt
│       ├── metadata.txt
│       ├── official_split.txt
│       ├── patient_diagnosis.txt
│       ├── patient_list_foldwise.txt
│       └── script.sh
├── main.py
├── method
│   ├── __init__.py
│   ├── patchmix.py
│   └── patchmix_cl.py
├── models
│   ├── __init__.py
│   ├── ast.py
│   ├── cnn6.py
│   ├── efficientnet.py
│   ├── projector.py
│   ├── resnet.py
│   └── ssast.py
├── requirements.txt
├── scripts
│   ├── icbhi_ce.sh
│   ├── icbhi_patchmix_ce.sh
│   └── icbhi_patchmix_cl.sh
└── util
    ├── __init__.py
    ├── augmentation.py
    ├── icbhi_dataset.py
    ├── icbhi_util.py
    ├── misc.py
    └── time_warping.py
/.gitignore:
--------------------------------------------------------------------------------
1 | save/
2 | data/
3 | *.swp
4 | __pycache__/
5 | *.pyc
6 | *.jpg
7 | *.ipynb
8 | *.pth
9 | *.json
10 | *.sh
11 | .ipynb_checkpoints/
12 | jupyter_notebook/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Sangmin Bae
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Patch-Mix Contrastive Learning (INTERSPEECH 2023)
2 | [arXiv](https://arxiv.org/abs/2305.14032) | [Conference](https://www.isca-archive.org/interspeech_2023/bae23b_interspeech.html#) | [BibTeX](#bibtex)
3 |
4 |
5 |
6 |
7 |
8 |
9 | **Patch-Mix Contrastive Learning with Audio Spectrogram Transformer on Respiratory Sound Classification**
10 | [Sangmin Bae](https://www.raymin0223.com)\*,
11 | [June-Woo Kim](https://github.com/kaen2891)\*,
12 | [Won-Yang Cho](https://github.com/wonyangcho),
13 | [Hyerim Baek](https://github.com/rimiyeyo),
14 | [Soyoun Son](https://github.com/soyounson),
15 | [Byungjo Lee](https://github.com/bzlee-bio),
16 | [Changwan Ha](https://github.com/cwh1981),
17 | [Kyongpil Tae](https://github.com/kyongpiltae),
18 | [Sungnyun Kim](https://github.com/sungnyun)$^\dagger$,
19 | [Se-Young Yun](https://fbsqkd.github.io)$^\dagger$
20 | \* equal contribution $^\dagger$ corresponding authors
21 |
22 | - We demonstrate that a **model pretrained on large-scale visual and audio datasets** generalizes to the respiratory sound classification task.
23 | - We introduce a straightforward **Patch-Mix augmentation**, which randomly mixes patches between different samples, on top of the Audio Spectrogram Transformer (AST); a minimal sketch is given below.
24 | - To overcome the label hierarchy in lung sound datasets, we propose an effective **Patch-Mix Contrastive Learning** to distinguish the mixed representations in the latent space.
25 |
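
As a rough illustration of the Patch-Mix augmentation described above, the following is a minimal PyTorch sketch. The `patch_mix` helper is hypothetical and illustrative; it is not the exact implementation in `method/patchmix.py`.

```python
import torch
from torch.distributions.beta import Beta

def patch_mix(patch_embed, labels, beta=1.0):
    """Swap a random subset of patch embeddings between samples in a batch.

    patch_embed: (batch, num_patches, dim) patch embeddings from the encoder
    labels:      (batch,) integer class labels
    Returns the mixed embeddings, the original and permuted labels, and the
    effective mixing ratio lam used for label interpolation.
    """
    batch, num_patches, _ = patch_embed.shape
    lam = Beta(beta, beta).sample().item()                   # interpolation coefficient
    perm = torch.randperm(batch, device=patch_embed.device)  # partner sample for each item

    # replace roughly (1 - lam) of the patch positions with the partner's patches
    num_mix = int(round((1.0 - lam) * num_patches))
    mix_idx = torch.randperm(num_patches, device=patch_embed.device)[:num_mix]

    mixed = patch_embed.clone()
    mixed[:, mix_idx, :] = patch_embed[perm][:, mix_idx, :]

    lam = 1.0 - num_mix / num_patches                        # ratio actually applied
    return mixed, labels, labels[perm], lam
```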
26 |
27 | ## Requirements
28 | Install the necessary packages with:
29 | ```
30 | $ pip install torch torchvision torchaudio
31 | $ pip install -r requirements.txt
32 | ```
33 |
34 |
35 | ## Data Preparation
36 | Download the ICBHI dataset files from the [official page](https://bhichallenge.med.auth.gr/ICBHI_2017_Challenge).
37 | ```bash
38 | $ wget https://bhichallenge.med.auth.gr/sites/default/files/ICBHI_final_database/ICBHI_final_database.zip
39 | ```
40 | All `*.wav` and `*.txt` files should be saved in `data/icbhi_dataset/audio_test_data`.
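
The same step can also be done in Python. Below is a minimal sketch (the archive name comes from the `wget` command above; everything else is an assumed helper, not a script shipped with this repository):

```python
import zipfile
from pathlib import Path

# Illustrative only: unpack the downloaded archive and move every *.wav / *.txt
# file into data/icbhi_dataset/audio_test_data, flattening any nested folders.
archive = Path("ICBHI_final_database.zip")
target = Path("data/icbhi_dataset/audio_test_data")
target.mkdir(parents=True, exist_ok=True)

with zipfile.ZipFile(archive) as zf:
    zf.extractall("ICBHI_extracted")

for f in Path("ICBHI_extracted").rglob("*"):
    if f.suffix in (".wav", ".txt"):
        f.rename(target / f.name)
```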
41 |
42 | Note that the ICBHI dataset consists of a total of 6,898 respiratory cycles;
43 | 1,864 contain crackles, 886 contain wheezes, 506 contain both, and the remaining 3,642 are normal, across 920 annotated audio recordings from 126 subjects.
44 |
45 |
46 | ## Training
47 | To train the model, simply run the shell scripts in `scripts/`.
48 | 1. **`scripts/icbhi_ce.sh`**: Cross-Entropy loss with AST model.
49 | 2. **`scripts/icbhi_patchmix_ce.sh`**: Patch-Mix loss with AST model, where the label depends on the interpolation ratio (see the sketch after this list).
50 | 3. **`scripts/icbhi_patchmix_cl.sh`**: Patch-Mix contrastive loss with AST model.
51 |
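
For the Patch-Mix loss in option 2, a minimal sketch of the label interpolation is shown below. It reuses the hypothetical `patch_mix` helper sketched earlier; `encoder`, `classifier`, `patch_embed`, and `labels` are placeholders, and the actual implementation lives in `method/patchmix.py` and `main.py`.

```python
import torch.nn.functional as F

# mix patches between samples, then classify the mixed representation
mixed, y_a, y_b, lam = patch_mix(patch_embed, labels, beta=1.0)
logits = classifier(encoder(mixed))

# cross-entropy interpolated by the fraction of patches kept vs. replaced
loss = lam * F.cross_entropy(logits, y_a) + (1.0 - lam) * F.cross_entropy(logits, y_b)
```

The contrastive variant in option 3 instead distinguishes the mixed representations in the latent space (see `method/patchmix_cl.py`).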
52 | Important arguments for different data settings.
53 | - `--dataset`: other lung sound or heart sound datasets can also be implemented
54 | - `--class_split`: "lungsound" or "diagnosis" classification
55 | - `--n_cls`: set the number of classes to 4 or 2 (normal / abnormal) for lung sound classification
56 | - `--test_fold`: "official" denotes the 60/40% train/test split, and "0"~"4" denote the 80/20% splits
57 |
58 | Important arguments for models.
59 | - `--model`: network architecture, see [models](models/)
60 | - `--from_sl_official`: load ImageNet pretrained checkpoint
61 | - `--audioset_pretrained`: load AudioSet pretrained checkpoint (only supported for AST and SSAST)
62 |
63 | Important arguments for evaluation.
64 | - `--eval`: switch mode to evaluation without any training
65 | - `--pretrained`: load a pretrained checkpoint; requires the `pretrained_ckpt` argument
66 | - `--pretrained_ckpt`: path for the pretrained checkpoint
67 |
68 | The pretrained model checkpoints will be saved at `save/[EXP_NAME]/best.pth`.
69 |
70 | ## Result
71 |
72 | Patch-Mix Contrastive Learning achieves a state-of-the-art ICBHI Score of 62.37%, surpassing the previous best Score by +4.08%.
73 |
74 |
75 |
76 |
77 |
78 | ## BibTeX
79 | If you find this repo useful for your research, please consider citing our paper:
80 |
81 | ```
82 | @inproceedings{bae23b_interspeech,
83 | title = {Patch-Mix Contrastive Learning with Audio Spectrogram Transformer on Respiratory Sound Classification},
84 | author = {Sangmin Bae and June-Woo Kim and Won-Yang Cho and Hyerim Baek and Soyoun Son and Byungjo Lee and Changwan Ha and Kyongpil Tae and Sungnyun Kim and Se-Young Yun},
85 | year = {2023},
86 | booktitle = {INTERSPEECH 2023},
87 | pages = {5436--5440},
88 | doi = {10.21437/Interspeech.2023-1426},
89 | issn = {2958-1796},
90 | }
91 | ```
92 |
93 | ## Contact
94 | - Sangmin Bae: bsmn0223@kaist.ac.kr
95 | - June-Woo Kim: kaen2891@gmail.com
96 |
--------------------------------------------------------------------------------
/data/icbhi_dataset/filename_differences.txt:
--------------------------------------------------------------------------------
1 | '101_1b1_Al_sc_AKGC417L'
2 | '101_1b1_Pr_sc_AKGC417L'
3 | '102_1b1_Ar_sc_AKGC417L'
4 | '105_1b1_Tc_sc_LittC2SE'
5 | '108_1b1_Al_sc_LittC2SE'
6 | '111_1b2_Tc_sc_LittC2SE'
7 | '111_1b3_Tc_sc_LittC2SE'
8 | '115_1b1_Ar_sc_LittC2SE'
9 | '116_1b2_Pl_sc_LittC2SE'
10 | '116_1b2_Tc_sc_LittC2SE'
11 | '119_1b1_Ar_sc_AKGC417L'
12 | '121_1b1_Tc_sc_LittC2SE'
13 | '121_1p1_Tc_sc_LittC2SE'
14 | '123_1b1_Al_sc_AKGC417L'
15 | '125_1b1_Tc_sc_LittC2SE'
16 | '126_1b1_Al_sc_AKGC417L'
17 | '127_1b1_Ar_sc_LittC2SE'
18 | '129_1b1_Ar_sc_LittC2SE'
19 | '131_1b1_Al_sc_LittC2SE'
20 | '136_1b1_Ar_sc_AKGC417L'
21 | '137_1b1_Ar_sc_LittC2SE'
22 | '137_1b1_Ll_sc_LittC2SE'
23 | '143_1b1_Al_sc_AKGC417L'
24 | '144_1b1_Al_sc_AKGC417L'
25 | '144_1b1_Tc_sc_AKGC417L'
26 | '148_1b1_Al_sc_LittC2SE'
27 | '149_1b1_Al_sc_LittC2SE'
28 | '149_1b1_Lr_sc_LittC2SE'
29 | '149_1b1_Pl_sc_LittC2SE'
30 | '150_1b2_Al_sc_AKGC417L'
31 | '152_1b1_Al_sc_LittC2SE'
32 | '153_1b1_Al_sc_LittC2SE'
33 | '159_1b1_Al_sc_AKGC417L'
34 | '159_1b1_Ar_sc_AKGC417L'
35 | '159_1b1_Ll_sc_AKGC417L'
36 | '159_1b1_Pr_sc_AKGC417L'
37 | '161_1b1_Al_sc_LittC2SE'
38 | '161_1b1_Pl_sc_LittC2SE'
39 | '164_1b1_Ll_sc_LittC2SE'
40 | '165_1b1_Ar_sc_AKGC417L'
41 | '165_1b1_Pl_sc_AKGC417L'
42 | '165_1b1_Pr_sc_AKGC417L'
43 | '167_1b1_Al_sc_LittC2SE'
44 | '167_1b1_Pr_sc_LittC2SE'
45 | '168_1b1_Al_sc_LittC2SE'
46 | '169_1b1_Lr_sc_AKGC417L'
47 | '169_1b2_Ll_sc_AKGC417L'
48 | '171_1b1_Al_sc_AKGC417L'
49 | '173_1b1_Al_sc_AKGC417L'
50 | '179_1b1_Al_sc_LittC2SE'
51 | '179_1b1_Tc_sc_LittC2SE'
52 | '182_1b1_Tc_sc_LittC2SE'
53 | '183_1b1_Pl_sc_AKGC417L'
54 | '183_1b1_Tc_sc_AKGC417L'
55 | '184_1b1_Ar_sc_LittC2SE'
56 | '187_1b1_Ll_sc_AKGC417L'
57 | '188_1b1_Al_sc_LittC2SE'
58 | '188_1b1_Ar_sc_LittC2SE'
59 | '188_1b1_Pl_sc_LittC2SE'
60 | '188_1b1_Tc_sc_LittC2SE'
61 | '190_1b1_Tc_sc_AKGC417L'
62 | '194_1b1_Lr_sc_AKGC417L'
63 | '194_1b1_Pr_sc_AKGC417L'
64 | '196_1b1_Pr_sc_LittC2SE'
65 | '197_1b1_Al_sc_AKGC417L'
66 | '197_1b1_Tc_sc_AKGC417L'
67 | '201_1b1_Al_sc_LittC2SE'
68 | '201_1b1_Ar_sc_LittC2SE'
69 | '201_1b2_Al_sc_LittC2SE'
70 | '201_1b2_Ar_sc_LittC2SE'
71 | '201_1b3_Al_sc_LittC2SE'
72 | '201_1b3_Ar_sc_LittC2SE'
73 | '202_1b1_Ar_sc_AKGC417L'
74 | '206_1b1_Ar_sc_LittC2SE'
75 | '206_1b1_Lr_sc_LittC2SE'
76 | '206_1b1_Pl_sc_LittC2SE'
77 | '208_1b1_Ll_sc_LittC2SE'
78 | '209_1b1_Tc_sc_LittC2SE'
79 | '210_1b1_Al_sc_LittC2SE'
80 | '210_1b1_Ar_sc_LittC2SE'
81 | '214_1b1_Ar_sc_AKGC417L'
82 | '215_1b2_Ar_sc_LittC2SE'
83 | '215_1b3_Tc_sc_LittC2SE'
84 | '216_1b1_Al_sc_AKGC417L'
85 | '216_1b1_Pl_sc_AKGC417L'
86 | '217_1b1_Tc_sc_LittC2SE'
87 | '224_1b1_Tc_sc_AKGC417L'
88 | '224_1b2_Al_sc_AKGC417L'
89 | '225_1b1_Pl_sc_AKGC417L'
90 | '226_1b1_Al_sc_LittC2SE'
91 | '226_1b1_Ll_sc_LittC2SE'
92 |
--------------------------------------------------------------------------------
/data/icbhi_dataset/filename_format.txt:
--------------------------------------------------------------------------------
1 | Elements contained in the filenames:
2 |
3 | Patient number (101,102,...,226)
4 | Recording index
5 | Chest location (Trachea (Tc), {Anterior (A), Posterior (P), Lateral (L)}{left (l), right (r)})
6 | Acquisition mode (sequential/single channel (sc), simultaneous/multichannel (mc))
7 | Recording equipment (AKG C417L Microphone, 3M Littmann Classic II SE Stethoscope, 3M Littmann 3200 Electronic Stethoscope, WelchAllyn Meditron Master Elite Electronic Stethoscope)
--------------------------------------------------------------------------------
/data/icbhi_dataset/metadata.txt:
--------------------------------------------------------------------------------
1 | 101 3.0 F NA 19.0 99.0 Pr
2 | 102 0.75 F NA 9.8 73.0 Ar
3 | 103 70.0 F 33.0 NA NA Ar
4 | 104 70.0 F 28.47 NA NA Al
5 | 105 7.0 F NA 32.0 135.0 Tc
6 | 106 73.0 F 21.0 NA NA Pl
7 | 107 75.0 F 33.7 NA NA Ll
8 | 108 3.0 M NA NA NA Al
9 | 109 84.0 F 33.53 NA NA Lr
10 | 110 75.0 M 25.21 NA NA Pr
11 | 111 63.0 M 28.4 NA NA Tc
12 | 112 60.0 M 22.86 NA NA Pr
13 | 113 58.0 M 28.41 NA NA Lr
14 | 114 77.0 M 23.12 NA NA Pr
15 | 115 0.58 M NA 7.14 64.0 Ar
16 | 116 56.0 M 28.58 NA NA Pl
17 | 117 68.0 M 24.4 NA NA Tc
18 | 118 81.0 M 36.76 NA NA Pr
19 | 119 2.0 F NA 15.2 94.0 Ar
20 | 120 78.0 M 35.14 NA NA Ar
21 | 121 13.0 F NA 65.0 170.0 Tc
22 | 122 66.0 M 33.0 NA NA Ar
23 | 123 5.0 M NA 25.0 125.0 Al
24 | 124 65.0 M 29.07 NA NA Al
25 | 125 14.0 M NA 62.0 170.0 Tc
26 | 126 1.0 F NA 10.18 80.0 Al
27 | 127 2.0 M NA 12.6 98.0 Ar
28 | 128 65.0 F 24.3 NA NA Tc
29 | 129 6.0 M NA 23.0 119.0 Ar
30 | 130 85.0 F 17.1 NA NA Pl
31 | 131 3.0 M NA 14.0 97.0 Al
32 | 132 71.0 M 34.0 NA NA Lr
33 | 133 68.0 M 27.4 NA NA Pl
34 | 134 61.0 M 32.0 NA NA Al
35 | 135 70.0 M 21.0 NA NA Ar
36 | 136 5.0 M NA 16.2 110.0 Ar
37 | 137 4.0 M NA 18.0 104.0 Ar
38 | 138 56.0 F 21.6 NA NA Ll
39 | 139 61.0 M 28.68 NA NA Lr
40 | 140 79.0 F 23.0 NA NA Tc
41 | 141 66.0 M 22.4 NA NA Pr
42 | 142 78.0 M 26.1 NA NA Pl
43 | 143 0.25 F NA 8.24 68.0 Al
44 | 144 3.0 M NA 16.7 100.0 Al
45 | 145 69.0 M 23.4 NA NA Lr
46 | 146 67.0 M 28.0 NA NA Pl
47 | 147 77.0 M 25.7 NA NA Ar
48 | 148 4.0 M NA 33.0 110.0 Al
49 | 149 0.67 M NA 9.5 70.0 Lr
50 | 150 0.67 F NA 8.12 74.0 Al
51 | 151 75.0 M 28.4 NA NA Ll
52 | 152 16.0 M NA 70.0 183.0 Al
53 | 153 3.0 M NA 16.7 103.0 Al
54 | 154 65.0 M 28.1 NA NA Ll
55 | 155 69.0 M 26.0 NA NA Al
56 | 156 80.0 M 22.9 NA NA Al
57 | 157 62.0 M 53.5 NA NA Pr
58 | 158 63.0 M 16.5 NA NA Lr
59 | 159 0.83 F NA 11.0 80.0 Al
60 | 160 74.0 M 27.4 NA NA Tc
61 | 161 2.0 F NA 12.0 85.0 Pl
62 | 162 67.0 F 24.9 NA NA Ll
63 | 163 62.0 M 28.3 NA NA Pl
64 | 164 1.0 M NA 13.0 NA Ll
65 | 165 2.0 F NA 12.7 97.0 Pr
66 | 166 71.0 M 25.06 NA NA Ar
67 | 167 1.0 F NA 11.5 86.4 Pr
68 | 168 19.0 F 17.35 NA NA Al
69 | 169 50.0 F 28.81 NA NA Ll
70 | 170 79.0 M 22.6 NA NA Pr
71 | 171 9.0 M NA 32.0 133.0 Al
72 | 172 73.0 M 29.3 NA NA Lr
73 | 173 3.0 M NA 17.3 NA Al
74 | 174 68.0 M 26.4 NA NA Al
75 | 175 63.0 M 28.34 NA NA Pl
76 | 176 65.0 M 30.1 NA NA Pl
77 | 177 56.0 M 22.1 NA NA Pr
78 | 178 58.0 M 30.1 NA NA Tc
79 | 179 10.0 F NA 15.0 104.0 Al
80 | 180 93.0 M 29.03 NA NA Lr
81 | 181 65.0 M 26.4 NA NA Ar
82 | 182 11.0 M NA 33.0 136.0 Tc
83 | 183 14.0 F NA NA NA Tc
84 | 184 2.0 F NA 15.0 100.0 Ar
85 | 185 75.0 M 27.7 NA NA Pl
86 | 186 71.0 M 30.0 NA NA Al
87 | 187 0.5 F NA 8.26 71.0 Ll
88 | 188 3.0 M NA 16.0 100.0 Al
89 | 189 75.0 F 26.2 NA NA Lr
90 | 190 3.0 F NA NA NA Tc
91 | 191 74.0 F 36.0 NA NA Pr
92 | 192 69.0 M 28.0 NA NA Ar
93 | 193 77.0 M 26.3 NA NA Al
94 | 194 2.0 M NA 12.8 86.0 Pr
95 | 195 67.0 M 29.41 NA NA Al
96 | 196 21.0 F 25.5 NA NA Pr
97 | 197 16.0 F NA NA NA Al
98 | 198 71.0 M 18.6 NA NA Pr
99 | 199 71.0 M 20.0 NA NA Ll
100 | 200 72.0 F 27.8 NA NA Ar
101 | 201 73.0 F 28.52 NA NA Al
102 | 202 2.0 M NA 11.84 87.0 Ar
103 | 203 57.0 F 24.0 NA NA Pr
104 | 204 66.0 M 29.76 NA NA Ll
105 | 205 45.0 M 20.1 NA NA Al
106 | 206 3.0 M NA 13.0 92.0 Pl
107 | 207 63.0 F 29.6 NA NA Pl
108 | 208 5.0 F NA 24.1 117.0 Ll
109 | 209 14.0 F NA 80.0 183.0 Tc
110 | 210 1.0 F NA 12.96 76.0 Al
111 | 211 70.0 F 31.1 NA NA Ar
112 | 212 83.0 M 23.0 NA NA Tc
113 | 213 58.0 F 24.7 NA NA Al
114 | 214 5.0 M NA 30.0 118.0 Ar
115 | 215 56.0 F 25.35 NA NA Ar
116 | 216 1.0 M NA 10.25 78.0 Pl
117 | 217 12.0 F NA NA NA Tc
118 | 218 75.0 M 26.29 NA NA Al
119 | 219 81.0 M 26.0 NA NA Ar
120 | 220 66.0 M 35.4 NA NA Al
121 | 221 74.0 F 29.0 NA NA Pl
122 | 222 60.0 M NA NA NA Ar
123 | 223 NA NA NA NA NA Ll
124 | 224 10.0 F NA 32.3 143.0 Tc
125 | 225 0.83 M NA 7.8 74.0 Pl
126 | 226 4.0 M NA 16.7 103.0 Ll
127 |
--------------------------------------------------------------------------------
/data/icbhi_dataset/official_split.txt:
--------------------------------------------------------------------------------
1 | 101_1b1_Al_sc_Meditron test
2 | 101_1b1_Pr_sc_Meditron test
3 | 102_1b1_Ar_sc_Meditron test
4 | 103_2b2_Ar_mc_LittC2SE train
5 | 104_1b1_Al_sc_Litt3200 test
6 | 104_1b1_Ar_sc_Litt3200 test
7 | 104_1b1_Ll_sc_Litt3200 test
8 | 104_1b1_Lr_sc_Litt3200 test
9 | 104_1b1_Pl_sc_Litt3200 test
10 | 104_1b1_Pr_sc_Litt3200 test
11 | 105_1b1_Tc_sc_Meditron train
12 | 106_2b1_Pl_mc_LittC2SE train
13 | 106_2b1_Pr_mc_LittC2SE train
14 | 107_2b3_Al_mc_AKGC417L train
15 | 107_2b3_Ar_mc_AKGC417L train
16 | 107_2b3_Ll_mc_AKGC417L train
17 | 107_2b3_Lr_mc_AKGC417L train
18 | 107_2b3_Pl_mc_AKGC417L train
19 | 107_2b3_Pr_mc_AKGC417L train
20 | 107_2b3_Tc_mc_AKGC417L train
21 | 107_2b4_Al_mc_AKGC417L train
22 | 107_2b4_Ar_mc_AKGC417L train
23 | 107_2b4_Ll_mc_AKGC417L train
24 | 107_2b4_Lr_mc_AKGC417L train
25 | 107_2b4_Pl_mc_AKGC417L train
26 | 107_2b4_Pr_mc_AKGC417L train
27 | 107_2b4_Tc_mc_AKGC417L train
28 | 107_2b5_Al_mc_AKGC417L train
29 | 107_2b5_Ar_mc_AKGC417L train
30 | 107_2b5_Ll_mc_AKGC417L train
31 | 107_2b5_Lr_mc_AKGC417L train
32 | 107_2b5_Pl_mc_AKGC417L train
33 | 107_2b5_Pr_mc_AKGC417L train
34 | 107_2b5_Tc_mc_AKGC417L train
35 | 107_3p2_Al_mc_AKGC417L train
36 | 107_3p2_Ar_mc_AKGC417L train
37 | 107_3p2_Ll_mc_AKGC417L train
38 | 107_3p2_Lr_mc_AKGC417L train
39 | 107_3p2_Pl_mc_AKGC417L train
40 | 107_3p2_Pr_mc_AKGC417L train
41 | 107_3p2_Tc_mc_AKGC417L train
42 | 108_1b1_Al_sc_Meditron train
43 | 109_1b1_Al_sc_Litt3200 test
44 | 109_1b1_Ar_sc_Litt3200 test
45 | 109_1b1_Ll_sc_Litt3200 test
46 | 109_1b1_Lr_sc_Litt3200 test
47 | 109_1b1_Pl_sc_Litt3200 test
48 | 109_1b1_Pr_sc_Litt3200 test
49 | 110_1b1_Pr_sc_Meditron train
50 | 110_1p1_Al_sc_Meditron train
51 | 110_1p1_Ll_sc_Meditron train
52 | 110_1p1_Lr_sc_Meditron train
53 | 110_1p1_Pr_sc_Meditron train
54 | 111_1b2_Tc_sc_Meditron train
55 | 111_1b3_Tc_sc_Meditron train
56 | 112_1b1_Ar_sc_Meditron train
57 | 112_1b1_Lr_sc_Meditron train
58 | 112_1p1_Ll_sc_Litt3200 train
59 | 112_1p1_Pl_sc_Litt3200 train
60 | 112_1p1_Pr_sc_Litt3200 train
61 | 113_1b1_Al_sc_Litt3200 test
62 | 113_1b1_Ar_sc_Litt3200 test
63 | 113_1b1_Ll_sc_Litt3200 test
64 | 113_1b1_Lr_sc_Litt3200 test
65 | 113_1b1_Pl_sc_Litt3200 test
66 | 113_1b1_Pr_sc_Litt3200 test
67 | 114_1b4_Al_mc_AKGC417L train
68 | 114_1b4_Ar_mc_AKGC417L train
69 | 114_1b4_Lr_mc_AKGC417L train
70 | 114_1b4_Pl_mc_AKGC417L train
71 | 114_1b4_Pr_mc_AKGC417L train
72 | 115_1b1_Ar_sc_Meditron train
73 | 116_1b2_Pl_sc_Meditron train
74 | 116_1b2_Tc_sc_Meditron train
75 | 117_1b2_Tc_mc_LittC2SE train
76 | 117_1b3_Tc_mc_LittC2SE train
77 | 118_1b1_Al_sc_Litt3200 test
78 | 118_1b1_Ar_sc_Litt3200 test
79 | 118_1b1_Ll_sc_Litt3200 test
80 | 118_1b1_Lr_sc_Litt3200 test
81 | 118_1b1_Pl_sc_Litt3200 test
82 | 118_1b1_Pr_sc_Litt3200 test
83 | 119_1b1_Ar_sc_Meditron test
84 | 120_1b1_Al_sc_Meditron train
85 | 120_1b1_Ar_sc_Meditron train
86 | 120_1b1_Lr_sc_Meditron train
87 | 120_1b1_Pl_sc_Meditron train
88 | 120_1b1_Pr_sc_Meditron train
89 | 121_1b1_Tc_sc_Meditron train
90 | 121_1p1_Tc_sc_Meditron train
91 | 122_2b1_Al_mc_LittC2SE train
92 | 122_2b1_Ar_mc_LittC2SE train
93 | 122_2b1_Tc_mc_LittC2SE train
94 | 122_2b2_Al_mc_LittC2SE train
95 | 122_2b2_Ar_mc_LittC2SE train
96 | 122_2b2_Tc_mc_LittC2SE train
97 | 122_2b3_Al_mc_LittC2SE train
98 | 122_2b3_Ar_mc_LittC2SE train
99 | 122_2b3_Tc_mc_LittC2SE train
100 | 123_1b1_Al_sc_Meditron test
101 | 124_1b1_Al_sc_Litt3200 test
102 | 124_1b1_Ar_sc_Litt3200 test
103 | 124_1b1_Ll_sc_Litt3200 test
104 | 124_1b1_Lr_sc_Litt3200 test
105 | 124_1b1_Pl_sc_Litt3200 test
106 | 124_1b1_Pr_sc_Litt3200 test
107 | 125_1b1_Tc_sc_Meditron train
108 | 126_1b1_Al_sc_Meditron test
109 | 127_1b1_Ar_sc_Meditron train
110 | 128_1b3_Tc_mc_LittC2SE train
111 | 129_1b1_Ar_sc_Meditron train
112 | 130_1p2_Al_mc_AKGC417L train
113 | 130_1p2_Ar_mc_AKGC417L train
114 | 130_1p2_Ll_mc_AKGC417L train
115 | 130_1p2_Lr_mc_AKGC417L train
116 | 130_1p2_Pl_mc_AKGC417L train
117 | 130_1p2_Pr_mc_AKGC417L train
118 | 130_1p2_Tc_mc_AKGC417L train
119 | 130_1p3_Al_mc_AKGC417L train
120 | 130_1p3_Ar_mc_AKGC417L train
121 | 130_1p3_Ll_mc_AKGC417L train
122 | 130_1p3_Lr_mc_AKGC417L train
123 | 130_1p3_Pl_mc_AKGC417L train
124 | 130_1p3_Pr_mc_AKGC417L train
125 | 130_1p3_Tc_mc_AKGC417L train
126 | 130_1p4_Al_mc_AKGC417L train
127 | 130_1p4_Ar_mc_AKGC417L train
128 | 130_1p4_Ll_mc_AKGC417L train
129 | 130_1p4_Lr_mc_AKGC417L train
130 | 130_1p4_Pl_mc_AKGC417L train
131 | 130_1p4_Pr_mc_AKGC417L train
132 | 130_1p4_Tc_mc_AKGC417L train
133 | 130_2b2_Al_mc_AKGC417L train
134 | 130_2b2_Ar_mc_AKGC417L train
135 | 130_2b2_Ll_mc_AKGC417L train
136 | 130_2b2_Lr_mc_AKGC417L train
137 | 130_2b2_Pl_mc_AKGC417L train
138 | 130_2b2_Pr_mc_AKGC417L train
139 | 130_2b2_Tc_mc_AKGC417L train
140 | 130_2b3_Al_mc_AKGC417L train
141 | 130_2b3_Ar_mc_AKGC417L train
142 | 130_2b3_Ll_mc_AKGC417L train
143 | 130_2b3_Lr_mc_AKGC417L train
144 | 130_2b3_Pl_mc_AKGC417L train
145 | 130_2b3_Pr_mc_AKGC417L train
146 | 130_2b3_Tc_mc_AKGC417L train
147 | 130_2b4_Al_mc_AKGC417L train
148 | 130_2b4_Ar_mc_AKGC417L train
149 | 130_2b4_Ll_mc_AKGC417L train
150 | 130_2b4_Lr_mc_AKGC417L train
151 | 130_2b4_Pl_mc_AKGC417L train
152 | 130_2p3_Pl_mc_AKGC417L train
153 | 130_2p5_Al_mc_AKGC417L train
154 | 130_2p5_Ar_mc_AKGC417L train
155 | 130_2p5_Lr_mc_AKGC417L train
156 | 130_2p5_Pl_mc_AKGC417L train
157 | 130_2p5_Pr_mc_AKGC417L train
158 | 130_2p5_Tc_mc_AKGC417L train
159 | 130_3b3_Ll_mc_AKGC417L train
160 | 130_3b4_Al_mc_AKGC417L train
161 | 130_3b4_Ar_mc_AKGC417L train
162 | 130_3b4_Lr_mc_AKGC417L train
163 | 130_3b4_Pl_mc_AKGC417L train
164 | 130_3b4_Pr_mc_AKGC417L train
165 | 130_3p2_Al_mc_AKGC417L train
166 | 130_3p2_Ar_mc_AKGC417L train
167 | 130_3p2_Pl_mc_AKGC417L train
168 | 130_3p2_Pr_mc_AKGC417L train
169 | 130_3p2_Tc_mc_AKGC417L train
170 | 130_3p3_Al_mc_AKGC417L train
171 | 130_3p3_Pl_mc_AKGC417L train
172 | 130_3p3_Pr_mc_AKGC417L train
173 | 130_3p3_Tc_mc_AKGC417L train
174 | 130_3p4_Al_mc_AKGC417L train
175 | 130_3p4_Pl_mc_AKGC417L train
176 | 130_3p4_Pr_mc_AKGC417L train
177 | 130_3p4_Tc_mc_AKGC417L train
178 | 131_1b1_Al_sc_Meditron train
179 | 132_2b1_Lr_mc_LittC2SE train
180 | 132_2b2_Lr_mc_LittC2SE train
181 | 133_2p2_Al_mc_AKGC417L test
182 | 133_2p2_Ar_mc_AKGC417L test
183 | 133_2p2_Pl_mc_AKGC417L test
184 | 133_2p2_Tc_mc_AKGC417L test
185 | 133_2p3_Al_mc_AKGC417L test
186 | 133_2p3_Ar_mc_AKGC417L test
187 | 133_2p3_Pl_mc_AKGC417L test
188 | 133_2p3_Pr_mc_AKGC417L test
189 | 133_2p3_Tc_mc_AKGC417L test
190 | 133_2p4_Al_mc_AKGC417L test
191 | 133_2p4_Ar_mc_AKGC417L test
192 | 133_2p4_Pl_mc_AKGC417L test
193 | 133_2p4_Pr_mc_AKGC417L test
194 | 133_2p4_Tc_mc_AKGC417L test
195 | 133_3p2_Al_mc_AKGC417L test
196 | 133_3p2_Ar_mc_AKGC417L test
197 | 133_3p2_Pl_mc_AKGC417L test
198 | 133_3p2_Pr_mc_AKGC417L test
199 | 133_3p4_Tc_mc_AKGC417L test
200 | 134_2b1_Al_mc_LittC2SE train
201 | 134_2b1_Ar_mc_LittC2SE train
202 | 134_2b2_Al_mc_LittC2SE train
203 | 134_2b2_Ar_mc_LittC2SE train
204 | 134_2b3_Ar_mc_LittC2SE train
205 | 135_2b1_Al_mc_LittC2SE train
206 | 135_2b1_Ar_mc_LittC2SE train
207 | 135_2b1_Pl_mc_LittC2SE train
208 | 135_2b1_Tc_mc_LittC2SE train
209 | 135_2b2_Al_mc_LittC2SE train
210 | 135_2b2_Ar_mc_LittC2SE train
211 | 135_2b2_Pl_mc_LittC2SE train
212 | 135_2b2_Tc_mc_LittC2SE train
213 | 135_2b3_Al_mc_LittC2SE train
214 | 135_2b3_Ar_mc_LittC2SE train
215 | 135_2b3_Pl_mc_LittC2SE train
216 | 135_2b3_Pr_mc_LittC2SE train
217 | 135_2b3_Tc_mc_LittC2SE train
218 | 136_1b1_Ar_sc_Meditron test
219 | 137_1b1_Ar_sc_Meditron train
220 | 137_1b1_Ll_sc_Meditron train
221 | 138_1p2_Al_mc_AKGC417L train
222 | 138_1p2_Ar_mc_AKGC417L train
223 | 138_1p2_Ll_mc_AKGC417L train
224 | 138_1p2_Lr_mc_AKGC417L train
225 | 138_1p2_Pl_mc_AKGC417L train
226 | 138_1p2_Pr_mc_AKGC417L train
227 | 138_1p2_Tc_mc_AKGC417L train
228 | 138_1p3_Al_mc_AKGC417L train
229 | 138_1p3_Ar_mc_AKGC417L train
230 | 138_1p3_Ll_mc_AKGC417L train
231 | 138_1p3_Lr_mc_AKGC417L train
232 | 138_1p3_Pl_mc_AKGC417L train
233 | 138_1p3_Pr_mc_AKGC417L train
234 | 138_1p3_Tc_mc_AKGC417L train
235 | 138_1p4_Ar_mc_AKGC417L train
236 | 138_1p4_Ll_mc_AKGC417L train
237 | 138_1p4_Lr_mc_AKGC417L train
238 | 138_1p4_Pl_mc_AKGC417L train
239 | 138_1p4_Pr_mc_AKGC417L train
240 | 138_1p4_Tc_mc_AKGC417L train
241 | 138_2p2_Al_mc_AKGC417L train
242 | 138_2p2_Ar_mc_AKGC417L train
243 | 138_2p2_Ll_mc_AKGC417L train
244 | 138_2p2_Lr_mc_AKGC417L train
245 | 138_2p2_Pl_mc_AKGC417L train
246 | 138_2p2_Pr_mc_AKGC417L train
247 | 138_2p2_Tc_mc_AKGC417L train
248 | 139_1b1_Al_sc_Litt3200 test
249 | 139_1b1_Ar_sc_Litt3200 test
250 | 139_1b1_Ll_sc_Litt3200 test
251 | 139_1b1_Lr_sc_Litt3200 test
252 | 139_1b1_Pl_sc_Litt3200 test
253 | 139_1b1_Pr_sc_Litt3200 test
254 | 140_2b2_Ll_mc_LittC2SE train
255 | 140_2b2_Tc_mc_LittC2SE train
256 | 140_2b3_Ll_mc_LittC2SE train
257 | 140_2b3_Tc_mc_LittC2SE train
258 | 141_1b1_Pr_mc_LittC2SE train
259 | 141_1b2_Ar_mc_LittC2SE train
260 | 141_1b2_Lr_mc_LittC2SE train
261 | 141_1b2_Pr_mc_LittC2SE train
262 | 141_1b2_Tc_mc_LittC2SE train
263 | 141_1b3_Al_mc_LittC2SE train
264 | 141_1b3_Ar_mc_LittC2SE train
265 | 141_1b3_Pr_mc_LittC2SE train
266 | 142_1b1_Pl_mc_LittC2SE train
267 | 143_1b1_Al_sc_Meditron test
268 | 144_1b1_Al_sc_Meditron test
269 | 144_1b1_Tc_sc_Meditron test
270 | 145_2b2_Al_mc_AKGC417L test
271 | 145_2b2_Ar_mc_AKGC417L test
272 | 145_2b2_Lr_mc_AKGC417L test
273 | 145_2b2_Pr_mc_AKGC417L test
274 | 145_3b2_Ar_mc_AKGC417L test
275 | 145_3b2_Lr_mc_AKGC417L test
276 | 145_3b4_Pl_mc_AKGC417L test
277 | 146_2b2_Pl_mc_AKGC417L test
278 | 146_2b4_Al_mc_AKGC417L test
279 | 146_2b4_Ar_mc_AKGC417L test
280 | 146_2b4_Ll_mc_AKGC417L test
281 | 146_2b4_Lr_mc_AKGC417L test
282 | 146_2b4_Pr_mc_AKGC417L test
283 | 146_8p3_Al_mc_AKGC417L test
284 | 146_8p3_Ar_mc_AKGC417L test
285 | 146_8p3_Lr_mc_AKGC417L test
286 | 146_8p3_Pl_mc_AKGC417L test
287 | 146_8p3_Pr_mc_AKGC417L test
288 | 147_1b2_Tc_mc_AKGC417L test
289 | 147_1b3_Tc_mc_AKGC417L test
290 | 147_1b4_Tc_mc_AKGC417L test
291 | 147_2b2_Al_mc_AKGC417L test
292 | 147_2b2_Ar_mc_AKGC417L test
293 | 147_2b2_Pl_mc_AKGC417L test
294 | 147_2b3_Al_mc_AKGC417L test
295 | 147_2b3_Ar_mc_AKGC417L test
296 | 147_2b3_Ll_mc_AKGC417L test
297 | 147_2b3_Lr_mc_AKGC417L test
298 | 147_2b3_Pl_mc_AKGC417L test
299 | 147_2b4_Al_mc_AKGC417L test
300 | 147_2b4_Ar_mc_AKGC417L test
301 | 147_2b4_Ll_mc_AKGC417L test
302 | 147_2b4_Lr_mc_AKGC417L test
303 | 147_2b4_Pl_mc_AKGC417L test
304 | 148_1b1_Al_sc_Meditron train
305 | 149_1b1_Al_sc_Meditron test
306 | 149_1b1_Lr_sc_Meditron test
307 | 149_1b1_Pl_sc_Meditron test
308 | 150_1b2_Al_sc_Meditron test
309 | 151_2p2_Al_mc_AKGC417L test
310 | 151_2p2_Ar_mc_AKGC417L test
311 | 151_2p2_Ll_mc_AKGC417L test
312 | 151_2p2_Lr_mc_AKGC417L test
313 | 151_2p2_Pl_mc_AKGC417L test
314 | 151_2p2_Pr_mc_AKGC417L test
315 | 151_2p2_Tc_mc_AKGC417L test
316 | 151_2p3_Al_mc_AKGC417L test
317 | 151_2p3_Ar_mc_AKGC417L test
318 | 151_2p3_Ll_mc_AKGC417L test
319 | 151_2p3_Lr_mc_AKGC417L test
320 | 151_2p3_Pl_mc_AKGC417L test
321 | 151_2p3_Pr_mc_AKGC417L test
322 | 151_2p3_Tc_mc_AKGC417L test
323 | 151_2p4_Al_mc_AKGC417L test
324 | 151_2p4_Ar_mc_AKGC417L test
325 | 151_2p4_Ll_mc_AKGC417L test
326 | 151_2p4_Lr_mc_AKGC417L test
327 | 151_2p4_Pl_mc_AKGC417L test
328 | 151_2p4_Pr_mc_AKGC417L test
329 | 151_2p4_Tc_mc_AKGC417L test
330 | 151_3p2_Al_mc_AKGC417L test
331 | 151_3p2_Ar_mc_AKGC417L test
332 | 151_3p2_Lr_mc_AKGC417L test
333 | 151_3p2_Pl_mc_AKGC417L test
334 | 151_3p2_Pr_mc_AKGC417L test
335 | 151_3p2_Tc_mc_AKGC417L test
336 | 151_3p3_Ll_mc_AKGC417L test
337 | 152_1b1_Al_sc_Meditron train
338 | 153_1b1_Al_sc_Meditron train
339 | 154_1b3_Al_mc_AKGC417L train
340 | 154_1b3_Ar_mc_AKGC417L train
341 | 154_1b3_Ll_mc_AKGC417L train
342 | 154_1b3_Lr_mc_AKGC417L train
343 | 154_1b3_Pl_mc_AKGC417L train
344 | 154_1b3_Pr_mc_AKGC417L train
345 | 154_1b3_Tc_mc_AKGC417L train
346 | 154_2b4_Al_mc_AKGC417L train
347 | 154_2b4_Ar_mc_AKGC417L train
348 | 154_2b4_Ll_mc_AKGC417L train
349 | 154_2b4_Lr_mc_AKGC417L train
350 | 154_2b4_Pl_mc_AKGC417L train
351 | 154_2b4_Pr_mc_AKGC417L train
352 | 154_2b4_Tc_mc_AKGC417L train
353 | 154_3b3_Al_mc_AKGC417L train
354 | 154_3b3_Ar_mc_AKGC417L train
355 | 154_3b3_Ll_mc_AKGC417L train
356 | 154_4b4_Al_mc_AKGC417L train
357 | 154_4b4_Ar_mc_AKGC417L train
358 | 154_4b4_Ll_mc_AKGC417L train
359 | 154_4b4_Lr_mc_AKGC417L train
360 | 154_4b4_Pl_mc_AKGC417L train
361 | 154_4b4_Pr_mc_AKGC417L train
362 | 155_2b1_Al_mc_LittC2SE train
363 | 156_2b3_Al_mc_AKGC417L test
364 | 156_2b3_Ar_mc_AKGC417L train
365 | 156_2b3_Ll_mc_AKGC417L train
366 | 156_2b3_Lr_mc_AKGC417L test
367 | 156_2b3_Pl_mc_AKGC417L test
368 | 156_2b3_Pr_mc_AKGC417L train
369 | 156_5b3_Al_mc_AKGC417L train
370 | 156_5b3_Ar_mc_AKGC417L test
371 | 156_5b3_Ll_mc_AKGC417L test
372 | 156_5b3_Lr_mc_AKGC417L train
373 | 156_5b3_Pl_mc_AKGC417L test
374 | 156_5b3_Pr_mc_AKGC417L train
375 | 156_8b3_Al_mc_AKGC417L test
376 | 156_8b3_Ar_mc_AKGC417L train
377 | 156_8b3_Ll_mc_AKGC417L train
378 | 156_8b3_Lr_mc_AKGC417L test
379 | 156_8b3_Pl_mc_AKGC417L train
380 | 157_1b1_Al_sc_Meditron train
381 | 157_1b1_Ar_sc_Meditron train
382 | 157_1b1_Lr_sc_Meditron train
383 | 157_1b1_Pl_sc_Meditron train
384 | 157_1b1_Pr_sc_Meditron train
385 | 158_1b3_Ar_mc_LittC2SE train
386 | 158_1p2_Al_mc_AKGC417L train
387 | 158_1p2_Ar_mc_AKGC417L train
388 | 158_1p2_Ll_mc_AKGC417L train
389 | 158_1p2_Lr_mc_AKGC417L train
390 | 158_1p2_Pl_mc_AKGC417L train
391 | 158_1p2_Pr_mc_AKGC417L train
392 | 158_1p2_Tc_mc_AKGC417L train
393 | 158_1p3_Al_mc_AKGC417L train
394 | 158_1p3_Ar_mc_AKGC417L train
395 | 158_1p3_Ll_mc_AKGC417L train
396 | 158_1p3_Lr_mc_AKGC417L train
397 | 158_1p3_Pl_mc_AKGC417L train
398 | 158_1p3_Pr_mc_AKGC417L train
399 | 158_1p3_Tc_mc_AKGC417L train
400 | 158_1p4_Al_mc_AKGC417L train
401 | 158_1p4_Ar_mc_AKGC417L train
402 | 158_1p4_Lr_mc_AKGC417L train
403 | 158_1p4_Pl_mc_AKGC417L train
404 | 158_1p4_Pr_mc_AKGC417L train
405 | 158_1p4_Tc_mc_AKGC417L train
406 | 158_2p2_Ar_mc_AKGC417L train
407 | 158_2p3_Lr_mc_AKGC417L train
408 | 158_2p3_Tc_mc_AKGC417L train
409 | 159_1b1_Al_sc_Meditron train
410 | 159_1b1_Ar_sc_Meditron train
411 | 159_1b1_Ll_sc_Meditron train
412 | 159_1b1_Pr_sc_Meditron train
413 | 160_1b2_Al_mc_AKGC417L test
414 | 160_1b2_Ar_mc_AKGC417L test
415 | 160_1b2_Lr_mc_AKGC417L test
416 | 160_1b2_Pl_mc_AKGC417L test
417 | 160_1b2_Pr_mc_AKGC417L test
418 | 160_1b2_Tc_mc_AKGC417L test
419 | 160_1b3_Al_mc_AKGC417L test
420 | 160_1b3_Ar_mc_AKGC417L test
421 | 160_1b3_Lr_mc_AKGC417L test
422 | 160_1b3_Pl_mc_AKGC417L test
423 | 160_1b3_Pr_mc_AKGC417L test
424 | 160_1b3_Tc_mc_AKGC417L test
425 | 160_1b4_Al_mc_AKGC417L test
426 | 160_1b4_Ar_mc_AKGC417L test
427 | 160_1b4_Lr_mc_AKGC417L test
428 | 160_1b4_Pl_mc_AKGC417L test
429 | 160_1b4_Pr_mc_AKGC417L test
430 | 160_1b4_Tc_mc_AKGC417L test
431 | 160_2b3_Lr_mc_AKGC417L test
432 | 160_2b4_Ar_mc_AKGC417L test
433 | 160_2b4_Pl_mc_AKGC417L test
434 | 160_2b4_Pr_mc_AKGC417L test
435 | 160_2b4_Tc_mc_AKGC417L test
436 | 161_1b1_Al_sc_Meditron train
437 | 161_1b1_Pl_sc_Meditron train
438 | 162_1b2_Al_mc_AKGC417L train
439 | 162_1b2_Ar_mc_AKGC417L train
440 | 162_1b2_Ll_mc_AKGC417L train
441 | 162_1b2_Lr_mc_AKGC417L train
442 | 162_1b2_Pl_mc_AKGC417L train
443 | 162_1b2_Pr_mc_AKGC417L train
444 | 162_1b2_Tc_mc_AKGC417L train
445 | 162_2b2_Al_mc_AKGC417L train
446 | 162_2b2_Ar_mc_AKGC417L train
447 | 162_2b2_Pl_mc_AKGC417L train
448 | 162_2b2_Pr_mc_AKGC417L train
449 | 162_2b2_Tc_mc_AKGC417L train
450 | 162_2b3_Al_mc_AKGC417L train
451 | 162_2b3_Ar_mc_AKGC417L train
452 | 162_2b3_Lr_mc_AKGC417L train
453 | 162_2b3_Pl_mc_AKGC417L train
454 | 162_2b3_Pr_mc_AKGC417L train
455 | 162_2b3_Tc_mc_AKGC417L train
456 | 162_2b4_Al_mc_AKGC417L train
457 | 162_2b4_Ar_mc_AKGC417L train
458 | 162_2b4_Lr_mc_AKGC417L train
459 | 162_2b4_Pl_mc_AKGC417L train
460 | 162_2b4_Pr_mc_AKGC417L train
461 | 162_2b4_Tc_mc_AKGC417L train
462 | 163_2b2_Al_mc_AKGC417L train
463 | 163_2b2_Ar_mc_AKGC417L train
464 | 163_2b2_Ll_mc_AKGC417L train
465 | 163_2b2_Lr_mc_AKGC417L train
466 | 163_2b2_Pl_mc_AKGC417L train
467 | 163_2b2_Pr_mc_AKGC417L train
468 | 163_2b2_Tc_mc_AKGC417L train
469 | 163_8b3_Al_mc_AKGC417L train
470 | 163_8b3_Ar_mc_AKGC417L train
471 | 163_8b3_Ll_mc_AKGC417L train
472 | 163_8b3_Lr_mc_AKGC417L train
473 | 163_8b3_Pl_mc_AKGC417L train
474 | 163_8b3_Pr_mc_AKGC417L train
475 | 164_1b1_Ll_sc_Meditron train
476 | 165_1b1_Ar_sc_Meditron test
477 | 165_1b1_Pl_sc_Meditron test
478 | 165_1b1_Pr_sc_Meditron test
479 | 166_1p1_Al_sc_Meditron train
480 | 166_1p1_Ar_sc_Meditron train
481 | 166_1p1_Ll_sc_Meditron train
482 | 166_1p1_Pl_sc_Meditron train
483 | 166_1p1_Pr_sc_Meditron train
484 | 167_1b1_Al_sc_Meditron train
485 | 167_1b1_Pr_sc_Meditron train
486 | 168_1b1_Al_sc_Meditron train
487 | 169_1b1_Lr_sc_Meditron test
488 | 169_1b2_Ll_sc_Meditron test
489 | 170_1b2_Al_mc_AKGC417L test
490 | 170_1b2_Ar_mc_AKGC417L test
491 | 170_1b2_Lr_mc_AKGC417L test
492 | 170_1b2_Pl_mc_AKGC417L test
493 | 170_1b2_Pr_mc_AKGC417L test
494 | 170_1b2_Tc_mc_AKGC417L test
495 | 170_1b3_Al_mc_AKGC417L test
496 | 170_1b3_Ar_mc_AKGC417L test
497 | 170_1b3_Ll_mc_AKGC417L test
498 | 170_1b3_Lr_mc_AKGC417L test
499 | 170_1b3_Pl_mc_AKGC417L test
500 | 170_1b3_Pr_mc_AKGC417L test
501 | 170_1b3_Tc_mc_AKGC417L test
502 | 170_1b4_Al_mc_AKGC417L test
503 | 170_1b4_Ar_mc_AKGC417L test
504 | 170_1b4_Lr_mc_AKGC417L test
505 | 170_1b4_Pl_mc_AKGC417L test
506 | 170_1b4_Pr_mc_AKGC417L test
507 | 170_1b4_Tc_mc_AKGC417L test
508 | 170_2b2_Al_mc_AKGC417L test
509 | 170_2b2_Ar_mc_AKGC417L test
510 | 170_2b2_Lr_mc_AKGC417L test
511 | 170_2b2_Pl_mc_AKGC417L test
512 | 170_2b2_Pr_mc_AKGC417L test
513 | 170_2b2_Tc_mc_AKGC417L test
514 | 171_1b1_Al_sc_Meditron test
515 | 172_1b3_Al_mc_AKGC417L train
516 | 172_1b3_Ar_mc_AKGC417L train
517 | 172_1b3_Ll_mc_AKGC417L train
518 | 172_1b3_Lr_mc_AKGC417L train
519 | 172_1b3_Pl_mc_AKGC417L train
520 | 172_1b3_Pr_mc_AKGC417L train
521 | 172_1b3_Tc_mc_AKGC417L train
522 | 172_1b4_Al_mc_AKGC417L train
523 | 172_1b4_Ar_mc_AKGC417L train
524 | 172_1b4_Ll_mc_AKGC417L train
525 | 172_1b4_Lr_mc_AKGC417L train
526 | 172_1b4_Pl_mc_AKGC417L train
527 | 172_1b4_Pr_mc_AKGC417L train
528 | 172_1b4_Tc_mc_AKGC417L train
529 | 172_1b5_Al_mc_AKGC417L train
530 | 172_1b5_Ar_mc_AKGC417L train
531 | 172_1b5_Ll_mc_AKGC417L train
532 | 172_1b5_Lr_mc_AKGC417L train
533 | 172_1b5_Pl_mc_AKGC417L train
534 | 172_1b5_Pr_mc_AKGC417L train
535 | 172_1b5_Tc_mc_AKGC417L train
536 | 172_2b5_Al_mc_AKGC417L train
537 | 172_2b5_Ar_mc_AKGC417L train
538 | 172_2b5_Lr_mc_AKGC417L train
539 | 172_2b5_Pl_mc_AKGC417L train
540 | 172_2b5_Pr_mc_AKGC417L train
541 | 172_2b5_Tc_mc_AKGC417L train
542 | 173_1b1_Al_sc_Meditron test
543 | 174_1p2_Ar_mc_AKGC417L test
544 | 174_1p2_Ll_mc_AKGC417L test
545 | 174_1p2_Lr_mc_AKGC417L test
546 | 174_1p2_Pl_mc_AKGC417L test
547 | 174_1p2_Pr_mc_AKGC417L test
548 | 174_1p2_Tc_mc_AKGC417L test
549 | 174_1p3_Ar_mc_AKGC417L test
550 | 174_1p3_Ll_mc_AKGC417L test
551 | 174_1p3_Lr_mc_AKGC417L test
552 | 174_1p3_Pl_mc_AKGC417L test
553 | 174_1p3_Pr_mc_AKGC417L test
554 | 174_1p3_Tc_mc_AKGC417L test
555 | 174_1p4_Ar_mc_AKGC417L test
556 | 174_1p4_Ll_mc_AKGC417L test
557 | 174_1p4_Lr_mc_AKGC417L test
558 | 174_1p4_Pl_mc_AKGC417L test
559 | 174_1p4_Pr_mc_AKGC417L test
560 | 174_1p4_Tc_mc_AKGC417L test
561 | 174_2p3_Al_mc_AKGC417L test
562 | 174_2p3_Ar_mc_AKGC417L test
563 | 174_2p3_Pl_mc_AKGC417L test
564 | 174_2p3_Pr_mc_AKGC417L test
565 | 174_2p3_Tc_mc_AKGC417L test
566 | 175_1b1_Al_sc_Litt3200 test
567 | 175_1b1_Ar_sc_Litt3200 test
568 | 175_1b1_Ll_sc_Litt3200 test
569 | 175_1b1_Lr_sc_Litt3200 test
570 | 175_1b1_Pl_sc_Litt3200 test
571 | 175_1b1_Pr_sc_Litt3200 test
572 | 176_1b3_Al_mc_AKGC417L test
573 | 176_1b3_Ar_mc_AKGC417L test
574 | 176_1b3_Ll_mc_AKGC417L test
575 | 176_1b3_Lr_mc_AKGC417L test
576 | 176_1b3_Pl_mc_AKGC417L test
577 | 176_1b3_Pr_mc_AKGC417L test
578 | 176_1b3_Tc_mc_AKGC417L test
579 | 176_1b4_Al_mc_AKGC417L test
580 | 176_1b4_Ar_mc_AKGC417L test
581 | 176_1b4_Ll_mc_AKGC417L test
582 | 176_1b4_Lr_mc_AKGC417L test
583 | 176_1b4_Pl_mc_AKGC417L test
584 | 176_1b4_Pr_mc_AKGC417L test
585 | 176_1b4_Tc_mc_AKGC417L test
586 | 176_2b3_Al_mc_AKGC417L test
587 | 176_2b3_Ar_mc_AKGC417L test
588 | 176_2b3_Ll_mc_AKGC417L test
589 | 176_2b3_Lr_mc_AKGC417L test
590 | 176_2b3_Pl_mc_AKGC417L test
591 | 176_2b3_Pr_mc_AKGC417L test
592 | 176_2b3_Tc_mc_AKGC417L test
593 | 177_1b2_Al_mc_AKGC417L test
594 | 177_1b2_Ar_mc_AKGC417L test
595 | 177_1b2_Lr_mc_AKGC417L test
596 | 177_1b2_Pl_mc_AKGC417L test
597 | 177_1b2_Pr_mc_AKGC417L test
598 | 177_1b2_Tc_mc_AKGC417L test
599 | 177_1b4_Al_mc_AKGC417L test
600 | 177_1b4_Ar_mc_AKGC417L test
601 | 177_1b4_Lr_mc_AKGC417L test
602 | 177_1b4_Pl_mc_AKGC417L test
603 | 177_1b4_Pr_mc_AKGC417L test
604 | 177_1b4_Tc_mc_AKGC417L test
605 | 177_2b4_Al_mc_AKGC417L test
606 | 177_2b4_Lr_mc_AKGC417L test
607 | 177_2b4_Pl_mc_AKGC417L test
608 | 177_2b4_Pr_mc_AKGC417L test
609 | 177_2b4_Tc_mc_AKGC417L test
610 | 178_1b2_Al_mc_AKGC417L test
611 | 178_1b2_Ar_mc_AKGC417L test
612 | 178_1b2_Lr_mc_AKGC417L test
613 | 178_1b2_Pl_mc_AKGC417L test
614 | 178_1b2_Pr_mc_AKGC417L test
615 | 178_1b2_Tc_mc_AKGC417L test
616 | 178_1b3_Al_mc_AKGC417L test
617 | 178_1b3_Ar_mc_AKGC417L test
618 | 178_1b3_Lr_mc_AKGC417L test
619 | 178_1b3_Pl_mc_AKGC417L test
620 | 178_1b3_Pr_mc_AKGC417L test
621 | 178_1b3_Tc_mc_AKGC417L test
622 | 178_1b6_Al_mc_AKGC417L test
623 | 178_1b6_Ar_mc_AKGC417L test
624 | 178_1b6_Ll_mc_AKGC417L test
625 | 178_1b6_Lr_mc_AKGC417L test
626 | 178_1b6_Pl_mc_AKGC417L test
627 | 178_1b6_Pr_mc_AKGC417L test
628 | 178_1b6_Tc_mc_AKGC417L test
629 | 178_2b2_Al_mc_AKGC417L test
630 | 178_2b2_Ar_mc_AKGC417L test
631 | 178_2b2_Lr_mc_AKGC417L test
632 | 178_2b2_Pr_mc_AKGC417L test
633 | 178_2b2_Tc_mc_AKGC417L test
634 | 179_1b1_Al_sc_Meditron train
635 | 179_1b1_Tc_sc_Meditron train
636 | 180_1b4_Al_mc_AKGC417L train
637 | 180_1b4_Ar_mc_AKGC417L train
638 | 180_1b4_Lr_mc_AKGC417L train
639 | 180_1b4_Pl_mc_AKGC417L train
640 | 180_1b4_Pr_mc_AKGC417L train
641 | 181_1b1_Ar_mc_LittC2SE train
642 | 181_1b1_Tc_mc_LittC2SE train
643 | 181_1b2_Ar_mc_LittC2SE train
644 | 181_1b3_Tc_mc_LittC2SE train
645 | 182_1b1_Tc_sc_Meditron test
646 | 183_1b1_Pl_sc_Meditron train
647 | 183_1b1_Tc_sc_Meditron train
648 | 184_1b1_Ar_sc_Meditron train
649 | 185_1b1_Al_sc_Litt3200 test
650 | 185_1b1_Ar_sc_Litt3200 test
651 | 185_1b1_Ll_sc_Litt3200 test
652 | 185_1b1_Lr_sc_Litt3200 test
653 | 185_1b1_Pl_sc_Litt3200 test
654 | 185_1b1_Pr_sc_Litt3200 test
655 | 186_2b2_Al_mc_AKGC417L train
656 | 186_2b2_Ar_mc_AKGC417L train
657 | 186_2b2_Lr_mc_AKGC417L train
658 | 186_2b2_Pl_mc_AKGC417L train
659 | 186_2b2_Pr_mc_AKGC417L train
660 | 186_2b2_Tc_mc_AKGC417L train
661 | 186_2b3_Al_mc_AKGC417L train
662 | 186_2b3_Ar_mc_AKGC417L train
663 | 186_2b3_Lr_mc_AKGC417L train
664 | 186_2b3_Pl_mc_AKGC417L train
665 | 186_2b3_Pr_mc_AKGC417L train
666 | 186_2b3_Tc_mc_AKGC417L train
667 | 186_2b4_Al_mc_AKGC417L train
668 | 186_2b4_Ar_mc_AKGC417L train
669 | 186_2b4_Lr_mc_AKGC417L train
670 | 186_2b4_Pl_mc_AKGC417L train
671 | 186_2b4_Pr_mc_AKGC417L train
672 | 186_2b4_Tc_mc_AKGC417L train
673 | 186_3b3_Al_mc_AKGC417L train
674 | 186_3b3_Ar_mc_AKGC417L train
675 | 186_3b3_Lr_mc_AKGC417L train
676 | 186_3b3_Pl_mc_AKGC417L train
677 | 186_3b3_Pr_mc_AKGC417L train
678 | 186_3b3_Tc_mc_AKGC417L train
679 | 187_1b1_Ll_sc_Meditron test
680 | 188_1b1_Al_sc_Meditron train
681 | 188_1b1_Ar_sc_Meditron train
682 | 188_1b1_Pl_sc_Meditron train
683 | 188_1b1_Tc_sc_Meditron train
684 | 189_1b2_Lr_mc_LittC2SE train
685 | 190_1b1_Tc_sc_Meditron train
686 | 191_2b1_Pl_mc_LittC2SE train
687 | 191_2b1_Pr_mc_LittC2SE train
688 | 191_2b2_Tc_mc_LittC2SE train
689 | 192_2b1_Al_mc_LittC2SE train
690 | 192_2b1_Ar_mc_LittC2SE train
691 | 192_2b2_Al_mc_LittC2SE train
692 | 192_2b2_Ar_mc_LittC2SE train
693 | 192_2b3_Al_mc_LittC2SE train
694 | 192_2b3_Ar_mc_LittC2SE train
695 | 193_1b2_Al_mc_AKGC417L train
696 | 193_1b2_Ar_mc_AKGC417L train
697 | 193_1b2_Ll_mc_AKGC417L train
698 | 193_1b2_Pl_mc_AKGC417L train
699 | 193_1b2_Pr_mc_AKGC417L train
700 | 193_1b2_Tc_mc_AKGC417L train
701 | 193_1b4_Lr_mc_AKGC417L train
702 | 193_7b3_Al_mc_AKGC417L train
703 | 193_7b3_Ar_mc_AKGC417L train
704 | 193_7b3_Ll_mc_AKGC417L train
705 | 193_7b3_Lr_mc_AKGC417L train
706 | 193_7b3_Pl_mc_AKGC417L train
707 | 193_7b3_Pr_mc_AKGC417L train
708 | 193_7b3_Tc_mc_AKGC417L train
709 | 194_1b1_Lr_sc_Meditron test
710 | 194_1b1_Pr_sc_Meditron test
711 | 195_1b1_Al_sc_Litt3200 test
712 | 195_1b1_Ar_sc_Litt3200 test
713 | 195_1b1_Ll_sc_Litt3200 test
714 | 195_1b1_Lr_sc_Litt3200 test
715 | 195_1b1_Pl_sc_Litt3200 test
716 | 195_1b1_Pr_sc_Litt3200 test
717 | 196_1b1_Pr_sc_Meditron train
718 | 197_1b1_Al_sc_Meditron train
719 | 197_1b1_Tc_sc_Meditron train
720 | 198_1b5_Al_mc_AKGC417L test
721 | 198_1b5_Ar_mc_AKGC417L test
722 | 198_1b5_Ll_mc_AKGC417L test
723 | 198_1b5_Lr_mc_AKGC417L test
724 | 198_1b5_Pl_mc_AKGC417L test
725 | 198_1b5_Pr_mc_AKGC417L test
726 | 198_1b5_Tc_mc_AKGC417L test
727 | 198_6p1_Al_mc_AKGC417L test
728 | 198_6p1_Ar_mc_AKGC417L test
729 | 198_6p1_Ll_mc_AKGC417L test
730 | 198_6p1_Lr_mc_AKGC417L test
731 | 198_6p1_Pl_mc_AKGC417L test
732 | 198_6p1_Pr_mc_AKGC417L test
733 | 198_6p1_Tc_mc_AKGC417L test
734 | 199_2b1_Ll_mc_LittC2SE train
735 | 199_2b3_Ll_mc_LittC2SE train
736 | 200_2p2_Al_mc_AKGC417L train
737 | 200_2p2_Ar_mc_AKGC417L train
738 | 200_2p2_Lr_mc_AKGC417L train
739 | 200_2p2_Pl_mc_AKGC417L train
740 | 200_2p2_Pr_mc_AKGC417L train
741 | 200_2p2_Tc_mc_AKGC417L train
742 | 200_2p3_Al_mc_AKGC417L train
743 | 200_2p3_Ar_mc_AKGC417L train
744 | 200_2p3_Lr_mc_AKGC417L train
745 | 200_2p3_Pl_mc_AKGC417L train
746 | 200_2p3_Pr_mc_AKGC417L train
747 | 200_2p3_Tc_mc_AKGC417L train
748 | 200_2p4_Al_mc_AKGC417L train
749 | 200_2p4_Ar_mc_AKGC417L train
750 | 200_2p4_Lr_mc_AKGC417L train
751 | 200_2p4_Pl_mc_AKGC417L train
752 | 200_2p4_Pr_mc_AKGC417L train
753 | 200_2p4_Tc_mc_AKGC417L train
754 | 200_3p4_Al_mc_AKGC417L train
755 | 200_3p4_Ar_mc_AKGC417L train
756 | 200_3p4_Pl_mc_AKGC417L train
757 | 200_3p4_Pr_mc_AKGC417L train
758 | 200_3p4_Tc_mc_AKGC417L train
759 | 201_1b1_Al_sc_Meditron train
760 | 201_1b1_Ar_sc_Meditron train
761 | 201_1b2_Al_sc_Meditron train
762 | 201_1b2_Ar_sc_Meditron train
763 | 201_1b3_Al_sc_Meditron train
764 | 201_1b3_Ar_sc_Meditron train
765 | 202_1b1_Ar_sc_Meditron test
766 | 203_1p2_Al_mc_AKGC417L train
767 | 203_1p2_Ar_mc_AKGC417L train
768 | 203_1p2_Lr_mc_AKGC417L train
769 | 203_1p2_Pl_mc_AKGC417L train
770 | 203_1p2_Pr_mc_AKGC417L train
771 | 203_1p2_Tc_mc_AKGC417L train
772 | 203_1p3_Al_mc_AKGC417L train
773 | 203_1p3_Ar_mc_AKGC417L train
774 | 203_1p3_Pl_mc_AKGC417L train
775 | 203_1p3_Pr_mc_AKGC417L train
776 | 203_1p3_Tc_mc_AKGC417L train
777 | 203_1p4_Al_mc_AKGC417L train
778 | 203_1p4_Ar_mc_AKGC417L train
779 | 203_1p4_Pl_mc_AKGC417L train
780 | 203_1p4_Pr_mc_AKGC417L train
781 | 203_1p4_Tc_mc_AKGC417L train
782 | 203_2p3_Al_mc_AKGC417L train
783 | 203_2p3_Ar_mc_AKGC417L train
784 | 203_2p3_Pl_mc_AKGC417L train
785 | 203_2p3_Pr_mc_AKGC417L train
786 | 203_2p3_Tc_mc_AKGC417L train
787 | 204_2b5_Al_mc_AKGC417L train
788 | 204_2b5_Ar_mc_AKGC417L train
789 | 204_2b5_Ll_mc_AKGC417L train
790 | 204_7p5_Al_mc_AKGC417L train
791 | 204_7p5_Ar_mc_AKGC417L train
792 | 204_7p5_Ll_mc_AKGC417L train
793 | 204_7p5_Lr_mc_AKGC417L train
794 | 204_7p5_Pr_mc_AKGC417L train
795 | 204_7p5_Tc_mc_AKGC417L train
796 | 205_1b3_Al_mc_AKGC417L test
797 | 205_1b3_Ar_mc_AKGC417L test
798 | 205_1b3_Ll_mc_AKGC417L test
799 | 205_1b3_Lr_mc_AKGC417L test
800 | 205_1b3_Pl_mc_AKGC417L test
801 | 205_1b3_Pr_mc_AKGC417L test
802 | 205_2b2_Pr_mc_AKGC417L test
803 | 205_2b3_Al_mc_AKGC417L test
804 | 205_2b3_Ar_mc_AKGC417L test
805 | 205_2b3_Ll_mc_AKGC417L test
806 | 205_2b4_Pl_mc_AKGC417L test
807 | 205_3b4_Al_mc_AKGC417L test
808 | 205_3b4_Ar_mc_AKGC417L test
809 | 205_3b4_Pl_mc_AKGC417L test
810 | 205_3b4_Pr_mc_AKGC417L test
811 | 205_4b2_Al_mc_AKGC417L test
812 | 205_4b2_Ar_mc_AKGC417L test
813 | 205_4b2_Lr_mc_AKGC417L test
814 | 205_4b2_Pl_mc_AKGC417L test
815 | 205_4b2_Pr_mc_AKGC417L test
816 | 206_1b1_Ar_sc_Meditron train
817 | 206_1b1_Lr_sc_Meditron train
818 | 206_1b1_Pl_sc_Meditron train
819 | 207_2b2_Al_mc_AKGC417L test
820 | 207_2b2_Ar_mc_AKGC417L test
821 | 207_2b2_Pl_mc_AKGC417L test
822 | 207_2b2_Pr_mc_AKGC417L test
823 | 207_2b2_Tc_mc_AKGC417L test
824 | 207_2b3_Al_mc_AKGC417L test
825 | 207_2b3_Ar_mc_AKGC417L test
826 | 207_2b3_Pl_mc_AKGC417L test
827 | 207_2b3_Pr_mc_AKGC417L test
828 | 207_2b3_Tc_mc_AKGC417L test
829 | 207_2b4_Al_mc_AKGC417L test
830 | 207_2b4_Ar_mc_AKGC417L test
831 | 207_2b4_Pl_mc_AKGC417L test
832 | 207_2b4_Pr_mc_AKGC417L test
833 | 207_2b4_Tc_mc_AKGC417L test
834 | 207_3b2_Al_mc_AKGC417L test
835 | 207_3b2_Ar_mc_AKGC417L test
836 | 207_3b2_Lr_mc_AKGC417L test
837 | 207_3b2_Pl_mc_AKGC417L test
838 | 207_3b2_Pr_mc_AKGC417L test
839 | 207_3b2_Tc_mc_AKGC417L test
840 | 208_1b1_Ll_sc_Meditron train
841 | 209_1b1_Tc_sc_Meditron train
842 | 210_1b1_Al_sc_Meditron train
843 | 210_1b1_Ar_sc_Meditron train
844 | 211_1p2_Ar_mc_AKGC417L test
845 | 211_1p2_Pl_mc_AKGC417L test
846 | 211_1p2_Pr_mc_AKGC417L test
847 | 211_1p3_Ar_mc_AKGC417L test
848 | 211_1p5_Ar_mc_AKGC417L test
849 | 211_2p2_Tc_mc_AKGC417L test
850 | 211_2p3_Tc_mc_AKGC417L test
851 | 211_2p4_Tc_mc_AKGC417L test
852 | 212_2b2_Tc_mc_LittC2SE train
853 | 213_1p2_Al_mc_AKGC417L train
854 | 213_1p2_Ar_mc_AKGC417L train
855 | 213_1p2_Lr_mc_AKGC417L train
856 | 213_1p2_Pl_mc_AKGC417L train
857 | 213_1p2_Pr_mc_AKGC417L train
858 | 213_1p2_Tc_mc_AKGC417L train
859 | 213_1p3_Al_mc_AKGC417L train
860 | 213_1p3_Ar_mc_AKGC417L train
861 | 213_1p3_Pl_mc_AKGC417L train
862 | 213_1p3_Pr_mc_AKGC417L train
863 | 213_1p5_Al_mc_AKGC417L train
864 | 213_1p5_Ar_mc_AKGC417L train
865 | 213_1p5_Pl_mc_AKGC417L train
866 | 213_1p5_Pr_mc_AKGC417L train
867 | 213_1p5_Tc_mc_AKGC417L train
868 | 213_2p2_Al_mc_AKGC417L train
869 | 213_2p2_Ar_mc_AKGC417L train
870 | 213_2p2_Pl_mc_AKGC417L train
871 | 213_2p2_Pr_mc_AKGC417L train
872 | 213_2p2_Tc_mc_AKGC417L train
873 | 214_1b1_Ar_sc_Meditron test
874 | 215_1b2_Ar_sc_Meditron train
875 | 215_1b3_Tc_sc_Meditron train
876 | 216_1b1_Al_sc_Meditron test
877 | 216_1b1_Pl_sc_Meditron test
878 | 217_1b1_Tc_sc_Meditron train
879 | 218_1b1_Al_sc_Meditron test
880 | 218_1b1_Ar_sc_Meditron train
881 | 218_1b1_Lr_sc_Meditron train
882 | 218_1b1_Pl_sc_Meditron test
883 | 218_1b1_Pr_sc_Meditron test
884 | 218_1p1_Ar_sc_Litt3200 train
885 | 218_1p1_Pl_sc_Litt3200 test
886 | 218_1p1_Pr_sc_Litt3200 train
887 | 219_2b1_Ar_mc_LittC2SE train
888 | 219_2b1_Tc_mc_LittC2SE train
889 | 219_2b2_Ar_mc_LittC2SE train
890 | 219_2b2_Tc_mc_LittC2SE train
891 | 219_2b3_Tc_mc_LittC2SE train
892 | 220_1b1_Tc_mc_LittC2SE train
893 | 220_1b2_Al_mc_LittC2SE train
894 | 221_2b1_Al_mc_LittC2SE train
895 | 221_2b1_Ar_mc_LittC2SE train
896 | 221_2b1_Lr_mc_LittC2SE train
897 | 221_2b1_Pl_mc_LittC2SE train
898 | 221_2b2_Al_mc_LittC2SE train
899 | 221_2b2_Ar_mc_LittC2SE train
900 | 221_2b2_Lr_mc_LittC2SE train
901 | 221_2b2_Pl_mc_LittC2SE train
902 | 221_2b3_Al_mc_LittC2SE train
903 | 221_2b3_Ar_mc_LittC2SE train
904 | 221_2b3_Lr_mc_LittC2SE train
905 | 221_2b3_Pr_mc_LittC2SE train
906 | 222_1b1_Ar_sc_Meditron train
907 | 222_1b1_Lr_sc_Meditron train
908 | 222_1b1_Pr_sc_Meditron train
909 | 223_1b1_Pr_sc_Meditron test
910 | 223_1b1_Pl_sc_Meditron test
911 | 223_1b1_Ar_sc_Meditron test
912 | 223_1b1_Al_sc_Meditron test
913 | 223_1b1_Lr_sc_Meditron test
914 | 223_1b1_Ll_sc_Meditron test
915 | 224_1b1_Tc_sc_Meditron test
916 | 224_1b2_Al_sc_Meditron test
917 | 225_1b1_Pl_sc_Meditron test
918 | 226_1b1_Al_sc_Meditron train
919 | 226_1b1_Ll_sc_Meditron train
920 | 226_1b1_Pl_sc_LittC2SE train
--------------------------------------------------------------------------------
/data/icbhi_dataset/patient_diagnosis.txt:
--------------------------------------------------------------------------------
1 | 101 URTI
2 | 102 Healthy
3 | 103 Asthma
4 | 104 COPD
5 | 105 URTI
6 | 106 COPD
7 | 107 COPD
8 | 108 LRTI
9 | 109 COPD
10 | 110 COPD
11 | 111 Bronchiectasis
12 | 112 COPD
13 | 113 COPD
14 | 114 COPD
15 | 115 LRTI
16 | 116 Bronchiectasis
17 | 117 COPD
18 | 118 COPD
19 | 119 URTI
20 | 120 COPD
21 | 121 Healthy
22 | 122 Pneumonia
23 | 123 Healthy
24 | 124 COPD
25 | 125 Healthy
26 | 126 Healthy
27 | 127 Healthy
28 | 128 COPD
29 | 129 URTI
30 | 130 COPD
31 | 131 URTI
32 | 132 COPD
33 | 133 COPD
34 | 134 COPD
35 | 135 Pneumonia
36 | 136 Healthy
37 | 137 URTI
38 | 138 COPD
39 | 139 COPD
40 | 140 Pneumonia
41 | 141 COPD
42 | 142 COPD
43 | 143 Healthy
44 | 144 Healthy
45 | 145 COPD
46 | 146 COPD
47 | 147 COPD
48 | 148 URTI
49 | 149 Bronchiolitis
50 | 150 URTI
51 | 151 COPD
52 | 152 Healthy
53 | 153 Healthy
54 | 154 COPD
55 | 155 COPD
56 | 156 COPD
57 | 157 COPD
58 | 158 COPD
59 | 159 Healthy
60 | 160 COPD
61 | 161 Bronchiolitis
62 | 162 COPD
63 | 163 COPD
64 | 164 URTI
65 | 165 URTI
66 | 166 COPD
67 | 167 Bronchiolitis
68 | 168 Bronchiectasis
69 | 169 Bronchiectasis
70 | 170 COPD
71 | 171 Healthy
72 | 172 COPD
73 | 173 Bronchiolitis
74 | 174 COPD
75 | 175 COPD
76 | 176 COPD
77 | 177 COPD
78 | 178 COPD
79 | 179 Healthy
80 | 180 COPD
81 | 181 COPD
82 | 182 Healthy
83 | 183 Healthy
84 | 184 Healthy
85 | 185 COPD
86 | 186 COPD
87 | 187 Healthy
88 | 188 URTI
89 | 189 COPD
90 | 190 URTI
91 | 191 Pneumonia
92 | 192 COPD
93 | 193 COPD
94 | 194 Healthy
95 | 195 COPD
96 | 196 Bronchiectasis
97 | 197 URTI
98 | 198 COPD
99 | 199 COPD
100 | 200 COPD
101 | 201 Bronchiectasis
102 | 202 Healthy
103 | 203 COPD
104 | 204 COPD
105 | 205 COPD
106 | 206 Bronchiolitis
107 | 207 COPD
108 | 208 Healthy
109 | 209 Healthy
110 | 210 URTI
111 | 211 COPD
112 | 212 COPD
113 | 213 COPD
114 | 214 Healthy
115 | 215 Bronchiectasis
116 | 216 Bronchiolitis
117 | 217 Healthy
118 | 218 COPD
119 | 219 Pneumonia
120 | 220 COPD
121 | 221 COPD
122 | 222 COPD
123 | 223 COPD
124 | 224 Healthy
125 | 225 Healthy
126 | 226 Pneumonia
--------------------------------------------------------------------------------
/data/icbhi_dataset/patient_list_foldwise.txt:
--------------------------------------------------------------------------------
1 | 158 0
2 | 193 0
3 | 177 0
4 | 170 0
5 | 180 0
6 | 211 0
7 | 147 0
8 | 107 0
9 | 162 0
10 | 156 0
11 | 146 0
12 | 200 0
13 | 138 0
14 | 160 0
15 | 203 0
16 | 204 0
17 | 172 0
18 | 207 0
19 | 163 0
20 | 205 0
21 | 213 0
22 | 114 0
23 | 130 0
24 | 154 0
25 | 186 0
26 | 184 0
27 | 153 1
28 | 115 1
29 | 224 1
30 | 223 1
31 | 201 1
32 | 218 1
33 | 127 1
34 | 137 1
35 | 215 1
36 | 161 1
37 | 206 1
38 | 101 1
39 | 168 1
40 | 131 1
41 | 216 1
42 | 120 1
43 | 188 1
44 | 167 1
45 | 210 1
46 | 197 1
47 | 183 1
48 | 152 1
49 | 173 1
50 | 108 1
51 | 208 1
52 | 105 2
53 | 110 2
54 | 116 2
55 | 196 2
56 | 182 2
57 | 222 2
58 | 166 2
59 | 209 2
60 | 144 2
61 | 111 2
62 | 165 2
63 | 148 2
64 | 164 2
65 | 159 2
66 | 121 2
67 | 157 2
68 | 217 2
69 | 123 2
70 | 169 2
71 | 179 2
72 | 190 2
73 | 125 2
74 | 129 2
75 | 225 2
76 | 136 2
77 | 118 3
78 | 185 3
79 | 112 3
80 | 124 3
81 | 104 3
82 | 195 3
83 | 175 3
84 | 212 3
85 | 140 3
86 | 219 3
87 | 132 3
88 | 142 3
89 | 220 3
90 | 122 3
91 | 191 3
92 | 128 3
93 | 226 3
94 | 141 3
95 | 103 3
96 | 134 3
97 | 117 3
98 | 192 3
99 | 106 3
100 | 155 3
101 | 199 3
102 | 174 4
103 | 145 4
104 | 151 4
105 | 176 4
106 | 178 4
107 | 133 4
108 | 198 4
109 | 214 4
110 | 149 4
111 | 143 4
112 | 187 4
113 | 202 4
114 | 119 4
115 | 194 4
116 | 126 4
117 | 150 4
118 | 171 4
119 | 102 4
120 | 109 4
121 | 113 4
122 | 139 4
123 | 189 4
124 | 181 4
125 | 221 4
126 | 135 4
127 |
--------------------------------------------------------------------------------
/data/icbhi_dataset/script.sh:
--------------------------------------------------------------------------------
1 | n=1
2 | appendix="Meditron"
3 | newname="Blank new name"
4 | while IFS= read -r oldname; do
5 | # echo $oldname $newname $appendix
6 | truncatedoldname=${oldname%?????????}
7 |
8 | newname=$truncatedoldname$appendix
9 |
10 | newname=${newname#"'"}
11 | oldname=${oldname#"'"}
12 | oldname=${oldname%"'"}
13 |
14 | mv -vn "$oldname".txt "$newname".txt
15 | mv -vn "$oldname".wav "$newname".wav
16 | done < "filename_differences.txt"
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import os
3 | import sys
4 | import json
5 | import warnings
6 | warnings.filterwarnings("ignore")
7 |
8 | import math
9 | import time
10 | import random
11 | import pickle
12 | import argparse
13 | import numpy as np
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 | import torch.backends.cudnn as cudnn
19 | from torchvision import transforms
20 |
21 | from util.icbhi_dataset import ICBHIDataset
22 | from util.icbhi_util import get_score
23 | from util.augmentation import SpecAugment
24 | from util.misc import adjust_learning_rate, warmup_learning_rate, set_optimizer, update_moving_average
25 | from util.misc import AverageMeter, accuracy, save_model, update_json
26 | from models import get_backbone_class, Projector
27 | from method import PatchMixLoss, PatchMixConLoss
28 |
29 |
30 | def parse_args():
31 | parser = argparse.ArgumentParser('argument for supervised training')
32 |
33 | parser.add_argument('--seed', type=int, default=0)
34 | parser.add_argument('--print_freq', type=int, default=10)
35 | parser.add_argument('--save_freq', type=int, default=100)
36 | parser.add_argument('--save_dir', type=str, default='./save')
37 | parser.add_argument('--tag', type=str, default='',
38 | help='tag for experiment name')
39 | parser.add_argument('--resume', type=str, default=None,
40 | help='path of model checkpoint to resume')
41 | parser.add_argument('--eval', action='store_true',
42 | help='only evaluation with pretrained encoder and classifier')
43 | parser.add_argument('--two_cls_eval', action='store_true',
44 | help='evaluate with two classes')
45 |
46 | # optimization
47 | parser.add_argument('--optimizer', type=str, default='adam')
48 | parser.add_argument('--epochs', type=int, default=400)
49 | parser.add_argument('--learning_rate', type=float, default=1e-3)
50 | parser.add_argument('--lr_decay_epochs', type=str, default='120,160')
51 | parser.add_argument('--lr_decay_rate', type=float, default=0.1)
52 | parser.add_argument('--weight_decay', type=float, default=1e-4)
53 | parser.add_argument('--momentum', type=float, default=0.9)
54 | parser.add_argument('--cosine', action='store_true',
55 | help='using cosine annealing')
56 | parser.add_argument('--warm', action='store_true',
57 | help='warm-up for large batch training')
58 | parser.add_argument('--warm_epochs', type=int, default=0,
59 | help='warmup epochs')
60 | parser.add_argument('--weighted_loss', action='store_true',
61 | help='weighted cross entropy loss (higher weights on abnormal class)')
62 | parser.add_argument('--mix_beta', default=1.0, type=float,
63 | help='patch-mix interpolation coefficient')
64 | parser.add_argument('--time_domain', action='store_true',
65 | help='patchmix for the specific time domain')
66 |
67 | # dataset
68 | parser.add_argument('--dataset', type=str, default='icbhi')
69 | parser.add_argument('--data_folder', type=str, default='./data/')
70 | parser.add_argument('--batch_size', type=int, default=128)
71 | parser.add_argument('--num_workers', type=int, default=8)
72 | # icbhi dataset
73 | parser.add_argument('--class_split', type=str, default='lungsound',
74 | help='lungsound: (normal, crackles, wheezes, both), diagnosis: (healthy, chronic diseases, non-chronic diseases)')
75 | parser.add_argument('--n_cls', type=int, default=4,
76 | help='set k-way classification problem')
77 | parser.add_argument('--test_fold', type=str, default='official', choices=['official', '0', '1', '2', '3', '4'],
78 | help='test fold to use official 60-40 split or 80-20 split from RespireNet')
79 | parser.add_argument('--weighted_sampler', action='store_true',
80 |                         help='weighted sampler inversely proportional to class ratio')
81 | parser.add_argument('--stetho_id', type=int, default=-1,
82 | help='stethoscope device id, use only when finetuning on each stethoscope data')
83 | parser.add_argument('--sample_rate', type=int, default=16000,
84 |                         help='sampling rate when loading audio data; it denotes the number of samples per second')
85 | parser.add_argument('--butterworth_filter', type=int, default=None,
86 | help='apply specific order butterworth band-pass filter')
87 | parser.add_argument('--desired_length', type=int, default=8,
88 | help='fixed length size of individual cycle')
89 | parser.add_argument('--nfft', type=int, default=1024,
90 | help='the frequency size of fast fourier transform')
91 | parser.add_argument('--n_mels', type=int, default=128,
92 | help='the number of mel filter banks')
93 | parser.add_argument('--concat_aug_scale', type=float, default=0,
94 | help='to control the number (scale) of concatenation-based augmented samples')
95 | parser.add_argument('--pad_types', type=str, default='repeat',
96 | help='zero: zero-padding, repeat: padding with duplicated samples, aug: padding with augmented samples')
97 | parser.add_argument('--resz', type=float, default=1,
98 | help='resize the scale of mel-spectrogram')
99 | parser.add_argument('--raw_augment', type=int, default=0,
100 |                         help='control the number of augmented raw audio samples')
101 | parser.add_argument('--blank_region_clip', action='store_true',
102 | help='remove the blank region, high frequency region')
103 | parser.add_argument('--specaug_policy', type=str, default='icbhi_ast_sup',
104 | help='policy (argument values) for SpecAugment')
105 | parser.add_argument('--specaug_mask', type=str, default='mean',
106 | help='specaug mask value', choices=['mean', 'zero'])
107 |
108 | # model
109 | parser.add_argument('--model', type=str, default='ast')
110 | parser.add_argument('--pretrained', action='store_true')
111 | parser.add_argument('--pretrained_ckpt', type=str, default=None,
112 | help='path to pre-trained encoder model')
113 | parser.add_argument('--from_sl_official', action='store_true',
114 | help='load from supervised imagenet-pretrained model (official PyTorch)')
115 | parser.add_argument('--ma_update', action='store_true',
116 | help='whether to use moving average update for model')
117 | parser.add_argument('--ma_beta', type=float, default=0,
118 | help='moving average value')
119 | # for AST
120 | parser.add_argument('--audioset_pretrained', action='store_true',
121 | help='load from imagenet- and audioset-pretrained model')
122 | # for SSAST
123 | parser.add_argument('--ssast_task', type=str, default='ft_avgtok',
124 |                         help='fine-tuning task for SSAST', choices=['ft_avgtok', 'ft_cls'])
125 | parser.add_argument('--fshape', type=int, default=16,
126 | help='fshape of SSAST')
127 | parser.add_argument('--tshape', type=int, default=16,
128 | help='tshape of SSAST')
129 | parser.add_argument('--ssast_pretrained_type', type=str, default='Patch',
130 | help='pretrained ckpt version of SSAST model')
131 |
132 | parser.add_argument('--method', type=str, default='ce')
133 | # Patch-Mix CL loss
134 | parser.add_argument('--proj_dim', type=int, default=768)
135 | parser.add_argument('--temperature', type=float, default=0.06)
136 | parser.add_argument('--alpha', type=float, default=1.0)
137 | parser.add_argument('--negative_pair', type=str, default='all',
138 |                         help='method for selecting negative pairs', choices=['all', 'diff_label'])
139 | parser.add_argument('--target_type', type=str, default='grad_block',
140 |                         help='how to construct the target representation', choices=['grad_block', 'grad_flow', 'project_block', 'project_flow'])
141 |
142 | args = parser.parse_args()
143 |
144 | iterations = args.lr_decay_epochs.split(',')
145 | args.lr_decay_epochs = list([])
146 | for it in iterations:
147 | args.lr_decay_epochs.append(int(it))
148 |
149 | args.model_name = '{}_{}_{}'.format(args.dataset, args.model, args.method)
150 | if args.tag:
151 | args.model_name += '_{}'.format(args.tag)
152 |
153 | if args.method in ['patchmix', 'patchmix_cl']:
154 | assert args.model in ['ast', 'ssast']
155 |
156 | args.save_folder = os.path.join(args.save_dir, args.model_name)
157 | if not os.path.isdir(args.save_folder):
158 | os.makedirs(args.save_folder)
159 |
160 | if args.warm:
161 | args.warmup_from = args.learning_rate * 0.1
162 | args.warm_epochs = 10
163 | if args.cosine:
164 | eta_min = args.learning_rate * (args.lr_decay_rate ** 3)
165 | args.warmup_to = eta_min + (args.learning_rate - eta_min) * (
166 | 1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
167 | else:
168 | args.warmup_to = args.learning_rate
169 |
170 | if args.dataset == 'icbhi':
171 | if args.class_split == 'lungsound':
172 | if args.n_cls == 4:
173 | args.cls_list = ['normal', 'crackle', 'wheeze', 'both']
174 | elif args.n_cls == 2:
175 | args.cls_list = ['normal', 'abnormal']
176 | elif args.class_split == 'diagnosis':
177 | if args.n_cls == 3:
178 | args.cls_list = ['healthy', 'chronic_diseases', 'non-chronic_diseases']
179 | elif args.n_cls == 2:
180 | args.cls_list = ['healthy', 'unhealthy']
181 | else:
182 | raise NotImplementedError
183 |
184 | return args
185 |
186 |
187 | def set_loader(args):
188 | if args.dataset == 'icbhi':
189 |         # get raw information and calculate mean and std for normalization
190 | # dataset = ICBHIDataset(train_flag=True, transform=transforms.Compose([transforms.ToTensor()]), args=args, print_flag=False, mean_std=True)
191 | # mean, std = get_mean_and_std(dataset)
192 | # args.h, args.w = dataset.h, dataset.w
193 |
194 | # print('*' * 20)
195 | # print('[Raw dataset information]')
196 |         # print('Stethoscope device number: {}, and patient number without overlap: {}'.format(len(dataset.device_to_id), len(set(sum(dataset.device_id_to_patient.values(), []))) ))
197 |         # for device, id in dataset.device_to_id.items():
198 |         #     print('Device {} ({}): {} number of patients'.format(id, device, len(dataset.device_id_to_patient[id])))
199 |         # print('Spectrogram shape on ICBHI dataset: {} (height) and {} (width)'.format(args.h, args.w))
200 | # print('Mean and std of ICBHI dataset: {} (mean) and {} (std)'.format(round(mean.item(), 2), round(std.item(), 2)))
201 |
202 | args.h, args.w = 798, 128
203 | train_transform = [transforms.ToTensor(),
204 | SpecAugment(args),
205 | transforms.Resize(size=(int(args.h * args.resz), int(args.w * args.resz)))]
206 | val_transform = [transforms.ToTensor(),
207 | transforms.Resize(size=(int(args.h * args.resz), int(args.w * args.resz)))]
208 | # train_transform.append(transforms.Normalize(mean=mean, std=std))
209 | # val_transform.append(transforms.Normalize(mean=mean, std=std))
210 |
211 | train_transform = transforms.Compose(train_transform)
212 | val_transform = transforms.Compose(val_transform)
213 |
214 | train_dataset = ICBHIDataset(train_flag=True, transform=train_transform, args=args, print_flag=True)
215 | val_dataset = ICBHIDataset(train_flag=False, transform=val_transform, args=args, print_flag=True)
216 |
217 | # for weighted_loss
218 | args.class_nums = train_dataset.class_nums
219 | else:
220 |         raise NotImplementedError
221 |
222 | if args.weighted_sampler:
223 | reciprocal_weights = []
224 | for idx in range(len(train_dataset)):
225 | reciprocal_weights.append(train_dataset.class_ratio[train_dataset.labels[idx]])
226 | weights = (1 / torch.Tensor(reciprocal_weights))
227 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(train_dataset))
228 | else:
229 | sampler = None
230 |
231 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=sampler is None,
232 | num_workers=args.num_workers, pin_memory=True, sampler=sampler, drop_last=True)
233 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
234 | num_workers=args.num_workers, pin_memory=True, sampler=None)
235 |
236 | return train_loader, val_loader, args
237 |
238 |
239 | def set_model(args):
240 | kwargs = {}
241 | if args.model == 'ast':
242 | kwargs['input_fdim'] = int(args.h * args.resz)
243 | kwargs['input_tdim'] = int(args.w * args.resz)
244 | kwargs['label_dim'] = args.n_cls
245 | kwargs['imagenet_pretrain'] = args.from_sl_official
246 | kwargs['audioset_pretrain'] = args.audioset_pretrained
247 | kwargs['mix_beta'] = args.mix_beta # for Patch-MixCL
248 | elif args.model == 'ssast':
249 | kwargs['label_dim'] = args.n_cls
250 | kwargs['fshape'], kwargs['tshape'] = args.fshape, args.tshape
251 | kwargs['fstride'], kwargs['tstride'] = 10, 10
252 | kwargs['input_tdim'] = 798
253 | kwargs['task'] = args.ssast_task
254 | kwargs['pretrain_stage'] = not args.audioset_pretrained
255 | kwargs['load_pretrained_mdl_path'] = args.ssast_pretrained_type
256 | kwargs['mix_beta'] = args.mix_beta # for Patch-MixCL
257 |
258 | model = get_backbone_class(args.model)(**kwargs)
259 | classifier = nn.Linear(model.final_feat_dim, args.n_cls) if args.model not in ['ast', 'ssast'] else deepcopy(model.mlp_head)
260 |
261 | if not args.weighted_loss:
262 | weights = None
263 | criterion = nn.CrossEntropyLoss()
264 | else:
265 | weights = torch.tensor(args.class_nums, dtype=torch.float32)
266 | weights = 1.0 / (weights / weights.sum())
267 | weights /= weights.sum()
268 |
269 | criterion = nn.CrossEntropyLoss(weight=weights)
270 |
271 | if args.model not in ['ast', 'ssast'] and args.from_sl_official:
272 | model.load_sl_official_weights()
273 |         print('pretrained model loaded from official PyTorch ImageNet-pretrained weights')
274 |
275 | # load SSL pretrained checkpoint for linear evaluation
276 | if args.pretrained and args.pretrained_ckpt is not None:
277 | ckpt = torch.load(args.pretrained_ckpt, map_location='cpu')
278 | state_dict = ckpt['model']
279 |
280 |         # HOTFIX: DataParallel is always used during SSL pretraining, so strip the 'module.' (and 'backbone.') prefixes
281 | new_state_dict = {}
282 | for k, v in state_dict.items():
283 | if "module." in k:
284 | k = k.replace("module.", "")
285 | if "backbone." in k:
286 | k = k.replace("backbone.", "")
287 |
288 | new_state_dict[k] = v
289 | state_dict = new_state_dict
290 | model.load_state_dict(state_dict, strict=False)
291 |
292 | if ckpt.get('classifier', None) is not None:
293 | classifier.load_state_dict(ckpt['classifier'], strict=True)
294 |
295 | print('pretrained model loaded from: {}'.format(args.pretrained_ckpt))
296 |
297 | projector = Projector(model.final_feat_dim, args.proj_dim) if args.method == 'patchmix_cl' else nn.Identity()
298 |
299 | if args.method == 'ce':
300 | criterion = [criterion.cuda()]
301 | elif args.method == 'patchmix':
302 | criterion = [criterion.cuda(), PatchMixLoss(criterion=criterion).cuda()]
303 | elif args.method == 'patchmix_cl':
304 | criterion = [criterion.cuda(), PatchMixConLoss(temperature=args.temperature).cuda()]
305 |
306 | if torch.cuda.device_count() > 1:
307 | model = torch.nn.DataParallel(model)
308 |
309 | model.cuda()
310 | classifier.cuda()
311 | projector.cuda()
312 |
313 | optim_params = list(model.parameters()) + list(classifier.parameters()) + list(projector.parameters())
314 | optimizer = set_optimizer(args, optim_params)
315 |
316 | return model, classifier, projector, criterion, optimizer
317 |
318 |
319 | def train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args, scaler=None):
320 | model.train()
321 | classifier.train()
322 | projector.train()
323 |
324 | batch_time = AverageMeter()
325 | data_time = AverageMeter()
326 | losses = AverageMeter()
327 | top1 = AverageMeter()
328 |
329 | end = time.time()
330 | for idx, (images, labels, metadata) in enumerate(train_loader):
331 | if args.ma_update:
332 | # store the previous iter checkpoint
333 | with torch.no_grad():
334 | ma_ckpt = [deepcopy(model.state_dict()), deepcopy(classifier.state_dict()), deepcopy(projector.state_dict())]
335 |
336 | data_time.update(time.time() - end)
337 |
338 | images = images.cuda(non_blocking=True)
339 | labels = labels.cuda(non_blocking=True)
340 | bsz = labels.shape[0]
341 |
342 | warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer)
343 |
344 | with torch.cuda.amp.autocast():
345 | if args.method == 'ce':
346 | features = model(images)
347 | output = classifier(features)
348 | loss = criterion[0](output, labels)
349 |
350 | elif args.method == 'patchmix':
351 | mix_images, labels_a, labels_b, lam, index = model(images, y=labels, patch_mix=True, time_domain=args.time_domain)
352 | output = classifier(mix_images)
353 | loss = criterion[1](output, labels_a, labels_b, lam)
354 |
355 | elif args.method == 'patchmix_cl':
356 | features = model(images)
357 | output = classifier(features)
358 | loss = criterion[0](output, labels)
359 |
360 | if args.target_type == 'grad_block':
361 | proj1 = deepcopy(features.detach())
362 | elif args.target_type == 'grad_flow':
363 | proj1 = features
364 | elif args.target_type == 'project_block':
365 | proj1 = deepcopy(projector(features).detach())
366 | elif args.target_type == 'project_flow':
367 | proj1 = projector(features)
368 |
369 |                 # create the patch-mixed view used by the contrastive loss
370 | mix_images, labels_a, labels_b, lam, index = model(images, y=labels, patch_mix=True, time_domain=args.time_domain)
371 | proj2 = projector(mix_images)
372 | loss += args.alpha * criterion[1](proj1, proj2, labels, labels_b, lam, index, args)
373 |
374 | losses.update(loss.item(), bsz)
375 | [acc1], _ = accuracy(output[:bsz], labels, topk=(1,))
376 | top1.update(acc1[0], bsz)
377 |
378 | optimizer.zero_grad()
379 | scaler.scale(loss).backward()
380 | scaler.step(optimizer)
381 | scaler.update()
382 |
383 | # measure elapsed time
384 | batch_time.update(time.time() - end)
385 | end = time.time()
386 |
387 | if args.ma_update:
388 | with torch.no_grad():
389 | # exponential moving average update
390 | model = update_moving_average(args.ma_beta, model, ma_ckpt[0])
391 | classifier = update_moving_average(args.ma_beta, classifier, ma_ckpt[1])
392 | projector = update_moving_average(args.ma_beta, projector, ma_ckpt[2])
393 |
394 | # print info
395 | if (idx + 1) % args.print_freq == 0:
396 | print('Train: [{0}][{1}/{2}]\t'
397 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
398 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
399 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'
400 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
401 | epoch, idx + 1, len(train_loader), batch_time=batch_time,
402 | data_time=data_time, loss=losses, top1=top1))
403 | sys.stdout.flush()
404 |
405 | return losses.avg, top1.avg
406 |
407 |
408 | def validate(val_loader, model, classifier, criterion, args, best_acc, best_model=None):
409 | save_bool = False
410 | model.eval()
411 | classifier.eval()
412 |
413 | batch_time = AverageMeter()
414 | losses = AverageMeter()
415 | top1 = AverageMeter()
416 | hits, counts = [0.0] * args.n_cls, [0.0] * args.n_cls
417 |
418 | with torch.no_grad():
419 | end = time.time()
420 | for idx, (images, labels, metadata) in enumerate(val_loader):
421 | images = images.cuda(non_blocking=True)
422 | labels = labels.cuda(non_blocking=True)
423 | bsz = labels.shape[0]
424 |
425 | with torch.cuda.amp.autocast():
426 | features = model(images)
427 | output = classifier(features)
428 | loss = criterion[0](output, labels)
429 |
430 | losses.update(loss.item(), bsz)
431 | [acc1], _ = accuracy(output, labels, topk=(1,))
432 | top1.update(acc1[0], bsz)
433 |
434 | _, preds = torch.max(output, 1)
435 |             for i in range(preds.shape[0]):
436 |                 counts[labels[i].item()] += 1.0
437 |                 if not args.two_cls_eval:
438 |                     if preds[i].item() == labels[i].item():
439 |                         hits[labels[i].item()] += 1.0
440 |                 else: # only when args.n_cls == 4
441 |                     if labels[i].item() == 0 and preds[i].item() == labels[i].item():
442 |                         hits[labels[i].item()] += 1.0
443 |                     elif labels[i].item() != 0 and preds[i].item() > 0: # abnormal
444 |                         hits[labels[i].item()] += 1.0
445 |
446 | sp, se, sc = get_score(hits, counts)
447 |
448 | batch_time.update(time.time() - end)
449 | end = time.time()
450 |
451 | if (idx + 1) % args.print_freq == 0:
452 | print('Test: [{0}/{1}]\t'
453 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
454 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
455 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
456 | idx + 1, len(val_loader), batch_time=batch_time,
457 | loss=losses, top1=top1))
458 |
459 | if sc > best_acc[-1] and se > 5:
460 | save_bool = True
461 | best_acc = [sp, se, sc]
462 | best_model = [deepcopy(model.state_dict()), deepcopy(classifier.state_dict())]
463 |
464 | print(' * S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f} (Best S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f})'.format(sp, se, sc, best_acc[0], best_acc[1], best_acc[-1]))
465 | print(' * Acc@1 {top1.avg:.2f}'.format(top1=top1))
466 |
467 | return best_acc, best_model, save_bool
468 |
469 |
470 | def main():
471 | args = parse_args()
472 | with open(os.path.join(args.save_folder, 'train_args.json'), 'w') as f:
473 | json.dump(vars(args), f, indent=4)
474 |
475 | # fix seed
476 | random.seed(args.seed)
477 | np.random.seed(args.seed)
478 | torch.manual_seed(args.seed)
479 | torch.cuda.manual_seed(args.seed)
480 | cudnn.deterministic = True
481 | cudnn.benchmark = True
482 |
483 | best_model = None
484 | if args.dataset == 'icbhi':
485 | best_acc = [0, 0, 0] # Specificity, Sensitivity, Score
486 |
487 | train_loader, val_loader, args = set_loader(args)
488 | model, classifier, projector, criterion, optimizer = set_model(args)
489 |
490 | if args.resume:
491 | if os.path.isfile(args.resume):
492 | print("=> loading checkpoint '{}'".format(args.resume))
493 | checkpoint = torch.load(args.resume)
494 | args.start_epoch = checkpoint['epoch']
495 | model.load_state_dict(checkpoint['model'])
496 | optimizer.load_state_dict(checkpoint['optimizer'])
497 | args.start_epoch += 1
498 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
499 | else:
500 | print("=> no checkpoint found at '{}'".format(args.resume))
501 | else:
502 | args.start_epoch = 1
503 |
504 |     # use mixed precision
505 | scaler = torch.cuda.amp.GradScaler()
506 |
507 | print('*' * 20)
508 | if not args.eval:
509 | print('Training for {} epochs on {} dataset'.format(args.epochs, args.dataset))
510 | for epoch in range(args.start_epoch, args.epochs+1):
511 | adjust_learning_rate(args, optimizer, epoch)
512 |
513 | # train for one epoch
514 | time1 = time.time()
515 | loss, acc = train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args, scaler)
516 | time2 = time.time()
517 | print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
518 | epoch, time2-time1, acc))
519 |
520 | # eval for one epoch
521 | best_acc, best_model, save_bool = validate(val_loader, model, classifier, criterion, args, best_acc, best_model)
522 |
523 | # save a checkpoint of model and classifier when the best score is updated
524 | if save_bool:
525 | save_file = os.path.join(args.save_folder, 'best_epoch_{}.pth'.format(epoch))
526 | print('Best ckpt is modified with Score = {:.2f} when Epoch = {}'.format(best_acc[2], epoch))
527 | save_model(model, optimizer, args, epoch, save_file, classifier)
528 |
529 | if epoch % args.save_freq == 0:
530 | save_file = os.path.join(args.save_folder, 'epoch_{}.pth'.format(epoch))
531 | save_model(model, optimizer, args, epoch, save_file, classifier)
532 |
533 |         # save the model and classifier that achieved the best score
534 | save_file = os.path.join(args.save_folder, 'best.pth')
535 | model.load_state_dict(best_model[0])
536 | classifier.load_state_dict(best_model[1])
537 | save_model(model, optimizer, args, epoch, save_file, classifier)
538 | else:
539 | print('Testing the pretrained checkpoint on {} dataset'.format(args.dataset))
540 | best_acc, _, _ = validate(val_loader, model, classifier, criterion, args, best_acc)
541 |
542 | update_json('%s' % args.model_name, best_acc, path=os.path.join(args.save_dir, 'results.json'))
543 |
544 | if __name__ == '__main__':
545 | main()
546 |
--------------------------------------------------------------------------------
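
Note on scoring: `validate` above delegates the ICBHI metrics to `get_score(hits, counts)` from util/icbhi_util.py, which is not reproduced in this excerpt. A minimal sketch that is consistent with how `hits` and `counts` are accumulated (class 0 = normal, the remaining classes = abnormal) and with the percentage-scale check `se > 5` is shown below; the function name and exact rounding are illustrative, not the repository's definitive implementation.

    def get_score_sketch(hits, counts):
        # specificity: correctly classified normal cycles / all normal cycles
        sp = hits[0] / counts[0] * 100.0
        # sensitivity: correctly classified abnormal cycles / all abnormal cycles
        se = sum(hits[1:]) / sum(counts[1:]) * 100.0
        # ICBHI score: arithmetic mean of specificity and sensitivity
        sc = (sp + se) / 2.0
        return sp, se, sc
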
/method/__init__.py:
--------------------------------------------------------------------------------
1 | from .patchmix import PatchMixLoss
2 | from .patchmix_cl import PatchMixConLoss
--------------------------------------------------------------------------------
/method/patchmix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class PatchMixLoss(nn.Module):
8 | """ Refer to models/ast.py for the code of patch mixing
9 | """
10 | def __init__(self, criterion):
11 | super().__init__()
12 | self.criterion = criterion
13 |
14 | def forward(self, pred, y_a, y_b, lam):
15 | loss = lam * self.criterion(pred, y_a) + (1 - lam) * self.criterion(pred, y_b)
16 | return loss
17 |
--------------------------------------------------------------------------------
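
PatchMixLoss simply interpolates two cross-entropy terms with the mixing ratio `lam` returned by the patch-mix operation in models/ast.py. A small usage sketch (shapes and values are illustrative only):

    import torch
    import torch.nn as nn
    from method import PatchMixLoss

    criterion = PatchMixLoss(criterion=nn.CrossEntropyLoss())
    pred = torch.randn(8, 4)             # logits for a batch of 8 cycles, 4 classes
    y_a = torch.randint(0, 4, (8,))      # labels of the original samples
    y_b = y_a[torch.randperm(8)]         # labels of the samples whose patches were mixed in
    lam = 0.7                            # fraction of patches kept from the original sample
    loss = criterion(pred, y_a, y_b, lam)   # lam * CE(pred, y_a) + (1 - lam) * CE(pred, y_b)
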
/method/patchmix_cl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class PatchMixConLoss(nn.Module):
8 | def __init__(self, temperature=0.06):
9 | super().__init__()
10 | self.temperature = temperature
11 |
12 | def forward(self, projection1, projection2, labels_a, labels_b, lam, index, args):
13 | batch_size = projection1.shape[0]
14 | projection1, projection2 = F.normalize(projection1), F.normalize(projection2)
15 | anchor_dot_contrast = torch.div(torch.matmul(projection2, projection1.T), self.temperature)
16 |
17 | mask_a = torch.eye(batch_size).cuda()
18 | mask_b = torch.zeros(batch_size, batch_size).cuda()
19 | mask_b[torch.arange(batch_size).unsqueeze(1), index.view(-1, 1)] = 1
20 |
21 | mask = lam * mask_a + (1 - lam) * mask_b
22 |
23 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
24 | logits = anchor_dot_contrast - logits_max.detach() # for numerical stability
25 |
26 | exp_logits = torch.exp(logits)
27 | if args.negative_pair == 'diff_label':
28 | labels_a = labels_a.contiguous().view(-1, 1)
29 | logits_mask = torch.ne(labels_a, labels_a.T).cuda() + (mask_a.bool() + mask_b.bool())
30 | exp_logits *= logits_mask.float()
31 |
32 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
33 |
34 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
35 |
36 | loss = -mean_log_prob_pos
37 | loss = loss.view(1, batch_size)
38 |
39 | loss = loss.mean()
40 | return loss
41 |
--------------------------------------------------------------------------------
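
The core of PatchMixConLoss is the soft positive mask: for each mixed embedding in `projection2`, its own original sample counts as a positive with weight `lam`, and the sample it borrowed patches from (row `index[i]`) counts as a positive with weight `1 - lam`; all other pairs act as negatives (optionally restricted to different-label pairs). A toy illustration of the mask for a batch of 3, with values assumed for the example (the actual code builds the same tensors on GPU):

    import torch

    batch_size, lam = 3, 0.7
    index = torch.tensor([2, 0, 1])      # example permutation returned by patch_mix

    mask_a = torch.eye(batch_size)                   # positive weight lam on the sample itself
    mask_b = torch.zeros(batch_size, batch_size)
    mask_b[torch.arange(batch_size).unsqueeze(1), index.view(-1, 1)] = 1   # weight (1 - lam) on the mixed-in sample

    mask = lam * mask_a + (1 - lam) * mask_b
    # tensor([[0.7000, 0.0000, 0.3000],
    #         [0.3000, 0.7000, 0.0000],
    #         [0.0000, 0.3000, 0.7000]])
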
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .cnn6 import CNN6
2 | from .resnet import ResNet10, ResNet18, ResNet34, ResNet50, ResNet101
3 | from .efficientnet import EfficientNet_B0, EfficientNet_B1, EfficientNet_B2
4 | from .ast import ASTModel
5 | from .ssast import SSASTModel
6 | from .projector import Projector
7 |
8 | _backbone_class_map = {
9 | 'cnn6': CNN6,
10 | 'resnet10': ResNet10,
11 | 'resnet18': ResNet18,
12 | 'resnet34': ResNet34,
13 | 'resnet50': ResNet50,
14 | 'resnet101': ResNet101,
15 | 'efficientnet_b0': EfficientNet_B0,
16 | 'efficientnet_b1': EfficientNet_B1,
17 | 'efficientnet_b2': EfficientNet_B2,
18 | 'ast': ASTModel,
19 | 'ssast': SSASTModel
20 | }
21 |
22 |
23 | def get_backbone_class(key):
24 | if key in _backbone_class_map:
25 | return _backbone_class_map[key]
26 | else:
27 | raise ValueError('Invalid backbone: {}'.format(key))
--------------------------------------------------------------------------------
/models/ast.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.cuda.amp import autocast
5 | import os
6 | import wget
7 | import timm
8 | from copy import deepcopy
9 | from timm.models.layers import to_2tuple,trunc_normal_
10 |
11 |
12 | # override the timm package to relax the input shape constraint.
13 | class PatchEmbed(nn.Module):
14 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
15 | super().__init__()
16 |
17 | img_size = to_2tuple(img_size)
18 | patch_size = to_2tuple(patch_size)
19 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
20 | self.img_size = img_size
21 | self.patch_size = patch_size
22 | self.num_patches = num_patches
23 |
24 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
25 |
26 | def forward(self, x):
27 | x = self.proj(x).flatten(2).transpose(1, 2)
28 | return x
29 |
30 |
31 | class ASTModel(nn.Module):
32 | """
33 | The AST model.
34 | :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
35 |     :param fstride: the stride of patch splitting on the frequency dimension, for 16*16 patches, fstride=16 means no overlap, fstride=10 means overlap of 6
36 |     :param tstride: the stride of patch splitting on the time dimension, for 16*16 patches, tstride=16 means no overlap, tstride=10 means overlap of 6
37 | :param input_fdim: the number of frequency bins of the input spectrogram
38 | :param input_tdim: the number of time frames of the input spectrogram
39 | :param imagenet_pretrain: if use ImageNet pretrained model
40 | :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model
41 |     :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384]; base224 and base384 are the same model but are trained differently during ImageNet pretraining.
42 | """
43 | def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True, mix_beta=None):
44 | super(ASTModel, self).__init__()
45 | assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'
46 |
47 | if verbose == True:
48 | print('---------------AST Model Summary---------------')
49 | print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))
50 | # override timm input shape restriction
51 | timm.models.vision_transformer.PatchEmbed = PatchEmbed
52 | self.final_feat_dim = 768
53 | self.mix_beta = mix_beta
54 |
55 | # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
56 | if audioset_pretrain == False:
57 | if model_size == 'tiny224':
58 | self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
59 | elif model_size == 'small224':
60 | self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
61 | elif model_size == 'base224':
62 | self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
63 | elif model_size == 'base384':
64 | self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
65 | else:
66 | raise Exception('Model size must be one of tiny224, small224, base224, base384.')
67 | self.original_num_patches = self.v.patch_embed.num_patches
68 | self.oringal_hw = int(self.original_num_patches ** 0.5)
69 | self.original_embedding_dim = self.v.pos_embed.shape[2]
70 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))
71 |
72 |             # automatically get the intermediate shape
73 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
74 | num_patches = f_dim * t_dim
75 | self.v.patch_embed.num_patches = num_patches
76 | if verbose == True:
77 |                 print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
78 | print('number of patches={:d}'.format(num_patches))
79 |
80 | # the linear projection layer
81 | new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
82 | if imagenet_pretrain == True:
83 | new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
84 | new_proj.bias = self.v.patch_embed.proj.bias
85 | self.v.patch_embed.proj = new_proj
86 |
87 | # the positional embedding
88 | if imagenet_pretrain == True:
89 | # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
90 | new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
91 | # cut (from middle) or interpolate the second dimension of the positional embedding
92 | if t_dim <= self.oringal_hw:
93 | new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
94 | else:
95 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
96 | # cut (from middle) or interpolate the first dimension of the positional embedding
97 | if f_dim <= self.oringal_hw:
98 | new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
99 | else:
100 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
101 | # flatten the positional embedding
102 | new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
103 | # concatenate the above positional embedding with the cls token and distillation token of the deit model.
104 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
105 | else:
106 | # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding
107 | # TODO can use sinusoidal positional embedding instead
108 | new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
109 | self.v.pos_embed = new_pos_embed
110 | trunc_normal_(self.v.pos_embed, std=.02)
111 |
112 | # now load a model that is pretrained on both ImageNet and AudioSet
113 | elif audioset_pretrain == True:
114 | if audioset_pretrain == True and imagenet_pretrain == False:
115 | raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
116 | if model_size != 'base384':
117 | raise ValueError('currently only has base384 AudioSet pretrained model.')
118 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
119 |
120 | out_dir = './pretrained_models/'
121 | if not os.path.exists(out_dir):
122 | os.makedirs(out_dir, exist_ok=True)
123 |
124 | if os.path.exists(os.path.join(out_dir, 'audioset_10_10_0.4593.pth')) == False:
125 | # this model performs 0.4593 mAP on the audioset eval set
126 | audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
127 | wget.download(audioset_mdl_url, out=os.path.join(out_dir, 'audioset_10_10_0.4593.pth'))
128 |
129 | sd = torch.load(os.path.join(out_dir, 'audioset_10_10_0.4593.pth'), map_location=device)
130 | audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
131 | audio_model = torch.nn.DataParallel(audio_model)
132 | audio_model.load_state_dict(sd, strict=False)
133 | self.v = audio_model.module.v
134 | self.original_embedding_dim = self.v.pos_embed.shape[2]
135 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))
136 |
137 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
138 | num_patches = f_dim * t_dim
139 | self.v.patch_embed.num_patches = num_patches
140 | if verbose == True:
141 |                 print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
142 | print('number of patches={:d}'.format(num_patches))
143 |
144 | new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, 1212, 768).transpose(1, 2).reshape(1, 768, 12, 101)
145 | # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
146 | if t_dim < 101:
147 | new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
148 | # otherwise interpolate
149 | else:
150 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
151 | if f_dim < 12:
152 | new_pos_embed = new_pos_embed[:, :, 6 - int(f_dim/2): 6 - int(f_dim/2) + f_dim, :]
153 | # otherwise interpolate
154 | elif f_dim > 12:
155 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
156 | new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
157 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
158 |
159 | def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
160 | test_input = torch.randn(1, 1, input_fdim, input_tdim)
161 | test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
162 | test_out = test_proj(test_input)
163 | f_dim = test_out.shape[2]
164 | t_dim = test_out.shape[3]
165 | return f_dim, t_dim
166 |
167 | def square_patch(self, patch, hw_num_patch):
168 | h, w = hw_num_patch
169 | B, _, dim = patch.size()
170 | square = patch.reshape(B, h, w, dim)
171 | return square
172 |
173 | def flatten_patch(self, square):
174 | B, h, w, dim = square.shape
175 | patch = square.reshape(B, h * w, dim)
176 | return patch
177 |
178 | def patch_mix(self, image, target, time_domain=False, hw_num_patch=None):
179 | if self.mix_beta > 0:
180 | lam = np.random.beta(self.mix_beta, self.mix_beta)
181 | else:
182 | lam = 1
183 |
184 | batch_size, num_patch, dim = image.size()
185 | device = image.device
186 |
187 | index = torch.randperm(batch_size).to(device)
188 |
189 | if not time_domain:
190 | num_mask = int(num_patch * (1. - lam))
191 | mask = torch.randperm(num_patch)[:num_mask].to(device)
192 |
193 | image[:, mask, :] = image[index][:, mask, :]
194 | lam = 1 - (num_mask / num_patch)
195 | else:
196 | squared_1 = self.square_patch(image, hw_num_patch)
197 | squared_2 = self.square_patch(image[index], hw_num_patch)
198 |
199 | w_size = squared_1.size()[2]
200 | num_mask = int(w_size * (1. - lam))
201 | mask = torch.randperm(w_size)[:num_mask].to(device)
202 |
203 | squared_1[:, :, mask, :] = squared_2[:, :, mask, :]
204 | image = self.flatten_patch(squared_1)
205 | lam = 1 - (num_mask / w_size)
206 |
207 | y_a, y_b = target, target[index]
208 | return image, y_a, y_b, lam, index
209 |
210 | @autocast()
211 | def forward(self, x, y=None, patch_mix=False, time_domain=False):
212 | """
213 | :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
214 | :return: prediction
215 | """
216 | # x = x.unsqueeze(1)
217 | x = x.transpose(2, 3)
218 | h_patch, w_patch = int((x.size()[2] - 16) / 10) + 1, int((x.size()[3] - 16) / 10) + 1
219 |
220 | B = x.shape[0]
221 | x = self.v.patch_embed(x)
222 |
223 | if patch_mix:
224 | x, y_a, y_b, lam, index = self.patch_mix(x, y, time_domain=time_domain, hw_num_patch=[h_patch, w_patch])
225 |
226 | cls_tokens = self.v.cls_token.expand(B, -1, -1)
227 | dist_token = self.v.dist_token.expand(B, -1, -1)
228 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
229 | x = x + self.v.pos_embed
230 | x = self.v.pos_drop(x)
231 | for i, blk in enumerate(self.v.blocks):
232 | x = blk(x)
233 | x = self.v.norm(x)
234 | x = (x[:, 0] + x[:, 1]) / 2
235 | # x = self.mlp_head(x)
236 |
237 | if not patch_mix:
238 | return x
239 | else:
240 | return x, y_a, y_b, lam, index
--------------------------------------------------------------------------------
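
For reference, this is how main.py drives the patch-mix path of ASTModel: a single forward pass with `patch_mix=True` embeds the spectrogram into patches, mixes a random subset of patches across the batch, and returns the pooled feature together with both label sets, the effective mixing ratio, and the permutation index. A minimal sketch under the repository's 798x128 spectrogram setting (assumes timm==0.4.5 as required above; weights are left random here to avoid any download):

    import torch
    import torch.nn as nn
    from copy import deepcopy
    from models import ASTModel
    from method import PatchMixLoss

    model = ASTModel(label_dim=4, input_fdim=798, input_tdim=128,
                     imagenet_pretrain=False, audioset_pretrain=False, mix_beta=1.0)
    classifier = deepcopy(model.mlp_head)

    x = torch.randn(4, 1, 798, 128)      # batch of mel-spectrograms, as produced by set_loader
    y = torch.randint(0, 4, (4,))

    feat, y_a, y_b, lam, index = model(x, y=y, patch_mix=True)   # feat: (4, 768) pooled tokens
    logits = classifier(feat)
    loss = PatchMixLoss(criterion=nn.CrossEntropyLoss())(logits, y_a, y_b, lam)
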
/models/cnn6.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def init_layer(layer):
7 | """Initialize a Linear or Convolutional layer. """
8 | nn.init.xavier_uniform_(layer.weight)
9 | if hasattr(layer, 'bias'):
10 | if layer.bias is not None:
11 | layer.bias.data.fill_(0.)
12 |
13 |
14 | def init_bn(bn):
15 | """Initialize a Batchnorm layer. """
16 | bn.bias.data.fill_(0.)
17 | bn.weight.data.fill_(1.)
18 |
19 |
20 | class ConvBlock5x5(nn.Module): #for CNN6
21 | def __init__(self, in_channels, out_channels, stride=(1,1)):
22 |
23 | super(ConvBlock5x5, self).__init__()
24 |
25 | self.conv1 = nn.Conv2d(in_channels=in_channels,
26 | out_channels=out_channels,
27 | kernel_size=(5, 5), stride=stride,
28 | padding=(2, 2), bias=False)
29 |
30 | self.bn1 = nn.BatchNorm2d(out_channels)
31 |
32 | self.init_weight()
33 |
34 | def init_weight(self):
35 | init_layer(self.conv1)
36 | init_bn(self.bn1)
37 |
38 | def forward(self, input, pool_size=(2, 2), pool_type='avg'):
39 | x = input
40 | x = F.relu_(self.bn1(self.conv1(x)))
41 | if pool_type == 'max':
42 | x = F.max_pool2d(x, kernel_size=pool_size)
43 | elif pool_type == 'avg':
44 | x = F.avg_pool2d(x, kernel_size=pool_size)
45 | elif pool_type == 'avg+max':
46 | x1 = F.avg_pool2d(x, kernel_size=pool_size)
47 | x2 = F.max_pool2d(x, kernel_size=pool_size)
48 | x = x1 + x2
49 | else:
50 | raise Exception('Incorrect argument!')
51 |
52 | return x
53 |
54 |
55 | class CNN6(nn.Module):
56 | def __init__(self):
57 | super(CNN6, self).__init__()
58 | self.final_feat_dim = 512
59 |
60 | self.do_dropout = False
61 | self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64, stride=(1,1))
62 | self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128, stride=(1,1))
63 | self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256, stride=(1,1))
64 | self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512, stride=(1,1))
65 | self.dropout = nn.Dropout(0.2)
66 | # self.linear = nn.Linear(512, num_classes, bias=True)
67 |
68 | def load_sl_official_weights(self):
69 | """ download AudioSet pretrained CNN6 in https://zenodo.org/record/3960586#.Y8dz8y_kEiY
70 | """
71 | weights = torch.load('pretrained_models/Cnn6_mAP=0.343.pth')['model']
72 | state_dict = {k: v for k, v in weights.items() if k in self.state_dict().keys()}
73 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
74 |
75 | def forward(self, x):
76 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
77 | if self.do_dropout:
78 | x = self.dropout(x)
79 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
80 | if self.do_dropout:
81 | x = self.dropout(x)
82 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
83 | if self.do_dropout:
84 | x = self.dropout(x)
85 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
86 | if self.do_dropout:
87 | x = self.dropout(x)
88 |
89 | x = torch.mean(x, dim=3) #mean over time dim
90 | (x1, _) = torch.max(x, dim=2) #max over freq dim
91 | x2 = torch.mean(x, dim=2) #mean over freq dim (after mean over time)
92 | x = x1 + x2
93 |
94 | # if self.embed_only:
95 | # return x
96 | # return self.linear(x)
97 | return x
--------------------------------------------------------------------------------
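
The CNN6 head above follows the PANNs-style pooling: after the four 5x5 conv blocks, the feature map is averaged over the time axis, then max- and mean-pooled over the frequency axis and summed into a single 512-dim embedding. A quick shape check with a dummy input (the exact freq/time layout follows the dataset pipeline):

    import torch
    from models import CNN6

    model = CNN6()
    x = torch.randn(2, 1, 128, 798)      # dummy single-channel spectrogram batch
    feat = model(x)
    print(feat.shape)                    # torch.Size([2, 512])
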
/models/efficientnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import Tensor
4 | import torch.nn as nn
5 | import math
6 | import numpy as np
7 | import torch.nn.functional as F
8 | from torch.hub import load_state_dict_from_url
9 | from torchvision.models.efficientnet import _efficientnet_conf
10 |
11 |
12 | # torchvision.__version__ == '0.11.0+cu113' (arguments for _efficientnet_conf are different in latest version)
13 | class EfficientNet_B0(torchvision.models.efficientnet.EfficientNet):
14 | def __init__(self):
15 | inverted_residual_setting, last_channel = _efficientnet_conf('efficientnet_b0', width_mult=1.0, depth_mult=1.0)
16 | super().__init__(inverted_residual_setting, 0.2)
17 |
18 | del self.classifier
19 | self.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
20 | self.final_feat_dim = 1280
21 |
22 | def load_sl_official_weights(self, progress=True):
23 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
24 | progress=progress)
25 | del state_dict['features.0.0.weight']
26 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
27 | #if len(missing) > 0:
28 | #raise AssertionError('Model code may be incorrect')
29 |
30 | def load_ssl_official_weights(self, progress=True):
31 |         raise NotImplementedError
32 |
33 | def _forward_impl(self, x: Tensor) -> Tensor:
34 | # See note [TorchScript super()]
35 | x = self.features(x)
36 |
37 | x = self.avgpool(x)
38 | x = torch.flatten(x, 1)
39 |
40 | # x = self.classifier(x)
41 |
42 | return x
43 |
44 |
45 | class EfficientNet_B1(torchvision.models.efficientnet.EfficientNet):
46 | def __init__(self):
47 | inverted_residual_setting, last_channel = _efficientnet_conf('efficientnet_b1', width_mult=1.0, depth_mult=1.1)
48 | super().__init__(inverted_residual_setting, 0.2)
49 |
50 | del self.classifier
51 | self.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
52 | self.final_feat_dim = 1280
53 |
54 | def load_sl_official_weights(self, progress=True):
55 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
56 | progress=progress)
57 | del state_dict['features.0.0.weight']
58 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
59 | #if len(missing) > 0:
60 | #raise AssertionError('Model code may be incorrect')
61 |
62 | def load_ssl_official_weights(self, progress=True):
63 |         raise NotImplementedError
64 |
65 | def _forward_impl(self, x: Tensor) -> Tensor:
66 | # See note [TorchScript super()]
67 | x = self.features(x)
68 |
69 | x = self.avgpool(x)
70 | x = torch.flatten(x, 1)
71 |
72 | # x = self.classifier(x)
73 |
74 | return x
75 |
76 |
77 | class EfficientNet_B2(torchvision.models.efficientnet.EfficientNet):
78 | def __init__(self):
79 | inverted_residual_setting, last_channel = _efficientnet_conf('efficientnet_b2', width_mult=1.1, depth_mult=1.2)
80 | super().__init__(inverted_residual_setting, 0.3)
81 |
82 | del self.classifier
83 | self.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
84 | self.final_feat_dim = 1408
85 |
86 | def load_sl_official_weights(self, progress=True):
87 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
88 | progress=progress)
89 | del state_dict['features.0.0.weight']
90 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
91 | #if len(missing) > 0:
92 | #raise AssertionError('Model code may be incorrect')
93 |
94 | def load_ssl_official_weights(self, progress=True):
95 |         raise NotImplementedError
96 |
97 | def _forward_impl(self, x: Tensor) -> Tensor:
98 | # See note [TorchScript super()]
99 | x = self.features(x)
100 |
101 | x = self.avgpool(x)
102 | x = torch.flatten(x, 1)
103 |
104 | # x = self.classifier(x)
105 |
106 | return x
107 |
--------------------------------------------------------------------------------
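
As with the ResNet variants below, each EfficientNet here drops the ImageNet classifier and swaps the stem for a 1-channel convolution so that it accepts single-channel mel-spectrograms directly; `load_sl_official_weights` then loads the official torchvision weights for every layer except that stem. A quick shape check (assumes the torchvision version noted at the top of this file):

    import torch
    from models import EfficientNet_B0

    model = EfficientNet_B0()            # classifier removed, 1-channel stem
    feat = model(torch.randn(2, 1, 128, 798))
    print(feat.shape)                    # torch.Size([2, 1280])
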
/models/projector.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Projector(nn.Module):
5 | def __init__(self, in_dim, out_dim=768, apply_bn=True):
6 | super(Projector, self).__init__()
7 | self.linear1 = nn.Linear(in_dim, in_dim)
8 | self.linear2 = nn.Linear(in_dim, out_dim)
9 | self.bn = nn.BatchNorm1d(in_dim)
10 | self.relu = nn.ReLU()
11 | if apply_bn:
12 | self.projector = nn.Sequential(self.linear1, self.bn, self.relu, self.linear2)
13 | else:
14 | self.projector = nn.Sequential(self.linear1, self.relu, self.linear2)
15 |
16 | def forward(self, x):
17 | return self.projector(x)
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import Tensor
4 | import torch.nn as nn
5 | import math
6 | import numpy as np
7 | import torch.nn.functional as F
8 | from torch.hub import load_state_dict_from_url
9 |
10 |
11 | class ResNet10(torchvision.models.resnet.ResNet):
12 | def __init__(self, track_bn=True):
13 | def norm_layer(*args, **kwargs):
14 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn)
15 | super().__init__(torchvision.models.resnet.BasicBlock, [1, 1, 1, 1], norm_layer=norm_layer)
16 | del self.fc
17 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
18 | self.final_feat_dim = 512
19 |
20 | def load_sl_official_weights(self, progress=True):
21 |         raise NotImplementedError
22 |
23 | def load_ssl_official_weights(self, progress=True):
24 |         raise NotImplementedError
25 |
26 | def _forward_impl(self, x: Tensor) -> Tensor:
27 | # See note [TorchScript super()]
28 | x = self.conv1(x)
29 | x = self.bn1(x)
30 | x = self.relu(x)
31 | x = self.maxpool(x)
32 |
33 | x = self.layer1(x)
34 | x = self.layer2(x)
35 | x = self.layer3(x)
36 | x = self.layer4(x)
37 |
38 | x = self.avgpool(x)
39 | x = torch.flatten(x, 1)
40 | # x = self.fc(x)
41 |
42 | return x
43 |
44 |
45 | class ResNet18(torchvision.models.resnet.ResNet):
46 | def __init__(self, track_bn=True):
47 | def norm_layer(*args, **kwargs):
48 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn)
49 | super().__init__(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], norm_layer=norm_layer)
50 | del self.fc
51 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
52 | self.final_feat_dim = 512
53 |
54 | def load_sl_official_weights(self, progress=True):
55 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet18'],
56 | progress=progress)
57 |
58 | del state_dict['conv1.weight']
59 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
60 | # if len(missing) > 0:
61 | # raise AssertionError('Model code may be incorrect')
62 |
63 | def load_ssl_official_weights(self, progress=True):
64 |         raise NotImplementedError
65 |
66 | def _forward_impl(self, x: Tensor) -> Tensor:
67 | # See note [TorchScript super()]
68 | x = self.conv1(x)
69 | x = self.bn1(x)
70 | x = self.relu(x)
71 | x = self.maxpool(x)
72 |
73 | x = self.layer1(x)
74 | x = self.layer2(x)
75 | x = self.layer3(x)
76 | x = self.layer4(x)
77 |
78 | x = self.avgpool(x)
79 | x = torch.flatten(x, 1)
80 | # x = self.fc(x)
81 |
82 | return x
83 |
84 |
85 | class ResNet34(torchvision.models.resnet.ResNet):
86 | def __init__(self, track_bn=True):
87 | def norm_layer(*args, **kwargs):
88 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn)
89 | super().__init__(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], norm_layer=norm_layer)
90 | del self.fc
91 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
92 | self.final_feat_dim = 512
93 |
94 | def load_sl_official_weights(self, progress=True):
95 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet34'],
96 | progress=progress)
97 |
98 | del state_dict['conv1.weight']
99 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
100 | # if len(missing) > 0:
101 | # raise AssertionError('Model code may be incorrect')
102 |
103 | def load_ssl_official_weights(self, progress=True):
104 |         raise NotImplementedError
105 |
106 | def _forward_impl(self, x: Tensor) -> Tensor:
107 | # See note [TorchScript super()]
108 | x = self.conv1(x)
109 | x = self.bn1(x)
110 | x = self.relu(x)
111 | x = self.maxpool(x)
112 |
113 | x = self.layer1(x)
114 | x = self.layer2(x)
115 | x = self.layer3(x)
116 | x = self.layer4(x)
117 |
118 | x = self.avgpool(x)
119 | x = torch.flatten(x, 1)
120 | # x = self.fc(x)
121 |
122 | return x
123 |
124 | # finetune according to the stethoscope types (code from RespireNet)
125 | # def finetune(self, block_layer=5):
126 | # for idx, child in enumerate(self.children()):
127 | # if idx > block_layer:
128 | # break
129 | # for param in child.parameters():
130 | # param.requires_grad = False
131 |
132 |
133 | class ResNet50(torchvision.models.resnet.ResNet):
134 | def __init__(self, track_bn=True):
135 | def norm_layer(*args, **kwargs):
136 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn)
137 | super().__init__(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], norm_layer=norm_layer)
138 | del self.fc
139 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
140 | self.final_feat_dim = 2048
141 |
142 | def load_sl_official_weights(self, progress=True):
143 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet50'],
144 | progress=progress)
145 |
146 | del state_dict['conv1.weight']
147 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
148 | # if len(missing) > 0:
149 | # raise AssertionError('Model code may be incorrect')
150 |
151 | def load_ssl_official_weights(self, progress=True):
152 | # only SimCLR is available
153 | from pl_bolts.models.self_supervised import SimCLR
154 |
155 | weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
156 | simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
157 |
158 | state_dict = {}
159 | for k, v in simclr.state_dict().items():
160 | if 'encoder.' in k:
161 | k = k.replace('encoder.', '')
162 |             if 'fc' not in k and 'project' not in k:
163 | state_dict[k] = v
164 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
165 | # non_linear_evaluator.block_forward is a pretrained MLP classifier for SimCLR
166 | # refer to https://github.com/Lightning-AI/lightning-bolts/blob/bcbbf6ab6c36430946dd8a416ddc7e697e8507fc/pl_bolts/models/self_supervised/evaluator.py#L7
167 | if len(missing) > 0:
168 | raise AssertionError('Model code may be incorrect')
169 |
170 | def _forward_impl(self, x: Tensor) -> Tensor:
171 | # See note [TorchScript super()]
172 | x = self.conv1(x)
173 | x = self.bn1(x)
174 | x = self.relu(x)
175 | x = self.maxpool(x)
176 |
177 | x = self.layer1(x)
178 | x = self.layer2(x)
179 | x = self.layer3(x)
180 | x = self.layer4(x)
181 |
182 | x = self.avgpool(x)
183 | x = torch.flatten(x, 1)
184 | # x = self.fc(x)
185 |
186 | return x
187 |
188 |
189 | class ResNet101(torchvision.models.resnet.ResNet):
190 | def __init__(self, track_bn=True):
191 | def norm_layer(*args, **kwargs):
192 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn)
193 | super().__init__(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], norm_layer=norm_layer)
194 | del self.fc
195 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
196 | self.final_feat_dim = 2048
197 |
198 | def load_sl_official_weights(self, progress=True):
199 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet101'],
200 | progress=progress)
201 |
202 | del state_dict['conv1.weight']
203 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
204 | # if len(missing) > 0:
205 | # raise AssertionError('Model code may be incorrect')
206 |
207 | def load_ssl_official_weights(self, progress=True):
208 |         raise NotImplementedError
209 |
210 | def _forward_impl(self, x: Tensor) -> Tensor:
211 | # See note [TorchScript super()]
212 | x = self.conv1(x)
213 | x = self.bn1(x)
214 | x = self.relu(x)
215 | x = self.maxpool(x)
216 |
217 | x = self.layer1(x)
218 | x = self.layer2(x)
219 | x = self.layer3(x)
220 | x = self.layer4(x)
221 |
222 | x = self.avgpool(x)
223 | x = torch.flatten(x, 1)
224 | # x = self.fc(x)
225 |
226 | return x
227 |
--------------------------------------------------------------------------------
/models/ssast.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import sys
4 | from timm.models.layers import trunc_normal_
5 | import timm
6 | import numpy as np
7 | import os
8 | import wget
9 | from timm.models.layers import to_2tuple
10 | from random import randrange
11 | from matplotlib import pyplot as plt
12 | import random
13 |
14 |
15 | # override the timm package to relax the input shape constraint.
16 | class PatchEmbed(nn.Module):
17 | """ Image to Patch Embedding
18 | """
19 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
20 | super().__init__()
21 | img_size = to_2tuple(img_size)
22 | patch_size = to_2tuple(patch_size)
23 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
24 | self.img_size = img_size
25 | self.patch_size = patch_size
26 | self.num_patches = num_patches
27 |
28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
29 |
30 | def forward(self, x):
31 | x = self.proj(x).flatten(2).transpose(1, 2)
32 | return x
33 |
34 |
35 | def get_sinusoid_encoding(n_position, d_hid):
36 | ''' Sinusoid position encoding table '''
37 |
38 | def get_position_angle_vec(position):
39 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
40 |
41 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
42 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
43 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
44 |
45 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
46 |
47 |
48 | class SSASTModel(nn.Module):
49 | def __init__(self, label_dim=527, fshape=128, tshape=2, fstride=128, tstride=2,
50 | input_fdim=128, input_tdim=1024, model_size='base', task='ft_avgtok',
51 | pretrain_stage=True, load_pretrained_mdl_path=None, mix_beta=None):
52 |
53 | super(SSASTModel, self).__init__()
54 | assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'
55 |
56 | # override timm input shape restriction
57 | timm.models.vision_transformer.PatchEmbed = PatchEmbed
58 |
59 | self.task = task
60 |
61 | # pretrain the AST models
62 | if pretrain_stage == True:
63 | if load_pretrained_mdl_path != None:
64 | raise ValueError('Setting load_pretrained_mdl_path at pretraining stage is useless, pretraining is always from scratch, please change it to None.')
65 | if fstride != fshape or tstride != tshape:
66 | raise ValueError('fstride != fshape or tstride != tshape, they must be same at the pretraining stage, patch split overlapping is not supported.')
67 |
68 |             # build the ViT backbone from scratch according to the model size
69 | if model_size == 'tiny':
70 | self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=False)
71 | self.heads, self.depth = 3, 12
72 | self.cls_token_num = 2
73 | elif model_size == 'small':
74 | self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=False)
75 | self.heads, self.depth = 6, 12
76 | self.cls_token_num = 2
77 | elif model_size == 'base':
78 | self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=False)
79 | self.heads, self.depth = 12, 12
80 | self.cls_token_num = 2
81 | elif model_size == 'base_nokd':
82 | self.v = timm.create_model('vit_deit_base_patch16_384', pretrained=False)
83 | self.heads, self.depth = 12, 12
84 | self.cls_token_num = 1
85 | else:
86 | raise Exception('Model size must be one of tiny, small, base, base_nokd')
87 |
88 | self.original_num_patches = self.v.patch_embed.num_patches
89 | self.oringal_hw = int(self.original_num_patches ** 0.5)
90 | self.original_embedding_dim = self.v.pos_embed.shape[2]
91 |
92 | # SSL Pretraining Code
93 | self.softmax = nn.Softmax(dim=-1)
94 | self.lsoftmax = nn.LogSoftmax(dim=-1)
95 | self.fshape, self.tshape = fshape, tshape
96 | self.fstride, self.tstride = fstride, tstride
97 | self.input_fdim, self.input_tdim = input_fdim, input_tdim
98 |             # trick to make the state_dict track the pretraining input_fdim and input_tdim so they are saved by torch.save
99 | self.p_input_fdim, self.p_input_tdim = nn.Parameter(torch.tensor(input_fdim), requires_grad=False), nn.Parameter(torch.tensor(input_tdim), requires_grad=False)
100 |
101 | # masked patch classification (discriminative objective) layer
102 |             # we use two layers for the pretext task, but using a single layer gives similar performance.
103 |             # we map the transformer output (768-dim for base models) to the 256-dim patch input space, and then take the dot product with the flattened patch input (also 256-dim) to calculate the loss.
104 |             # alternatively, you can map the transformer output to the 768-dim patch embedding space and take the dot product with the patch embedding; performance is similar, but mapping to the 256-dim space is more efficient.
105 | self.cpredlayer = nn.Sequential(nn.Linear(self.original_embedding_dim, self.original_embedding_dim), nn.ReLU(), nn.Linear(self.original_embedding_dim, 256))
106 | # masked patch reconstruction (generative objective) layer
107 | self.gpredlayer = nn.Sequential(nn.Linear(self.original_embedding_dim, self.original_embedding_dim), nn.ReLU(), nn.Linear(self.original_embedding_dim, 256))
108 | self.unfold = torch.nn.Unfold(kernel_size=(fshape, tshape), stride=(fstride, tstride))
109 |
110 | # we use learnable mask embedding (follow the BEIT paper), but using a fixed mask embedding (e.g., 0) leads to same performance.
111 | self.mask_embed = nn.Parameter(torch.zeros([1, 1, self.original_embedding_dim]))
112 | self.mask_embed = torch.nn.init.xavier_normal_(self.mask_embed)
113 |
114 | # get the intermediate shape
115 | self.p_f_dim, self.p_t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape)
116 | num_patches = self.p_f_dim * self.p_t_dim
117 | self.num_patches = num_patches
118 | self.v.patch_embed.num_patches = num_patches
119 | print('pretraining patch split stride: frequency={:d}, time={:d}'.format(fstride, tstride))
120 | print('pretraining patch shape: frequency={:d}, time={:d}'.format(fshape, tshape))
121 | print('pretraining patch array dimension: frequency={:d}, time={:d}'.format(self.p_f_dim, self.p_t_dim))
122 | print('pretraining number of patches={:d}'.format(num_patches))
123 |
124 | # the linear patch projection layer, use 1 channel for spectrogram rather than the original 3 channels for RGB images.
125 | new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride))
126 | self.v.patch_embed.proj = new_proj
127 |
128 | # use trainable positional embedding
129 | new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + self.cls_token_num, self.original_embedding_dim))
130 | self.v.pos_embed = new_pos_embed
131 | trunc_normal_(self.v.pos_embed, std=.02)
132 |
133 |         # use a pretrained model for fine-tuning
134 |         elif pretrain_stage == False:
135 |             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
136 |             if load_pretrained_mdl_path is None:
137 |                 raise ValueError('Please set load_pretrained_mdl_path to load a pretrained model.')
138 |
139 | # added kaen2891
140 | out_dir = './pretrained_models/'
141 | if not os.path.exists(out_dir):
142 | os.makedirs(out_dir, exist_ok=True)
143 |
144 | if load_pretrained_mdl_path == 'Patch':
145 | if os.path.exists(os.path.join(out_dir, 'SSAST-Base-Patch-400.pth')) == False:
146 | # this model performs 59.9 on the Avg Audio Performance
147 | # more details are below: https://github.com/YuanGongND/ssast
148 | audioset_mdl_url = 'https://www.dropbox.com/s/ewrzpco95n9jdz6/SSAST-Base-Patch-400.pth?dl=1'
149 | print('Download SSAST-Base-Patch-400.pth \n')
150 | wget.download(audioset_mdl_url, out=os.path.join(out_dir, 'SSAST-Base-Patch-400.pth'))
151 | sd = torch.load(os.path.join(out_dir, 'SSAST-Base-Patch-400.pth'), map_location=device)
152 | print('Loaded SSAST-Base-Patch-400.pth successfully.')
153 |
154 | elif load_pretrained_mdl_path == 'Frame':
155 | if os.path.exists(os.path.join(out_dir, 'SSAST-Base-Frame-400.pth')) == False:
156 | # this model performs 57.6 on the Avg Audio Performance
157 | # more details are below: https://github.com/YuanGongND/ssast
158 | audioset_mdl_url = 'https://www.dropbox.com/s/nx6nl4d4bl71sm8/SSAST-Base-Frame-400.pth?dl=1'
159 | print('Download SSAST-Base-Frame-400.pth \n')
160 | wget.download(audioset_mdl_url, out=os.path.join(out_dir, 'SSAST-Base-Frame-400.pth'))
161 | sd = torch.load(os.path.join(out_dir, 'SSAST-Base-Frame-400.pth'), map_location=device)
162 | print('\n Loaded SSAST-Base-Frame-400.pth successfully.')
163 |
164 | # get the fshape and tshape, input_fdim and input_tdim in the pretraining stage
165 | try:
166 | p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
167 | p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
168 | except:
169 |                 raise ValueError('The loaded model was not saved from a torch.nn.DataParallel object. Wrap it with torch.nn.DataParallel and try again.')
170 |
171 |             print('now loading an SSL-pretrained model from ' + load_pretrained_mdl_path)
172 |             # during pretraining, fstride=fshape and tstride=tshape because no patch overlapping is used
173 |             # here, input_fdim and input_tdim should be the values used in pretraining, not those used in fine-tuning.
174 |             # we need to know input_fdim and input_tdim to do positional embedding cut/interpolation.
175 |             # it is generally better to use the same input_fdim for pretraining and fine-tuning, but input_tdim can safely differ
176 | audio_model = SSASTModel(fstride=p_fshape, tstride=p_tshape, fshape=p_fshape, tshape=p_tshape,
177 | input_fdim=p_input_fdim, input_tdim=p_input_tdim, pretrain_stage=True, model_size=model_size)
178 | audio_model = torch.nn.DataParallel(audio_model)
179 | audio_model.load_state_dict(sd, strict=False)
180 |
181 | self.v = audio_model.module.v
182 | self.final_feat_dim = self.original_embedding_dim = self.v.pos_embed.shape[2]
183 | self.mix_beta = mix_beta
184 | self.cls_token_num = audio_model.module.cls_token_num
185 |
186 | # mlp head for fine-tuning
187 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim),
188 | nn.Linear(self.original_embedding_dim, label_dim))
189 |
190 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape)
191 | # patch array dimension during pretraining
192 | p_f_dim, p_t_dim = audio_model.module.p_f_dim, audio_model.module.p_t_dim
193 | num_patches = f_dim * t_dim
194 | p_num_patches = p_f_dim * p_t_dim
195 | self.v.patch_embed.num_patches = num_patches
196 |             print('fine-tuning patch split stride: frequency={:d}, time={:d}'.format(fstride, tstride))
197 | print('fine-tuning number of patches={:d}'.format(num_patches))
198 |
199 | # patch shape should be same for pretraining and fine-tuning
200 | if fshape != p_fshape or tshape != p_tshape:
201 |                 raise ValueError('The patch shape of pretraining and fine-tuning is not consistent, pretraining: f={:d}, t={:d}, finetuning: f={:d}, t={:d}'.format(p_fshape, p_tshape, fshape, tshape))
202 |
203 | # patch split stride generally should be different for pretraining and fine-tuning, as patch split overlapping is only used in finetuning
204 | # during pretraining, p_fshape = p_fstride and p_tshape = p_tstride
205 | if fstride != p_fshape or tstride != p_tshape:
206 | # initialize a new patch embedding layer with desired new stride.
207 | new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride))
208 | # but the weights of patch embedding layer is still got from the pretrained models
209 | new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
210 | new_proj.bias = self.v.patch_embed.proj.bias
211 | self.v.patch_embed.proj = new_proj
212 |
213 | new_pos_embed = self.v.pos_embed[:, self.cls_token_num:, :].detach().reshape(1, p_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, p_f_dim, p_t_dim)
214 | # cut or interpolate the positional embedding
215 | if t_dim < p_t_dim:
216 | new_pos_embed = new_pos_embed[:, :, :, int(p_t_dim/2) - int(t_dim / 2): int(p_t_dim/2) - int(t_dim / 2) + t_dim]
217 | else:
218 |                 new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(p_f_dim, t_dim), mode='bilinear')  # keep the pretraining frequency grid, resize only the time axis
219 |             if f_dim < p_f_dim:
220 |                 new_pos_embed = new_pos_embed[:, :, int(p_f_dim/2) - int(f_dim / 2): int(p_f_dim/2) - int(f_dim / 2) + f_dim, :]  # crop length along the frequency axis is f_dim
221 | else:
222 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
223 |
224 | new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1, 2)
225 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :self.cls_token_num, :].detach(), new_pos_embed], dim=1))
226 |
227 | # get the shape of intermediate representation.
228 | def get_shape(self, fstride, tstride, input_fdim, input_tdim, fshape, tshape):
229 | test_input = torch.randn(1, 1, input_fdim, input_tdim)
230 | test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride))
231 | test_out = test_proj(test_input)
232 | f_dim = test_out.shape[2]
233 | t_dim = test_out.shape[3]
234 | return f_dim, t_dim
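    # Worked example (sketch, using the (B, 1, 1024, 128)-style inputs mentioned in forward() below):
    #   input_fdim=128, input_tdim=1024, fshape=tshape=16, fstride=tstride=16 (pretraining, no overlap)
    #     -> f_dim = (128 - 16) // 16 + 1 = 8,  t_dim = (1024 - 16) // 16 + 1 = 64,  i.e. 8 * 64 = 512 patches
    #   the same input with fstride=tstride=10 (overlapping patches, typical for fine-tuning)
    #     -> f_dim = (128 - 16) // 10 + 1 = 12, t_dim = (1024 - 16) // 10 + 1 = 101, i.e. 12 * 101 = 1212 patches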
235 |
236 | # generate mask for 16*16 patch
237 | def gen_maskid_patch(self, sequence_len=512, mask_size=100, cluster=3):
238 | mask_id = []
239 |
240 |         # randomize the clustering factor in [3, 6)
241 | cur_clus = randrange(cluster) + 3
242 |
243 | while len(list(set(mask_id))) <= mask_size:
244 | start_id = randrange(sequence_len)
245 |
246 | # this improves the efficiency, but might change the pretrained model
247 | # while start_id in mask_id:
248 | # start_id = randrange(sequence_len)
249 |
250 | cur_mask = []
251 | for i in range(0, cur_clus):
252 | for j in range(0, cur_clus):
253 | mask_cand = start_id + self.p_t_dim * i + j
254 | if mask_cand > 0 and mask_cand < sequence_len:
255 | cur_mask.append(mask_cand)
256 | mask_id = mask_id + cur_mask
257 | mask_id = list(set(mask_id))[:mask_size]
258 | return torch.tensor(mask_id)
259 |
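    # Example (sketch): with self.p_t_dim = 64 and a sampled cur_clus of 3, a start_id of 130 expands to
    # the 3x3 square of patch indices {130, 131, 132, 194, 195, 196, 258, 259, 260}
    # (start_id + p_t_dim * i + j for i, j in 0..2), i.e. a contiguous block in the 2-D patch grid;
    # blocks are accumulated until mask_size unique patch indices have been collected.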
260 | # using cluster for frame masking hurts the performance, so just use the naive random sampling
261 | def gen_maskid_frame(self, sequence_len=512, mask_size=100):
262 | mask_id = random.sample(range(0, sequence_len), mask_size)
263 | return torch.tensor(mask_id)
264 |
265 | def finetuningavgtok(self, x, y=None, patch_mix=False):
266 | B = x.shape[0]
267 | x = self.v.patch_embed(x)
268 |
269 | if patch_mix:
270 | x, y_a, y_b, lam, index = self.patch_mix(x, y)
271 |
272 | if self.cls_token_num == 2:
273 | cls_tokens = self.v.cls_token.expand(B, -1, -1)
274 | dist_token = self.v.dist_token.expand(B, -1, -1)
275 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
276 | else:
277 | cls_tokens = self.v.cls_token.expand(B, -1, -1)
278 | x = torch.cat((cls_tokens, x), dim=1)
279 |
280 | x = x + self.v.pos_embed
281 | x = self.v.pos_drop(x)
282 |
283 | for blk_id, blk in enumerate(self.v.blocks):
284 | x = blk(x)
285 | x = self.v.norm(x)
286 |
287 | # average output of all tokens except cls token(s)
288 | x = torch.mean(x[:, self.cls_token_num:, :], dim=1)
289 | # x = self.mlp_head(x)
290 |
291 | if not patch_mix:
292 | return x
293 | else:
294 | return x, y_a, y_b, lam, index
295 |
296 | def finetuningcls(self, x):
297 | B = x.shape[0]
298 | x = self.v.patch_embed(x)
299 | if self.cls_token_num == 2:
300 | cls_tokens = self.v.cls_token.expand(B, -1, -1)
301 | dist_token = self.v.dist_token.expand(B, -1, -1)
302 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
303 | else:
304 | cls_tokens = self.v.cls_token.expand(B, -1, -1)
305 | x = torch.cat((cls_tokens, x), dim=1)
306 | x = x + self.v.pos_embed
307 | x = self.v.pos_drop(x)
308 |
309 | for blk_id, blk in enumerate(self.v.blocks):
310 | x = blk(x)
311 | x = self.v.norm(x)
312 |
313 |         # if the model has two cls tokens (DeiT), average them as the clip-level representation
314 | if self.cls_token_num == 2:
315 | x = (x[:, 0] + x[:, 1]) / 2
316 | else:
317 | x = x[:, 0]
318 | # x = self.mlp_head(x)
319 | return x
320 |
321 | # masked patch pretraining with discriminative objective
322 | def mpc(self, x, mask_patch, cluster, show_mask=False):
323 | input = self.unfold(x).transpose(1, 2)
324 | B = x.shape[0]
325 | # x in shape (batch_size, sequence_len, embedding dim)
326 | x = self.v.patch_embed(x)
327 |
328 | # encode the patch
329 |         # size 12(batch_size) * 100(#mask_patch) * 256(patch input dim), prepare to save the true values of masked patches
330 | encode_samples = torch.empty((B, mask_patch, 256), device=x.device, requires_grad=False).float()
331 | # size 12(batch_size) * 100(#mask_patch), index of masked patches
332 | mask_index = torch.empty((B, mask_patch), device=x.device, requires_grad=False).long()
333 | # size 12(batch_size) * 512(sequence_len) * 768(hidden_dim)
334 | mask_dense = torch.ones([x.shape[0], x.shape[1], x.shape[2]], device=x.device)
335 |
336 | # for each audio clip in the batch
337 | for i in range(B):
338 | # randomly generate #mask_patch mask indexes without duplicate
339 | if cluster == True:
340 | # use this if you are masking e.g. 16*16 patches
341 | mask_index[i] = self.gen_maskid_patch(self.num_patches, mask_patch)
342 | else:
343 | # use this if you are masking frame, i.e., 128*2 patches
344 | mask_index[i] = self.gen_maskid_frame(self.num_patches, mask_patch)
345 | # copy the masked embeddings, note gradients are stopped in this path
346 | encode_samples[i] = input[i, mask_index[i], :].clone().detach()
347 | # mask the encode samples with 0
348 | mask_dense[i, mask_index[i], :] = 0
349 |
350 | # follow BEIT paper, mask with learnable masking embedding, but no performance diff observed compared with masking with 0s.
351 | mask_tokens = self.mask_embed.expand(B, x.shape[1], -1)
352 |
353 | # mask the patch
354 | x = x * mask_dense + (1-mask_dense) * mask_tokens
355 |
356 | # pass through the Transformer layers
357 | cls_tokens = self.v.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
358 | dist_token = self.v.dist_token.expand(B, -1, -1)
359 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
360 | x = x + self.v.pos_embed
361 | x = self.v.pos_drop(x)
362 | for blk in self.v.blocks:
363 | x = blk(x)
364 | x = self.v.norm(x)
365 |
366 | # prediction of the masked patch
367 |         pred = torch.empty((B, mask_patch, 256), device=x.device).float() # e.g. size 12*100*256
368 | for i in range(B):
369 |             # shift the indexes by cls_token_num to skip the cls (and distillation) token(s)
370 | # we map the output of transformer (768-dim for base models) to 256-dim patch input space, and then dot product with flattened patch input (also 256-dim) to calculate loss.
371 | # alternatively, you can map the output of transformer to 768-dim patch embedding space, and dot product with patch embedding. Performance-wise they are similar, but map to 256 space is more efficient.
372 | pred[i] = self.cpredlayer(x[i, mask_index[i] + self.cls_token_num, :])
373 |
374 | # calculate the NCE loss
375 | nce = torch.tensor(0.0).to(x.device)
376 | correct = torch.tensor(0.0).to(x.device)
377 | for i in np.arange(0, B):
378 | # negative samples are from the same batch
379 | # 8/12/2022: has a difference with equation (1) in the ssast paper but (likely) performance-wise similar, see https://github.com/YuanGongND/ssast/issues/13
380 | total = torch.mm(encode_samples[i], torch.transpose(pred[i], 0, 1)) # e.g. size 100*100
381 | correct += torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=0), torch.arange(0, mask_patch, device=x.device))) # correct is a tensor
382 | nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor
383 | acc = 1. * correct / (B * mask_patch)
384 | nce = nce / (-1. * B * mask_patch)
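        # For reference: `total` is a (mask_patch x mask_patch) similarity matrix between the true patch
        # contents (encode_samples) and the predictions; its diagonal holds the positive pairs while the
        # off-diagonal entries act as negatives from the other masked patches of the same clip, so `nce`
        # is the mean negative log-softmax of the diagonal and `acc` is the fraction of columns whose
        # argmax lands on the correct (diagonal) row.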
385 |
386 | # visualize the masked area, for probing test only, set show_mask = False for any training/inference.
387 | if show_mask == False:
388 | return acc, nce
389 | else:
390 | if B > 1:
391 | raise Exception('Currently only support single spectrogram probing test.')
392 |
393 | self.mask_correct = torch.nn.Parameter(torch.arange(0, mask_patch), requires_grad=False)
394 |
395 | pred = input.clone() # [B, 512, 256]
396 | masked = input.clone()
397 |
398 | for i in range(B):
399 | result = [float(t) * 99 for t in torch.eq(torch.argmax(self.softmax(total), dim=0), self.mask_correct)]
400 | pred[i, mask_index[i], :] = torch.tensor(result).reshape(mask_patch, 1).expand(mask_patch, 256)
401 | masked[i, mask_index[i], :] = 99.0
402 |
403 | # print(total)
404 | # print(self.softmax(total))
405 | # print(torch.argmax(self.softmax(total), dim=0))
406 | # print(self.mask_correct)
407 | # print(torch.eq(torch.argmax(self.softmax(total), dim=0), self.mask_correct))
408 | # print([float(t)*99 for t in torch.eq(torch.argmax(self.softmax(total), dim=0), self.mask_correct)])
409 |
410 | fold = torch.nn.Fold(output_size=([self.input_fdim, self.input_tdim]), kernel_size=(self.fshape, self.tshape), stride=(self.fstride, self.tstride))
411 | pred = fold(pred.transpose(1, 2))
412 | masked = fold(masked.transpose(1, 2))
413 |
414 | return pred, masked
415 |
416 |     # masked patch pretraining with generative objective
417 | def mpg(self, input, mask_patch, cluster):
418 | B = input.shape[0]
419 | x = self.v.patch_embed(input)
420 | input = self.unfold(input).transpose(1, 2)
421 |
422 | # size 12(batch_size) * 100(#mask_patch), index of masked patches
423 | mask_index = torch.empty((B, mask_patch), device=x.device, requires_grad=False).long()
424 | # size 12(batch_size) * 512(sequence_len) * 768(hidden_dim)
425 | mask_dense = torch.ones([x.shape[0], x.shape[1], x.shape[2]], device=x.device)
426 | for i in range(B):
427 | # randomly generate #mask_patch mask indexes without duplicate
428 | if cluster == True:
429 | # use this if you are masking e.g. 16*16 patches
430 | mask_index[i] = self.gen_maskid_patch(self.num_patches, mask_patch)
431 | else:
432 | # use this if you are masking frame, i.e., 128*2 patches
433 | mask_index[i] = self.gen_maskid_frame(self.num_patches, mask_patch)
434 | mask_dense[i, mask_index[i], :] = 0
435 |
436 | mask_tokens = self.mask_embed.expand(B, x.shape[1], -1)
437 |
438 | # follow BEIT paper, mask with learnable masking embedding, but no performance diff observed compared with masking with 0s.
439 | x = x * mask_dense + (1-mask_dense) * mask_tokens
440 |
441 | # go through the Transformer layers
442 | cls_tokens = self.v.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
443 | dist_token = self.v.dist_token.expand(B, -1, -1)
444 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
445 | x = x + self.v.pos_embed
446 | x = self.v.pos_drop(x)
447 | for blk in self.v.blocks:
448 | x = blk(x)
449 | x = self.v.norm(x)
450 |
451 | pred = torch.empty((B, mask_patch, self.fshape * self.tshape), device=x.device).float() # e.g. size 12*100*256
452 | target = torch.empty((B, mask_patch, self.fshape * self.tshape), device=x.device).float() # e.g. size 12*100*256
453 |
454 | for i in range(B):
455 |             # shift the indexes by cls_token_num to skip the cls (and distillation) token(s)
456 | pred[i] = self.gpredlayer(x[i, mask_index[i] + self.cls_token_num, :])
457 | target[i] = input[i, mask_index[i], :]
458 |
459 | # calculate the MSE loss
460 | mse = torch.mean((pred - target) ** 2)
461 |
462 | return mse
463 |
464 | def patch_mix(self, image, target):
465 | if self.mix_beta > 0:
466 | lam = np.random.beta(self.mix_beta, self.mix_beta)
467 | else:
468 | lam = 1
469 |
470 | batch_size, num_patch, dim = image.size()
471 | device = image.device
472 |
473 | index = torch.randperm(batch_size).to(device)
474 |
475 | num_mask = int(num_patch * (1. - lam))
476 | mask = torch.randperm(num_patch)[:num_mask].to(device)
477 |
478 | image[:, mask, :] = image[index][:, mask, :]
479 | lam = 1 - (num_mask / num_patch)
480 |
481 | y_a, y_b = target, target[index]
482 | return image, y_a, y_b, lam, index
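    # Usage sketch (hypothetical names, not the exact training-script call): the returned label pair and
    # mixing ratio are typically combined Mixup-style, e.g.
    #   feat, y_a, y_b, lam, index = model(x, y, patch_mix=True)   # task == 'ft_avgtok'
    #   logits = classifier(feat)                                   # e.g. a head like self.mlp_head
    #   loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
    # where `classifier` and `F` (torch.nn.functional) stand in for whatever the caller provides.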
483 |
484 | def forward(self, x, y=None, patch_mix=False):
485 | cluster = True
486 | mask_patch = 400
487 |
488 |         # expect input x = (batch_size, 1, time_frame_num, frequency_bins), e.g., (12, 1, 1024, 128)
489 |         # x = x.unsqueeze(1) # uncomment if the input arrives as (batch_size, time_frame_num, frequency_bins)
490 |         x = x.transpose(2, 3) # (12, 1, 128, 1024)
491 |
492 | # finetuning (ft), use the mean of all token (patch) output as clip-level representation.
493 | # this is default for SSAST fine-tuning as during pretraining, supervision signal is given to each token, not the [cls] token
494 | if self.task == 'ft_avgtok':
495 | return self.finetuningavgtok(x, y, patch_mix)
496 | # alternatively, use the [cls] token output as clip-level representation.
497 | elif self.task == 'ft_cls':
498 | return self.finetuningcls(x)
499 | # pretraining, masked patch classification (discriminative objective)
500 | elif self.task == 'pretrain_mpc':
501 | return self.mpc(x, mask_patch=mask_patch, cluster=cluster)
502 | # pretraining, masked patch reconstruction (generative objective)
503 | elif self.task == 'pretrain_mpg':
504 | return self.mpg(x, mask_patch=mask_patch, cluster=cluster)
505 | elif self.task == 'visualize_mask':
506 | return self.mpc(x, mask_patch=mask_patch, cluster=cluster, show_mask=True)
507 | else:
508 | raise Exception('Task unrecognized.')
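# Instantiation sketch (illustrative values only; the keyword names follow the arguments referenced in __init__ above):
#   model = SSASTModel(label_dim=4, fshape=16, tshape=16, fstride=10, tstride=10,
#                      input_fdim=128, input_tdim=1024, model_size='base',
#                      pretrain_stage=False, load_pretrained_mdl_path='Patch', mix_beta=1.0)
# self.task ('ft_avgtok', 'ft_cls', 'pretrain_mpc', 'pretrain_mpg', or 'visualize_mask') selects the branch
# taken in forward(); with 'ft_avgtok' and a (B, 1, 1024, 128) input, the model returns clip-level features
# of size original_embedding_dim (768 for the base model).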
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | tqdm
3 | librosa
4 | pandas
5 | nlpaug
6 | cmapy
7 | audiomentations
8 | timm==0.4.5
9 | wget
10 | cmake==3.18.4
11 | MulticoreTSNE
12 | seaborn
13 | numpy
--------------------------------------------------------------------------------
/scripts/icbhi_ce.sh:
--------------------------------------------------------------------------------
1 |
2 | MODEL="ast"
3 | SEED="1"
4 |
5 | for s in $SEED
6 | do
7 | for m in $MODEL
8 | do
9 | TAG="bs8_lr5e-5_ep50_seed${s}"
10 | CUDA_VISIBLE_DEVICES=0 python main.py --tag $TAG \
11 | --dataset icbhi \
12 | --seed $s \
13 | --class_split lungsound \
14 | --n_cls 4 \
15 | --epochs 50 \
16 | --batch_size 8 \
17 | --optimizer adam \
18 | --learning_rate 5e-5 \
19 | --weight_decay 1e-6 \
20 | --weighted_loss \
21 | --cosine \
22 | --model $m \
23 | --test_fold official \
24 | --pad_types repeat \
25 | --resz 1 \
26 | --n_mels 128 \
27 | --ma_update \
28 | --ma_beta 0.5 \
29 | --from_sl_official \
30 | --audioset_pretrained \
31 | --method ce
32 |
33 | # only for evaluation, add the following arguments
34 | # --eval \
35 | # --pretrained \
36 | # --pretrained_ckpt ./save/icbhi_ast_ce_bs8_lr5e-5_ep50_seed1/best.pth
37 |
38 | done
39 | done
40 |
--------------------------------------------------------------------------------
/scripts/icbhi_patchmix_ce.sh:
--------------------------------------------------------------------------------
1 |
2 | MODEL="ast"
3 | SEED="1"
4 |
5 | for s in $SEED
6 | do
7 | for m in $MODEL
8 | do
9 | TAG="bs8_lr5e-5_ep50_seed${s}"
10 | CUDA_VISIBLE_DEVICES=0 python main.py --tag $TAG \
11 | --dataset icbhi \
12 | --seed $s \
13 | --class_split lungsound \
14 | --n_cls 4 \
15 | --epochs 50 \
16 | --batch_size 8 \
17 | --optimizer adam \
18 | --learning_rate 5e-5 \
19 | --weight_decay 1e-6 \
20 | --cosine \
21 | --model $m \
22 | --test_fold official \
23 | --pad_types repeat \
24 | --resz 1 \
25 | --n_mels 128 \
26 | --ma_update \
27 | --ma_beta 0.5 \
28 | --from_sl_official \
29 | --audioset_pretrained \
30 | --method patchmix \
31 | --mix_beta 1.0
32 |
33 | # only for evaluation, add the following arguments
34 | # --eval \
35 | # --pretrained \
36 | # --pretrained_ckpt ./save/icbhi_ast_patchmix_bs8_lr5e-5_ep50_seed1/best.pth
37 |
38 | done
39 | done
40 |
41 |
--------------------------------------------------------------------------------
/scripts/icbhi_patchmix_cl.sh:
--------------------------------------------------------------------------------
1 |
2 | MODEL="ast"
3 | SEED="1"
4 |
5 | for s in $SEED
6 | do
7 | for m in $MODEL
8 | do
9 | TAG="bs8_lr5e-5_ep50_seed${s}_best_param"
10 | CUDA_VISIBLE_DEVICES=0 python main.py --tag $TAG \
11 | --dataset icbhi \
12 | --seed $s \
13 | --class_split lungsound \
14 | --n_cls 4 \
15 | --epochs 50 \
16 | --batch_size 8 \
17 | --optimizer adam \
18 | --learning_rate 5e-5 \
19 | --weight_decay 1e-6 \
20 | --cosine \
21 | --model $m \
22 | --test_fold official \
23 | --pad_types repeat \
24 | --resz 1 \
25 | --n_mels 128 \
26 | --ma_update \
27 | --ma_beta 0.5 \
28 | --from_sl_official \
29 | --audioset_pretrained \
30 | --method patchmix_cl \
31 | --temperature 0.06 \
32 | --proj_dim 768 \
33 | --alpha 1.0 \
34 | --mix_beta 1.0
35 |
36 | # only for evaluation, add the following arguments
37 | # --eval \
38 | # --pretrained \
39 | # --pretrained_ckpt ./save/icbhi_ast_patchmix_cl_bs8_lr5e-5_ep50_seed1_best_param/best.pth
40 |
41 | done
42 | done
43 |
44 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raymin0223/patch-mix_contrastive_learning/836b09fea1b70eb29fe0b25afa481286b56f5104/util/__init__.py
--------------------------------------------------------------------------------
/util/augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import nlpaug.augmenter.audio as naa
4 |
5 | import torch
6 | from torchvision.utils import _log_api_usage_once
7 | from torchvision.transforms import transforms
8 | from torchaudio import transforms as T
9 | from .time_warping import sparse_image_warp
10 |
11 | __all__ = ['augment_raw_audio', 'SpecAugment']
12 |
13 |
14 | def augment_raw_audio(sample, sample_rate, args):
15 | """
16 | Raw audio data augmentation technique
17 | you can utilize any library code
18 | 1) nlpaug
19 | 2) audiomentations
20 | 3) librosa
21 | """
22 |
23 | """ 1) nlpaug """
24 | augment_list = [
25 | # naa.CropAug(sampling_rate=sample_rate)
26 | naa.NoiseAug(), # apply noise injection operation
27 | naa.SpeedAug(), # apply speed adjustment operation
28 | naa.LoudnessAug(factor=(0.5, 2)), # apply adjusting loudness operation
29 | naa.VtlpAug(sampling_rate=sample_rate, zone=(0.0, 1.0)), # apply vocal tract length perturbation (VTLP) operation
30 | naa.PitchAug(sampling_rate=sample_rate, factor=(-1,3)) # apply pitch adjustment operation
31 | ]
32 |
33 | # randomly sample augmentation
34 | aug_idx = random.randint(0, len(augment_list)-1)
35 | sample = augment_list[aug_idx].augment(sample)
36 |
37 | # apply all augmentations
38 | # for aug_idx in range(len(augment_list)):
39 | # sample = augment_list[aug_idx].augment(sample)
40 |
41 | """ 2) audiomentations """
42 | # import audiomentations
43 | # from audiomentations import AddGaussianSNR, TimeStretch, PitchShift, Shift
44 |
45 | # # when using audiomentations library (not DEBUG yet)
46 | # audio_transforms = audiomentations.Compose([
47 | # AddGaussianSNR(min_snr_in_db=5, max_snr_in_db=40.0, p=0.5),
48 | # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
49 | # PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
50 | # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5),
51 | # ])
52 |
53 | # sample = audio_transforms(samples=sample, sample_rate=sample_rate)
54 |
55 | """ 3) librosa """
56 | # import librosa
57 |
58 | # def _noise(data):
59 | # noise_amp = 0.035 * np.random.uniform() * np.amax(data)
60 | # data = data + noise_amp * np.random.normal(size=data.shape[0])
61 | # return data
62 |
63 | # def _stretch(data, rate=0.8):
64 | # return librosa.effects.time_stretch(data, rate)
65 |
66 | # def _shift(data):
67 | # shift_range = int(np.random.uniform(low=-5, high=5) * 1000)
68 | # return np.roll(data, shift_range)
69 |
70 | # def _pitch(data, sampling_rate, pitch_factor=0.7):
71 | # return librosa.effects.pitch_shift(data, sampling_rate, pitch_factor)
72 |
73 | # sample = _noise(sample)
74 | # sample = _stretch(sample)
75 | # sample = _shift(sample)
76 | # sample = _pitch(sample, sample_rate)
77 |
78 | if type(sample) == list:
79 | return sample[0]
80 | else:
81 | return sample
82 |
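# Usage sketch (hypothetical waveform; `args` is accepted for interface consistency but the nlpaug branch
# above does not read it):
#   waveform = np.random.randn(16000 * 5).astype(np.float32)   # 5 s of audio at 16 kHz
#   augmented = augment_raw_audio(waveform, sample_rate=16000, args=None)
# nlpaug may return a list, hence the type check above; speed/VTLP-style augmentations can change the
# length, which is why samples are cut or padded again afterwards.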
83 |
84 | # Use this Class when you load dataset with librosa
85 | class SpecAugment(torch.nn.Module):
86 | '''
87 | Unofficial Implementation of SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
88 | Paper: https://arxiv.org/pdf/1904.08779.pdf
89 | Ref. github: https://github.com/pyyush/SpecAugment/blob/219fc6e9ed4838fe9700295700040b1da283c536/augment.py#L10
90 |
91 | Augmentation Parameters for policies
92 | -----------------------------------------
93 | Policy | W | F | m_F | T | p | m_T
94 | -----------------------------------------
95 | None | 0 | 0 | - | 0 | - | -
96 | -----------------------------------------
97 | LB | 80 | 27 | 1 | 100 | 1.0 | 1
98 | -----------------------------------------
99 | LD | 80 | 27 | 2 | 100 | 1.0 | 2
100 | -----------------------------------------
101 | SM | 40 | 15 | 2 | 70 | 0.2 | 2
102 | -----------------------------------------
103 | SS | 40 | 27 | 2 | 70 | 0.2 | 2
104 | -----------------------------------------
105 |
106 | LB : LibriSpeech basic
107 | LD : LibriSpeech double
108 | SM : Switchboard mild
109 | SS : Switchboard strong
110 | W : Time Warp parameter
111 | F : Frequency Mask parameter
112 | m_F : Number of Frequency masks
113 | T : Time Mask parameter
114 | p : Parameter for calculating upper bound for time mask
115 | m_T : Number of time masks
116 | '''
117 | #def __init__(self, policy, zero_mean_normalized=False):
118 | def __init__(self, args):
119 | super().__init__()
120 | _log_api_usage_once(self)
121 |
122 | self.policy = args.specaug_policy
123 | self.mask = args.specaug_mask
124 |
125 | # Policy Specific Parameters
126 | if self.policy == 'LB':
127 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 80, 27, 1, 100, 1.0, 1
128 | elif self.policy == 'LD':
129 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 80, 27, 2, 100, 1.0, 2
130 | elif self.policy == 'SM':
131 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 40, 15, 2, 70, 0.2, 2
132 | elif self.policy == 'SS':
133 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 40, 27, 2, 70, 0.2, 2
134 | elif self.policy == 'icbhi_sup':
135 | # following https://github.com/ilyassmoummad/scl_icbhi2017
136 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 0, 20, 2, 50, 1.0, 2
137 | elif self.policy == 'icbhi_ast_sup':
138 | self.W, self.F, self.m_F, self.T, self.p, self.m_T = 0, 48, 2, 160, 1.0, 2
139 |
140 | # mimic SpecAugment by using torchaudio.transforms
141 | # self.spec_aug = torch.nn.Sequential(
142 | # T.TimeStretch(0.8, fixed_rate=True),
143 | # T.FrequencyMasking(freq_mask_param=self.F),
144 | # T.TimeMasking(time_mask_param=self.T)
145 | # )
146 |
147 | def time_warp(self):
148 | """ Tensorflow version """
149 | # v, tau = self.mel_spectrogram.shape[1], self.mel_spectrogram.shape[2]
150 |
151 | # horiz_line_thru_ctr = self.mel_spectrogram[0][v//2]
152 |
153 | # random_pt = horiz_line_thru_ctr[random.randrange(self.W, tau - self.W)] # random point along the horizontal/time axis
154 | # w = np.random.uniform((-self.W), self.W) # distance
155 |
156 | # src_points = [[[v//2, random_pt[0]]]] # Source Points
157 | # dest_points = [[[v//2, random_pt[0] + w]]] # Destination Points
158 | # self.mel_spectrogram, _ = sparse_image_warp(self.mel_spectrogram, src_points, dest_points, num_boundary_points=2)
159 | # self.mel_spectrogram = self.mel_spectrogram.numpy()
160 |
161 | """ Pytorch version """
162 | # refer to https://github.com/zcaceres/spec_augment/blob/master/SpecAugment.ipynb
163 | num_rows = self.mel_spectrogram.shape[2]
164 | spec_len = self.mel_spectrogram.shape[1]
165 | device = self.mel_spectrogram.device
166 |
167 | # adapted from https://github.com/DemisEom/SpecAugment/
168 | pt = (num_rows - 2 * self.W) * torch.rand([1], dtype=torch.float) + self.W # random point along the time axis
169 | src_ctr_pt_freq = torch.arange(0, spec_len // 2) # control points on freq-axis
170 | src_ctr_pt_time = torch.ones_like(src_ctr_pt_freq) * pt # control points on time-axis
171 | src_ctr_pts = torch.stack((src_ctr_pt_freq, src_ctr_pt_time), dim=-1)
172 | src_ctr_pts = src_ctr_pts.float().to(device)
173 |
174 | # Destination
175 | w = 2 * self.W * torch.rand([1], dtype=torch.float) - self.W # distance
176 | dest_ctr_pt_freq = src_ctr_pt_freq
177 | dest_ctr_pt_time = src_ctr_pt_time + w
178 | dest_ctr_pts = torch.stack((dest_ctr_pt_freq, dest_ctr_pt_time), dim=-1)
179 | dest_ctr_pts = dest_ctr_pts.float().to(device)
180 |
181 | # warp
182 | source_control_point_locations = torch.unsqueeze(src_ctr_pts, 0) # (1, v//2, 2)
183 | dest_control_point_locations = torch.unsqueeze(dest_ctr_pts, 0) # (1, v//2, 2)
184 | warped_spectro, dense_flows = sparse_image_warp(self.mel_spectrogram, source_control_point_locations, dest_control_point_locations)
185 |
186 | return warped_spectro.squeeze(3)
187 |
188 | def freq_mask(self):
189 | if self.mask == 'mean':
190 |             # masking to the mean value
191 | mask_value = self.mel_spectrogram.mean()
192 | elif self.mask == 'zero':
193 |             # masking to zero
194 | mask_value = 0.
195 |
196 | v = self.mel_spectrogram.shape[1] # no. of mel bins
197 |
198 | # apply m_F frequency masks to the mel spectrogram
199 | for i in range(self.m_F):
200 | f = int(np.random.uniform(0, self.F)) # [0, F)
201 |             f0 = random.randint(0, v - f) # [0, v - f]
202 | self.mel_spectrogram[:, f0:f0 + f, :] = mask_value
203 |
204 | return self.mel_spectrogram
205 |
206 | def time_mask(self):
207 | if self.mask == 'mean':
208 |             # masking to the mean value
209 | mask_value = self.mel_spectrogram.mean()
210 | elif self.mask == 'zero':
211 |             # masking to zero
212 | mask_value = 0.
213 |
214 | tau = self.mel_spectrogram.shape[2] # time frames
215 |
216 | # apply m_T time masks to the mel spectrogram
217 | for i in range(self.m_T):
218 | t = int(np.random.uniform(0, self.T)) # [0, T)
219 |             t0 = random.randint(0, tau - t) # [0, tau - t]
220 | self.mel_spectrogram[:, :, t0:t0 + t] = mask_value
221 |
222 | return self.mel_spectrogram
223 |
224 | def forward(self, img):
225 | """
226 | Args:
227 | img (Tensor): Mel-spectrogram to be specaugmented.
228 | Returns:
229 | Tensor: Time-warped, time masked and freq masked image.
230 | """
231 | # self.mel_spectrogram = img # np.array [time, freq, channel]
232 | self.mel_spectrogram = img # torch.tensor [channel, time, freq]
233 | self.mel_spectrogram = self.mel_spectrogram.transpose(2, 1) # torch.tensor [channel, freq, time]
234 |
235 |         if self.p >= torch.rand(1):  # apply with probability p (uniform draw from [0, 1))
236 | if self.W:
237 | try:
238 | # input shape of time_warp should be [sample_size, time, freq]
239 | # assume that channel == 1 and augment each "one" sample
240 |                     self.mel_spectrogram = self.time_warp()
241 | except Exception as e:
242 | # torch.linalg.solve: (Batch element 0): The solver failed because the input matrix is singular.
243 | # print(e)
244 | pass
245 |
246 | self.mel_spectrogram = self.freq_mask()
247 | self.mel_spectrogram = self.time_mask()
248 |
249 | return self.mel_spectrogram.transpose(2, 1)
250 |
251 | def __repr__(self) -> str:
252 | return f"{self.__class__.__name__}()"
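# Usage sketch (hypothetical args; only the two fields read in __init__ are shown):
#   from argparse import Namespace
#   specaug = SpecAugment(Namespace(specaug_policy='icbhi_ast_sup', specaug_mask='mean'))
#   spec = torch.randn(1, 798, 128)       # [channel, time, freq] mel-spectrogram
#   out = specaug(spec)                   # same shape, with frequency/time masks applied
#                                         # (time warping only runs when the policy sets W > 0)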
--------------------------------------------------------------------------------
/util/icbhi_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import cv2
4 | import pickle
5 | import random
6 | import numpy as np
7 | import pandas as pd
8 | from tqdm import tqdm
9 |
10 | import librosa
11 | import torch
12 | from torch.utils.data import Dataset
13 | from copy import deepcopy
14 | from PIL import Image
15 |
16 | from .icbhi_util import get_annotations, save_image, generate_fbank, get_individual_cycles_librosa, split_pad_sample, generate_mel_spectrogram, concat_augmentation
17 | from .icbhi_util import get_individual_cycles_torchaudio, cut_pad_sample_torchaudio
18 | from .augmentation import augment_raw_audio
19 |
20 |
21 | class ICBHIDataset(Dataset):
22 | def __init__(self, train_flag, transform, args, print_flag=True, mean_std=False):
23 | data_folder = os.path.join(args.data_folder, 'icbhi_dataset/audio_test_data')
24 | folds_file = os.path.join(args.data_folder, 'icbhi_dataset/patient_list_foldwise.txt')
25 | official_folds_file = os.path.join(args.data_folder, 'icbhi_dataset/official_split.txt')
26 | test_fold = args.test_fold
27 |
28 | self.data_folder = data_folder
29 | self.train_flag = train_flag
30 | self.split = 'train' if train_flag else 'test'
31 | self.transform = transform
32 | self.args = args
33 | self.mean_std = mean_std
34 |
35 | # parameters for spectrograms
36 | self.sample_rate = args.sample_rate
37 | self.desired_length = args.desired_length
38 | self.pad_types = args.pad_types
39 | self.nfft = args.nfft
40 | self.hop = self.nfft // 2
41 | self.n_mels = args.n_mels
42 | self.f_min = 50
43 | self.f_max = 2000
44 | self.dump_images = False
45 |
46 | # ==========================================================================
47 | """ get ICBHI dataset meta information """
48 | # store stethoscope device information for each file or patient
49 | self.file_to_device = {}
50 | self.device_to_id = {'Meditron': 0, 'LittC2SE': 1, 'Litt3200': 2, 'AKGC417L': 3}
51 | self.device_id_to_patient = {0: [], 1: [], 2: [], 3: []}
52 |
53 | filenames = os.listdir(data_folder)
54 |         filenames = set([f.strip().split('.')[0] for f in filenames if '.wav' in f or '.txt' in f])
55 | for f in filenames:
56 | f += '.wav'
57 | # get the total number of devices from original dataset (icbhi dataset has 4 stethoscope devices)
58 | device = f.strip().split('_')[-1].split('.')[0]
59 | # if device not in self.device_to_id:
60 | # self.device_to_id[device] = device_id
61 | # self.device_id_to_patient[device_id] = []
62 | # device_id += 1
63 |
64 | # get the device information for each wav file
65 | self.file_to_device[f.strip().split('.')[0]] = self.device_to_id[device]
66 |
67 | pat_id = f.strip().split('_')[0]
68 | if pat_id not in self.device_id_to_patient[self.device_to_id[device]]:
69 | self.device_id_to_patient[self.device_to_id[device]].append(pat_id)
70 |
71 | # store all metadata (age, sex, adult_BMI, child_weight, child_height, device_index)
72 | self.file_to_metadata = {}
73 | meta_file = pd.read_csv(os.path.join(args.data_folder, 'icbhi_dataset/metadata.txt'), names=['age', 'sex', 'adult_BMI', 'child_weight', 'child_height', 'chest_location'], delimiter= '\t')
74 | meta_file['chest_location'].replace({'Tc':0, 'Al':1, 'Ar':2, 'Pl':3, 'Pr':4, 'Ll':5, 'Lr':6}, inplace=True)
75 | for f in filenames:
76 | pat_idx = int(f.strip().split('_')[0])
77 | info = list(meta_file.loc[pat_idx])
78 | info[1] = 0 if info[1] == 'M' else 1
79 |
80 | info = np.array(info)
81 | for idx in np.argwhere(np.isnan(info)):
82 | info[idx] = -1
83 |
84 | self.file_to_metadata[f] = torch.tensor(np.append(info, self.file_to_device[f.strip()]))
85 | # ==========================================================================
86 |
87 | # ==========================================================================
88 | """ train-test split based on train_flag and test_fold """
89 | if test_fold in ['0', '1', '2', '3', '4']: # from RespireNet, 80-20% split
90 | patient_dict = {}
91 | all_patients = open(folds_file).read().splitlines()
92 | for line in all_patients:
93 | idx, fold = line.strip().split(' ')
94 | if train_flag and int(fold) != int(test_fold):
95 | patient_dict[idx] = fold
96 | elif train_flag == False and int(fold) == int(test_fold):
97 | patient_dict[idx] = fold
98 |
99 | if print_flag:
100 | print('*' * 20)
101 | print('Train and test 80-20% split with test_fold {}'.format(test_fold))
102 |                 print('Patient number in {} dataset: {}'.format(self.split, len(patient_dict)))
103 | else:
104 | """
105 | args.test_fold == 'official', 60-40% split
106 | two patient dataset contain both train and test samples
107 | """
108 | patient_dict = {}
109 | all_fpath = open(official_folds_file).read().splitlines()
110 | for line in all_fpath:
111 | fpath, fold = line.strip().split('\t')
112 | if train_flag and fold == 'train':
113 | # idx = fpath.strip().split('_')[0]
114 | patient_dict[fpath] = fold
115 | elif not train_flag and fold == 'test':
116 | # idx = fpath.strip().split('_')[0]
117 | patient_dict[fpath] = fold
118 |
119 | if print_flag:
120 | print('*' * 20)
121 | print('Train and test 60-40% split with test_fold {}'.format(test_fold))
122 | print('File number in {} dataset: {}'.format(self.split, len(patient_dict)))
123 | # ==========================================================================
124 |
125 | # dict {filename: annotations}, annotation is for breathing cycle
126 | annotation_dict = get_annotations(args, data_folder)
127 |
128 | self.filenames = []
129 | for f in filenames:
130 |             # for the 'official' test_fold, some patients have recordings in both the train and test splits
131 | idx = f.split('_')[0] if test_fold in ['0', '1', '2', '3', '4'] else f
132 | if args.stetho_id >= 0: # extract specific device dataset
133 | if idx in patient_dict and self.file_to_device[f] == args.stetho_id:
134 | self.filenames.append(f)
135 | else: # use all dataset
136 | if idx in patient_dict:
137 | self.filenames.append(f)
138 |
139 | self.audio_data = [] # each sample is a tuple with (audio_data, label, filename)
140 | self.metadata = [] # (age, sex, adult_BMI, child_weight, child_height, device_idx)
141 | self.labels = []
142 |
143 | if print_flag:
144 | print('*' * 20)
145 | print("Extracting individual breathing cycles..")
146 |
147 | self.cycle_list = []
148 | self.filename_to_label = {}
149 | self.classwise_cycle_list = [[] for _ in range(args.n_cls)]
150 |
151 | # ==========================================================================
152 | """ extract individual cycles by librosa or torchaudio """
153 | for idx, filename in enumerate(self.filenames):
154 |             # you can use self.filename_to_label to get statistics of the original sample labels (it is not used by any other function)
155 | self.filename_to_label[filename] = []
156 |
157 | # "RespireNet" version: get original cycles 6,898 by librosa
158 | # sample_data = get_individual_cycles_librosa(args, annotation_dict[filename], data_folder, filename, args.sample_rate, args.n_cls, args.butterworth_filter)
159 |
160 | # "SCL" version: get original cycles 6,898 by torchaudio and cut_pad samples
161 | sample_data = get_individual_cycles_torchaudio(args, annotation_dict[filename], self.file_to_metadata[filename], data_folder, filename, args.sample_rate, args.n_cls)
162 |
163 | # cycles_with_labels: [(audio_chunk, label, metadata), (...)]
164 | cycles_with_labels = [(data[0], data[1], self.file_to_metadata[filename]) for data in sample_data]
165 |
166 | self.cycle_list.extend(cycles_with_labels)
167 | for d in cycles_with_labels:
168 | # {filename: [label for cycle 1, ...]}
169 | self.filename_to_label[filename].append(d[1])
170 | self.classwise_cycle_list[d[1]].append(d)
171 |
172 | # concatenation based augmentation scheme from "RespireNet" paper..
173 | # TODO: how to decide the meta information of generated cycles
174 | # if train_flag and args.concat_aug_scale and args.class_split == 'lungsound' and args.n_cls == 4:
175 | # self.classwise_cycle_list, self.cycle_list = concat_augmentation(self.classwise_cycle_list, self.cycle_list, scale=args.concat_aug_scale)
176 |
177 | for sample in self.cycle_list:
178 | self.metadata.append(sample[2])
179 |
180 | # "RespireNet" version: split and pad each cycle to the desired length (cycle numbers can be more than 6,898)
181 | # output = split_pad_sample(sample, args.desired_length, args.sample_rate, types=args.pad_types)
182 | # self.audio_data.extend(output)
183 |
184 | # "SCL" version
185 | self.audio_data.append(sample)
186 | # ==========================================================================
187 |
188 | self.class_nums = np.zeros(args.n_cls)
189 | for sample in self.audio_data:
190 | self.class_nums[sample[1]] += 1
191 | self.labels.append(sample[1])
192 | self.class_ratio = self.class_nums / sum(self.class_nums) * 100
193 |
194 | if print_flag:
195 | print('[Preprocessed {} dataset information]'.format(self.split))
196 | print('total number of audio data: {}'.format(len(self.audio_data)))
197 | for i, (n, p) in enumerate(zip(self.class_nums, self.class_ratio)):
198 | print('Class {} {:<9}: {:<4} ({:.1f}%)'.format(i, '('+args.cls_list[i]+')', int(n), p))
199 |
200 | # ==========================================================================
201 | """ convert mel-spectrogram """
202 | self.audio_images = []
203 | for index in range(len(self.audio_data)):
204 | audio, label = self.audio_data[index][0], self.audio_data[index][1]
205 |
206 | audio_image = []
207 | # self.aug_times = 1 + 5 * self.args.augment_times # original + five naa augmentations * augment_times (optional)
208 | for aug_idx in range(self.args.raw_augment+1):
209 | if aug_idx > 0:
210 | if self.train_flag and not mean_std:
211 | audio = augment_raw_audio(audio, self.sample_rate, self.args)
212 |
213 | # "RespireNet" version: pad incase smaller than desired length
214 | # audio = split_pad_sample([audio, 0,0], self.desired_length, self.sample_rate, types=self.pad_types)[0][0]
215 |
216 | # "SCL" version: cut longer sample or pad sample
217 | audio = cut_pad_sample_torchaudio(torch.tensor(audio), args)
218 | else:
219 | audio_image.append(None)
220 | continue
221 |
222 | image = generate_fbank(audio, self.sample_rate, n_mels=self.n_mels)
223 | # image = generate_mel_spectrogram(audio.squeeze(0).numpy(), self.sample_rate, n_mels=self.n_mels, f_max=self.f_max, nfft=self.nfft, hop=self.hop, args=self.args) # image [n_mels, 251, 1]
224 |
225 | # blank region clipping from "RespireNet" paper..
226 | if self.args.blank_region_clip:
227 | image_copy = deepcopy(generate_fbank(audio, self.sample_rate, n_mels=self.n_mels))
228 | # image_copy = deepcopy(generate_mel_spectrogram(audio.squeeze(0).numpy(), self.sample_rate, n_mels=self.n_mels, f_max=self.f_max, nfft=self.nfft, hop=self.hop, args=self.args)) # image [n_mels, 251, 1]
229 |
230 | image_copy[image_copy < 10] = 0
231 | for row in range(image_copy.shape[0]):
232 | black_percent = len(np.where(image_copy[row,:] == 0)[0]) / len(image_copy[row,:])
233 |                         # stop at the first row that is less than 80% blank (zero) and remember that `row`
234 | if black_percent < 0.80:
235 | break
236 |
237 |                     # clip the leading mostly-blank rows (0..row) and resize the image back
238 | if row + 1 < image.shape[0]:
239 | image = image[row+1:,:,:]
240 | image = cv2.resize(image, (image.shape[1], self.n_mels), interpolation=cv2.INTER_LINEAR)
241 | image = image[..., np.newaxis]
242 |
243 | audio_image.append(image)
244 | self.audio_images.append((audio_image, label))
245 |
246 | if self.dump_images:
247 | save_image(audio_image, './')
248 | self.dump_images = False
249 |
250 | self.h, self.w, _ = self.audio_images[0][0][0].shape
251 | # ==========================================================================
252 |
253 | def __getitem__(self, index):
254 | audio_images, label, metadata = self.audio_images[index][0], self.audio_images[index][1], self.metadata[index]
255 |
256 | if self.args.raw_augment and self.train_flag and not self.mean_std:
257 | aug_idx = random.randint(0, self.args.raw_augment)
258 | audio_image = audio_images[aug_idx]
259 | else:
260 | audio_image = audio_images[0]
261 |
262 | if self.transform is not None:
263 | audio_image = self.transform(audio_image)
264 |
265 | return audio_image, label, metadata
266 |
267 | def __len__(self):
268 | return len(self.audio_data)
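# Construction sketch (hypothetical, heavily abridged): the dataset is normally built from the argparse
# options in main.py; the fields below are the ones read in __init__ above, with illustrative values only.
#   from argparse import Namespace
#   args = Namespace(data_folder='./data', test_fold='official', sample_rate=16000,
#                    desired_length=8, pad_types='repeat', nfft=1024, n_mels=128,
#                    stetho_id=-1, class_split='lungsound', n_cls=4, raw_augment=0,
#                    blank_region_clip=False, cls_list=['normal', 'crackle', 'wheeze', 'both'])
#   train_set = ICBHIDataset(train_flag=True, transform=None, args=args)
#   fbank_image, label, metadata = train_set[0]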
--------------------------------------------------------------------------------
/util/icbhi_util.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import os
3 | import math
4 | import random
5 |
6 | import pandas as pd
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import cv2
11 | import cmapy
12 | import librosa
13 | import torch
14 | import torchaudio
15 | from torchaudio import transforms as T
16 | from scipy.signal import butter, lfilter
17 |
18 | from .augmentation import augment_raw_audio
19 |
20 | __all__ = ['get_annotations', 'save_image', 'get_mean_and_std', 'get_individual_cycles_librosa', 'get_individual_cycles_torchaudio', 'split_pad_sample', 'generate_mel_spectrogram', 'generate_fbank', 'concat_augmentation', 'get_score']
21 |
22 |
23 | # ==========================================================================
24 | """ ICBHI dataset information """
25 | def _extract_lungsound_annotation(file_name, data_folder):
26 | tokens = file_name.strip().split('_')
27 | recording_info = pd.DataFrame(data = [tokens], columns = ['Patient Number', 'Recording index', 'Chest location','Acquisition mode','Recording equipment'])
28 | recording_annotations = pd.read_csv(os.path.join(data_folder, file_name + '.txt'), names = ['Start', 'End', 'Crackles', 'Wheezes'], delimiter= '\t')
29 |
30 | return recording_info, recording_annotations
31 |
32 |
33 | def get_annotations(args, data_folder):
34 | if args.class_split == 'lungsound' or args.class_split in ['lungsound_meta', 'meta']:
35 | filenames = [f.strip().split('.')[0] for f in os.listdir(data_folder) if '.txt' in f]
36 |
37 | annotation_dict = {}
38 | for f in filenames:
39 | info, ann = _extract_lungsound_annotation(f, data_folder)
40 | annotation_dict[f] = ann
41 |
42 | elif args.class_split == 'diagnosis':
43 | filenames = [f.strip().split('.')[0] for f in os.listdir(data_folder) if '.txt' in f]
44 | tmp = pd.read_csv(os.path.join(args.data_folder, 'icbhi_dataset/patient_diagnosis.txt'), names=['Disease'], delimiter='\t')
45 |
46 | annotation_dict = {}
47 | for f in filenames:
48 | info, ann = _extract_lungsound_annotation(f, data_folder)
49 | ann.drop(['Crackles', 'Wheezes'], axis=1, inplace=True)
50 |
51 | disease = tmp.loc[int(f.strip().split('_')[0]), 'Disease']
52 | ann['Disease'] = disease
53 |
54 | annotation_dict[f] = ann
55 |
56 | return annotation_dict
57 |
58 |
59 | def save_image(image, fpath):
60 | save_dir = os.path.join(fpath, 'image.jpg')
61 | cv2.imwrite(save_dir, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
62 |
63 |
64 | def get_mean_and_std(dataset):
65 | """ Compute the mean and std value of mel-spectrogram """
66 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8)
67 |
68 | cnt = 0
69 | fst_moment = torch.zeros(1)
70 | snd_moment = torch.zeros(1)
71 | for inputs, _, _ in dataloader:
72 | b, c, h, w = inputs.shape
73 | nb_pixels = b * h * w
74 |
75 | fst_moment += torch.sum(inputs, dim=[0,2,3])
76 | snd_moment += torch.sum(inputs**2, dim=[0,2,3])
77 | cnt += nb_pixels
78 |
79 | mean = fst_moment / cnt
80 | std = torch.sqrt(snd_moment/cnt - mean**2)
81 |
82 | return mean, std
83 | # ==========================================================================
84 |
85 |
86 | # ==========================================================================
87 | """ data preprocessing """
88 | def _butter_bandpass(lowcut, highcut, fs, order=5):
89 | nyq = 0.5 * fs
90 | low = lowcut / nyq
91 | high = highcut / nyq
92 | b, a = butter(order, [low, high], btype='band')
93 |
94 | return b, a
95 |
96 |
97 | def _butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
98 | b, a = _butter_bandpass(lowcut, highcut, fs, order=order)
99 | y = lfilter(b, a, data)
100 |
101 | return y
102 |
103 |
104 | def _slice_data_librosa(start, end, data, sample_rate):
105 | """
106 | RespireNet paper..
107 | sample_rate denotes how many sample points for one second
108 | """
109 | max_ind = len(data)
110 | start_ind = min(int(start * sample_rate), max_ind)
111 | end_ind = min(int(end * sample_rate), max_ind)
112 |
113 | return data[start_ind: end_ind]
114 |
115 |
116 | def _get_lungsound_label(crackle, wheeze, n_cls):
117 | if n_cls == 4:
118 | if crackle == 0 and wheeze == 0:
119 | return 0
120 | elif crackle == 1 and wheeze == 0:
121 | return 1
122 | elif crackle == 0 and wheeze == 1:
123 | return 2
124 | elif crackle == 1 and wheeze == 1:
125 | return 3
126 |
127 | elif n_cls == 2:
128 | if crackle == 0 and wheeze == 0:
129 | return 0
130 | else:
131 | return 1
132 |
133 |
134 | def _get_diagnosis_label(disease, n_cls):
135 | if n_cls == 3:
136 | if disease in ['COPD', 'Bronchiectasis', 'Asthma']:
137 | return 1
138 | elif disease in ['URTI', 'LRTI', 'Pneumonia', 'Bronchiolitis']:
139 | return 2
140 | else:
141 | return 0
142 |
143 | elif n_cls == 2:
144 | if disease == 'Healthy':
145 | return 0
146 | else:
147 | return 1
148 |
149 |
150 | def get_individual_cycles_librosa(args, recording_annotations, data_folder, filename, sample_rate, n_cls, butterworth_filter=None):
151 | """
152 | RespireNet paper..
153 | Used to split each individual sound file into separate sound clips containing one respiratory cycle each
154 | output: [(audio_chunk:np.array, label:int), (...)]
155 | """
156 | sample_data = []
157 |
158 | # load file with specified sample rate (also converts to mono)
159 | data, rate = librosa.load(os.path.join(data_folder, filename+'.wav'), sr=sample_rate)
160 |
161 | if butterworth_filter:
162 | # butter bandpass filter
163 |         data = _butter_bandpass_filter(data, lowcut=200, highcut=1800, fs=sample_rate, order=butterworth_filter)
164 |
165 | for idx in recording_annotations.index:
166 | row = recording_annotations.loc[idx]
167 |
168 | start = row['Start'] # time (second)
169 | end = row['End'] # time (second)
170 | audio_chunk = _slice_data_librosa(start, end, data, rate)
171 |
172 | if args.class_split == 'lungsound':
173 | crackles = row['Crackles']
174 | wheezes = row['Wheezes']
175 | sample_data.append((audio_chunk, _get_lungsound_label(crackles, wheezes, n_cls)))
176 | elif args.class_split == 'diagnosis':
177 | disease = row['Disease']
178 | sample_data.append((audio_chunk, _get_diagnosis_label(disease, n_cls)))
179 |
180 | return sample_data
181 |
182 |
183 | def _slice_data_torchaudio(start, end, data, sample_rate):
184 | """
185 | SCL paper..
186 | sample_rate denotes how many sample points for one second
187 | """
188 | max_ind = data.shape[1]
189 | start_ind = min(int(start * sample_rate), max_ind)
190 | end_ind = min(int(end * sample_rate), max_ind)
191 |
192 | return data[:, start_ind: end_ind]
193 |
194 |
195 | def cut_pad_sample_torchaudio(data, args):
196 | fade_samples_ratio = 16
197 | fade_samples = int(args.sample_rate / fade_samples_ratio)
198 | fade_out = T.Fade(fade_in_len=0, fade_out_len=fade_samples, fade_shape='linear')
199 | target_duration = args.desired_length * args.sample_rate
200 |
201 | if data.shape[-1] > target_duration:
202 | data = data[..., :target_duration]
203 | else:
204 | if args.pad_types == 'zero':
205 | tmp = torch.zeros(1, target_duration, dtype=torch.float32)
206 | diff = target_duration - data.shape[-1]
207 | tmp[..., diff//2:data.shape[-1]+diff//2] = data
208 | data = tmp
209 | elif args.pad_types == 'repeat':
210 | ratio = math.ceil(target_duration / data.shape[-1])
211 | data = data.repeat(1, ratio)
212 | data = data[..., :target_duration]
213 | data = fade_out(data)
214 |
215 | return data
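# Worked example (sketch): with args.desired_length = 8 and args.sample_rate = 16000 the target is
# 128000 samples. Under pad_types == 'repeat', a 3 s cycle (48000 samples) is tiled
# ceil(128000 / 48000) = 3 times, truncated back to 128000 samples, and linearly faded out over the
# last 1000 samples (sample_rate / 16); a cycle longer than 8 s is simply cut to the first 8 s.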
216 |
217 | def get_individual_cycles_torchaudio(args, recording_annotations, metadata, data_folder, filename, sample_rate, n_cls):
218 | """
219 | SCL paper..
220 | used to split each individual sound file into separate sound clips containing one respiratory cycle each
221 | output: [(audio_chunk:np.array, label:int), (...)]
222 | """
223 | sample_data = []
224 | fpath = os.path.join(data_folder, filename+'.wav')
225 |
226 | sr = librosa.get_samplerate(fpath)
227 | data, _ = torchaudio.load(fpath)
228 |
229 | if sr != sample_rate:
230 | resample = T.Resample(sr, sample_rate)
231 | data = resample(data)
232 |
233 | fade_samples_ratio = 16
234 | fade_samples = int(sample_rate / fade_samples_ratio)
235 |
236 | fade = T.Fade(fade_in_len=fade_samples, fade_out_len=fade_samples, fade_shape='linear')
237 |
238 | data = fade(data)
239 | for idx in recording_annotations.index:
240 | row = recording_annotations.loc[idx]
241 |
242 | start = row['Start'] # time (second)
243 | end = row['End'] # time (second)
244 | audio_chunk = _slice_data_torchaudio(start, end, data, sample_rate)
245 |
246 | if args.class_split == 'lungsound':
247 | crackles = row['Crackles']
248 | wheezes = row['Wheezes']
249 | sample_data.append((audio_chunk, _get_lungsound_label(crackles, wheezes, n_cls)))
250 | elif args.class_split == 'diagnosis':
251 | disease = row['Disease']
252 | sample_data.append((audio_chunk, _get_diagnosis_label(disease, n_cls)))
253 |
254 | padded_sample_data = []
255 | for data, label in sample_data:
256 | data = cut_pad_sample_torchaudio(data, args)
257 | padded_sample_data.append((data, label))
258 |
259 | return padded_sample_data
260 |
261 |
262 | def _zero_padding(source, output_length):
263 | copy = np.zeros(output_length, dtype=np.float32)
264 | src_length = len(source)
265 |
266 | frac = src_length / output_length
267 | if frac < 0.5:
268 | # tile forward sounds to fill empty space
269 | cursor = 0
270 | while(cursor + src_length) < output_length:
271 | copy[cursor:(cursor + src_length)] = source[:]
272 | cursor += src_length
273 | else:
274 | # [src_length:] part will be zeros
275 | copy[:src_length] = source[:]
276 |
277 | return copy
278 |
279 |
280 | def _equally_slice_pad_sample(sample, desired_length, sample_rate):
281 | """
282 | pad_type == 0: zero-padding
283 | if sample length > desired_length,
284 | all equally sliced samples with samples_per_slice number are zero-padded or recursively duplicated
285 | """
286 | output_length = int(desired_length * sample_rate) # desired_length is second
287 | soundclip = sample[0].copy()
288 | n_samples = len(soundclip)
289 |
290 | total_length = n_samples / sample_rate # length of cycle in seconds
291 | n_slices = int(math.ceil(total_length / desired_length)) # get the minimum number of slices needed
292 | samples_per_slice = n_samples // n_slices
293 |
294 | output = [] # holds the resultant slices
295 |     src_start = 0 # starting index of the samples to copy from the sample buffer
296 | for i in range(n_slices):
297 | src_end = min(src_start + samples_per_slice, n_samples)
298 | length = src_end - src_start
299 |
300 | copy = _zero_padding(soundclip[src_start:src_end], output_length)
301 | output.append((copy, sample[1], sample[2]))
302 | src_start += length
303 |
304 | return output
305 |
306 |
307 | def _duplicate_padding(sample, source, output_length, sample_rate, types):
308 | # pad_type == 1 or 2
309 | copy = np.zeros(output_length, dtype=np.float32)
310 | src_length = len(source)
311 | left = output_length - src_length # amount to be padded
312 |
313 | if types == 'repeat':
314 | aug = sample
315 | else:
316 |         aug = augment_raw_audio(sample, sample_rate, None)  # augment_raw_audio also takes an args argument; None is fine since the nlpaug branch does not read it
317 |
318 | while len(aug) < left:
319 | aug = np.concatenate([aug, aug])
320 |
321 | prob = random.random()
322 | if prob < 0.5:
323 | # pad the back part of original sample
324 | copy[left:] = source
325 | copy[:left] = aug[len(aug)-left:]
326 | else:
327 | # pad the front part of original sample
328 | copy[:src_length] = source[:]
329 | copy[src_length:] = aug[:left]
330 |
331 | return copy
332 |
333 |
334 | def split_pad_sample(sample, desired_length, sample_rate, types='repeat'):
335 | """
336 | if the audio sample length > desired_length, then split and pad samples
337 | else simply pad samples according to pad_types
338 | * types 'zero' : simply pad by zeros (zero-padding)
339 | * types 'repeat' : pad with duplicate on both sides (half-n-half)
340 | * types 'aug' : pad with augmented sample on both sides (half-n-half)
341 | """
342 | if types == 'zero':
343 | return _equally_slice_pad_sample(sample, desired_length, sample_rate)
344 |
345 | output_length = int(desired_length * sample_rate)
346 | soundclip = sample[0].copy()
347 | n_samples = len(soundclip)
348 |
349 | output = []
350 | if n_samples > output_length:
351 | """
352 | if sample length > desired_length, slice samples with desired_length then just use them,
353 | and the last sample is padded according to the padding types
354 | """
355 | # frames[j] = x[j * hop_length : j * hop_length + frame_length]
356 | frames = librosa.util.frame(soundclip, frame_length=output_length, hop_length=output_length//2, axis=0)
357 | for i in range(frames.shape[0]):
358 | output.append((frames[i], sample[1], sample[2]))
359 |
360 | # get the last sample
361 | last_id = frames.shape[0] * (output_length//2)
362 | last_sample = soundclip[last_id:]
363 |
364 | padded = _duplicate_padding(soundclip, last_sample, output_length, sample_rate, types)
365 | output.append((padded, sample[1], sample[2]))
366 | else: # only pad
367 | padded = _duplicate_padding(soundclip, soundclip, output_length, sample_rate, types)
368 | output.append((padded, sample[1], sample[2]))
369 |
370 | return output
371 |
372 |
373 | def generate_mel_spectrogram(audio, sample_rate, n_mels=64, f_min=50, f_max=2000, nfft=1024, hop=512, args=None):
374 | """
375 | use librosa library and convert mel-spectrogram to have 3 channels
376 | """
377 | S = librosa.feature.melspectrogram(y=audio, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max, n_fft=nfft, hop_length=hop)
378 | # convert scale to dB from magnitude
379 | S = librosa.power_to_db(S, ref=np.max)
380 | S = (S - S.min()) / (S.max() - S.min())
381 | # S *= 255
382 |
383 | if args.blank_region_clip:
384 | S = cv2.flip(S, 0) # up-down flip
385 | img = S.reshape(S.shape[0], S.shape[1], 1)
386 |
387 | return img
388 |
389 |
390 | def generate_fbank(audio, sample_rate, n_mels=128):
391 | """
392 | use torchaudio library to convert mel fbank for AST model
393 | """
394 | assert sample_rate == 16000, 'input audio sampling rate must be 16kHz'
395 | fbank = torchaudio.compliance.kaldi.fbank(audio, htk_compat=True, sample_frequency=sample_rate, use_energy=False, window_type='hanning', num_mel_bins=n_mels, dither=0.0, frame_shift=10)
396 |
397 |     mean, std = -4.2677393, 4.5689974
398 |     fbank = (fbank - mean) / (std * 2) # normalize with precomputed dataset mean / std (AST-style)
399 | fbank = fbank.unsqueeze(-1).numpy()
400 | return fbank
401 |
402 |
403 | def concat_augmentation(classwise_cycle_list, cycle_list, scale=1.):
404 | """ From "RespireNet" paper..
405 | """
406 |
407 | def _get_random_cycles(classwise_cycle_list, idx1, idx2):
408 | i = random.randint(0, len(classwise_cycle_list[idx1])-1)
409 | j = random.randint(0, len(classwise_cycle_list[idx2])-1)
410 |
411 | sample_i = classwise_cycle_list[idx1][i]
412 | sample_j = classwise_cycle_list[idx2][j]
413 |
414 | return sample_i, sample_j
415 |
416 | print('*' * 20)
417 | # augment normal
418 | aug_nums = int(scale*len(classwise_cycle_list[0]) - len(classwise_cycle_list[0]))
419 | print('# of concatenation-based augmentation for normal class is {}'.format(aug_nums))
420 |
421 | for _ in range(aug_nums):
422 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 0, 0)
423 | new_sample = np.concatenate([sample_i[0], sample_j[0]])
424 | # cycle_list: [(audio_chunk, label, filename, pad_times), (...)]
425 | cycle_list.append((new_sample, 0, sample_i[2]+'-'+sample_j[2])) # sample_i[2] denotes filename
426 |
427 | # augment crackle
428 | aug_nums = int(scale*len(classwise_cycle_list[0]) - len(classwise_cycle_list[1]))
429 | print('# of concatenation-based augmentation for crackle class is {}'.format(aug_nums))
430 |
431 | for _ in range(aug_nums):
432 | aug_prob = random.random()
433 | if aug_prob < 0.6:
434 | # crackle_i + crackle_j
435 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 1, 1)
436 | elif aug_prob >= 0.6 and aug_prob < 0.8:
437 | # crackle_i + normal_j
438 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 1, 0)
439 | else:
440 | # normal_i + crackle_j
441 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 0, 1)
442 |
443 | new_sample = np.concatenate([sample_i[0], sample_j[0]])
444 | cycle_list.append((new_sample, 1, sample_i[2]+'-'+sample_j[2]))
445 |
446 | # augment wheeze
447 | aug_nums = int(scale*len(classwise_cycle_list[0]) - len(classwise_cycle_list[2]))
448 | print('# of concatenation-based augmentation for wheeze class is {}'.format(aug_nums))
449 |
450 | for _ in range(aug_nums):
451 | aug_prob = random.random()
452 | if aug_prob < 0.6:
453 | # wheeze_i + wheeze_j
454 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 2, 2)
455 | elif aug_prob >= 0.6 and aug_prob < 0.8:
456 | # wheeze_i + normal_j
457 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 2, 0)
458 | else:
459 | # normal_i + wheeze_j
460 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 0, 2)
461 |
462 | new_sample = np.concatenate([sample_i[0], sample_j[0]])
463 | cycle_list.append((new_sample, 2, sample_i[2]+'-'+sample_j[2]))
464 |
465 | # augment both
466 | aug_nums = int(scale*len(classwise_cycle_list[0]) - len(classwise_cycle_list[3]))
467 | print('# of concatenation-based augmentation for both class is {}'.format(aug_nums))
468 |
469 | for _ in range(aug_nums):
470 | aug_prob = random.random()
471 | if aug_prob < 0.5:
472 | # both_i + both_j
473 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 3, 3)
474 | elif aug_prob >= 0.5 and aug_prob < 0.7:
475 | # crackle_i + wheeze_j
476 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 1, 2)
477 |         elif aug_prob >= 0.7 and aug_prob < 0.8:
478 | # wheeze_i + crackle_j
479 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 2, 1)
480 |         elif aug_prob >= 0.8 and aug_prob < 0.9:
481 | # both_i + normal_j
482 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 3, 0)
483 | else:
484 | # normal_i + both_j
485 | sample_i, sample_j = _get_random_cycles(classwise_cycle_list, 0, 3)
486 |
487 | new_sample = np.concatenate([sample_i[0], sample_j[0]])
488 | cycle_list.append((new_sample, 3, sample_i[2]+'-'+sample_j[2]))
489 |
490 | return classwise_cycle_list, cycle_list
491 | # ==========================================================================
492 |
493 |
494 | # ==========================================================================
495 | """ evaluation metric """
496 | def get_score(hits, counts, pflag=False):
497 |     # specificity (S_p): accuracy on the normal class
498 |     sp = hits[0] / (counts[0] + 1e-10) * 100
499 |     # sensitivity (S_e): accuracy on the abnormal classes
500 |     se = sum(hits[1:]) / (sum(counts[1:]) + 1e-10) * 100
501 | sc = (sp + se) / 2.0
502 |
503 | if pflag:
504 | # print("************* Metrics ******************")
505 | print("S_p: {}, S_e: {}, Score: {}".format(sp, se, sc))
506 |
507 | return sp, se, sc
508 | # ==========================================================================
509 |
--------------------------------------------------------------------------------
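The `get_score` metric above averages specificity (S_p, accuracy on the normal class) and sensitivity (S_e, accuracy on the abnormal classes). Below is a minimal sketch of driving it from per-class hit/count tallies during evaluation; the dummy predictions, labels, and the 4-class convention (0 = normal, 1 = crackle, 2 = wheeze, 3 = both) are placeholders for illustration, and the import assumes the repository root is on PYTHONPATH with its dependencies installed.

```python
# Illustrative sketch, not part of the repository.
import torch
from util.icbhi_util import get_score

n_cls = 4  # 0: normal, 1: crackle, 2: wheeze, 3: both
preds = torch.tensor([0, 0, 1, 2, 3, 1, 0, 2])   # dummy model predictions
labels = torch.tensor([0, 1, 1, 2, 3, 3, 0, 0])  # dummy ground-truth labels

hits = [0.0] * n_cls    # correct predictions per class
counts = [0.0] * n_cls  # total samples per class
for p, t in zip(preds.tolist(), labels.tolist()):
    counts[t] += 1
    if p == t:
        hits[t] += 1

sp, se, score = get_score(hits, counts, pflag=True)  # prints S_p, S_e, Score
```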
/util/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 |
10 |
11 | """ train util """
12 | def adjust_learning_rate(args, optimizer, epoch):
13 | lr = args.learning_rate
14 | if args.cosine:
15 | eta_min = lr * (args.lr_decay_rate ** 3)
16 | lr = eta_min + (lr - eta_min) * (
17 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2
18 | else:
19 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
20 | if steps > 0:
21 | lr = lr * (args.lr_decay_rate ** steps)
22 |
23 | for i, param_group in enumerate(optimizer.param_groups):
24 | param_group['lr'] = lr
25 |
26 |
27 | def adjust_lr_wd(args, optimizer, epoch):
28 | lr = args.learning_rate
29 | if args.cosine:
30 | eta_min = lr * (args.lr_decay_rate ** 3)
31 | lr = eta_min + (lr - eta_min) * (
32 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2
33 | else:
34 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
35 | if steps > 0:
36 | lr = lr * (args.lr_decay_rate ** steps)
37 |
38 | wd = args.weight_decay
39 | if args.wd_scheduler:
40 | wd_min = args.weight_decay_end
41 | wd = wd_min + (wd - wd_min) * (
42 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2
43 |
44 | for i, param_group in enumerate(optimizer.param_groups):
45 | param_group['lr'] = lr
46 |         if i == 0: # apply weight decay only to the first (regularized) parameter group, as in DINO/ViT
47 | param_group['weight_decay'] = wd
48 |
49 |
50 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
51 | if args.warm and epoch <= args.warm_epochs:
52 | p = (batch_id + (epoch - 1) * total_batches) / \
53 | (args.warm_epochs * total_batches)
54 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
55 |
56 | for param_group in optimizer.param_groups:
57 | param_group['lr'] = lr
58 |
59 |
60 | def set_optimizer(args, optim_params=None, model=None):
61 | if model is not None:
62 | optim_params = model.parameters()
63 |
64 | if args.optimizer == 'sgd':
65 | optimizer = optim.SGD(optim_params,
66 | lr=args.learning_rate,
67 | momentum=args.momentum,
68 | weight_decay=args.weight_decay)
69 | elif args.optimizer == 'adam':
70 | optimizer = optim.Adam(optim_params,
71 | lr=args.learning_rate,
72 | weight_decay=args.weight_decay)
73 | else:
74 |         raise NotImplementedError
75 |
76 | return optimizer
77 |
78 |
79 | class MA:
80 | def __init__(self, beta):
81 | super().__init__()
82 | self.beta = beta
83 |
84 | def update_average(self, old, new):
85 | if old is None:
86 | return new
87 | return old * self.beta + (1 - self.beta) * new
88 |
89 |
90 | def update_moving_average(ma_beta, current_model, ma_ckpt):
91 | ma_updater = MA(ma_beta)
92 | new_state_dict = {}
93 | for (k1, current_params), (k2, ma_params) in zip(current_model.state_dict().items(), ma_ckpt.items()):
94 | assert k1 == k2
95 | old_weight, up_weight = ma_params.data, current_params.data
96 | new_state_dict[k1] = ma_updater.update_average(old_weight, up_weight)
97 |
98 | current_model.load_state_dict(new_state_dict)
99 | return current_model
100 |
101 |
102 | """ eval util """
103 | class AverageMeter(object):
104 | """ Computes and stores the average and current value """
105 | def __init__(self):
106 | self.reset()
107 |
108 | def reset(self):
109 | self.val = 0
110 | self.avg = 0
111 | self.sum = 0
112 | self.count = 0
113 |
114 | def update(self, val, n=1):
115 | self.val = val
116 | self.sum += val * n
117 | self.count += n
118 | self.avg = self.sum / self.count
119 |
120 |
121 | def accuracy(output, target, topk=(1,)):
122 | """Computes the accuracy over the k top predictions for the specified values of k"""
123 | with torch.no_grad():
124 | n_cls = output.shape[1]
125 | valid_topk = [k for k in topk if k <= n_cls]
126 |
127 | maxk = max(valid_topk)
128 | bsz = target.size(0)
129 |
130 | _, pred = output.topk(maxk, 1, True, True)
131 | pred = pred.t()
132 | correct = pred.eq(target.view(1, -1).expand_as(pred))
133 |
134 | res = []
135 | for k in topk:
136 | if k in valid_topk:
137 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
138 | res.append(correct_k.mul_(100.0 / bsz))
139 | else: res.append(torch.tensor([0.]))
140 |
141 | return res, bsz
142 |
143 |
144 | def save_model(model, optimizer, args, epoch, save_file, classifier):
145 | print('==> Saving...')
146 | state = {
147 | 'args': args,
148 | 'model': model.state_dict(),
149 | 'optimizer': optimizer.state_dict(),
150 | 'epoch': epoch,
151 | 'classifier': classifier.state_dict()
152 | }
153 |
154 | torch.save(state, save_file)
155 | del state
156 |
157 |
158 | def update_json(exp_name, acc, path='./save/results.json'):
159 | acc = [round(a, 2) for a in acc]
160 | if not os.path.exists(path):
161 | with open(path, 'w') as f:
162 | json.dump({}, f)
163 |
164 | with open(path, 'r', encoding="UTF-8") as f:
165 | result_dict = json.load(f)
166 | result_dict[exp_name] = acc
167 |
168 | with open(path, 'w') as f:
169 | json.dump(result_dict, f)
170 |
171 |     print('Best score: {} (sp, se, sc)'.format(acc))
172 | print('results updated to %s' % path)
--------------------------------------------------------------------------------
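The schedulers in `util/misc.py` are driven per epoch (`adjust_learning_rate`) and per batch during warmup (`warmup_learning_rate`). The following is a minimal sketch of how they plug into a training loop; the `Namespace` hyperparameter values and the linear stand-in model are placeholders, not the repository's actual configuration.

```python
# Illustrative sketch, not part of the repository. Hyperparameter values are placeholders.
from argparse import Namespace
import torch.nn as nn
from util.misc import set_optimizer, adjust_learning_rate, warmup_learning_rate

args = Namespace(
    optimizer='sgd', learning_rate=0.1, momentum=0.9, weight_decay=1e-4,
    cosine=True, lr_decay_rate=0.1, lr_decay_epochs=[60, 80], epochs=100,
    warm=True, warm_epochs=5, warmup_from=0.001, warmup_to=0.1,
)

model = nn.Linear(10, 4)                 # stand-in for the actual encoder
optimizer = set_optimizer(args, model=model)

total_batches = 50                       # stand-in for len(train_loader)
for epoch in range(1, args.epochs + 1):
    adjust_learning_rate(args, optimizer, epoch)   # cosine or step decay, once per epoch
    for batch_id in range(total_batches):
        warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer)
        # ... forward / backward / optimizer.step() would go here ...
```

With `cosine=True`, the rate follows a half-cosine from `learning_rate` down to `learning_rate * lr_decay_rate ** 3` over `epochs`, while the warmup linearly interpolates from `warmup_from` to `warmup_to` during the first `warm_epochs` epochs.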
/util/time_warping.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is from https://github.com/zcaceres/spec_augment
3 | """
4 | import torch
5 |
6 | __all__ = ['sparse_image_warp']
7 |
8 |
9 | def sparse_image_warp(img_tensor,
10 | source_control_point_locations,
11 | dest_control_point_locations,
12 | interpolation_order=2,
13 | regularization_weight=0.0,
14 | num_boundaries_points=0):
15 | device = img_tensor.device
16 | control_point_flows = (dest_control_point_locations - source_control_point_locations)
17 |
18 | # clamp_boundaries = num_boundary_points > 0
19 | # boundary_points_per_edge = num_boundary_points - 1
20 | batch_size, image_height, image_width = img_tensor.shape
21 | flattened_grid_locations = get_flat_grid_locations(image_height, image_width, device)
22 |
23 | # IGNORED FOR OUR BASIC VERSION...
24 | # flattened_grid_locations = constant_op.constant(
25 | # _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
26 |
27 | # if clamp_boundaries:
28 | # (dest_control_point_locations,
29 | # control_point_flows) = _add_zero_flow_controls_at_boundary(
30 | # dest_control_point_locations, control_point_flows, image_height,
31 | # image_width, boundary_points_per_edge)
32 |
33 | flattened_flows = interpolate_spline(
34 | dest_control_point_locations,
35 | control_point_flows,
36 | flattened_grid_locations,
37 | interpolation_order,
38 | regularization_weight)
39 |
40 | dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width)
41 |
42 | warped_image = dense_image_warp(img_tensor, dense_flows)
43 |
44 | return warped_image, dense_flows
45 |
46 |
47 | def get_grid_locations(image_height, image_width, device):
48 | y_range = torch.linspace(0, image_height - 1, image_height, device=device)
49 | x_range = torch.linspace(0, image_width - 1, image_width, device=device)
50 | y_grid, x_grid = torch.meshgrid(y_range, x_range)
51 | return torch.stack((y_grid, x_grid), -1)
52 |
53 |
54 | def flatten_grid_locations(grid_locations, image_height, image_width):
55 | return torch.reshape(grid_locations, [image_height * image_width, 2])
56 |
57 |
58 | def get_flat_grid_locations(image_height, image_width, device):
59 | y_range = torch.linspace(0, image_height - 1, image_height, device=device)
60 | x_range = torch.linspace(0, image_width - 1, image_width, device=device)
61 | y_grid, x_grid = torch.meshgrid(y_range, x_range)
62 | return torch.stack((y_grid, x_grid), -1).reshape([image_height * image_width, 2])
63 |
64 |
65 | def create_dense_flows(flattened_flows, batch_size, image_height, image_width):
66 | # possibly .view
67 | return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2])
68 |
69 |
70 | def interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0,):
71 | # First, fit the spline to the observed data.
72 | w, v = solve_interpolation(train_points, train_values, order, regularization_weight)
73 | # Then, evaluate the spline at the query locations.
74 | query_values = apply_interpolation(query_points, train_points, w, v, order)
75 |
76 | return query_values
77 |
78 |
79 | def solve_interpolation(train_points, train_values, order, regularization_weight, eps=1e-7):
80 | device = train_points.device
81 | b, n, d = train_points.shape
82 | k = train_values.shape[-1]
83 |
84 | # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
85 | # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
86 | # To account for python style guidelines we use
87 | # matrix_a for A and matrix_b for B.
88 |
89 | c = train_points
90 | f = train_values.float()
91 |
92 | matrix_a = phi(cross_squared_distance_matrix(c,c), order).unsqueeze(0) # [b, n, n]
93 | # if regularization_weight > 0:
94 | # batch_identity_matrix = array_ops.expand_dims(
95 | # linalg_ops.eye(n, dtype=c.dtype), 0)
96 | # matrix_a += regularization_weight * batch_identity_matrix
97 |
98 | # Append ones to the feature values for the bias term in the linear model.
99 | ones = torch.ones(n, dtype=train_points.dtype, device=device).view([-1, n, 1])
100 | matrix_b = torch.cat((c, ones), 2).float() # [b, n, d + 1]
101 |
102 | # [b, n + d + 1, n]
103 | left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)
104 |
105 | num_b_cols = matrix_b.shape[2] # d + 1
106 |
107 |     # In TensorFlow, zeros are used here. The PyTorch solver fails with exact zeros for reasons we don't understand,
108 |     # so instead we use standard-normal noise scaled down to a tiny eps on one side of our multiplication.
109 | # lhs_zeros = torch.zeros((b, num_b_cols, num_b_cols), device=device)
110 | lhs_zeros = torch.randn((b, num_b_cols, num_b_cols), device=device) * eps
111 | right_block = torch.cat((matrix_b, lhs_zeros),
112 | 1) # [b, n + d + 1, d + 1]
113 | lhs = torch.cat((left_block, right_block),
114 | 2) # [b, n + d + 1, n + d + 1]
115 |
116 | rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype, device=device).float()
117 | rhs = torch.cat((f, rhs_zeros), 1) # [b, n + d + 1, k]
118 |
119 | # Then, solve the linear system and unpack the results.
120 | X = torch.linalg.solve(lhs, rhs)
121 | w = X[:, :n, :]
122 | v = X[:, n:, :]
123 | return w, v
124 |
125 |
126 | def cross_squared_distance_matrix(x, y):
127 | """Pairwise squared distance between two (batch) matrices' rows (2nd dim).
128 | Computes the pairwise distances between rows of x and rows of y
129 | Args:
130 | x: [batch_size, n, d] float `Tensor`
131 | y: [batch_size, m, d] float `Tensor`
132 | Returns:
133 | squared_dists: [batch_size, n, m] float `Tensor`, where
134 | squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
135 | """
136 | x_norm_squared = torch.sum(torch.mul(x, x))
137 | y_norm_squared = torch.sum(torch.mul(y, y))
138 |
139 | x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0,1))
140 |
141 | # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
142 | squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared
143 |
144 | return squared_dists.float()
145 |
146 |
147 | def phi(r, order):
148 | """Coordinate-wise nonlinearity used to define the order of the interpolation.
149 | See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
150 | Args:
151 | r: input op
152 | order: interpolation order
153 | Returns:
154 |         phi_k evaluated coordinate-wise on r, for k = order
155 | """
156 | EPSILON=torch.tensor(1e-10, device=r.device)
157 |     # using EPSILON prevents log(0), sqrt(0), etc.
158 | # sqrt(0) is well-defined, but its gradient is not
159 | if order == 1:
160 | r = torch.max(r, EPSILON)
161 | r = torch.sqrt(r)
162 | return r
163 | elif order == 2:
164 | return 0.5 * r * torch.log(torch.max(r, EPSILON))
165 | elif order == 4:
166 | return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON))
167 | elif order % 2 == 0:
168 | r = torch.max(r, EPSILON)
169 | return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r)
170 | else:
171 | r = torch.max(r, EPSILON)
172 | return torch.pow(r, 0.5 * order)
173 |
174 |
175 | def apply_interpolation(query_points, train_points, w, v, order):
176 | """Apply polyharmonic interpolation model to data.
177 | Given coefficients w and v for the interpolation model, we evaluate
178 | interpolated function values at query_points.
179 | Args:
180 | query_points: `[b, m, d]` x values to evaluate the interpolation at
181 | train_points: `[b, n, d]` x values that act as the interpolation centers
182 | ( the c variables in the wikipedia article)
183 | w: `[b, n, k]` weights on each interpolation center
184 | v: `[b, d, k]` weights on each input dimension
185 | order: order of the interpolation
186 | Returns:
187 | Polyharmonic interpolation evaluated at points defined in query_points.
188 | """
189 | query_points = query_points.unsqueeze(0)
190 | # First, compute the contribution from the rbf term.
191 | pairwise_dists = cross_squared_distance_matrix(query_points.float(), train_points.float())
192 | phi_pairwise_dists = phi(pairwise_dists, order)
193 |
194 | rbf_term = torch.matmul(phi_pairwise_dists, w)
195 |
196 | # Then, compute the contribution from the linear term.
197 | # Pad query_points with ones, for the bias term in the linear model.
198 | ones = torch.ones_like(query_points[..., :1])
199 | query_points_pad = torch.cat((
200 | query_points,
201 | ones
202 | ), 2).float()
203 | linear_term = torch.matmul(query_points_pad, v)
204 |
205 | return rbf_term + linear_term
206 |
207 |
208 | def dense_image_warp(image, flow):
209 | """Image warping using per-pixel flow vectors.
210 | Apply a non-linear warp to the image, where the warp is specified by a dense
211 | flow field of offset vectors that define the correspondences of pixel values
212 | in the output image back to locations in the source image. Specifically, the
213 | pixel value at output[b, j, i, c] is
214 | images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].
215 | The locations specified by this formula do not necessarily map to an int
216 | index. Therefore, the pixel value is obtained by bilinear
217 | interpolation of the 4 nearest pixels around
218 | (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
219 | of the image, we use the nearest pixel values at the image boundary.
220 | Args:
221 | image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
222 | flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
223 | name: A name for the operation (optional).
224 | Note that image and flow can be of type tf.half, tf.float32, or tf.float64,
225 | and do not necessarily have to be the same type.
226 | Returns:
227 | A 4-D float `Tensor` with shape`[batch, height, width, channels]`
228 | and same type as input image.
229 | Raises:
230 | ValueError: if height < 2 or width < 2 or the inputs have the wrong number
231 | of dimensions.
232 | """
233 | image = image.unsqueeze(3) # add a single channel dimension to image tensor
234 | batch_size, height, width, channels = image.shape
235 | device = image.device
236 |
237 | # The flow is defined on the image grid. Turn the flow into a list of query
238 | # points in the grid space.
239 | grid_x, grid_y = torch.meshgrid(
240 | torch.arange(width, device=device), torch.arange(height, device=device))
241 |
242 | stacked_grid = torch.stack((grid_y, grid_x), dim=2).float()
243 |
244 | batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2)
245 |
246 | query_points_on_grid = batched_grid - flow
247 | query_points_flattened = torch.reshape(query_points_on_grid,
248 | [batch_size, height * width, 2])
249 | # Compute values at the query points, then reshape the result back to the
250 | # image grid.
251 | interpolated = interpolate_bilinear(image, query_points_flattened)
252 | interpolated = torch.reshape(interpolated,
253 | [batch_size, height, width, channels])
254 | return interpolated
255 |
256 |
257 | def interpolate_bilinear(grid,
258 | query_points,
259 | name='interpolate_bilinear',
260 | indexing='ij'):
261 | """Similar to Matlab's interp2 function.
262 | Finds values for query points on a grid using bilinear interpolation.
263 | Args:
264 | grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
265 | query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`.
266 | name: a name for the operation (optional).
267 | indexing: whether the query points are specified as row and column (ij),
268 | or Cartesian coordinates (xy).
269 | Returns:
270 | values: a 3-D `Tensor` with shape `[batch, N, channels]`
271 | Raises:
272 | ValueError: if the indexing mode is invalid, or if the shape of the inputs
273 | invalid.
274 | """
275 | if indexing != 'ij' and indexing != 'xy':
276 | raise ValueError('Indexing mode must be \'ij\' or \'xy\'')
277 |
278 |
279 | shape = grid.shape
280 | if len(shape) != 4:
281 | msg = 'Grid must be 4 dimensional. Received size: '
282 | raise ValueError(msg + str(grid.shape))
283 |
284 | batch_size, height, width, channels = grid.shape
285 |
286 | shape = [batch_size, height, width, channels]
287 | query_type = query_points.dtype
288 | grid_type = grid.dtype
289 | grid_device = grid.device
290 |
291 | num_queries = query_points.shape[1]
292 |
293 | alphas = []
294 | floors = []
295 | ceils = []
296 | index_order = [0, 1] if indexing == 'ij' else [1, 0]
297 | unstacked_query_points = query_points.unbind(2)
298 |
299 | for dim in index_order:
300 | queries = unstacked_query_points[dim]
301 |
302 | size_in_indexing_dimension = shape[dim + 1]
303 |
304 | # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
305 | # is still a valid index into the grid.
306 | max_floor = torch.tensor(size_in_indexing_dimension - 2, dtype=query_type, device=grid_device)
307 | min_floor = torch.tensor(0.0, dtype=query_type, device=grid_device)
308 | maxx = torch.max(min_floor, torch.floor(queries))
309 | floor = torch.min(maxx, max_floor)
310 | int_floor = floor.long()
311 | floors.append(int_floor)
312 | ceil = int_floor + 1
313 | ceils.append(ceil)
314 |
315 | # alpha has the same type as the grid, as we will directly use alpha
316 | # when taking linear combinations of pixel values from the image.
317 |
318 |
319 | alpha = (queries - floor).clone().detach().type(grid_type)
320 | min_alpha = torch.tensor(0.0, dtype=grid_type, device=grid_device)
321 | max_alpha = torch.tensor(1.0, dtype=grid_type, device=grid_device)
322 | alpha = torch.min(torch.max(min_alpha, alpha), max_alpha)
323 |
324 | # Expand alpha to [b, n, 1] so we can use broadcasting
325 | # (since the alpha values don't depend on the channel).
326 | alpha = torch.unsqueeze(alpha, 2)
327 | alphas.append(alpha)
328 |
329 | flattened_grid = torch.reshape(
330 | grid, [batch_size * height * width, channels])
331 | batch_offsets = torch.reshape(
332 | torch.arange(batch_size, device=grid_device) * height * width, [batch_size, 1])
333 |
334 | # This wraps array_ops.gather. We reshape the image data such that the
335 | # batch, y, and x coordinates are pulled into the first dimension.
336 | # Then we gather. Finally, we reshape the output back. It's possible this
337 | # code would be made simpler by using array_ops.gather_nd.
338 | def gather(y_coords, x_coords, name):
339 | linear_coordinates = batch_offsets + y_coords * width + x_coords
340 | gathered_values = torch.gather(flattened_grid.t(), 1, linear_coordinates)
341 | return torch.reshape(gathered_values,
342 | [batch_size, num_queries, channels])
343 |
344 | # grab the pixel values in the 4 corners around each query point
345 | top_left = gather(floors[0], floors[1], 'top_left')
346 | top_right = gather(floors[0], ceils[1], 'top_right')
347 | bottom_left = gather(ceils[0], floors[1], 'bottom_left')
348 | bottom_right = gather(ceils[0], ceils[1], 'bottom_right')
349 |
350 | interp_top = alphas[1] * (top_right - top_left) + top_left
351 | interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
352 | interp = alphas[0] * (interp_bottom - interp_top) + interp_top
353 |
354 | return interp
--------------------------------------------------------------------------------
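`sparse_image_warp` above fits a polyharmonic spline to a set of control-point displacements and applies the resulting dense flow with bilinear interpolation, which is the backbone of SpecAugment-style time warping. Below is a minimal sketch of warping a dummy spectrogram along the time axis; the control-point construction (a vertical line of points at a random time index, shifted by a random distance) is an assumption for illustration only, not necessarily how the augmentation is configured elsewhere in this codebase.

```python
# Illustrative sketch, not part of the repository.
import random
import torch
from util.time_warping import sparse_image_warp

n_mels, n_frames, W = 64, 200, 40                # W: maximum warp distance in frames
spec = torch.randn(1, n_mels, n_frames)          # dummy spectrogram, shape [batch, freq, time]

pt = float(random.randrange(W, n_frames - W))    # time index to warp
shift = float(random.randrange(-W + 1, W))       # signed warp distance (frames)
freqs = torch.arange(n_mels // 2, dtype=torch.float)  # control points along the freq axis

# control points are (row=freq, col=time); only the time coordinate is shifted
src_pts = torch.stack([freqs, torch.full_like(freqs, pt)], dim=-1).unsqueeze(0)          # [1, n, 2]
dst_pts = torch.stack([freqs, torch.full_like(freqs, pt + shift)], dim=-1).unsqueeze(0)  # [1, n, 2]

warped, dense_flows = sparse_image_warp(spec, src_pts, dst_pts)
warped = warped.squeeze(3)                       # back to [batch, freq, time]
```

The warped output keeps the input resolution; the content is only stretched or compressed locally around the shifted time index.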