├── LICENSE
├── README.md
├── docs
│   └── framework.png
├── finetune
│   ├── __pycache__
│   │   └── finetune_dataset.cpython-310.pyc
│   ├── data_split
│   │   ├── chapman
│   │   │   ├── chapman_test.csv
│   │   │   ├── chapman_train.csv
│   │   │   └── chapman_val.csv
│   │   ├── icbeb
│   │   │   ├── icbeb_test.csv
│   │   │   ├── icbeb_train.csv
│   │   │   └── icbeb_val.csv
│   │   └── ptbxl
│   │       ├── form
│   │       │   ├── ptbxl_form_test.csv
│   │       │   ├── ptbxl_form_train.csv
│   │       │   └── ptbxl_form_val.csv
│   │       ├── rhythm
│   │       │   ├── ptbxl_rhythm_test.csv
│   │       │   ├── ptbxl_rhythm_train.csv
│   │       │   └── ptbxl_rhythm_val.csv
│   │       ├── sub_class
│   │       │   ├── ptbxl_sub_class_test.csv
│   │       │   ├── ptbxl_sub_class_train.csv
│   │       │   └── ptbxl_sub_class_val.csv
│   │       └── super_class
│   │           ├── ptbxl_super_class_test.csv
│   │           ├── ptbxl_super_class_train.csv
│   │           └── ptbxl_super_class_val.csv
│   ├── finetune_dataset.py
│   ├── main_single.py
│   ├── models
│   │   ├── resnet1d.py
│   │   └── vit1d.py
│   ├── preprocess.ipynb
│   └── sub_script
│       ├── chapman
│       │   └── sub_chapman.sh
│       ├── icbeb
│       │   └── sub_icbeb.sh
│       ├── ptbxl
│       │   ├── sub_ptbxl.sh
│       │   ├── sub_ptbxl_form.sh
│       │   ├── sub_ptbxl_rhythm.sh
│       │   ├── sub_ptbxl_sub_class.sh
│       │   └── sub_ptbxl_super_class.sh
│       └── run_all_linear.sh
├── pretrain
│   ├── config.yaml
│   ├── launch.sh
│   ├── main.py
│   └── preprocess.ipynb
├── utils
│   ├── __pycache__
│   │   ├── resnet1d.cpython-310.pyc
│   │   ├── utils_builder.cpython-310.pyc
│   │   ├── utils_dataset.cpython-310.pyc
│   │   ├── utils_loss.cpython-310.pyc
│   │   ├── utils_trainer.cpython-310.pyc
│   │   ├── vit1d.cpython-310.pyc
│   │   └── zeroshot_val.cpython-310.pyc
│   ├── resnet1d.py
│   ├── utils_builder.py
│   ├── utils_dataset.py
│   ├── utils_loss.py
│   ├── utils_optimizer.py
│   ├── utils_trainer.py
│   ├── vit1d.py
│   └── zeroshot_val.py
└── zeroshot
    ├── CKEPE_prompt.json
    ├── __pycache__
    │   └── zeroshot_dataset.cpython-39.pyc
    ├── test_zeroshot.py
    ├── zeroshot.sh
    └── zeroshot_config.yaml

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Che Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MERL
2 | [Zero-Shot ECG Classification with Multimodal Learning and Test-time Clinical Knowledge Enhancement](https://arxiv.org/abs/2403.06659), ICML 2024.
3 |
4 | ![framework](docs/framework.png)
5 |
6 | ### Installation
7 | To clone this repository:
8 | ```
9 | git clone https://github.com/cheliu-computation/MERL.git
10 | ```
11 |
12 | ### Dataset downloading
13 | The datasets we used are as follows:
14 | - **MIMIC-IV-ECG**: We downloaded the [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/) dataset for the ECG signals and their paired ECG reports.
15 |
16 | - **PTB-XL**: We downloaded the [PTB-XL](https://physionet.org/content/ptb-xl/1.0.3/) dataset, which consists of four subsets: Superclass, Subclass, Form, and Rhythm.
17 |
18 | - **CPSC2018**: We downloaded the [CPSC2018](http://2018.icbeb.org/Challenge.html) dataset, which consists of three training sets.
19 |
20 | - **CSN (Chapman-Shaoxing-Ningbo)**: We downloaded the [CSN](https://physionet.org/content/ecg-arrhythmia/1.0.0/) dataset.
21 |
22 |
23 | ### Data Preprocessing
24 | We preprocessed the pretraining dataset and split it into train/val sets using the code in `pretrain/preprocess.ipynb`.\
25 | We preprocessed the downstream datasets and split them into train/val/test sets using the code in `finetune/preprocess.ipynb`.\
26 | We also provide the train/val/test split CSV files in `finetune/data_split`.
27 |
28 | ### Pre-training
29 |
30 | We pre-trained MERL on MIMIC-IV-ECG using this command:
31 |
32 | ```
33 | bash MERL/pretrain/launch.sh
34 | ```
35 |
36 | Pre-trained models can be found [here](https://drive.google.com/drive/folders/13wb4DppUciMn-Y_qC2JRWTbZdz3xX0w2?usp=drive_link).\
37 | We uploaded pretrained models with both ResNet and ViT backbones.\
38 | `xxx_ckpt.pth` is the full pretrained model, used for zero-shot classification.\
39 | `xxx_encoder.pth` is the ECG encoder only, used for linear probing.
40 |
41 | ### Downstream tasks
42 | We evaluate the performance of MERL on three scenarios: zero-shot classification, linear probing, and domain transferring.
43 |
44 | #### zero-shot classification
45 | We evaluate the zero-shot classification performance of our model using this command:
46 | ```
47 | cd MERL/zeroshot
48 | bash zeroshot.sh
49 | ```
50 | We also release the CKEPE prompt in `zeroshot/CKEPE_prompt.json`.\
51 | Due to copyright, we are unable to release the original SCP code database, but you can find all the information at [https://www.iso.org/standard/84664.html](https://www.iso.org/standard/84664.html).
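At test time, zero-shot classification scores each ECG embedding against text embeddings of the candidate class prompts and takes the most similar class. The sketch below illustrates only that scoring step with random stand-in embeddings; the real embeddings come from the pretrained ECG and text encoders (see `test_zeroshot.py` and `utils/zeroshot_val.py`), and the function name here is hypothetical:

```python
import numpy as np

def zeroshot_classify(ecg_emb: np.ndarray, class_text_embs: np.ndarray) -> int:
    """Return the index of the class whose prompt embedding has the
    highest cosine similarity with the ECG embedding."""
    ecg = ecg_emb / np.linalg.norm(ecg_emb)
    txt = class_text_embs / np.linalg.norm(class_text_embs, axis=1, keepdims=True)
    sims = txt @ ecg  # one cosine-similarity score per class
    return int(np.argmax(sims))

# Stand-in embeddings: 3 classes in a 128-dim joint space.
rng = np.random.default_rng(0)
class_embs = rng.normal(size=(3, 128))
ecg = class_embs[1] + 0.01 * rng.normal(size=128)  # an ECG close to class 1
print(zeroshot_classify(ecg, class_embs))
```

In the multi-label setting used by the downstream benchmarks, the per-class similarity scores would be thresholded (or ranked for AUC) rather than reduced to a single argmax.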
52 |
53 | #### linear probing
54 | We provide a bash script for evaluating the linear probing performance of MERL:
55 | ```
56 | cd MERL/finetune/sub_script
57 | bash run_all_linear.sh
58 | ```
59 | You can use `--dataset` to set the specific dataset for finetuning. Three datasets are available: `ptbxl`, `chapman`, and `icbeb`.
60 | You can use `--ratio` to set the fraction of training data used for finetuning.
61 |
62 | #### domain transferring
63 | For the domain transferring scenario, you do not need to run any new experiments; you only compute the metrics across the overlapping categories.
64 |
65 | ### Reference
66 | If you found our work useful in your research, please consider citing it at:
67 | ```bibtex
68 | @inproceedings{liuzero,
69 |   title={Zero-Shot ECG Classification with Multimodal Learning and Test-time Clinical Knowledge Enhancement},
70 |   author={Liu, Che and Wan, Zhongwei and Ouyang, Cheng and Shah, Anand and Bai, Wenjia and Arcucci, Rossella},
71 |   booktitle={Forty-first International Conference on Machine Learning}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/docs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/docs/framework.png
--------------------------------------------------------------------------------
/finetune/__pycache__/finetune_dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/finetune/__pycache__/finetune_dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/finetune/data_split/icbeb/icbeb_val.csv:
--------------------------------------------------------------------------------
1 |
patient_id,ecg_id,filename,validation,age,sex,scp_codes,AFIB,VPC,NORM,1AVB,CRBBB,STE,PAC,CLBBB,STD 2 | 4176,4176,A2600,False,32.0,0,{},0,0,1,0,0,0,0,0,0 3 | 3553,3553,A1103,False,73.0,0,{},0,0,0,0,0,0,0,0,1 4 | 5631,5631,A1535,False,83.0,1,{},0,0,0,1,0,0,0,0,0 5 | 4144,4144,A6135,False,74.0,0,{},1,0,0,0,0,0,0,0,0 6 | 1265,1265,A2305,False,81.0,1,{},0,0,0,1,0,0,0,0,0 7 | 6301,6301,A5381,False,64.0,0,{},0,1,0,0,0,0,0,0,0 8 | 2005,2005,A1803,False,84.0,1,{},1,0,0,0,0,0,0,0,0 9 | 4732,4732,A4350,False,79.0,1,{},1,0,0,0,0,0,0,0,0 10 | 1391,1391,A1699,False,34.0,0,{},0,0,1,0,0,0,0,0,0 11 | 5713,5713,A2934,False,76.0,1,{},0,1,0,0,0,0,0,0,0 12 | 5567,5567,A0107,False,26.0,0,{},0,0,1,0,0,0,0,0,0 13 | 3574,3574,A0205,False,79.0,1,{},1,0,0,0,0,0,0,0,0 14 | 3435,3435,A4205,False,69.0,1,{},1,0,0,0,0,0,0,0,0 15 | 5688,5688,A3575,False,21.0,1,{},0,0,0,0,1,0,0,0,0 16 | 4389,4389,A0498,False,72.0,1,{},1,0,0,0,0,0,0,0,0 17 | 3382,3382,A4726,False,31.0,0,{},0,0,0,0,0,0,0,0,1 18 | 3829,3829,A3819,False,74.0,1,{},0,0,0,0,1,0,0,0,0 19 | 877,877,A0701,False,74.0,0,{},1,0,0,0,0,0,0,0,0 20 | 4014,4014,A5679,False,52.0,1,{},1,0,0,0,0,0,0,0,0 21 | 4991,4991,A0481,False,39.0,0,{},0,0,1,0,0,0,0,0,0 22 | 3983,3983,A5186,False,71.0,0,{},0,0,0,1,0,0,0,0,0 23 | 1851,1851,A4170,False,26.0,0,{},0,0,0,0,1,0,0,0,0 24 | 708,708,A4615,False,70.0,0,{},1,0,0,0,1,0,0,0,0 25 | 932,932,A1911,False,59.0,1,{},0,0,0,0,1,0,0,0,0 26 | 4882,4882,A2634,False,67.0,1,{},0,1,0,0,0,0,0,0,0 27 | 5541,5541,A0178,False,70.0,0,{},0,0,0,0,1,0,0,0,0 28 | 4780,4780,A0584,False,82.0,0,{},1,0,0,0,0,0,0,0,0 29 | 4973,4973,A5725,False,29.0,1,{},1,0,0,0,0,0,0,0,0 30 | 2413,2413,A0561,False,65.0,1,{},0,0,0,0,1,0,0,0,0 31 | 4805,4805,A5467,False,75.0,1,{},1,0,0,0,0,0,0,0,0 32 | 5477,5477,A1559,False,62.0,1,{},0,1,0,0,1,0,0,0,0 33 | 149,149,A2603,False,77.0,0,{},0,0,0,1,0,0,0,0,0 34 | 840,840,A2712,False,55.0,0,{},1,0,0,0,0,0,0,0,0 35 | 2185,2185,A0690,False,23.0,0,{},0,0,1,0,0,0,0,0,0 36 | 
5862,5862,A4858,False,62.0,0,{},0,0,0,1,0,0,0,0,0 37 | 832,832,A3426,False,84.0,0,{},0,0,0,0,1,0,0,0,0 38 | 1323,1323,A3801,False,40.0,1,{},0,0,0,0,0,1,0,0,0 39 | 132,132,A2962,False,68.0,0,{},1,0,0,0,0,0,0,0,0 40 | 3336,3336,A4776,False,77.0,1,{},0,0,0,1,0,0,0,0,0 41 | 5127,5127,A2997,False,21.0,0,{},0,0,0,0,0,0,0,0,1 42 | 5118,5118,A6795,False,49.0,0,{},0,0,0,0,0,0,0,0,1 43 | 5308,5308,A6769,False,59.0,0,{},0,0,0,0,0,0,0,0,1 44 | 3422,3422,A0731,False,50.0,1,{},1,0,0,0,0,0,0,0,0 45 | 1485,1485,A0946,False,84.0,1,{},0,0,0,1,0,0,0,0,0 46 | 6279,6279,A5362,False,71.0,0,{},1,0,0,0,1,0,0,0,0 47 | 5852,5852,A4269,False,87.0,1,{},0,0,0,1,0,0,0,0,0 48 | 6656,6656,A3385,False,39.0,1,{},0,0,0,0,0,0,0,0,1 49 | 1868,1868,A4302,False,35.0,1,{},0,0,0,0,0,0,0,0,1 50 | 4462,4462,A3528,False,65.0,0,{},0,1,0,0,0,0,0,0,0 51 | 4276,4276,A2033,False,84.0,1,{},0,0,0,1,0,0,0,0,0 52 | 1141,1141,A1199,False,69.0,1,{},1,0,0,0,0,0,0,0,0 53 | 6856,6856,A1322,False,70.0,0,{},0,1,0,0,0,0,0,0,0 54 | 1124,1124,A3494,False,56.0,1,{},0,1,0,0,0,0,0,0,0 55 | 5359,5359,A3125,False,82.0,1,{},0,0,0,1,0,0,0,0,0 56 | 1219,1219,A3576,False,68.0,1,{},0,0,0,0,1,0,0,0,0 57 | 6804,6804,A4148,False,65.0,0,{},0,0,0,1,0,0,0,0,0 58 | 6132,6132,A0896,False,63.0,0,{},0,0,0,0,1,0,0,0,0 59 | 408,408,A5674,False,54.0,0,{},1,0,0,0,0,0,0,0,0 60 | 1723,1723,A0881,False,56.0,0,{},0,1,0,0,1,0,0,0,0 61 | 4708,4708,A3112,False,78.0,0,{},1,0,0,0,0,0,0,0,0 62 | 597,597,A4901,False,69.0,0,{},0,1,0,0,0,0,0,0,0 63 | 6706,6706,A3254,False,91.0,1,{},1,0,0,0,0,0,0,0,0 64 | 4821,4821,A2632,False,70.0,1,{},0,0,0,0,1,0,0,0,0 65 | 6287,6287,A5217,False,53.0,1,{},0,0,0,1,0,0,0,0,0 66 | 5261,5261,A6759,False,29.0,1,{},0,0,0,0,1,0,0,0,0 67 | 1235,1235,A1101,False,90.0,1,{},0,0,0,0,1,0,0,0,0 68 | 5345,5345,A2165,False,59.0,1,{},0,0,1,0,0,0,0,0,0 69 | 1905,1905,A5541,False,74.0,1,{},0,0,0,0,1,0,0,0,0 70 | 6541,6541,A5618,False,83.0,0,{},0,0,0,0,0,0,1,0,0 71 | 1610,1610,A0932,False,49.0,1,{},1,0,0,0,1,0,0,0,0 72 | 
3660,3660,A6620,False,83.0,0,{},0,0,0,1,0,0,0,0,0 73 | 4808,4808,A4186,False,84.0,1,{},0,0,0,1,0,0,0,0,0 74 | 5059,5059,A3123,False,54.0,1,{},0,0,0,0,0,0,1,0,0 75 | 579,579,A3639,False,68.0,1,{},0,1,0,0,0,0,0,0,0 76 | 1337,1337,A6457,False,63.0,0,{},0,0,0,0,0,0,1,0,0 77 | 1637,1637,A1582,False,25.0,0,{},0,0,1,0,0,0,0,0,0 78 | 2659,2659,A6398,False,76.0,0,{},1,0,0,0,0,0,0,0,0 79 | 5715,5715,A0223,False,67.0,1,{},0,0,0,1,0,0,0,0,0 80 | 2395,2395,A6276,False,60.0,0,{},0,0,0,0,0,0,1,0,0 81 | 1775,1775,A2930,False,18.0,1,{},0,0,0,0,1,0,0,0,0 82 | 4778,4778,A6819,False,48.0,0,{},0,0,0,0,0,0,0,0,1 83 | 3494,3494,A5282,False,56.0,0,{},0,1,0,0,1,0,0,0,0 84 | 2468,2468,A4899,False,78.0,1,{},0,0,0,1,0,0,0,0,0 85 | 6325,6325,A0821,False,82.0,0,{},0,0,0,0,0,0,1,0,0 86 | 4096,4096,A5739,False,44.0,1,{},0,0,0,0,1,0,0,0,0 87 | 6492,6492,A5167,False,27.0,0,{},0,0,0,0,1,0,0,0,0 88 | 3758,3758,A2108,False,65.0,1,{},0,0,0,0,1,0,0,0,0 89 | 6851,6851,A5430,False,63.0,1,{},0,0,0,1,0,0,0,0,0 90 | 5078,5078,A1071,False,84.0,1,{},0,0,0,0,1,0,0,0,0 91 | 3233,3233,A0612,False,82.0,1,{},0,0,0,1,0,0,0,0,0 92 | 4345,4345,A0022,False,27.0,0,{},0,0,0,0,1,0,0,0,0 93 | 4596,4596,A3476,False,91.0,1,{},1,0,0,0,0,0,0,0,0 94 | 5392,5392,A1334,False,57.0,0,{},0,1,0,0,0,0,0,0,0 95 | 3641,3641,A5210,False,86.0,1,{},0,1,0,0,0,0,0,0,0 96 | 6784,6784,A3162,False,51.0,0,{},0,0,0,0,1,0,0,0,0 97 | 2461,2461,A5267,False,75.0,0,{},1,0,0,0,0,0,0,0,0 98 | 2357,2357,A4588,False,69.0,1,{},0,0,0,0,0,0,0,0,1 99 | 1949,1949,A5353,False,65.0,1,{},1,0,0,0,0,0,0,0,0 100 | 1788,1788,A5562,False,62.0,0,{},0,0,0,0,1,0,0,0,0 101 | 2586,2586,A3258,False,26.0,0,{},0,0,0,0,0,0,0,0,1 102 | 4735,4735,A4616,False,74.0,1,{},1,0,0,0,1,0,0,0,0 103 | 6323,6323,A5187,False,77.0,1,{},1,0,0,0,0,0,0,0,0 104 | 5244,5244,A3015,False,67.0,1,{},0,0,0,0,0,0,0,0,1 105 | 68,68,A3152,False,52.0,1,{},0,0,0,0,0,0,0,0,1 106 | 1532,1532,A1982,False,89.0,1,{},0,0,0,1,0,0,0,0,0 107 | 6840,6840,A3367,False,75.0,1,{},1,0,0,0,0,0,0,0,0 108 | 
4557,4557,A2800,False,62.0,1,{},0,0,0,0,1,0,0,0,0 109 | 5383,5383,A5229,False,66.0,1,{},0,0,1,0,0,0,0,0,0 110 | 975,975,A1452,False,77.0,0,{},0,0,0,0,1,0,1,0,0 111 | 5525,5525,A3048,False,66.0,1,{},1,0,0,0,0,0,0,0,0 112 | 1493,1493,A3924,False,44.0,0,{},0,0,1,0,0,0,0,0,0 113 | 6499,6499,A3457,False,39.0,0,{},0,0,0,0,0,0,0,0,1 114 | 3902,3902,A1294,False,35.0,1,{},0,1,0,0,0,0,0,0,0 115 | 2883,2883,A1316,False,85.0,0,{},0,0,0,0,0,0,0,0,1 116 | 1913,1913,A4736,False,79.0,0,{},0,0,0,0,0,0,0,1,0 117 | 3137,3137,A1888,False,29.0,0,{},0,0,0,0,0,0,0,0,1 118 | 2009,2009,A4887,False,72.0,0,{},0,0,0,0,0,0,0,0,1 119 | 1814,1814,A0653,False,15.0,0,{},0,0,1,0,0,0,0,0,0 120 | 44,44,A4507,False,69.0,0,{},0,0,0,1,0,0,0,0,0 121 | 6448,6448,A2995,False,62.0,1,{},0,0,0,0,1,0,0,0,0 122 | 4784,4784,A6806,False,55.0,1,{},0,0,1,0,0,0,0,0,0 123 | 4111,4111,A4695,False,74.0,1,{},0,0,0,1,0,0,0,0,0 124 | 2561,2561,A6636,False,79.0,0,{},0,0,0,0,1,0,0,0,0 125 | 4602,4602,A4036,False,64.0,0,{},0,0,1,0,0,0,0,0,0 126 | 3369,3369,A1588,False,84.0,1,{},0,0,0,0,0,1,0,0,0 127 | 664,664,A5247,False,64.0,1,{},0,0,0,0,1,0,0,0,0 128 | 4770,4770,A2869,False,58.0,1,{},0,0,0,0,0,0,0,0,1 129 | 2168,2168,A0382,False,84.0,1,{},0,0,0,0,1,0,0,0,0 130 | 4568,4568,A2296,False,87.0,0,{},0,0,0,0,1,0,1,0,0 131 | 5009,5009,A3418,False,70.0,0,{},1,0,0,0,0,0,0,0,0 132 | 3895,3895,A3500,False,26.0,0,{},0,0,1,0,0,0,0,0,0 133 | 1825,1825,A0409,False,84.0,1,{},0,0,0,1,0,0,0,0,0 134 | 4174,4174,A5463,False,35.0,1,{},0,0,0,0,0,0,1,0,0 135 | 5932,5932,A6616,False,28.0,0,{},0,0,0,0,0,0,0,0,1 136 | 5197,5197,A2250,False,77.0,0,{},0,0,0,0,1,0,0,0,0 137 | 1956,1956,A4195,False,68.0,1,{},0,0,0,0,0,0,1,0,0 138 | 5610,5610,A6269,False,45.0,0,{},0,0,1,0,0,0,0,0,0 139 | 2585,2585,A2497,False,18.0,1,{},0,0,0,0,1,0,0,0,0 140 | 55,55,A1873,False,45.0,1,{},0,0,0,0,1,0,0,0,0 141 | 3676,3676,A1489,False,56.0,1,{},0,0,1,0,0,0,0,0,0 142 | 6028,6028,A3223,False,80.0,1,{},1,0,0,0,0,0,0,0,0 143 | 4149,4149,A6027,False,52.0,0,{},0,0,0,0,0,0,0,0,1 
144 | 5296,5296,A6500,False,58.0,0,{},0,1,0,0,0,0,0,0,0 145 | 5882,5882,A6488,False,51.0,0,{},1,0,0,0,0,0,0,0,0 146 | 5438,5438,A4674,False,67.0,0,{},0,0,0,0,0,0,0,0,1 147 | 1869,1869,A2228,False,25.0,0,{},0,1,0,0,0,0,0,0,0 148 | 2455,2455,A1864,False,12.0,0,{},0,0,1,0,0,0,0,0,0 149 | 4876,4876,A1683,False,62.0,0,{},0,0,0,0,1,0,0,0,0 150 | 3520,3520,A1649,False,65.0,0,{},0,0,1,0,0,0,0,0,0 151 | 146,146,A3185,False,71.0,1,{},0,0,0,0,0,0,1,0,0 152 | 4148,4148,A6781,False,64.0,1,{},1,0,0,0,0,0,0,0,0 153 | 1156,1156,A0793,False,24.0,0,{},0,0,0,0,1,0,0,0,0 154 | 1270,1270,A2779,False,66.0,1,{},0,0,0,0,1,0,0,0,0 155 | 1561,1561,A0287,False,18.0,0,{},0,0,1,0,0,0,0,0,0 156 | 2953,2953,A2440,False,60.0,0,{},0,0,0,0,0,0,1,0,0 157 | 1791,1791,A5122,False,62.0,0,{},0,0,0,0,0,0,0,0,1 158 | 6764,6764,A3643,False,32.0,0,{},0,0,1,0,0,0,0,0,0 159 | 681,681,A4140,False,53.0,1,{},0,1,0,0,0,0,0,0,0 160 | 6184,6184,A1733,False,46.0,1,{},0,0,0,0,1,0,0,0,0 161 | 1438,1438,A0226,False,80.0,0,{},0,0,0,0,1,0,0,0,0 162 | 4813,4813,A3907,False,72.0,0,{},0,0,0,0,1,0,0,0,0 163 | 5874,5874,A0665,False,58.0,1,{},0,0,1,0,0,0,0,0,0 164 | 2114,2114,A2011,False,80.0,1,{},1,1,0,0,0,0,0,0,0 165 | 3969,3969,A2523,False,27.0,1,{},0,0,0,0,0,1,0,0,0 166 | 4988,4988,A2267,False,37.0,0,{},0,0,0,0,0,0,1,0,1 167 | 156,156,A5615,False,17.0,0,{},0,1,0,0,0,0,0,0,0 168 | 4673,4673,A0440,False,59.0,0,{},0,1,0,0,0,0,0,0,0 169 | 5720,5720,A5624,False,67.0,1,{},1,0,0,0,0,0,0,0,0 170 | 4822,4822,A0027,False,55.0,0,{},0,0,0,0,1,0,0,0,0 171 | 5868,5868,A6700,False,71.0,0,{},0,0,0,0,1,0,0,0,0 172 | 5420,5420,A5886,False,81.0,1,{},0,0,0,1,0,0,0,0,0 173 | 2493,2493,A1521,False,80.0,0,{},0,0,0,1,0,0,0,0,0 174 | 5391,5391,A6107,False,73.0,0,{},0,0,0,0,1,0,0,0,0 175 | 5312,5312,A3585,False,80.0,0,{},0,0,0,0,0,0,1,0,0 176 | 5775,5775,A3762,False,51.0,1,{},0,0,0,0,0,0,1,0,0 177 | 2053,2053,A2276,False,36.0,1,{},0,0,0,0,1,0,0,0,0 178 | 496,496,A6538,False,79.0,1,{},0,0,0,1,0,0,0,0,0 179 | 
759,759,A6022,False,60.0,0,{},1,0,0,0,0,0,0,0,0 180 | 4785,4785,A0579,False,49.0,0,{},0,0,1,0,0,0,0,0,0 181 | 3584,3584,A6472,False,38.0,0,{},0,0,0,0,0,0,0,0,1 182 | 2511,2511,A3556,False,60.0,0,{},0,0,0,0,0,0,0,1,0 183 | 3447,3447,A5692,False,61.0,1,{},0,0,0,0,1,0,0,0,0 184 | 1988,1988,A1347,False,86.0,1,{},0,0,0,0,0,0,1,0,0 185 | 878,878,A2913,False,60.0,0,{},0,0,0,0,0,0,0,0,1 186 | 3945,3945,A2333,False,41.0,0,{},0,0,0,0,0,0,0,0,1 187 | 6696,6696,A1764,False,73.0,0,{},1,0,0,0,0,0,0,0,0 188 | 2238,2238,A0284,False,71.0,1,{},0,0,0,1,0,0,0,0,0 189 | 3421,3421,A3650,False,31.0,0,{},0,0,1,0,0,0,0,0,0 190 | 3069,3069,A0015,False,41.0,0,{},0,0,0,0,1,0,0,0,0 191 | 5948,5948,A6444,False,70.0,0,{},1,0,0,0,0,0,0,0,0 192 | 6391,6391,A6798,False,82.0,0,{},1,0,0,0,1,0,0,0,0 193 | 6382,6382,A1689,False,74.0,0,{},0,0,0,1,1,0,0,0,0 194 | 6716,6716,A0078,False,64.0,1,{},0,0,0,0,1,0,0,0,0 195 | 916,916,A1463,False,27.0,0,{},0,0,1,0,0,0,0,0,0 196 | 4954,4954,A4554,False,66.0,1,{},1,0,0,0,0,0,0,0,0 197 | 1548,1548,A1875,False,84.0,1,{},1,0,0,0,0,0,0,0,0 198 | 6625,6625,A5326,False,56.0,0,{},0,0,0,0,1,0,0,0,0 199 | 4083,4083,A0863,False,69.0,0,{},0,1,0,0,0,0,0,0,0 200 | 5205,5205,A5212,False,48.0,1,{},0,0,0,0,1,0,0,0,0 201 | 6072,6072,A4874,False,80.0,0,{},1,0,0,0,0,0,0,0,0 202 | 1572,1572,A4981,False,75.0,1,{},1,0,0,0,0,0,0,0,0 203 | 4121,4121,A4008,False,71.0,1,{},0,0,0,0,1,0,0,0,0 204 | 6314,6314,A0391,False,71.0,1,{},0,0,0,1,0,0,0,0,0 205 | 824,824,A1211,False,29.0,0,{},0,0,0,0,0,0,1,0,0 206 | 2786,2786,A6241,False,52.0,1,{},0,0,0,0,1,0,0,0,0 207 | 4620,4620,A0117,False,70.0,1,{},1,0,0,0,0,0,0,0,0 208 | 2322,2322,A1393,False,76.0,1,{},0,1,0,0,0,0,0,0,0 209 | 3769,3769,A3947,False,42.0,0,{},0,0,1,0,0,0,0,0,0 210 | 3183,3183,A1460,False,51.0,0,{},0,0,0,0,0,0,0,0,1 211 | 3524,3524,A2666,False,89.0,1,{},0,0,0,1,0,0,0,0,0 212 | 5673,5673,A3147,False,84.0,1,{},1,0,0,0,0,0,0,0,0 213 | 4484,4484,A1124,False,88.0,0,{},1,0,0,0,0,0,0,0,0 214 | 
3415,3415,A1238,False,77.0,1,{},1,0,0,0,0,0,0,0,0 215 | 6604,6604,A0123,False,83.0,0,{},0,1,0,0,0,0,0,0,0 216 | 4486,4486,A4046,False,59.0,0,{},0,0,0,0,0,0,0,0,1 217 | 2131,2131,A6723,False,47.0,0,{},0,0,0,0,0,0,1,0,0 218 | 3811,3811,A1025,False,85.0,1,{},1,0,0,0,0,0,0,0,0 219 | 4368,4368,A5612,False,46.0,1,{},0,0,0,0,1,0,0,0,0 220 | 1168,1168,A2218,False,49.0,0,{},0,1,0,0,0,0,0,0,0 221 | 6504,6504,A3897,False,86.0,1,{},0,0,0,0,0,0,1,0,0 222 | 2688,2688,A5716,False,65.0,0,{},0,0,1,0,0,0,0,0,0 223 | 2081,2081,A4354,False,69.0,0,{},1,0,0,0,0,0,0,0,0 224 | 2022,2022,A2368,False,67.0,1,{},0,0,0,0,1,0,0,0,0 225 | 6188,6188,A6595,False,37.0,0,{},0,0,0,0,0,0,0,0,1 226 | 3047,3047,A5032,False,72.0,1,{},0,0,0,0,0,0,1,0,0 227 | 5094,5094,A3358,False,26.0,0,{},0,0,1,0,0,0,0,0,0 228 | 3399,3399,A4758,False,33.0,0,{},0,0,0,0,0,0,0,0,1 229 | 2826,2826,A6230,False,95.0,0,{},0,0,0,0,0,0,0,0,1 230 | 2902,2902,A1274,False,66.0,1,{},0,0,0,0,0,0,1,0,0 231 | 2881,2881,A0443,False,62.0,0,{},0,0,0,0,1,0,0,0,0 232 | 826,826,A5194,False,52.0,1,{},0,0,1,0,0,0,0,0,0 233 | 2326,2326,A1633,False,70.0,1,{},0,0,0,0,0,0,0,0,1 234 | 1129,1129,A2887,False,64.0,1,{},1,0,0,0,0,0,0,0,0 235 | 4319,4319,A1027,False,76.0,1,{},1,0,0,0,1,0,0,0,0 236 | 30,30,A1001,False,41.0,1,{},0,1,0,0,0,0,0,0,0 237 | 5804,5804,A1734,False,65.0,0,{},0,0,0,0,0,0,0,0,1 238 | 1533,1533,A1031,False,67.0,0,{},0,0,0,1,0,0,0,0,0 239 | 1225,1225,A5702,False,84.0,1,{},0,0,0,0,0,0,0,1,0 240 | 329,329,A2476,False,50.0,1,{},0,1,0,0,0,0,0,0,0 241 | 1069,1069,A5772,False,76.0,1,{},1,0,0,0,0,0,0,0,0 242 | 4295,4295,A3686,False,75.0,1,{},0,0,0,0,1,0,0,0,0 243 | 1250,1250,A3938,False,64.0,1,{},0,0,0,1,0,0,0,0,0 244 | 3756,3756,A5568,False,82.0,0,{},0,0,0,1,0,0,0,0,0 245 | 1399,1399,A2140,False,22.0,0,{},0,0,0,0,1,0,0,0,0 246 | 1568,1568,A3176,False,79.0,0,{},0,1,0,0,0,0,0,0,0 247 | 3437,3437,A6345,False,76.0,1,{},0,1,0,0,0,0,0,0,0 248 | 2121,2121,A0996,False,59.0,1,{},0,0,0,0,1,0,0,0,0 249 | 
5198,5198,A5404,False,79.0,1,{},1,0,0,0,0,0,0,0,0 250 | 3504,3504,A0749,False,47.0,0,{},0,0,0,0,0,0,1,0,0 251 | 3042,3042,A0398,False,51.0,1,{},0,0,0,0,0,0,0,1,0 252 | 2387,2387,A6579,False,40.0,0,{},0,0,0,1,0,0,0,0,0 253 | 5666,5666,A4727,False,66.0,1,{},1,0,0,0,0,0,0,0,0 254 | 6723,6723,A1480,False,50.0,1,{},0,0,1,0,0,0,0,0,0 255 | 1093,1093,A6157,False,36.0,1,{},0,0,0,0,1,0,0,0,0 256 | 6666,6666,A6735,False,13.0,1,{},0,0,1,0,0,0,0,0,0 257 | 1750,1750,A6625,False,61.0,0,{},0,0,0,0,0,0,0,0,1 258 | 6271,6271,A2474,False,58.0,1,{},0,0,0,0,0,1,0,0,0 259 | 3001,3001,A2543,False,23.0,0,{},0,0,1,0,0,0,0,0,0 260 | 2037,2037,A3063,False,90.0,1,{},0,0,0,0,1,1,0,0,0 261 | 3724,3724,A0571,False,73.0,1,{},0,0,0,0,1,0,0,0,0 262 | 1855,1855,A2064,False,69.0,1,{},1,0,0,0,0,0,0,0,0 263 | 359,359,A4954,False,44.0,0,{},0,0,0,0,1,0,0,0,0 264 | 1625,1625,A3730,False,64.0,1,{},0,0,0,0,1,0,0,0,0 265 | 892,892,A5222,False,41.0,0,{},0,0,0,0,1,0,0,0,1 266 | 1087,1087,A1812,False,35.0,1,{},0,0,0,0,1,0,0,0,0 267 | 6670,6670,A2514,False,63.0,0,{},0,1,0,1,0,0,0,0,0 268 | 1017,1017,A2063,False,45.0,0,{},0,0,0,0,0,0,0,0,1 269 | 1297,1297,A1947,False,32.0,0,{},0,1,0,0,0,0,0,0,0 270 | 3621,3621,A1713,False,90.0,0,{},1,0,0,0,0,0,0,0,0 271 | 4933,4933,A4224,False,42.0,1,{},0,0,0,0,1,0,0,0,0 272 | 1029,1029,A5403,False,64.0,0,{},0,0,0,0,0,0,0,0,1 273 | 6321,6321,A2411,False,53.0,1,{},0,0,0,1,0,0,0,0,0 274 | 6570,6570,A3253,False,67.0,1,{},0,0,0,0,1,0,0,0,0 275 | 3256,3256,A2570,False,6.0,0,{},0,0,1,0,0,0,0,0,0 276 | 2920,2920,A6017,False,25.0,0,{},0,0,0,0,0,0,0,0,1 277 | 4794,4794,A3931,False,83.0,0,{},0,0,0,0,1,0,0,0,0 278 | 1277,1277,A0760,False,73.0,0,{},0,0,0,0,0,0,0,1,0 279 | 1883,1883,A4655,False,79.0,1,{},0,0,0,1,0,0,0,0,0 280 | 5800,5800,A6156,False,54.0,0,{},0,0,0,0,0,0,0,0,1 281 | 511,511,A0044,False,84.0,1,{},0,0,0,0,0,0,0,1,0 282 | 107,107,A6168,False,62.0,1,{},0,0,0,0,0,0,0,0,1 283 | 1780,1780,A1968,False,74.0,0,{},1,0,0,0,0,0,0,0,0 284 | 
5256,5256,A6400,False,46.0,0,{},0,0,0,0,0,0,0,0,1 285 | 4550,4550,A0972,False,68.0,1,{},0,0,0,0,0,0,1,0,0 286 | 215,215,A6275,False,12.0,1,{},0,0,0,0,0,1,0,0,0 287 | 3467,3467,A1760,False,85.0,0,{},0,0,0,1,0,0,0,0,0 288 | 3210,3210,A3182,False,78.0,0,{},1,0,0,0,0,0,0,0,0 289 | 4237,4237,A3447,False,75.0,1,{},0,0,0,0,1,0,0,0,0 290 | 6401,6401,A5433,False,84.0,0,{},1,0,0,0,0,0,0,0,0 291 | 345,345,A4026,False,27.0,0,{},0,0,0,0,0,0,0,0,1 292 | 4716,4716,A2671,False,62.0,1,{},0,0,0,0,0,0,1,0,0 293 | 1281,1281,A2383,False,44.0,0,{},0,0,1,0,0,0,0,0,0 294 | 4688,4688,A0837,False,50.0,0,{},0,1,0,0,0,0,0,0,0 295 | 6628,6628,A3424,False,10.0,1,{},0,0,1,0,0,0,0,0,0 296 | 2727,2727,A5005,False,73.0,1,{},0,0,0,1,0,0,0,0,0 297 | 3942,3942,A0556,False,84.0,1,{},0,0,0,0,0,0,0,1,0 298 | 1692,1692,A0268,False,64.0,1,{},0,0,0,0,1,0,0,0,0 299 | 3955,3955,A5926,False,34.0,0,{},0,0,0,0,0,0,1,0,0 300 | 5810,5810,A5397,False,65.0,1,{},1,0,0,0,0,0,0,0,0 301 | 2318,2318,A1054,False,78.0,1,{},1,0,0,0,0,0,0,0,0 302 | 5089,5089,A0173,False,34.0,0,{},0,0,1,0,0,0,0,0,0 303 | 3320,3320,A1940,False,88.0,1,{},0,0,0,0,0,0,1,0,0 304 | 6648,6648,A0567,False,30.0,0,{},0,0,1,0,0,0,0,0,0 305 | 1781,1781,A3091,False,75.0,1,{},0,0,0,0,1,0,0,0,0 306 | 3331,3331,A4775,False,78.0,1,{},0,0,0,0,1,0,0,0,0 307 | 4065,4065,A6415,False,47.0,0,{},0,0,1,0,0,0,0,0,0 308 | 5056,5056,A0906,False,68.0,1,{},0,0,0,0,1,0,0,0,0 309 | 4329,4329,A1068,False,58.0,0,{},0,1,0,0,0,0,0,0,0 310 | 1471,1471,A2098,False,66.0,0,{},0,0,0,0,1,0,0,0,0 311 | 6181,6181,A0026,False,74.0,1,{},1,0,0,0,0,0,0,0,0 312 | 1842,1842,A1589,False,79.0,1,{},1,0,0,0,0,0,0,0,0 313 | 937,937,A6196,False,13.0,1,{},0,0,1,0,0,0,0,0,0 314 | 1390,1390,A3750,False,64.0,1,{},0,0,0,1,0,0,0,0,0 315 | 5784,5784,A2589,False,80.0,1,{},0,0,0,0,0,0,1,0,0 316 | 1464,1464,A5480,False,47.0,1,{},0,0,0,0,1,0,0,0,0 317 | 6326,6326,A1139,False,62.0,1,{},0,0,0,0,1,0,0,0,0 318 | 2836,2836,A4994,False,54.0,0,{},0,0,1,0,0,0,0,0,0 319 | 
6817,6817,A1593,False,75.0,0,{},0,0,0,0,0,0,0,0,1 320 | 4790,4790,A2840,False,70.0,1,{},0,0,0,0,1,0,0,0,0 321 | 6870,6870,A1974,False,23.0,1,{},0,0,0,1,0,0,0,0,0 322 | 809,809,A1414,False,70.0,1,{},0,0,0,1,0,0,0,0,0 323 | 6497,6497,A2272,False,43.0,0,{},0,0,0,0,1,0,0,0,0 324 | 298,298,A3914,False,36.0,1,{},0,1,0,0,0,0,0,0,0 325 | 1645,1645,A1278,False,62.0,0,{},0,0,0,0,0,0,1,0,1 326 | 1953,1953,A1106,False,23.0,0,{},0,0,1,0,0,0,0,0,0 327 | 479,479,A0885,False,80.0,0,{},0,0,0,0,1,0,0,0,0 328 | 1666,1666,A3919,False,64.0,1,{},0,0,0,0,0,0,0,0,1 329 | 3818,3818,A2273,False,49.0,0,{},0,0,1,0,0,0,0,0,0 330 | 2584,2584,A1007,False,73.0,0,{},1,0,0,0,0,0,0,0,0 331 | 3026,3026,A2768,False,28.0,0,{},0,0,0,0,0,0,0,0,1 332 | 5082,5082,A0271,False,92.0,0,{},1,0,0,0,0,0,0,0,0 333 | 385,385,A2268,False,30.0,0,{},0,0,0,0,0,0,1,0,0 334 | 3374,3374,A3175,False,43.0,0,{},0,0,1,0,0,0,0,0,0 335 | 4352,4352,A2885,False,93.0,1,{},1,0,0,0,0,0,0,0,0 336 | 6486,6486,A0248,False,83.0,0,{},0,0,0,1,0,0,0,0,0 337 | 4907,4907,A4345,False,84.0,0,{},1,0,0,0,0,0,0,0,0 338 | 3339,3339,A6274,False,29.0,1,{},0,0,1,0,0,0,0,0,0 339 | 3645,3645,A0096,False,45.0,1,{},0,0,0,0,1,0,0,0,0 340 | 5321,5321,A1806,False,44.0,0,{},0,0,1,0,0,0,0,0,0 341 | 3580,3580,A3920,False,49.0,0,{},0,0,1,0,0,0,0,0,0 342 | 1873,1873,A2410,False,60.0,1,{},0,0,0,1,0,0,0,0,0 343 | 6743,6743,A0585,False,37.0,1,{},0,0,0,1,0,0,0,0,0 344 | 1739,1739,A0446,False,87.0,0,{},1,0,0,0,0,0,0,0,0 345 | 367,367,A1892,False,43.0,1,{},0,0,1,0,0,0,0,0,0 346 | 6160,6160,A1122,False,29.0,0,{},0,0,1,0,0,0,0,0,0 347 | 3231,3231,A1470,False,36.0,1,{},0,0,1,0,0,0,0,0,0 348 | 115,115,A6224,False,77.0,1,{},0,1,0,0,0,0,0,0,0 349 | 6416,6416,A2850,False,57.0,1,{},1,0,0,0,0,0,0,0,0 350 | 1173,1173,A0685,False,90.0,1,{},0,0,0,1,0,0,0,0,1 351 | 6423,6423,A1657,False,89.0,0,{},0,0,0,0,0,1,0,0,0 352 | 2700,2700,A3277,False,80.0,1,{},1,0,0,0,0,0,0,0,0 353 | 5520,5520,A0170,False,22.0,0,{},0,0,1,0,0,0,0,0,0 354 | 2656,2656,A0093,False,49.0,1,{},0,0,0,0,0,1,0,0,0 
355 | 742,742,A0759,False,40.0,1,{},0,0,0,0,1,0,0,0,0 356 | 3302,3302,A6596,False,47.0,1,{},0,0,0,0,1,0,0,0,0 357 | 6528,6528,A5225,False,56.0,1,{},1,0,0,0,0,0,0,0,0 358 | 1140,1140,A0419,False,48.0,1,{},0,0,1,0,0,0,0,0,0 359 | 2154,2154,A6179,False,77.0,1,{},0,0,0,0,1,0,0,0,0 360 | 1252,1252,A2073,False,51.0,1,{},0,1,0,0,0,0,0,0,0 361 | 4959,4959,A6452,False,89.0,0,{},0,0,0,1,0,0,0,0,0 362 | 1757,1757,A4344,False,65.0,1,{},1,0,0,0,0,0,0,1,0 363 | 4320,4320,A5098,False,78.0,1,{},1,0,0,0,0,0,0,0,0 364 | 3048,3048,A1203,False,85.0,1,{},0,0,0,0,1,0,0,0,0 365 | 3502,3502,A5155,False,76.0,0,{},0,0,1,0,0,0,0,0,0 366 | 638,638,A5972,False,87.0,0,{},0,0,0,0,1,0,1,0,0 367 | 5961,5961,A3326,False,55.0,1,{},0,0,0,0,1,0,0,0,0 368 | 418,418,A5779,False,20.0,1,{},0,0,1,0,0,0,0,0,0 369 | 6467,6467,A6775,False,37.0,0,{},0,0,0,0,0,0,0,0,1 370 | 6345,6345,A6697,False,13.0,0,{},0,0,1,0,0,0,0,0,0 371 | 785,785,A0488,False,47.0,1,{},0,0,0,0,0,0,0,0,1 372 | 774,774,A6237,False,78.0,0,{},0,1,0,0,1,0,0,0,0 373 | 2981,2981,A4853,False,58.0,0,{},0,0,1,0,0,0,0,0,0 374 | 1387,1387,A5125,False,60.0,0,{},0,0,1,0,0,0,0,0,0 375 | 3813,3813,A3684,False,85.0,1,{},0,0,0,1,0,0,0,0,0 376 | 2507,2507,A3999,False,59.0,0,{},0,0,0,0,1,0,0,0,0 377 | 5844,5844,A5598,False,66.0,0,{},0,0,0,1,0,0,0,0,0 378 | 6018,6018,A5123,False,54.0,1,{},0,0,0,0,1,0,0,0,0 379 | 3991,3991,A6097,False,60.0,1,{},1,0,0,0,0,0,0,0,0 380 | 4118,4118,A0955,False,40.0,0,{},0,0,1,0,0,0,0,0,0 381 | 1157,1157,A5346,False,78.0,0,{},0,0,0,0,1,0,0,0,0 382 | 2739,2739,A1333,False,81.0,1,{},0,0,0,0,1,0,0,0,0 383 | 2658,2658,A4562,False,74.0,0,{},0,0,0,0,0,0,1,0,0 384 | 6424,6424,A0858,False,31.0,0,{},0,0,1,0,0,0,0,0,0 385 | 2791,2791,A5993,False,76.0,1,{},0,0,0,0,0,0,1,0,0 386 | 5908,5908,A2896,False,69.0,0,{},0,0,0,0,0,0,0,0,1 387 | 5411,5411,A3661,False,57.0,1,{},0,0,0,0,0,0,1,0,0 388 | 6702,6702,A2735,False,22.0,1,{},0,0,0,0,0,0,1,0,0 389 | 3651,3651,A5669,False,68.0,1,{},0,0,0,0,1,0,0,0,0 390 | 
781,781,A2585,False,60.0,0,{},0,0,0,0,1,0,0,0,0 391 | 5422,5422,A4701,False,73.0,1,{},0,0,0,0,1,0,0,0,0 392 | 5930,5930,A1046,False,50.0,0,{},0,0,0,0,0,0,0,0,1 393 | 2006,2006,A3178,False,83.0,1,{},0,0,0,1,0,0,0,0,0 394 | 6407,6407,A4424,False,46.0,1,{},0,0,0,0,0,1,0,0,0 395 | 4548,4548,A4672,False,71.0,0,{},0,0,0,0,0,0,1,0,0 396 | 3357,3357,A3619,False,43.0,0,{},0,0,1,0,0,0,0,0,0 397 | 5025,5025,A2802,False,68.0,0,{},0,0,0,0,0,0,0,0,1 398 | 1771,1771,A2025,False,33.0,1,{},0,0,0,0,1,0,0,0,0 399 | 4364,4364,A2944,False,32.0,0,{},0,0,1,0,0,0,0,0,0 400 | 2308,2308,A6522,False,37.0,0,{},0,0,0,0,0,0,0,0,1 401 | 268,268,A4374,False,51.0,1,{},0,0,1,0,0,0,0,0,0 402 | 1543,1543,A1300,False,89.0,0,{},0,0,0,0,0,0,0,1,0 403 | 2502,2502,A4093,False,18.0,0,{},0,0,1,0,0,0,0,0,0 404 | 1343,1343,A6042,False,48.0,0,{},0,0,1,0,0,0,0,0,0 405 | 4977,4977,A4399,False,84.0,1,{},0,0,1,0,0,0,0,0,0 406 | 5158,5158,A1021,False,60.0,0,{},0,1,0,0,0,0,0,0,0 407 | 4317,4317,A4096,False,39.0,0,{},0,0,0,0,1,0,0,0,0 408 | 4473,4473,A3793,False,46.0,1,{},0,0,0,0,1,0,0,0,0 409 | 2758,2758,A5075,False,64.0,1,{},0,0,0,0,1,0,0,0,0 410 | 3028,3028,A0062,False,66.0,1,{},0,0,0,0,0,0,1,0,0 411 | 3596,3596,A0993,False,93.0,1,{},1,0,0,0,0,0,0,0,0 412 | 1952,1952,A3832,False,59.0,1,{},0,0,1,0,0,0,0,0,0 413 | 6458,6458,A6204,False,74.0,1,{},0,0,0,0,0,0,1,0,0 414 | 6194,6194,A4108,False,76.0,1,{},1,0,0,0,0,0,0,0,0 415 | 1612,1612,A4061,False,24.0,0,{},0,0,0,0,0,0,0,0,1 416 | 2401,2401,A3723,False,53.0,0,{},0,0,1,0,0,0,0,0,0 417 | 321,321,A1511,False,55.0,0,{},0,0,0,0,0,0,1,0,0 418 | 2372,2372,A0030,False,46.0,1,{},0,0,1,0,0,0,0,0,0 419 | 3563,3563,A1740,False,60.0,0,{},1,0,0,0,1,0,0,0,0 420 | 6417,6417,A3945,False,61.0,0,{},0,1,0,0,0,0,0,0,0 421 | 2266,2266,A1252,False,81.0,1,{},1,0,0,0,0,0,0,0,0 422 | 1186,1186,A1925,False,86.0,1,{},0,0,0,1,0,0,0,0,0 423 | 2503,2503,A0878,False,69.0,0,{},0,0,0,0,1,0,0,0,0 424 | 6013,6013,A3511,False,41.0,0,{},0,0,0,0,0,0,0,0,1 425 | 
3708,3708,A5987,False,75.0,1,{},0,0,0,1,0,0,0,0,0 426 | 6157,6157,A6585,False,50.0,1,{},0,0,1,0,0,0,0,0,0 427 | 5372,5372,A1273,False,65.0,1,{},0,0,0,0,0,0,0,1,0 428 | 5619,5619,A5947,False,68.0,1,{},0,0,0,1,0,0,0,0,0 429 | 3443,3443,A2729,False,61.0,0,{},0,1,0,0,0,0,0,0,0 430 | 4796,4796,A2304,False,66.0,1,{},0,0,0,0,0,1,0,0,0 431 | 964,964,A5431,False,79.0,0,{},0,0,0,0,0,1,0,0,0 432 | 2852,2852,A2091,False,79.0,0,{},0,0,0,0,0,0,1,0,0 433 | 6525,6525,A0049,False,85.0,0,{},0,0,0,0,0,0,1,0,0 434 | 2790,2790,A1886,False,68.0,1,{},0,0,0,0,1,0,0,0,0 435 | 1549,1549,A2468,False,66.0,0,{},0,1,0,0,0,0,0,0,0 436 | 3680,3680,A0633,False,89.0,0,{},0,0,0,0,1,0,0,0,0 437 | 1581,1581,A5360,False,68.0,1,{},0,0,0,1,0,0,0,0,0 438 | 5738,5738,A4781,False,54.0,1,{},0,0,1,0,0,0,0,0,0 439 | 6543,6543,A5083,False,69.0,1,{},1,0,0,0,0,0,0,0,0 440 | 6046,6046,A2956,False,15.0,0,{},0,0,1,0,0,0,0,0,0 441 | 4245,4245,A1437,False,76.0,0,{},0,0,1,0,0,0,0,0,0 442 | 3891,3891,A5179,False,78.0,0,{},0,0,0,0,0,0,1,0,0 443 | 6290,6290,A6197,False,50.0,0,{},0,1,0,0,0,0,0,0,0 444 | 3909,3909,A6166,False,65.0,1,{},0,0,0,0,0,0,0,1,0 445 | 1334,1334,A1144,False,38.0,1,{},0,0,1,0,0,0,0,0,0 446 | 1746,1746,A2510,False,52.0,0,{},0,0,0,1,0,0,0,0,0 447 | 4414,4414,A5441,False,54.0,0,{},0,1,0,0,1,0,0,0,0 448 | 5958,5958,A2008,False,29.0,0,{},0,0,0,0,1,0,0,0,0 449 | 3911,3911,A1117,False,80.0,0,{},0,0,0,0,1,0,0,0,0 450 | 4817,4817,A2157,False,81.0,1,{},1,0,0,0,0,0,0,0,0 451 | 4996,4996,A4504,False,12.0,0,{},0,0,1,0,0,0,0,0,0 452 | 6698,6698,A2418,False,33.0,0,{},0,1,0,0,0,0,0,0,0 453 | 4580,4580,A3481,False,48.0,1,{},0,1,0,0,0,0,0,0,0 454 | 2049,2049,A3694,False,81.0,1,{},0,0,0,0,0,0,1,0,0 455 | 349,349,A5314,False,15.0,1,{},0,0,0,0,0,1,0,0,0 456 | 2887,2887,A3351,False,34.0,0,{},0,0,1,0,0,0,0,0,0 457 | 6454,6454,A4673,False,92.0,1,{},0,0,0,0,0,0,1,1,0 458 | 137,137,A1650,False,90.0,1,{},0,0,0,1,0,0,0,0,0 459 | 5055,5055,A4382,False,47.0,0,{},0,0,0,0,1,0,0,0,0 460 | 
5634,5634,A5088,False,23.0,0,{},0,0,0,0,1,0,0,0,0 461 | 5156,5156,A1163,False,55.0,1,{},1,0,0,0,0,0,0,0,0 462 | 1598,1598,A4799,False,74.0,0,{},0,0,0,0,0,0,0,0,1 463 | 1971,1971,A0017,False,60.0,0,{},1,0,0,0,0,0,0,0,0 464 | 680,680,A0539,False,65.0,0,{},0,0,0,1,0,0,0,0,0 465 | 1443,1443,A2569,False,33.0,0,{},0,0,1,0,0,0,0,0,0 466 | 1571,1571,A0028,False,58.0,0,{},0,0,0,0,1,0,0,0,0 467 | 1595,1595,A4129,False,72.0,0,{},1,0,0,0,0,0,0,0,0 468 | 1889,1889,A6635,False,50.0,0,{},0,0,0,0,1,0,0,0,0 469 | 5483,5483,A1900,False,27.0,0,{},0,1,0,0,0,0,0,0,0 470 | 701,701,A5865,False,32.0,0,{},0,1,0,0,0,0,0,0,0 471 | 2831,2831,A2902,False,69.0,1,{},1,0,0,0,0,0,0,0,0 472 | 1736,1736,A5890,False,73.0,0,{},0,0,0,0,0,0,1,0,0 473 | 419,419,A3427,False,49.0,1,{},0,0,0,0,0,0,0,0,1 474 | 1605,1605,A2434,False,80.0,1,{},0,0,0,0,0,0,1,0,0 475 | 6164,6164,A5213,False,81.0,1,{},0,0,0,0,1,0,0,0,0 476 | 3556,3556,A3755,False,59.0,1,{},1,0,0,0,0,0,0,0,0 477 | 1467,1467,A4714,False,30.0,0,{},0,0,1,0,0,0,0,0,0 478 | 2471,2471,A2086,False,25.0,1,{},0,0,1,0,0,0,0,0,0 479 | 3448,3448,A3339,False,83.0,0,{},1,0,0,0,0,0,0,0,0 480 | 4370,4370,A3134,False,64.0,0,{},0,0,1,0,0,0,0,0,0 481 | 779,779,A2701,False,35.0,1,{},0,0,0,0,1,0,0,0,0 482 | 874,874,A1091,False,61.0,0,{},0,0,0,0,0,0,1,0,0 483 | 967,967,A1708,False,52.0,0,{},0,0,0,0,0,0,0,0,1 484 | 4872,4872,A3212,False,79.0,1,{},1,0,0,0,0,0,0,0,0 485 | 5694,5694,A2188,False,44.0,1,{},0,1,0,0,0,0,0,0,0 486 | 3587,3587,A2380,False,65.0,0,{},0,0,1,0,0,0,0,0,0 487 | 1335,1335,A1457,False,72.0,1,{},0,0,0,1,0,0,0,0,0 488 | 2101,2101,A4502,False,79.0,0,{},1,0,0,0,0,0,0,0,0 489 | 3297,3297,A5007,False,43.0,1,{},0,1,0,0,0,0,0,0,0 490 | 3578,3578,A1006,False,46.0,1,{},0,0,0,1,0,0,0,0,0 491 | 4107,4107,A1167,False,68.0,1,{},0,0,0,0,1,0,0,0,0 492 | 3834,3834,A4408,False,83.0,1,{},0,1,0,0,0,0,0,1,0 493 | 1105,1105,A5234,False,61.0,1,{},0,0,0,0,1,0,0,0,0 494 | 490,490,A2128,False,33.0,0,{},0,1,0,0,0,0,0,0,0 495 | 1773,1773,A5449,False,40.0,0,{},0,0,1,0,0,0,0,0,0 496 
| 3907,3907,A3181,False,67.0,0,{},0,0,0,0,1,0,0,0,0 497 | 920,920,A6514,False,79.0,0,{},0,0,0,0,1,0,0,0,0 498 | 2043,2043,A4333,False,52.0,1,{},0,0,0,0,1,0,0,0,0 499 | 1575,1575,A3183,False,83.0,1,{},1,0,0,0,0,0,0,0,0 500 | 6853,6853,A2971,False,65.0,0,{},0,0,1,0,0,0,0,0,0 501 | 4949,4949,A1137,False,29.0,0,{},0,0,1,0,0,0,0,0,0 502 | 5860,5860,A1753,False,71.0,1,{},1,0,0,0,0,0,0,0,0 503 | 1938,1938,A3843,False,71.0,0,{},0,0,0,0,1,0,0,0,0 504 | 4539,4539,A6813,False,60.0,1,{},0,0,0,0,0,0,0,1,0 505 | 929,929,A3859,False,67.0,0,{},0,0,0,0,1,0,0,0,0 506 | 4626,4626,A1545,False,74.0,1,{},0,0,0,1,0,0,0,0,0 507 | 6601,6601,A3228,False,35.0,0,{},0,0,1,0,0,0,0,0,0 508 | 5112,5112,A6108,False,61.0,1,{},0,0,0,0,1,0,0,0,0 509 | 2260,2260,A0857,False,61.0,1,{},1,0,0,0,0,0,0,1,0 510 | 5174,5174,A4393,False,67.0,0,{},0,0,0,0,1,0,0,0,0 511 | 2156,2156,A2406,False,74.0,1,{},1,0,0,0,0,0,0,0,0 512 | 5512,5512,A5001,False,31.0,0,{},0,0,0,0,0,0,0,0,1 513 | 3,3,A0346,False,60.0,1,{},0,0,0,0,0,0,1,0,0 514 | 3281,3281,A3665,False,24.0,0,{},0,0,0,0,0,0,0,0,1 515 | 1990,1990,A5057,False,55.0,1,{},0,0,0,0,1,0,0,0,0 516 | 5176,5176,A5251,False,59.0,0,{},0,1,0,0,0,0,0,0,1 517 | 361,361,A1478,False,46.0,1,{},0,0,0,0,1,0,0,0,0 518 | 728,728,A6656,False,82.0,1,{},0,0,0,0,1,0,0,0,0 519 | 3787,3787,A6653,False,43.0,0,{},0,1,0,0,0,0,0,0,0 520 | 6431,6431,A6002,False,72.0,0,{},0,1,0,0,0,0,0,0,0 521 | 3276,3276,A3275,False,71.0,1,{},0,0,0,0,1,0,0,0,0 522 | 6075,6075,A6315,False,74.0,1,{},0,0,0,0,0,0,1,0,0 523 | 5348,5348,A3626,False,82.0,1,{},1,0,0,0,1,0,0,0,0 524 | 880,880,A6083,False,59.0,1,{},0,0,0,0,1,0,0,0,0 525 | 5557,5557,A1250,False,47.0,0,{},0,0,0,0,0,0,1,0,0 526 | 4493,4493,A6142,False,37.0,0,{},0,1,0,0,0,0,0,0,0 527 | 5735,5735,A2563,False,74.0,0,{},1,0,0,0,0,0,0,0,0 528 | 2100,2100,A2457,False,81.0,0,{},0,0,0,0,0,0,0,0,1 529 | 2110,2110,A5584,False,77.0,1,{},1,0,0,0,0,0,0,0,0 530 | 721,721,A1196,False,77.0,1,{},0,0,0,0,1,0,0,0,0 531 | 3100,3100,A0848,False,90.0,0,{},1,0,0,0,1,0,0,0,0 532 | 
6730,6730,A3647,False,76.0,0,{},0,0,0,0,0,0,1,0,0 533 | 3739,3739,A4380,False,55.0,1,{},0,0,0,1,0,0,0,0,0 534 | 5827,5827,A3543,False,76.0,1,{},1,0,0,0,0,0,0,0,0 535 | 5532,5532,A3106,False,69.0,0,{},0,0,0,0,0,0,1,0,0 536 | 6233,6233,A5795,False,61.0,0,{},1,0,0,0,1,0,0,0,0 537 | 3540,3540,A6063,False,74.0,1,{},0,0,0,0,0,0,1,0,0 538 | 4915,4915,A4536,False,75.0,1,{},0,1,0,0,1,0,0,0,0 539 | 6672,6672,A5215,False,87.0,1,{},0,0,0,0,1,0,1,0,0 540 | 3852,3852,A1226,False,29.0,1,{},0,0,0,0,0,0,1,0,0 541 | 1,1,A1488,False,77.0,1,{},0,0,0,1,0,0,0,0,0 542 | 3082,3082,A5309,False,84.0,0,{},0,0,0,0,1,0,1,0,0 543 | 2054,2054,A5688,False,30.0,1,{},0,0,0,1,0,0,0,0,0 544 | 2309,2309,A3601,False,72.0,0,{},0,0,1,0,0,0,0,0,0 545 | 2878,2878,A6709,False,54.0,1,{},0,0,1,0,0,0,0,0,0 546 | 2666,2666,A0827,False,28.0,1,{},0,0,1,0,0,0,0,0,0 547 | 3023,3023,A2079,False,68.0,1,{},0,0,0,0,1,0,0,0,1 548 | 5530,5530,A4767,False,63.0,0,{},0,0,0,0,0,0,0,0,1 549 | 12,12,A2891,False,16.0,0,{},0,0,1,0,0,0,0,0,0 550 | 3039,3039,A1544,False,59.0,0,{},0,1,0,0,1,0,0,0,0 551 | 891,891,A1100,False,82.0,1,{},0,0,0,0,1,0,0,0,0 552 | 625,625,A5514,False,68.0,0,{},0,0,0,1,0,0,0,0,0 553 | -------------------------------------------------------------------------------- /finetune/finetune_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.model_selection import train_test_split 6 | import os 7 | import wfdb 8 | from scipy.io import loadmat 9 | 10 | ''' 11 | In this code: 12 | PTB-XL has four subset: superclass, subclass, form, rhythm 13 | ICBEB is CPSC2018 dataset mentioned in the original paper 14 | Chapman is the CSN dataset from the original paper 15 | ''' 16 | 17 | class ECGDataset(Dataset): 18 | def __init__(self, data_path, csv_file, mode='train', dataset_name='ptbxl', backbone='resnet18'): 19 | """ 20 | Args: 21 | data_path (string): Path to 
store raw data. 22 |             csv_file (string): Path to the .csv file with labels and data path. 23 |             mode (string): train/val/test; dataset_name (string): ptbxl/icbeb/chapman. 24 |         """ 25 |         self.dataset_name = dataset_name 26 | 27 |         if self.dataset_name == 'ptbxl': 28 |             self.labels_name = list(csv_file.columns[6:]) 29 |             self.num_classes = len(self.labels_name) 30 | 31 |             self.data_path = data_path 32 |             self.ecg_path = csv_file['filename_hr'] 33 |             # in ptbxl, columns 0-5 are other meta data, columns 6-end are the labels 34 |             self.labels = csv_file.iloc[:, 6:].values 35 | 36 |         elif self.dataset_name == 'icbeb': 37 |             self.labels_name = list(csv_file.columns[7:]) 38 |             self.num_classes = len(self.labels_name) 39 | 40 |             self.data_path = data_path 41 |             self.ecg_path = csv_file['ecg_id'].astype(str) 42 |             # in icbeb, columns 0-6 are other meta data, columns 7-end are the labels 43 |             self.labels = csv_file.iloc[:, 7:].values 44 | 45 |         elif self.dataset_name == 'chapman': 46 |             self.labels_name = list(csv_file.columns[3:]) 47 |             self.num_classes = len(self.labels_name) 48 | 49 |             self.data_path = data_path 50 |             self.ecg_path = csv_file['ecg_path'].astype(str) 51 |             # in chapman, columns 0-2 are other meta data, columns 3-end are the labels 52 |             self.labels = csv_file.iloc[:, 3:].values 53 | 54 |         else: 55 |             raise ValueError("dataset_name should be either 'ptbxl' or 'icbeb' or 'chapman'") 56 | 57 |     def __len__(self): 58 |         return self.labels.shape[0] 59 | 60 |     def __getitem__(self, idx): 61 |         if self.dataset_name == 'ptbxl': 62 |             ecg_path = os.path.join(self.data_path, self.ecg_path[idx]) 63 |             # the wfdb format file includes ecg and other meta data 64 |             # the first element is the ecg data 65 |             ecg = wfdb.rdsamp(ecg_path)[0] 66 |             # the raw ecg shape is (5000, 12) 67 |             # transform to (12, 5000) 68 |             ecg = ecg.T 69 | 70 |             ecg = ecg[:, :5000] 71 |             # normalize to 0-1 72 |             ecg = (ecg - np.min(ecg))/(np.max(ecg) - np.min(ecg) + 1e-8) 73 | 74 |             ecg = torch.from_numpy(ecg).float() 75 |             target = self.labels[idx] 76 |             target =
torch.from_numpy(target).float() 77 | 78 |         elif self.dataset_name == 'icbeb': 79 |             ecg_path = os.path.join(self.data_path, self.ecg_path[idx]) 80 |             # icbeb stores the raw ecg data as .dat files 81 |             ecg = wfdb.rdsamp(ecg_path) 82 |             # the raw ecg shape is (n, 12), n is different for each sample 83 |             # transform to (12, n) 84 |             ecg = ecg[0].T 85 |             # icbeb ECGs have different lengths, so we preprocess them to the same length 86 |             # we only keep the first 2500 points as METS did 87 |             ecg = ecg[:, :2500] 88 | 89 |             # padding to 5000 to match the pre-trained data length 90 |             ecg = np.pad(ecg, ((0, 0), (0, 2500)), 'constant', constant_values=0) 91 |             ecg = ecg[:, :5000] 92 | 93 |             # normalize to 0-1 94 |             ecg = (ecg - np.min(ecg))/(np.max(ecg) - np.min(ecg) + 1e-8) 95 | 96 |             ecg = torch.from_numpy(ecg).float() 97 |             target = self.labels[idx] 98 |             target = torch.from_numpy(target).float() 99 | 100 |         elif self.dataset_name == 'chapman': 101 |             # chapman ecg_path has / at the start, so we need to remove it 102 |             ecg_path = os.path.join(self.data_path, self.ecg_path[idx][1:]) 103 |             # raw data is (12, 5000), do not need to transform 104 |             ecg = loadmat(ecg_path)['val'] 105 |             ecg = ecg.astype(np.float32) 106 | 107 |             ecg = ecg[:, :5000] 108 | 109 |             # normalize to 0-1 110 |             ecg = (ecg - np.min(ecg))/(np.max(ecg) - np.min(ecg) + 1e-8) 111 | 112 |             ecg = torch.from_numpy(ecg).float() 113 |             target = self.labels[idx] 114 |             target = torch.from_numpy(target).float() 115 | 116 |         # switch aVL and aVF 117 |         # In MIMIC-ECG, the lead order is I, II, III, aVR, aVF, aVL, V1, V2, V3, V4, V5, V6 118 |         # In downstream datasets, the lead order is I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6 119 |         ecg[[4, 5]] = ecg[[5, 4]] 120 | 121 | 122 | 123 |         return ecg, target 124 | 125 | def getdataset(data_path, csv_path, mode='train', dataset_name='ptbxl', ratio=100, backbone='resnet18'): 126 |     ratio = int(ratio) 127 | 128 |     if dataset_name == 'ptbxl': 129 |         csv = pd.read_csv(csv_path) 130 |         if mode
== 'train' and ratio != 100: 131 |             csv, _ = train_test_split(csv, train_size=(ratio/100), random_state=42) 132 |     elif dataset_name == 'icbeb': 133 |         csv = pd.read_csv(csv_path) 134 |         if mode == 'train' and ratio != 100: 135 |             csv, _ = train_test_split(csv, train_size=(ratio/100), random_state=42) 136 |     elif dataset_name == 'chapman': 137 |         csv = pd.read_csv(csv_path) 138 |         if mode == 'train' and ratio != 100: 139 |             csv, _ = train_test_split(csv, train_size=(ratio/100), random_state=42) 140 |     else: 141 |         raise ValueError("dataset_name should be either 'ptbxl' or 'icbeb' or 'chapman'!") 142 | 143 |     csv.reset_index(drop=True, inplace=True) 144 | 145 |     dataset = ECGDataset(data_path, csv, mode=mode, dataset_name=dataset_name, backbone=backbone) 146 | 147 |     return dataset -------------------------------------------------------------------------------- /finetune/main_single.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import json 4 | import math 5 | import os 6 | import random 7 | import signal 8 | import subprocess 9 | import sys 10 | import pandas as pd 11 | import time 12 | import numpy as np 13 | from tqdm import tqdm 14 | from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, roc_auc_score, precision_recall_curve 15 | from torch.cuda.amp import autocast as autocast 16 | from torch.cuda.amp import GradScaler as GradScaler 17 | from torch import nn, optim 18 | import torch 19 | import torch.nn.functional as F 20 | from matplotlib import pyplot as plt 21 | from finetune_dataset import getdataset 22 | from models.resnet1d import ResNet18, ResNet34, ResNet50, ResNet101 23 | from models.vit1d import vit_base, vit_small, vit_tiny, vit_middle 24 | 25 | parser = argparse.ArgumentParser(description='MERL Finetuning') 26 | parser.add_argument('--dataset', default='ptbxl_super_class', 27 |                     type=str, help='dataset name') 28 | parser.add_argument('--ratio',
default=100, 29 |                     type=int, help='training data ratio') 30 | parser.add_argument('--workers', default=8, type=int, metavar='N', 31 |                     help='number of data loader workers') 32 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 33 |                     help='number of total epochs to run') 34 | parser.add_argument('--batch-size', default=256, type=int, metavar='N', 35 |                     help='mini-batch size') 36 | parser.add_argument('--test-batch-size', default=256, type=int, metavar='N', 37 |                     help='mini-batch size') 38 | parser.add_argument('--learning-rate', default=0.3, type=float, metavar='LR', 39 |                     help='base learning rate for weights') 40 | parser.add_argument('--weight-decay', default=1e-4, type=float, metavar='W', 41 |                     help='weight decay') 42 | parser.add_argument('--pretrain_path', default='your_pretrained_encoder.pth', type=str, 43 |                     help='path to pretrain weight directory') 44 | parser.add_argument('--checkpoint-dir', default='./checkpoint_finetune/', type=Path, 45 |                     metavar='DIR', help='path to checkpoint directory') 46 | parser.add_argument('--backbone', default='resnet18', type=str, metavar='B', 47 |                     help='backbone name') 48 | parser.add_argument('--num_leads', default=12, type=int, metavar='B', 49 |                     help='number of leads') 50 | parser.add_argument('--name', default='LinearProbing', type=str, metavar='B', 51 |                     help='exp name') 52 | 53 | def main(): 54 |     args = parser.parse_args() 55 |     args.ngpus_per_node = torch.cuda.device_count() 56 |     batch_size = int(args.batch_size) 57 |     test_batch_size = int(args.test_batch_size) 58 |     args.checkpoint_dir.mkdir(parents=True, exist_ok=True) 59 |     torch.cuda.empty_cache() 60 |     device_id = torch.cuda.device_count() 61 |     torch.manual_seed(42) 62 |     random.seed(0) 63 |     np.random.seed(0) 64 |     torch.backends.cudnn.benchmark = True 65 |     print(f'this task uses {args.dataset} dataset') 66 | 67 |     data_split_path = 'your_path/MERL/finetune/data_split' 68 |     data_meta_path = 'your_path/downstream' 69 | 70 |     if 'ptbxl' in args.dataset: 71 |         # set
the path where you store the ptbxl dataset 72 | data_path = f'{data_meta_path}/ptbxl' 73 | data_split_path = os.path.join(data_split_path, f'ptbxl/{args.dataset[6:]}') 74 | 75 | train_csv_path = f'{args.dataset}_train.csv' 76 | train_csv_path = os.path.join(data_split_path, train_csv_path) 77 | val_csv_path = f'{args.dataset}_val.csv' 78 | val_csv_path = os.path.join(data_split_path, val_csv_path) 79 | test_csv_path = f'{args.dataset}_test.csv' 80 | test_csv_path = os.path.join(data_split_path, test_csv_path) 81 | 82 | train_dataset = getdataset(data_path, train_csv_path, mode='train', dataset_name='ptbxl', ratio=args.ratio, 83 | backbone=args.backbone) 84 | val_dataset = getdataset(data_path, val_csv_path, mode='val', dataset_name='ptbxl', 85 | backbone=args.backbone) 86 | test_dataset = getdataset(data_path, test_csv_path, mode='test', dataset_name='ptbxl', 87 | backbone=args.backbone) 88 | 89 | args.labels_name = train_dataset.labels_name 90 | num_classes = train_dataset.num_classes 91 | 92 | elif args.dataset == 'CPSC2018': 93 | # set the path where you store the CPSC2018 dataset, the CPSC2018 dataset folder should be icbeb2018/records500/... 
94 |         data_path = f'{data_meta_path}/icbeb2018/records500' 95 |         data_split_path = os.path.join(data_split_path, 'icbeb') 96 | 97 |         train_csv_path = 'icbeb_train.csv' 98 |         train_csv_path = os.path.join(data_split_path, train_csv_path) 99 |         val_csv_path = 'icbeb_val.csv' 100 |         val_csv_path = os.path.join(data_split_path, val_csv_path) 101 |         test_csv_path = 'icbeb_test.csv' 102 |         test_csv_path = os.path.join(data_split_path, test_csv_path) 103 | 104 |         train_dataset = getdataset(data_path, train_csv_path, mode='train', dataset_name='icbeb', ratio=args.ratio, 105 |                                    backbone=args.backbone) 106 |         val_dataset = getdataset(data_path, val_csv_path, mode='val', dataset_name='icbeb', 107 |                                  backbone=args.backbone) 108 |         test_dataset = getdataset(data_path, test_csv_path, mode='test', dataset_name='icbeb', 109 |                                   backbone=args.backbone) 110 | 111 |         args.labels_name = train_dataset.labels_name 112 |         num_classes = train_dataset.num_classes 113 | 114 |     elif args.dataset == 'CSN': 115 |         # set the path where you store the CSN dataset, the CSN dataset folder should be chapman/...
116 |         data_path = f'{data_meta_path}/' 117 |         data_split_path = os.path.join(data_split_path, 'chapman') 118 | 119 |         train_csv_path = 'chapman_train.csv' 120 |         train_csv_path = os.path.join(data_split_path, train_csv_path) 121 |         val_csv_path = 'chapman_val.csv' 122 |         val_csv_path = os.path.join(data_split_path, val_csv_path) 123 |         test_csv_path = 'chapman_test.csv' 124 |         test_csv_path = os.path.join(data_split_path, test_csv_path) 125 | 126 |         train_dataset = getdataset(data_path, train_csv_path, mode='train', dataset_name='chapman', ratio=args.ratio, 127 |                                    backbone=args.backbone) 128 |         val_dataset = getdataset(data_path, val_csv_path, mode='val', dataset_name='chapman', 129 |                                  backbone=args.backbone) 130 |         test_dataset = getdataset(data_path, test_csv_path, mode='test', dataset_name='chapman', 131 |                                   backbone=args.backbone) 132 | 133 |         args.labels_name = train_dataset.labels_name 134 |         num_classes = train_dataset.num_classes 135 | 136 |     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 137 |                                                num_workers=args.workers, pin_memory=True) 138 |     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=test_batch_size, shuffle=False, 139 |                                              num_workers=args.workers, pin_memory=True) 140 |     test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, 141 |                                               num_workers=args.workers, pin_memory=True) 142 | 143 |     ckpt_path = args.pretrain_path 144 |     ckpt = torch.load(ckpt_path, map_location='cpu') 145 | 146 |     if 'resnet' in args.backbone: 147 |         if args.backbone == 'resnet18': 148 |             model = ResNet18(num_classes=num_classes) 149 |         elif args.backbone == 'resnet50': 150 |             model = ResNet50(num_classes=num_classes) 151 |         elif args.backbone == 'resnet101': 152 |             model = ResNet101(num_classes=num_classes) 153 | 154 |         model.load_state_dict(ckpt, strict=False) 155 |         print(f'load pretrained model from {args.pretrain_path}, the backbone is {args.backbone}, using
{args.num_leads} leads') 156 |         if 'linear' in args.name: 157 |             for param in model.parameters(): 158 |                 param.requires_grad = False 159 |             print(f'freeze backbone for {args.name} with {args.backbone}') 160 | 161 |             for param in model.linear.parameters(): 162 |                 param.requires_grad = True 163 | 164 |     if 'vit' in args.backbone: 165 |         if args.backbone == 'vit_tiny': 166 |             model = vit_tiny(num_classes=num_classes, num_leads=args.num_leads) 167 |         elif args.backbone == 'vit_small': 168 |             model = vit_small(num_classes=num_classes, num_leads=args.num_leads) 169 |         elif args.backbone == 'vit_middle': 170 |             model = vit_middle(num_classes=num_classes, num_leads=args.num_leads) 171 |         elif args.backbone == 'vit_base': 172 |             model = vit_base(num_classes=num_classes, num_leads=args.num_leads) 173 | 174 |         model.load_state_dict(ckpt, strict=False) 175 |         print(f'load pretrained model from {args.pretrain_path}, the backbone is {args.backbone}, using {args.num_leads} leads') 176 |         if 'linear' in args.name: 177 |             for param in model.parameters(): 178 |                 param.requires_grad = False 179 |             print(f'freeze backbone for {args.name} with {args.backbone}') 180 | 181 |             model.reset_head(num_classes=num_classes) 182 |             model.head.weight.requires_grad = True 183 |             model.head.bias.requires_grad = True 184 | 185 | 186 |     model = model.to('cuda') 187 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 188 |     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 189 |                                                      milestones=[40], 190 |                                                      gamma=0.1, 191 |                                                      last_epoch=-1) 192 |     criterion = nn.BCEWithLogitsLoss() 193 | 194 |     # automatically resume from checkpoint if it exists 195 |     if (args.checkpoint_dir / (args.backbone+'-checkpoint-'+'B-'+str(batch_size)+args.dataset+'R-'+str(args.ratio)+'.pth')).is_file(): 196 |         ckpt = torch.load(args.checkpoint_dir / (args.backbone+'-checkpoint-'+'B-'+str(batch_size)+args.dataset+'R-'+str(args.ratio)+'.pth'), 197 |                           map_location='cpu') 198 |         start_epoch = ckpt['epoch'] 199 |
model.load_state_dict(ckpt['model']) 200 |         optimizer.load_state_dict(ckpt['optimizer']) 201 |     else: 202 |         os.makedirs(args.checkpoint_dir, exist_ok=True) 203 |         start_epoch = 0 204 | 205 |     global_step = 0 206 | 207 |     log = { 208 |         'epoch': [], 209 |         'val_acc': [], 210 |         'val_f1': [], 211 |         'val_precision': [], 212 |         'val_recall': [], 213 |         'val_auc': [], 214 |         'test_acc': [], 215 |         'test_f1': [], 216 |         'test_precision': [], 217 |         'test_recall': [], 218 |         'test_auc': [] 219 |     } 220 |     class_log = { 221 |         'val_log': [], 222 |         'test_log': [] 223 |     } 224 | 225 |     scaler = GradScaler() 226 |     for epoch in tqdm(range(start_epoch, args.epochs)): 227 |         model.train() 228 |         for step, (ecg, target) in tqdm(enumerate(train_loader, start=epoch * len(train_loader))): 229 |             optimizer.zero_grad() 230 |             with autocast(): 231 |                 output = model(ecg.to('cuda')) 232 |                 loss = criterion(output, target.to('cuda')) 233 | 234 |             scaler.scale(loss).backward() 235 |             scaler.step(optimizer) 236 |             scaler.update() 237 |         # the epoch-based MultiStepLR scheduler is stepped once per epoch, after evaluation 238 | 239 |         val_acc, val_f1, val_precision, val_recall, val_auc, val_metric_class = infer(model, val_loader, args) 240 |         test_acc, test_f1, test_precision, test_recall, test_auc, test_metric_class = infer(model, test_loader, args) 241 | 242 |         log['epoch'].append(epoch) 243 |         log['val_acc'].append(val_acc) 244 |         log['val_f1'].append(val_f1) 245 |         log['val_precision'].append(val_precision) 246 |         log['val_recall'].append(val_recall) 247 |         log['val_auc'].append(val_auc) 248 |         log['test_acc'].append(test_acc) 249 |         log['test_f1'].append(test_f1) 250 |         log['test_precision'].append(test_precision) 251 |         log['test_recall'].append(test_recall) 252 |         log['test_auc'].append(test_auc) 253 | 254 |         class_log['val_log'].append(val_metric_class) 255 |         class_log['test_log'].append(test_metric_class) 256 | 257 |         scheduler.step() 258 | 259 |     csv = pd.DataFrame(log) 260 |     csv.columns = ['epoch', 'val_acc', 261 |                    'val_f1', 'val_precision', 262 |                    'val_recall', 'val_auc', 263 |                    'test_acc', 264
|                    'test_f1', 'test_precision', 265 |                    'test_recall', 'test_auc'] 266 | 267 |     val_class_csv = pd.concat(class_log['val_log'], axis=0) 268 |     test_class_csv = pd.concat(class_log['test_log'], axis=0) 269 |     val_class_csv.to_csv(f'{args.checkpoint_dir}/'+args.name+'-'+args.backbone+'-B-'+str(batch_size)+args.dataset+'R-'+str(args.ratio)+'-val-class.csv', index=False) 270 |     test_class_csv.to_csv(f'{args.checkpoint_dir}/'+args.name+'-'+args.backbone+'-B-'+str(batch_size)+args.dataset+'R-'+str(args.ratio)+'-test-class.csv', index=False) 271 | 272 |     csv.to_csv(f'{args.checkpoint_dir}/'+args.name+'-'+args.backbone+'-B-'+str(batch_size)+args.dataset+'R-'+str(args.ratio)+'.csv', index=False) 273 | 274 |     print(f'max val acc: {max(log["val_acc"])}\n \ 275 |     max val f1: {max(log["val_f1"])}\n \ 276 |     max val precision: {max(log["val_precision"])}\n \ 277 |     max val recall: {max(log["val_recall"])}\n \ 278 |     max val auc: {max(log["val_auc"])}\n \ 279 |     max test acc: {max(log["test_acc"])}\n \ 280 |     max test f1: {max(log["test_f1"])}\n \ 281 |     max test precision: {max(log["test_precision"])}\n \ 282 |     max test recall: {max(log["test_recall"])}\n \ 283 |     max test auc: {max(log["test_auc"])}\n') 284 |     # plot each metric in one subplot 285 |     plt.figure(figsize=(10, 10)) 286 |     plt.subplot(1, 3, 1) 287 |     plt.plot(log['epoch'], log['val_acc'], label='val_acc') 288 |     plt.plot(log['epoch'], log['test_acc'], label='test_acc') 289 |     plt.legend() 290 |     plt.subplot(1, 3, 2) 291 |     plt.plot(log['epoch'], log['val_f1'], label='val_f1') 292 |     plt.plot(log['epoch'], log['test_f1'], label='test_f1') 293 |     plt.legend() 294 | 295 |     # precision and recall are not computed per epoch, so they are not plotted 296 |     # plt.plot(log['epoch'], log['val_precision'], label='val_precision') 297 |     # plt.plot(log['epoch'], log['test_precision'], label='test_precision') 298 |     # plt.plot(log['epoch'], log['val_recall'], label='val_recall') 299 |     # plt.plot(log['epoch'], log['test_recall'], label='test_recall') 300 |     # plt.legend() 301 |     plt.subplot(1, 3, 3) 302 |     plt.plot(log['epoch'], log['val_auc'], label='val_auc') 303 |     plt.plot(log['epoch'], log['test_auc'], label='test_auc') 304 |     plt.legend() 305 |     plt.savefig(f'{args.checkpoint_dir}/'+args.name+'-'+args.backbone+'-B-'+str(batch_size)+args.dataset+'R-'+str(args.ratio)+'.png') 306 |     plt.close() 307 | 308 | @torch.no_grad() 309 | def infer(model, loader, args): 310 |     # evaluate 311 | 312 |     model.eval() 313 | 314 |     y_pred = [] 315 | 316 |     y_true = [] 317 | 318 |     for step, (ecg, target) in tqdm(enumerate(loader)): 319 | 320 |         input_label_list = target.to('cuda') 321 | 322 |         predictions = model(ecg.to('cuda')) 323 |         y_true.append(input_label_list.cpu().detach().numpy()) 324 | 325 |         for index, val in enumerate(predictions): 326 |             y_pred.append(val.cpu().detach().numpy().reshape(1, -1)) 327 | 328 |     y_true = np.concatenate(y_true, axis=0) 329 |     y_pred = np.concatenate(y_pred, axis=0) 330 |     auc = roc_auc_score(y_true, y_pred, average='macro') 331 | 332 |     max_f1s = [] 333 |     accs = [] 334 | 335 |     for i in range(y_pred.shape[1]): 336 |         gt_np = y_true[:, i] 337 |         pred_np = y_pred[:, i] 338 |         precision, recall, thresholds = precision_recall_curve(gt_np, pred_np) 339 |         numerator = 2 * recall * precision 340 |         denom = recall + precision 341 |         f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom!=0)) 342 |         max_f1 = np.max(f1_scores) 343 |         max_f1_thresh = thresholds[np.argmax(f1_scores)] 344 |         max_f1s.append(max_f1) 345 |         accs.append(accuracy_score(gt_np, pred_np>max_f1_thresh)) 346 | 347 | 348 |     max_f1s = [i*100 for i in max_f1s] 349 |     accs = [i*100 for i in accs] 350 |     f1 = np.array(max_f1s).mean() 351 |     acc =
np.array(accs).mean() 352 | 353 |     # we do not compute precision and recall here. 354 |     precision, recall = 0, 0 355 | 356 |     class_name = args.labels_name 357 | 358 |     metric_dict = {element: [] for element in class_name} 359 | 360 |     for i in range(len(list(metric_dict.keys()))): 361 |         key = list(metric_dict.keys())[i] 362 |         metric_dict[key].append(roc_auc_score(y_true[:, i], y_pred[:, i])) 363 |     metric_class = pd.DataFrame(metric_dict) 364 | 365 |     return acc, f1, precision, recall, auc, metric_class 366 | 367 | 368 | if __name__ == '__main__': 369 |     main() 370 | -------------------------------------------------------------------------------- /finetune/models/resnet1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class BasicBlock(nn.Module): 5 |     expansion = 1 6 | 7 |     def __init__(self, in_channels, out_channels, stride=1): 8 |         super(BasicBlock, self).__init__() 9 | 10 |         # first convolution, kernel size 3 11 |         self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 12 |         self.bn1 = nn.BatchNorm1d(out_channels) 13 | 14 |         # second convolution, kernel size 3 15 |         self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1, bias=False) 16 |         self.bn2 = nn.BatchNorm1d(out_channels) 17 | 18 |         self.shortcut = nn.Sequential() 19 |         if stride != 1 or in_channels != self.expansion * out_channels: 20 |             self.shortcut = nn.Sequential( 21 |                 nn.Conv1d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False), 22 |                 nn.BatchNorm1d(self.expansion * out_channels) 23 |             ) 24 | 25 |     def forward(self, x): 26 |         out = torch.relu(self.bn1(self.conv1(x))) 27 |         out = self.bn2(self.conv2(out)) 28 |         out += self.shortcut(x) 29 |         out = torch.relu(out) 30 |         return out 31 | 32 | class Bottleneck(nn.Module): 33 |     expansion = 4 34 | 35 |     def __init__(self, in_channels, out_channels, stride=1): 36 |         super(Bottleneck, self).__init__() 37 | 38 |         self.conv1 =
nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm1d(out_channels) 40 | 41 | self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 42 | self.bn2 = nn.BatchNorm1d(out_channels) 43 | 44 | self.conv3 = nn.Conv1d(out_channels, self.expansion * out_channels, kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm1d(self.expansion * out_channels) 46 | 47 | self.shortcut = nn.Sequential() 48 | if stride != 1 or in_channels != self.expansion * out_channels: 49 | self.shortcut = nn.Sequential( 50 | nn.Conv1d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False), 51 | nn.BatchNorm1d(self.expansion * out_channels) 52 | ) 53 | 54 | def forward(self, x): 55 | out = torch.relu(self.bn1(self.conv1(x))) 56 | out = torch.relu(self.bn2(self.conv2(out))) 57 | out = self.bn3(self.conv3(out)) 58 | out += self.shortcut(x) 59 | out = torch.relu(out) 60 | return out 61 | 62 | 63 | class ResNet(nn.Module): 64 | def __init__(self, block, num_blocks, num_classes=10): 65 | super(ResNet, self).__init__() 66 | self.in_channels = 64 67 | 68 | self.conv1 = nn.Conv1d(12, 64, kernel_size=7, stride=2, padding=3, bias=False) 69 | self.bn1 = nn.BatchNorm1d(64) 70 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 71 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 72 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 73 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 74 | self.linear = nn.Linear(512 * block.expansion, num_classes) 75 | 76 | self.avgpool = nn.AdaptiveAvgPool1d((1)) 77 | 78 | def _make_layer(self, block, out_channels, num_blocks, stride): 79 | strides = [stride] + [1] * (num_blocks - 1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_channels, out_channels, stride)) 83 | self.in_channels = out_channels * block.expansion 84 | return nn.Sequential(*layers) 85 
| 86 |     def forward(self, x): 87 |         out = torch.relu(self.bn1(self.conv1(x))) 88 |         out = self.layer1(out) 89 |         out = self.layer2(out) 90 |         out = self.layer3(out) 91 |         out = self.layer4(out) 92 |         out = self.avgpool(out) 93 |         out = out.view(out.size(0), -1) 94 |         out = self.linear(out) 95 |         return out 96 | 97 | def ResNet18(num_classes): 98 |     return ResNet(BasicBlock, [2, 2, 2, 2], num_classes) 99 | 100 | def ResNet34(num_classes): 101 |     return ResNet(BasicBlock, [3, 4, 6, 3], num_classes) 102 | 103 | def ResNet50(num_classes): 104 |     return ResNet(Bottleneck, [3, 4, 6, 3], num_classes) 105 | 106 | def ResNet101(num_classes): 107 |     # ResNet-101 uses the Bottleneck block, which contains three convolution layers per block, 108 |     # with [3, 4, 23, 3] blocks per stage 109 |     return ResNet(Bottleneck, [3, 4, 23, 3], num_classes) 110 | 111 | def ResNet152(num_classes): 112 |     return ResNet(Bottleneck, [3, 8, 36, 3], num_classes) -------------------------------------------------------------------------------- /finetune/models/vit1d.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch Implementation of Vision Transformer 3 | ("An Image is Worth 16X16 Words: Transformers for Image Recognition at Scale") 4 | 5 | Reference 6 | - Paper: https://arxiv.org/abs/2010.11929 7 | - Code: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | from einops import rearrange 12 | from einops.layers.torch import Rearrange 13 | 14 | 15 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py#L137 16 | class DropPath(nn.Module): 17 |     ''' 18 |     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
19 | ''' 20 | def __init__(self, drop_prob: float, scale_by_keep: bool = True): 21 | super(DropPath, self).__init__() 22 | self.drop_prob = drop_prob 23 | self.scale_by_keep = scale_by_keep 24 | 25 | def forward(self, x): 26 | if self.drop_prob <= 0. or not self.training: 27 | return x 28 | keep_prob = 1 - self.drop_prob 29 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 30 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 31 | if keep_prob > 0.0 and self.scale_by_keep: 32 | random_tensor.div_(keep_prob) 33 | return x * random_tensor 34 | 35 | 36 | class PreNorm(nn.Module): 37 | def __init__(self, 38 | dim: int, 39 | fn: nn.Module): 40 | super().__init__() 41 | self.norm = nn.LayerNorm(dim) 42 | self.fn = fn 43 | 44 | def forward(self, x, **kwargs): 45 | return self.fn(self.norm(x), **kwargs) 46 | 47 | 48 | class FeedForward(nn.Module): 49 | """ 50 | MLP Module with GELU activation fn + dropout. 51 | """ 52 | def __init__(self, 53 | input_dim: int, 54 | output_dim: int, 55 | hidden_dim: int, 56 | drop_out_rate=0.): 57 | super().__init__() 58 | self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim), 59 | nn.GELU(), 60 | nn.Dropout(drop_out_rate), 61 | nn.Linear(hidden_dim, output_dim), 62 | nn.Dropout(drop_out_rate)) 63 | 64 | def forward(self, x): 65 | return self.net(x) 66 | 67 | 68 | class Attention(nn.Module): 69 | def __init__(self, 70 | input_dim: int, 71 | output_dim: int, 72 | heads: int = 8, 73 | dim_head: int = 64, 74 | qkv_bias: bool = True, 75 | drop_out_rate: float = 0., 76 | attn_drop_out_rate: float = 0.): 77 | super().__init__() 78 | inner_dim = dim_head * heads 79 | project_out = not (heads == 1 and dim_head == input_dim) 80 | 81 | self.heads = heads 82 | self.scale = dim_head ** -0.5 83 | 84 | self.attend = nn.Softmax(dim=-1) 85 | self.dropout = nn.Dropout(attn_drop_out_rate) 86 | self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias) 87 | 88 | if project_out: 89 | self.to_out = 
nn.Sequential(nn.Linear(inner_dim, output_dim), 90 | nn.Dropout(drop_out_rate)) 91 | else: 92 | self.to_out = nn.Identity() 93 | 94 | def forward(self, x): 95 | qkv = self.to_qkv(x).chunk(3, dim=-1) 96 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 97 | 98 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 99 | 100 | attn = self.attend(dots) 101 | attn = self.dropout(attn) 102 | out = torch.matmul(attn, v) 103 | out = rearrange(out, 'b h n d -> b n (h d)') 104 | out = self.to_out(out) 105 | return out 106 | 107 | 108 | class TransformerBlock(nn.Module): 109 | def __init__(self, 110 | input_dim: int, 111 | output_dim: int, 112 | hidden_dim: int, 113 | heads: int = 8, 114 | dim_head: int = 32, 115 | qkv_bias: bool = True, 116 | drop_out_rate: float = 0., 117 | attn_drop_out_rate: float = 0., 118 | drop_path_rate: float = 0.): 119 | super().__init__() 120 | attn = Attention(input_dim=input_dim, 121 | output_dim=output_dim, 122 | heads=heads, 123 | dim_head=dim_head, 124 | qkv_bias=qkv_bias, 125 | drop_out_rate=drop_out_rate, 126 | attn_drop_out_rate=attn_drop_out_rate) 127 | self.attn = PreNorm(dim=input_dim, 128 | fn=attn) 129 | self.droppath1 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() 130 | 131 | ff = FeedForward(input_dim=output_dim, 132 | output_dim=output_dim, 133 | hidden_dim=hidden_dim, 134 | drop_out_rate=drop_out_rate) 135 | self.ff = PreNorm(dim=output_dim, 136 | fn=ff) 137 | self.droppath2 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() 138 | 139 | def forward(self, x): 140 | x = self.droppath1(self.attn(x)) + x 141 | x = self.droppath2(self.ff(x)) + x 142 | return x 143 | 144 | 145 | class ViT(nn.Module): 146 | def __init__(self, 147 | num_leads: int, 148 | seq_len: int, 149 | patch_size: int, 150 | width: int = 768, 151 | depth: int = 12, 152 | mlp_dim: int = 3072, 153 | heads: int = 12, 154 | dim_head: int = 64, 155 | qkv_bias: bool = True, 156 | drop_out_rate: float 
= 0., 157 | attn_drop_out_rate: float = 0., 158 | drop_path_rate: float = 0., 159 | **kwargs): 160 | super().__init__() 161 | assert seq_len % patch_size == 0, 'The sequence length must be divisible by the patch size.' 162 | num_patches = seq_len // patch_size 163 | 164 | # conv patch start 165 | self.to_patch_embedding = nn.Conv1d(num_leads, width, kernel_size=patch_size, stride=patch_size, bias=False) 166 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, width)) 167 | 168 | self.dropout = nn.Dropout(drop_out_rate) 169 | 170 | 171 | self.depth = depth 172 | self.width = width 173 | drop_path_rate_list = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 174 | for i in range(depth): 175 | block = TransformerBlock(input_dim=width, 176 | output_dim=width, 177 | hidden_dim=mlp_dim, 178 | heads=heads, 179 | dim_head=dim_head, 180 | qkv_bias=qkv_bias, 181 | drop_out_rate=drop_out_rate, 182 | attn_drop_out_rate=attn_drop_out_rate, 183 | drop_path_rate=drop_path_rate_list[i]) 184 | self.add_module(f'block{i}', block) 185 | 186 | self.norm = nn.LayerNorm(width) 187 | self.head = nn.Identity() 188 | 189 | def forward_encoding(self, series): 190 | 191 | # for conv patch 192 | x = self.to_patch_embedding(series) 193 | x = rearrange(x, 'b c n -> b n c') 194 | x = x + self.pos_embedding 195 | 196 | # transformer blocks 197 | x = self.dropout(x) 198 | for i in range(self.depth): 199 | x = getattr(self, f'block{i}')(x) 200 | 201 | x = torch.mean(x, dim=1) # global average pooling 202 | 203 | return self.norm(x) 204 | 205 | def forward(self, series): 206 | x = self.forward_encoding(series) 207 | x = self.head(x) 208 | return x 209 | 210 | def reset_head(self, num_classes=1): 211 | del self.head 212 | self.head = nn.Linear(self.width, num_classes) 213 | 214 | 215 | def vit_tiny(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs): 216 | model_args = dict(num_leads=num_leads, 217 | num_classes=num_classes, 218 | seq_len=seq_len, 219 | 
patch_size=patch_size, 220 | width=192, 221 | depth=12, 222 | heads=3, 223 | mlp_dim=768, 224 | **kwargs) 225 | return ViT(**model_args) 226 | 227 | 228 | def vit_small(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs): 229 | model_args = dict(num_leads=num_leads, 230 | num_classes=num_classes, 231 | seq_len=seq_len, 232 | patch_size=patch_size, 233 | width=384, 234 | depth=12, 235 | heads=6, 236 | mlp_dim=1536, 237 | **kwargs) 238 | return ViT(**model_args) 239 | 240 | 241 | def vit_middle(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs): 242 | model_args = dict(num_leads=num_leads, 243 | num_classes=num_classes, 244 | seq_len=seq_len, 245 | patch_size=patch_size, 246 | width=512, 247 | depth=12, 248 | heads=8, 249 | mlp_dim=2048, 250 | **kwargs) 251 | return ViT(**model_args) 252 | 253 | 254 | def vit_base(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs): 255 | model_args = dict(num_leads=num_leads, 256 | num_classes=num_classes, 257 | seq_len=seq_len, 258 | patch_size=patch_size, 259 | width=768, 260 | depth=12, 261 | heads=12, 262 | mlp_dim=3072, 263 | **kwargs) 264 | return ViT(**model_args) 265 | -------------------------------------------------------------------------------- /finetune/preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# This notebook is for preprocessing PTBXL, CPSC2018, and CSN datasets for finetuning tasks." 
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import pandas as pd\n", 18 | "import wfdb\n", 19 | "import os\n", 20 | "import ast\n", 21 | "from matplotlib import pyplot as plt\n", 22 | "import seaborn as sns\n", 23 | "from pprint import pprint\n", 24 | "from tqdm import tqdm\n", 25 | "from scipy.ndimage import zoom\n", 26 | "from scipy.io import loadmat\n", 27 | "from sklearn.model_selection import train_test_split" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# set the split file path to store your processed csv file\n", 37 | "split_path = ''\n", 38 | "# set the meta path for the raw ecg you download\n", 39 | "meta_path = ''" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "# Preprocessing PTB-XL dataset" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "'''\n", 56 | "Since PTB-XL provides an official split, we use the official split for the finetuning dataset.\n", 57 | "The official preprocessing code is described in the original paper: https://www.nature.com/articles/s41597-020-0495-6\n", 58 | "We also list the preprocessed csv files in MERL/finetune/data_split/ptbxl\n", 59 | "'''" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "# Preprocessing CPSC2018 Dataset" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "'''\n", 76 | "This dataset provides its raw files in .mat format.\n", 77 | "We first convert the .mat files to .hea and .dat files using the wfdb package.\n", 78 | "Then we store the data at 500 Hz and downsample a copy to 100 Hz.\n", 79 | "All information about this dataset can be found at: 
http://2018.icbeb.org/Challenge.html\n", 80 | "'''\n", 81 | "\n", 82 | "# here is your original data folder; you should download the data from the website above\n", 83 | "ori_data_folder = os.path.join(meta_path, 'icbeb2018')\n", 84 | "\n", 85 | "# here is the output folder to store the preprocessed data\n", 86 | "output_folder = os.path.join(meta_path, 'icbeb2018')\n", 87 | "output_datafolder_100 = output_folder+ '/records100/'\n", 88 | "output_datafolder_500 = output_folder+ '/records500/'\n", 89 | "if not os.path.exists(output_folder):\n", 90 | " os.makedirs(output_folder)\n", 91 | "else:\n", 92 | " print('The folder already exists')\n", 93 | "if not os.path.exists(output_datafolder_100):\n", 94 | " os.makedirs(output_datafolder_100)\n", 95 | "else:\n", 96 | " print('The folder already exists')\n", 97 | "if not os.path.exists(output_datafolder_500):\n", 98 | " os.makedirs(output_datafolder_500)\n", 99 | "else:\n", 100 | " print('The folder already exists')\n", 101 | "\n", 102 | "# function to store 12-lead ECG data in wfdb format\n", 103 | "def store_as_wfdb(signame, data, sigfolder, fs):\n", 104 | " channel_itos=['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']\n", 105 | " wfdb.wrsamp(signame,\n", 106 | " fs=fs,\n", 107 | " sig_name=channel_itos, \n", 108 | " p_signal=data,\n", 109 | " units=['mV']*len(channel_itos),\n", 110 | " fmt = ['16']*len(channel_itos), \n", 111 | " write_dir=sigfolder) \n", 112 | "\n", 113 | "# load the reference csv file\n", 114 | "reference_path = os.path.join(output_folder, 'REFERENCE.csv')\n", 115 | "df_reference = pd.read_csv(reference_path)\n", 116 | "\n", 117 | "# define the label dictionary\n", 118 | "# label_dict = {1:'NORM', 2:'AFIB', 3:'1AVB', 4:'CLBBB', 5:'CRBBB', 6:'PAC', 7:'VPC', 8:'STD_', 9:'STE_'}\n", 119 | "label_dict = {1:'NORM', 2:'AFIB', 3:'1AVB', 4:'CLBBB', 5:'CRBBB', 6:'PAC', 7:'VPC', 8:'STD', 9:'STE'}\n", 120 | "\n", 121 | "data = {'ecg_id':[], 'filename':[], 'validation':[], 'age':[], 
'sex':[], 'scp_codes':[]}\n", 122 | "\n", 123 | "# read all .mat files from the folder then convert to .hea and .dat files\n", 124 | "ecg_counter = 0\n", 125 | "for folder in ['all_data']:\n", 126 | " filenames = os.listdir(os.path.join(ori_data_folder, folder))\n", 127 | " for filename in tqdm(filenames):\n", 128 | " if filename.split('.')[1] == 'mat':\n", 129 | " ecg_counter += 1\n", 130 | " name = filename.split('.')[0]\n", 131 | "\n", 132 | " sex, age, sig = loadmat(ori_data_folder + '/' + folder + '/' + filename)['ECG'][0][0]\n", 133 | " data['ecg_id'].append(ecg_counter)\n", 134 | " data['filename'].append(name)\n", 135 | " data['validation'].append(False)\n", 136 | " data['age'].append(age[0][0])\n", 137 | " data['sex'].append(1 if sex[0] == 'Male' else 0)\n", 138 | " labels = df_reference[df_reference.Recording == name][['First_label' ,'Second_label' ,'Third_label']].values.flatten()\n", 139 | " labels = labels[~np.isnan(labels)].astype(int)\n", 140 | " data['scp_codes'].append({label_dict[key]:1 for key in labels})\n", 141 | "\n", 142 | " # # resample to 500 hz data\n", 143 | " # store_as_wfdb(str(ecg_counter), sig.T, output_datafolder_500, 500)\n", 144 | " # # resample to 100 hz data\n", 145 | " # down_sig = np.array([zoom(channel, .2) for channel in sig])\n", 146 | " # store_as_wfdb(str(ecg_counter), down_sig.T, output_datafolder_100, 100)\n", 147 | "\n", 148 | "df = pd.DataFrame(data)\n", 149 | "df['patient_id'] = df.ecg_id\n", 150 | "# df = stratisfy_df(df, 'strat_fold')\n", 151 | "# df.to_csv(output_folder+'icbeb_database.csv')" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# make the patient_id column the first column\n", 161 | "cols = list(df.columns)\n", 162 | "cols = [cols[-1]] + cols[:-1]\n", 163 | "switched_df = df[cols]" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": 
[], 171 | "source": [ 172 | "# Extract all unique labels from the 'scp_codes' column\n", 173 | "# all_labels = set()\n", 174 | "# for item in switched_df['scp_codes']:\n", 175 | "# all_labels.update(item.keys())\n", 176 | "\n", 177 | "all_labels = ['AFIB', 'VPC', 'NORM', '1AVB', 'CRBBB', 'STE', 'PAC', 'CLBBB', 'STD']\n", 178 | "\n", 179 | "\n", 180 | "# # Create new columns for each label\n", 181 | "for label in all_labels:\n", 182 | " switched_df[label] = switched_df['scp_codes'].apply(lambda x: x.get(label, 0))\n", 183 | "\n", 184 | "cols = list(switched_df.columns)\n", 185 | "print(cols)\n", 186 | "# cols[-1] = 'STD'\n", 187 | "# cols[-4] = 'STE'\n", 188 | "# # replace columns name\n", 189 | "# switched_df.columns = cols\n" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "# split train test val\n", 199 | "train_df, test_df = train_test_split(switched_df, test_size=0.2, random_state=42)\n", 200 | "train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)\n", 201 | "\n", 202 | "print(f'train_df shape: {train_df.shape}')\n", 203 | "print(f'val_df shape: {val_df.shape}')\n", 204 | "print(f'test_df shape: {test_df.shape}')\n", 205 | "\n", 206 | "# save the csv files\n", 207 | "# train_df.to_csv(split_path+'icbeb_train.csv', index=False)\n", 208 | "# val_df.to_csv(split_path+'icbeb_val.csv', index=False)\n", 209 | "# test_df.to_csv(split_path+'icbeb_test.csv', index=False)\n" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "# Preprocessing CSN Dataset" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "'''\n", 226 | "For all details of the dataset, please refer to: https://physionet.org/content/ecg-arrhythmia/1.0.0/\n", 227 | "'''\n", 228 | "\n", 229 | "your_path = meta_path\n", 230 | "\n", 231 | 
"data_path = f'{your_path}chapman/WFDBRecords'\n", 232 | "folders = os.listdir(data_path)\n", 233 | "num_folders = len(folders)\n", 234 | "folders = sorted(folders)\n", 235 | "folders = [os.path.join(data_path, f) for f in folders]\n", 236 | "folders = [f for f in folders if os.path.isdir(f)]\n", 237 | "\n", 238 | "dict_with_empty_lists = {f\"{i:02d}\": [] for i in range(1, 47)}\n", 239 | "for i, folder in enumerate(folders):\n", 240 | " subfolders = os.listdir(folder)\n", 241 | " subfolders = sorted(subfolders)\n", 242 | " subfolders = [os.path.join(folder, f) for f in subfolders]\n", 243 | " subfolders = [f for f in subfolders if os.path.isdir(f)]\n", 244 | " dict_with_empty_lists[f\"{i+1:02d}\"] = subfolders\n", 245 | "\n", 246 | "\n", 247 | "# replace this '/raid/cl522/ecg-text/downstream' with your own path\n", 248 | "for key in dict_with_empty_lists.keys():\n", 249 | " dict_with_empty_lists[key] = [x.replace(f'{your_path}', '') for x in dict_with_empty_lists[key]]\n", 250 | "\n", 251 | "def read_header_file(file_path):\n", 252 | " with open(file_path, 'r') as file:\n", 253 | " lines = file.readlines()\n", 254 | " header_info = [line.strip() for line in lines]\n", 255 | " return header_info\n", 256 | "\n", 257 | "df = {'ecg_path': [], \n", 258 | " 'age': [], \n", 259 | " 'diagnose': []}\n", 260 | "\n", 261 | "ref = pd.read_csv(f'{your_path}chapman/ConditionNames_SNOMED-CT.csv')\n", 262 | "ref['Snomed_CT'] = ref['Snomed_CT'].astype(str)\n", 263 | "\n", 264 | "# count the number of .mat files in each folder\n", 265 | "total_files = 0\n", 266 | "for key in tqdm(dict_with_empty_lists.keys()):\n", 267 | " for folder in dict_with_empty_lists[key]:\n", 268 | " files = os.listdir(f'{your_path}'+folder)\n", 269 | " mat_files = [f for f in files if f.endswith('.mat')]\n", 270 | " hea_files = [f for f in files if f.endswith('.hea')]\n", 271 | " \n", 272 | " mat_files_path = [os.path.join(f'{your_path}', folder, f) for f in mat_files]\n", 273 | " hea_files_path = 
[os.path.join(f'{your_path}', folder, f) for f in hea_files]\n", 274 | " mat_files_path = sorted(mat_files_path)\n", 275 | " hea_files_path = sorted(hea_files_path)\n", 276 | "\n", 277 | " for file, hea_file in zip(mat_files_path, hea_files_path):\n", 278 | " mat = loadmat(file)\n", 279 | " ecg = mat['val']\n", 280 | " hea = read_header_file(hea_file)\n", 281 | " \n", 282 | " df['ecg_path'].append(file)\n", 283 | " df['age'].append(hea[0].split()[1])\n", 284 | " \n", 285 | " try:\n", 286 | " diagnose_str = []\n", 287 | " Dx_idx = [i for i, s in enumerate(hea) if 'Dx' in s]\n", 288 | " diagnose_code = hea[Dx_idx[0]].split()[1]\n", 289 | " diagnose_code = diagnose_code.split(',')\n", 290 | " for i in range(len(diagnose_code)):\n", 291 | " diagnose = ref[ref['Snomed_CT'] == diagnose_code[i]]['Acronym Name']\n", 292 | " diagnose = diagnose.values[0]\n", 293 | " diagnose_str.append(diagnose)\n", 294 | " diagnose_str = ','.join(diagnose_str)\n", 295 | " df['diagnose'].append(diagnose_str)\n", 296 | " except Exception:\n", 297 | " df['diagnose'].append('Unknown')\n", 298 | "\n" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "new_df = pd.DataFrame(df)\n", 308 | "new_df = new_df[new_df['diagnose'] != 'Unknown']\n", 309 | "new_df.reset_index(inplace=True, drop=True)\n", 310 | "\n", 311 | "unique_labels = []\n", 312 | "for labels in new_df['diagnose']:\n", 313 | " labels = labels.split(',')\n", 314 | " unique_labels.extend(labels)\n", 315 | "\n", 316 | "unique_labels = list(set(unique_labels))\n", 317 | "# Create new columns for each unique label\n", 318 | "for label in unique_labels:\n", 319 | " new_df[label] = new_df['diagnose'].apply(lambda x: 1 if label in x else 0)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "# count the number of samples for each label\n", 329 | "label_count 
= {}\n", 330 | "for label in unique_labels:\n", 331 | " label_count[label] = new_df[label].sum()\n", 332 | "# sort the label_count dictionary\n", 333 | "label_count = dict(sorted(label_count.items(), key=lambda item: item[1], reverse=True))\n", 334 | "# drop the label with less than 10 samples\n", 335 | "for key in list(label_count.keys()):\n", 336 | " if label_count[key] < 10:\n", 337 | " del label_count[key]\n", 338 | "# drop the label columns not in label_count (keep the metadata columns)\n", 339 | "for key in list(new_df.columns):\n", 340 | " if key in unique_labels and key not in label_count.keys():\n", 341 | " new_df.drop(key, axis=1, inplace=True)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "# split train test val\n", 351 | "train_df, test_df = train_test_split(new_df, test_size=0.2, random_state=42)\n", 352 | "train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)\n", 353 | "train_df.reset_index(inplace=True, drop=True)\n", 354 | "val_df.reset_index(inplace=True, drop=True)\n", 355 | "test_df.reset_index(inplace=True, drop=True)\n", 356 | "\n", 357 | "print(f'train_df shape: {train_df.shape}')\n", 358 | "print(f'val_df shape: {val_df.shape}')\n", 359 | "print(f'test_df shape: {test_df.shape}')\n", 360 | "\n", 361 | "# save the csv files\n", 362 | "# train_df.to_csv(f'{split_path}chapman/'+'chapman_train.csv', index=False)\n", 363 | "# val_df.to_csv(f'{split_path}chapman/'+'chapman_val.csv', index=False)\n", 364 | "# test_df.to_csv(f'{split_path}chapman/'+'chapman_test.csv', index=False)" 365 | ] 366 | } 367 | ], 368 | "metadata": { 369 | "kernelspec": { 370 | "display_name": "medvlp", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | 
"pygments_lexer": "ipython3", 384 | "version": "3.9.19" 385 | }, 386 | "orig_nbformat": 4 387 | }, 388 | "nbformat": 4, 389 | "nbformat_minor": 2 390 | } 391 | -------------------------------------------------------------------------------- /finetune/sub_script/chapman/sub_chapman.sh: -------------------------------------------------------------------------------- 1 | task_name=$1 2 | backbone=$2 3 | pretrain_path=$3 4 | ckpt_dir="/home/cl522/github_repo/ETP/finetune/ckpt/chapman/$task_name" 5 | 6 | python main_single.py \ 7 | --checkpoint-dir $ckpt_dir \ 8 | --batch-size 16 \ 9 | --dataset chapman \ 10 | --pretrain_path $pretrain_path \ 11 | --ratio 1 \ 12 | --learning-rate 0.001 \ 13 | --backbone $backbone \ 14 | --epochs 100 \ 15 | --name $task_name 16 | 17 | python main_single.py \ 18 | --checkpoint-dir $ckpt_dir \ 19 | --batch-size 16 \ 20 | --dataset chapman \ 21 | --pretrain_path $pretrain_path \ 22 | --ratio 10 \ 23 | --learning-rate 0.001 \ 24 | --backbone $backbone \ 25 | --epochs 100 \ 26 | --name $task_name 27 | 28 | python main_single.py \ 29 | --checkpoint-dir $ckpt_dir \ 30 | --batch-size 16 \ 31 | --dataset chapman \ 32 | --pretrain_path $pretrain_path \ 33 | --ratio 100 \ 34 | --learning-rate 0.001 \ 35 | --backbone $backbone \ 36 | --epochs 100 \ 37 | --name $task_name -------------------------------------------------------------------------------- /finetune/sub_script/icbeb/sub_icbeb.sh: -------------------------------------------------------------------------------- 1 | task_name=$1 2 | backbone=$2 3 | pretrain_path=$3 4 | ckpt_dir="/home/cl522/github_repo/ETP/finetune/ckpt/icbeb/$task_name" 5 | 6 | python main_single.py \ 7 | --checkpoint-dir $ckpt_dir \ 8 | --batch-size 16 \ 9 | --dataset icbeb \ 10 | --pretrain_path $pretrain_path \ 11 | --ratio 1 \ 12 | --learning-rate 0.001 \ 13 | --backbone $backbone \ 14 | --epochs 100 \ 15 | --name $task_name 16 | 17 | python main_single.py \ 18 | --checkpoint-dir $ckpt_dir \ 19 | --batch-size 16 \ 20 | 
--dataset icbeb \ 21 | --pretrain_path $pretrain_path \ 22 | --ratio 10 \ 23 | --learning-rate 0.001 \ 24 | --backbone $backbone \ 25 | --epochs 100 \ 26 | --name $task_name 27 | 28 | python main_single.py \ 29 | --checkpoint-dir $ckpt_dir \ 30 | --batch-size 16 \ 31 | --dataset icbeb \ 32 | --pretrain_path $pretrain_path \ 33 | --ratio 100 \ 34 | --learning-rate 0.001 \ 35 | --backbone $backbone \ 36 | --epochs 100 \ 37 | --name $task_name -------------------------------------------------------------------------------- /finetune/sub_script/ptbxl/sub_ptbxl.sh: -------------------------------------------------------------------------------- 1 | task_name=$1 2 | backbone=$2 3 | pretrain_path=$3 4 | 5 | bash sub_ptbxl_form.sh $task_name $backbone $pretrain_path 6 | bash sub_ptbxl_rhythm.sh $task_name $backbone $pretrain_path 7 | bash sub_ptbxl_super_class.sh $task_name $backbone $pretrain_path 8 | bash sub_ptbxl_sub_class.sh $task_name $backbone $pretrain_path -------------------------------------------------------------------------------- /finetune/sub_script/ptbxl/sub_ptbxl_form.sh: -------------------------------------------------------------------------------- 1 | task_name=$1 2 | backbone=$2 3 | pretrain_path=$3 4 | ckpt_dir="/home/cl522/github_repo/ETP/finetune/ckpt/ptbxl_form/$task_name" 5 | 6 | python main_single.py \ 7 | --checkpoint-dir $ckpt_dir \ 8 | --batch-size 16 \ 9 | --ratio 1 \ 10 | --dataset ptbxl_form \ 11 | --pretrain_path $pretrain_path \ 12 | --learning-rate 0.001 \ 13 | --backbone $backbone \ 14 | --epochs 100 \ 15 | --name $task_name 16 | 17 | python main_single.py \ 18 | --checkpoint-dir $ckpt_dir \ 19 | --batch-size 16 \ 20 | --ratio 10 \ 21 | --dataset ptbxl_form \ 22 | --pretrain_path $pretrain_path \ 23 | --learning-rate 0.001 \ 24 | --backbone $backbone \ 25 | --epochs 100 \ 26 | --name $task_name 27 | 28 | python main_single.py \ 29 | --checkpoint-dir $ckpt_dir \ 30 | --batch-size 16 \ 31 | --ratio 100 \ 32 | --dataset ptbxl_form \ 33 | 
--pretrain_path $pretrain_path \ 34 | --learning-rate 0.001 \ 35 | --backbone $backbone \ 36 | --epochs 100 \ 37 | --name $task_name -------------------------------------------------------------------------------- /finetune/sub_script/ptbxl/sub_ptbxl_rhythm.sh: -------------------------------------------------------------------------------- 1 | task_name=$1 2 | backbone=$2 3 | pretrain_path=$3 4 | ckpt_dir="/home/cl522/github_repo/ETP/finetune/ckpt/ptbxl_rhythm/$task_name" 5 | 6 | python main_single.py \ 7 | --checkpoint-dir $ckpt_dir \ 8 | --batch-size 16 \ 9 | --ratio 1 \ 10 | --dataset ptbxl_rhythm \ 11 | --pretrain_path $pretrain_path \ 12 | --learning-rate 0.001 \ 13 | --backbone $backbone \ 14 | --epochs 100 \ 15 | --name $task_name 16 | 17 | python main_single.py \ 18 | --checkpoint-dir $ckpt_dir \ 19 | --batch-size 16 \ 20 | --ratio 10 \ 21 | --dataset ptbxl_rhythm \ 22 | --pretrain_path $pretrain_path \ 23 | --learning-rate 0.001 \ 24 | --backbone $backbone \ 25 | --epochs 100 \ 26 | --name $task_name 27 | 28 | python main_single.py \ 29 | --checkpoint-dir $ckpt_dir \ 30 | --batch-size 16 \ 31 | --ratio 100 \ 32 | --dataset ptbxl_rhythm \ 33 | --pretrain_path $pretrain_path \ 34 | --learning-rate 0.001 \ 35 | --backbone $backbone \ 36 | --epochs 100 \ 37 | --name $task_name -------------------------------------------------------------------------------- /finetune/sub_script/ptbxl/sub_ptbxl_sub_class.sh: -------------------------------------------------------------------------------- 1 | task_name=$1 2 | backbone=$2 3 | pretrain_path=$3 4 | ckpt_dir="/home/cl522/github_repo/ETP/finetune/ckpt/ptbxl_sub_class/$task_name" 5 | 6 | python main_single.py \ 7 | --checkpoint-dir $ckpt_dir \ 8 | --batch-size 16 \ 9 | --ratio 1 \ 10 | --dataset ptbxl_sub_class \ 11 | --pretrain_path $pretrain_path \ 12 | --learning-rate 0.001 \ 13 | --backbone $backbone \ 14 | --epochs 100 \ 15 | --name $task_name 16 | 17 | python main_single.py \ 18 | --checkpoint-dir $ckpt_dir \ 19 
| --batch-size 16 \ 20 | --ratio 10 \ 21 | --dataset ptbxl_sub_class \ 22 | --pretrain_path $pretrain_path \ 23 | --learning-rate 0.001 \ 24 | --backbone $backbone \ 25 | --epochs 100 \ 26 | --name $task_name 27 | 28 | python main_single.py \ 29 | --checkpoint-dir $ckpt_dir \ 30 | --batch-size 16 \ 31 | --ratio 100 \ 32 | --dataset ptbxl_sub_class \ 33 | --pretrain_path $pretrain_path \ 34 | --learning-rate 0.001 \ 35 | --backbone $backbone \ 36 | --epochs 100 \ 37 | --name $task_name -------------------------------------------------------------------------------- /finetune/sub_script/ptbxl/sub_ptbxl_super_class.sh: -------------------------------------------------------------------------------- 1 | task_name=$1 2 | backbone=$2 3 | pretrain_path=$3 4 | ckpt_dir="/home/cl522/github_repo/ETP/finetune/ckpt/ptbxl_super_class/$task_name" 5 | 6 | python main_single.py \ 7 | --checkpoint-dir $ckpt_dir \ 8 | --batch-size 16 \ 9 | --ratio 1 \ 10 | --dataset ptbxl_super_class \ 11 | --pretrain_path $pretrain_path \ 12 | --learning-rate 0.001 \ 13 | --backbone $backbone \ 14 | --epochs 100 \ 15 | --name $task_name 16 | 17 | python main_single.py \ 18 | --checkpoint-dir $ckpt_dir \ 19 | --batch-size 16 \ 20 | --ratio 10 \ 21 | --dataset ptbxl_super_class \ 22 | --pretrain_path $pretrain_path \ 23 | --learning-rate 0.001 \ 24 | --backbone $backbone \ 25 | --epochs 100 \ 26 | --name $task_name 27 | 28 | python main_single.py \ 29 | --checkpoint-dir $ckpt_dir \ 30 | --batch-size 16 \ 31 | --ratio 100 \ 32 | --dataset ptbxl_super_class \ 33 | --pretrain_path $pretrain_path \ 34 | --learning-rate 0.001 \ 35 | --backbone $backbone \ 36 | --epochs 100 \ 37 | --name $task_name -------------------------------------------------------------------------------- /finetune/sub_script/run_all_linear.sh: -------------------------------------------------------------------------------- 1 | taskname='your_taskname' 2 | backbone='resnet18' 3 | pretrain_path='your_pretrained_encoder.pth' 4 | 5 | cd 
icbeb 6 | bash sub_icbeb.sh $taskname $backbone $pretrain_path 7 | 8 | cd .. 9 | cd chapman 10 | bash sub_chapman.sh $taskname $backbone $pretrain_path 11 | 12 | cd .. 13 | cd ptbxl 14 | bash sub_ptbxl.sh $taskname $backbone $pretrain_path 15 | -------------------------------------------------------------------------------- /pretrain/config.yaml: -------------------------------------------------------------------------------- 1 | network: 2 | # ecg_model: resnet18 3 | ecg_model: vit_tiny 4 | num_leads: 12 5 | ### this part does not control builder/trainer 6 | text_model: ncbi/MedCPT-Query-Encoder 7 | free_layers: 6 # set 12 to freeze all layers in bert 8 | feature_dim: 768 9 | 10 | projection_head: 11 | mlp_hidden_size: 256 12 | projection_size: 256 13 | ### 14 | 15 | dataset: 16 | dataset_name: 'mimic' 17 | data_path: 'your_path/' # add your ECG data path here 18 | 19 | # params for trainer 20 | trainer: 21 | batch_size: 128 22 | val_batch_size: 512 23 | checkpoint_interval: 50 24 | max_epochs: 20 25 | num_workers: 8 26 | 27 | optimizer: 28 | params: 29 | lr: 1.0e-3 30 | weight_decay: 1.0e-8 31 | 32 | # params for zeroshot eval 33 | zeroshot: 34 | prompt_type: 'CKEPE' 35 | prompt_dict: 'your_path/MERL/zeroshot/CKEPE_prompt.json' 36 | batch_size: 256 37 | num_workers: 8 38 | meta_data_path: 'your_path/downstream' 39 | meta_split_path: 'your_path/MERL/finetune/data_split' 40 | 41 | val_sets: 42 | ### 43 | ptbxl_super_class: 44 | data_path: 'ptbxl' 45 | split_path: 'ptbxl/super_class/ptbxl_super_class_val.csv' 46 | ### 47 | ptbxl_sub_class: 48 | data_path: 'ptbxl' 49 | split_path: 'ptbxl/sub_class/ptbxl_sub_class_val.csv' 50 | ### 51 | ptbxl_form: 52 | data_path: 'ptbxl' 53 | split_path: 'ptbxl/form/ptbxl_form_val.csv' 54 | ### 55 | ptbxl_rhythm: 56 | data_path: 'ptbxl' 57 | split_path: 'ptbxl/rhythm/ptbxl_rhythm_val.csv' 58 | ### 59 | icbeb: 60 | data_path: 'icbeb2018/records500' 61 | split_path: 'icbeb/icbeb_val.csv' 62 | ### 63 | chapman: 64 | data_path: '' 65 | 
split_path: 'chapman/chapman_val.csv' 66 | 67 | # your model name 68 | wandb_name: 'vit_tiny_demo' -------------------------------------------------------------------------------- /pretrain/launch.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | 3 | wandb online 4 | cd /your_path/MERL/pretrain 5 | torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=101 --rdzv_endpoint=localhost:29502 main.py 6 | -------------------------------------------------------------------------------- /pretrain/main.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torch.nn.parallel import DistributedDataParallel as DDP 3 | import torch.multiprocessing as mp 4 | import torch.distributed as dist 5 | import tempfile 6 | import os 7 | from torch import optim 8 | import torch.nn as nn 9 | import pandas as pd 10 | import numpy as np 11 | import torch 12 | import yaml 13 | import sys 14 | sys.path.append("../utils") 15 | from utils_trainer import trainer_wBert 16 | import utils_dataset 17 | import utils_builder 18 | 19 | import wandb 20 | 21 | 22 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 23 | 24 | 25 | def ddp_main(): 26 | dist.init_process_group("nccl") 27 | torch.cuda.empty_cache() 28 | rank = dist.get_rank() 29 | 30 | print(f"Start running basic DDP example on rank {rank}.") 31 | device_id = rank % torch.cuda.device_count() 32 | 33 | # set up 34 | config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader) 35 | 36 | if device_id == 0: 37 | run = wandb.init( 38 | # Set the project where this run will be logged 39 | project="MERL_ICML", 40 | name = config['wandb_name'], 41 | # Track hyperparameters and run metadata 42 | config={ 43 | "learning_rate": config['optimizer']['params']['lr'], 44 | "total_epochs": config['trainer']['max_epochs'], 45 | 'weight_decay': config['optimizer']['params']['weight_decay'], 46 | 'ecg_model': config['network']['ecg_model'], 
47 | 'text_model': config['network']['text_model'], 48 | 'batch_size': config['trainer']['batch_size'], 49 | 'val_zeroshot': 'all_sets', 50 | 'prompt_type': config['zeroshot']['prompt_type'], 51 | } 52 | ) 53 | 54 | torch.manual_seed(42) 55 | random.seed(0) 56 | np.random.seed(0) 57 | # loading data path 58 | data_path = config['dataset']['data_path'] 59 | 60 | # define ECG-text dataset 61 | dataset = utils_dataset.ECG_TEXT_Dsataset( 62 | data_path=data_path, dataset_name=config['dataset']['dataset_name']) 63 | train_dataset = dataset.get_dataset(train_test='train') 64 | val_dataset = dataset.get_dataset(train_test='val') 65 | 66 | # building model part 67 | # -------------------- 68 | model = utils_builder.ECGCLIP(config['network']) 69 | 70 | ''' 71 | you can freeze the first N bert encoder layers. 72 | set the number of layers via free_layers in config.yaml 73 | (the default config freezes 6 layers) 74 | ''' 75 | if config['network']['free_layers'] is not None: 76 | for layer_idx in range(int(config['network']['free_layers'])): 77 | for param in list(model.lm_model.encoder.layer[layer_idx].parameters()): 78 | param.requires_grad = False 79 | 80 | model = model.to(device_id) 81 | model = DDP(model, device_ids=[device_id], find_unused_parameters=True) 82 | 83 | # -------------------- 84 | 85 | # -------------------- 86 | optimizer = torch.optim.AdamW( 87 | model.parameters(), 88 | **config['optimizer']['params'], 89 | betas=(0.9, 0.999) 90 | ) 91 | 92 | # -------------------- 93 | trainer = trainer_wBert(model=model, 94 | optimizer=optimizer, 95 | device=rank, 96 | model_name=config['wandb_name'], 97 | **config['trainer']) 98 | # -------------------- 99 | 100 | # -------------------- 101 | # I_T_P_trainer 102 | trainer.train_w_TextEmb(train_dataset, val_dataset, config['zeroshot']) 103 | 104 | 105 | ddp_main() 106 | -------------------------------------------------------------------------------- /pretrain/preprocess.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import pandas as pd\n", 11 | "import wfdb\n", 12 | "import os\n", 13 | "from sklearn.model_selection import train_test_split\n", 14 | "from matplotlib import pyplot as plt\n", 15 | "import seaborn as snss\n", 16 | "from pprint import pprint\n", 17 | "from tqdm import tqdm\n", 18 | "import sys\n", 19 | "sys.path.append(\"../finetune/\")\n", 20 | "sys.path.append(\"../utils\")" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# set your meta path of mimic-ecg\n", 30 | "meta_path = 'your_path/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0'\n", 31 | "report_csv = pd.read_csv(f'{meta_path}/machine_measurements.csv', low_memory=False)\n", 32 | "record_csv = pd.read_csv(f'{meta_path}/record_list.csv', low_memory=False)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "def process_report(row):\n", 42 | " # Select the relevant columns and filter out NaNs\n", 43 | " report = row[['report_0', 'report_1', 'report_2', 'report_3', 'report_4', \n", 44 | " 'report_5', 'report_6', 'report_7', 'report_8', 'report_9', \n", 45 | " 'report_10', 'report_11', 'report_12', 'report_13', 'report_14', \n", 46 | " 'report_15', 'report_16', 'report_17']].dropna()\n", 47 | " # Concatenate the report\n", 48 | " report = '. 
'.join(report)\n", 49 | " # Replace and preprocess text\n", 50 | " report = report.replace('EKG', 'ECG').replace('ekg', 'ecg')\n", 51 | " report = report.strip(' ***').strip('*** ').strip('***').strip('=-').strip('=')\n", 52 | " # Convert to lowercase\n", 53 | " report = report.lower()\n", 54 | "\n", 55 | " # concatenate the report if the report length is not 0\n", 56 | " total_report = ''\n", 57 | " if len(report.split()) != 0:\n", 58 | " total_report = report\n", 59 | " total_report = total_report.replace('\\n', ' ')\n", 60 | " total_report = total_report.replace('\\r', ' ')\n", 61 | " total_report = total_report.replace('\\t', ' ')\n", 62 | " total_report += '.'\n", 63 | " if len(report.split()) == 0:\n", 64 | " total_report = 'empty'\n", 65 | " # Calculate the length of the report in words\n", 66 | " return len(report.split()), total_report\n", 67 | "\n", 68 | "tqdm.pandas()\n", 69 | "report_csv['report_length'], report_csv['total_report'] = zip(*report_csv.progress_apply(process_report, axis=1))\n", 70 | "# Filter out reports with less than 4 words\n", 71 | "report_csv = report_csv[report_csv['report_length'] >= 4]\n", 72 | "\n", 73 | "# you should get 771693 here\n", 74 | "print(report_csv.shape)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "report_csv.reset_index(drop=True, inplace=True)\n", 84 | "record_csv = record_csv[record_csv['study_id'].isin(report_csv['study_id'])]\n", 85 | "record_csv.reset_index(drop=True, inplace=True)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "# build an empty numpy array to store the data, we use int16 to save the space\n", 95 | "temp_npy = np.zeros((len(record_csv), 12, 5000), dtype=np.int16)\n", 96 | "\n", 97 | "for p in tqdm(record_csv['path']):\n", 98 | " # read the data\n", 99 | " ecg_path = os.path.join(meta_path, p)\n", 100 | " 
record = wfdb.rdsamp(ecg_path)[0]\n", 101 | " record = record.T\n", 102 | " # replace each nan with the mean of its neighbouring values\n", 103 | " # detect nan in each lead\n", 104 | " if np.isnan(record).sum() == 0 and np.isinf(record).sum() == 0:\n", 105 | " # normalize to 0-1\n", 106 | " record = (record - record.min()) / (record.max() - record.min())\n", 107 | " # scale the data\n", 108 | " record *= 1000\n", 109 | " # convert to int16\n", 110 | " record = record.astype(np.int16)\n", 111 | " # store the data\n", 112 | " temp_npy[record_csv[record_csv['path'] == p].index[0]] = record[:, :5000]\n", 113 | "\n", 114 | " else:\n", 115 | " if np.isinf(record).sum() == 0:\n", 116 | " for i in range(record.shape[0]):\n", 117 | " nan_idx = np.where(np.isnan(record[i, :]))[0]\n", 118 | " for idx in nan_idx:\n", 119 | " record[i, idx] = np.nanmean(record[i, max(0, idx-6):min(idx+6, record.shape[1])])\n", 120 | " if np.isnan(record).sum() == 0:\n", 121 | " for i in range(record.shape[0]):\n", 122 | " inf_idx = np.where(np.isinf(record[i, :]))[0]\n", 123 | " for idx in inf_idx:\n", 124 | " record[i, idx] = np.nanmean(np.where(np.isinf(record[i, max(0, idx-6):min(idx+6, record.shape[1])]), np.nan, record[i, max(0, idx-6):min(idx+6, record.shape[1])]))\n", 125 | "\n", 126 | " # normalize to 0-1\n", 127 | " record = (record - record.min()) / (record.max() - record.min())\n", 128 | " # scale the data\n", 129 | " record *= 1000\n", 130 | " # convert to int16\n", 131 | " record = record.astype(np.int16)\n", 132 | " # store the data\n", 133 | " temp_npy[record_csv[record_csv['path'] == p].index[0]] = record[:, :5000]" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "# split to train and val\n", 143 | "train_npy, val_npy, train_csv, val_csv = train_test_split(temp_npy, report_csv, test_size=0.02, random_state=42)\n", 144 | "\n", 145 | "train_csv.reset_index(drop=True, inplace=True)\n", 146 | "val_csv.reset_index(drop=True, inplace=True)\n", 147 | "\n", 148 | "# save to your
path\n", 149 | "np.save(\"your_path_train.npy\", train_npy)\n", 150 | "np.save(\"your_path_val.npy\", val_npy)\n", 151 | "train_csv.to_csv(\"your_path_train.csv\", index=False)\n", 152 | "val_csv.to_csv(\"your_path_val.csv\", index=False)" 153 | ] 154 | } 155 | ], 156 | "metadata": { 157 | "kernelspec": { 158 | "display_name": "chen", 159 | "language": "python", 160 | "name": "python3" 161 | }, 162 | "language_info": { 163 | "codemirror_mode": { 164 | "name": "ipython", 165 | "version": 3 166 | }, 167 | "file_extension": ".py", 168 | "mimetype": "text/x-python", 169 | "name": "python", 170 | "nbconvert_exporter": "python", 171 | "pygments_lexer": "ipython3", 172 | "version": "3.9.19" 173 | } 174 | }, 175 | "nbformat": 4, 176 | "nbformat_minor": 2 177 | } 178 | -------------------------------------------------------------------------------- /utils/__pycache__/resnet1d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/utils/__pycache__/resnet1d.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_builder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/utils/__pycache__/utils_builder.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/utils/__pycache__/utils_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_loss.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/utils/__pycache__/utils_loss.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/utils/__pycache__/utils_trainer.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/vit1d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/utils/__pycache__/vit1d.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/zeroshot_val.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/utils/__pycache__/zeroshot_val.cpython-310.pyc -------------------------------------------------------------------------------- /utils/resnet1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class BasicBlock(nn.Module): 5 | expansion = 1 6 | 7 | def __init__(self, in_channels, out_channels, stride=1): 8 | super(BasicBlock, self).__init__() 9 | 10 | # 3x3 Convolution (stride controls downsampling) 11 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 12 | self.bn1 = nn.BatchNorm1d(out_channels) 13 | 14 | # 3x3 Convolution 15 | self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1,
bias=False) 16 | self.bn2 = nn.BatchNorm1d(out_channels) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_channels != self.expansion * out_channels: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv1d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm1d(self.expansion * out_channels) 23 | ) 24 | 25 | def forward(self, x): 26 | out = torch.relu(self.bn1(self.conv1(x))) 27 | out = self.bn2(self.conv2(out)) 28 | out += self.shortcut(x) 29 | out = torch.relu(out) 30 | return out 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, in_channels, out_channels, stride=1): 36 | super(Bottleneck, self).__init__() 37 | 38 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm1d(out_channels) 40 | 41 | self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 42 | self.bn2 = nn.BatchNorm1d(out_channels) 43 | 44 | self.conv3 = nn.Conv1d(out_channels, self.expansion * out_channels, kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm1d(self.expansion * out_channels) 46 | 47 | self.shortcut = nn.Sequential() 48 | if stride != 1 or in_channels != self.expansion * out_channels: 49 | self.shortcut = nn.Sequential( 50 | nn.Conv1d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False), 51 | nn.BatchNorm1d(self.expansion * out_channels) 52 | ) 53 | 54 | def forward(self, x): 55 | out = torch.relu(self.bn1(self.conv1(x))) 56 | out = torch.relu(self.bn2(self.conv2(out))) 57 | out = self.bn3(self.conv3(out)) 58 | out += self.shortcut(x) 59 | out = torch.relu(out) 60 | return out 61 | 62 | 63 | class ResNet(nn.Module): 64 | def __init__(self, block, num_blocks, num_classes=10): 65 | super(ResNet, self).__init__() 66 | self.in_channels = 64 67 | 68 | self.conv1 = nn.Conv1d(12, 64, kernel_size=7, stride=2, padding=3, bias=False) 69 | self.bn1 = 
nn.BatchNorm1d(64) 70 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 71 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 72 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 73 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 74 | self.linear = nn.Linear(512 * block.expansion, num_classes) 75 | 76 | self.avgpool = nn.AdaptiveAvgPool1d(1) 77 | 78 | def _make_layer(self, block, out_channels, num_blocks, stride): 79 | strides = [stride] + [1] * (num_blocks - 1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_channels, out_channels, stride)) 83 | self.in_channels = out_channels * block.expansion 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | out = torch.relu(self.bn1(self.conv1(x))) 88 | out = self.layer1(out) 89 | out = self.layer2(out) 90 | out = self.layer3(out) 91 | out = self.layer4(out) 92 | # out = self.avgpool(out) 93 | # out = out.view(out.size(0), -1) 94 | # out = self.linear(out) 95 | return out 96 | 97 | def ResNet18(): 98 | return ResNet(BasicBlock, [2, 2, 2, 2]) 99 | 100 | def ResNet34(): 101 | return ResNet(BasicBlock, [3, 4, 6, 3]) 102 | 103 | def ResNet50(): 104 | return ResNet(Bottleneck, [3, 4, 6, 3]) 105 | 106 | def ResNet101(): 107 | # ResNet101 uses the Bottleneck block, which stacks three convolution layers 108 | # per block, with a [3, 4, 23, 3] stage configuration 109 | return ResNet(Bottleneck, [3, 4, 23, 3]) 110 | 111 | def ResNet152(): 112 | return ResNet(Bottleneck, [3, 8, 36, 3]) -------------------------------------------------------------------------------- /utils/utils_builder.py: -------------------------------------------------------------------------------- 1 | # builds the CLIP-style ECG encoder + text encoder model 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import torchvision 8 | 9 | from
torch.nn.functional import normalize 10 | from transformers import AutoModel, AutoTokenizer 11 | from resnet1d import ResNet18, ResNet34, ResNet50, ResNet101 12 | from vit1d import vit_base, vit_small, vit_tiny, vit_middle 13 | 14 | 15 | class AttentionPool2d(nn.Module): 16 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 17 | super().__init__() 18 | self.positional_embedding = nn.Parameter(torch.randn(1, spacial_dim + 1, embed_dim) / embed_dim) 19 | self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) 20 | 21 | self.mhsa = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) 22 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 23 | self.num_heads = num_heads 24 | 25 | def forward(self, x): 26 | x = x.permute(0, 2, 1) # convert X shape (B, C, L) to (B, L, C) 27 | 28 | self.cls_tokens = self.cls_token + self.positional_embedding[:, :1, :] 29 | self.cls_tokens = self.cls_tokens.expand(x.shape[0], -1, -1) 30 | x = torch.cat((self.cls_tokens, x), dim=1) 31 | x = x + self.positional_embedding[:, :, :].to(x.dtype) # (L+1)NC 32 | x, att_map = self.mhsa(x[:, :1, :], x, x, average_attn_weights=True) 33 | x = self.c_proj(x) 34 | return x.squeeze(0), att_map[:, :, 1:] 35 | 36 | class ECGCLIP(torch.nn.Module): 37 | def __init__(self, network_config): 38 | super(ECGCLIP, self).__init__() 39 | 40 | self.proj_hidden = network_config['projection_head']['mlp_hidden_size'] 41 | self.proj_out = network_config['projection_head']['projection_size'] 42 | 43 | # ecg signal encoder 44 | self.ecg_model = network_config['ecg_model'] 45 | self.num_leads = network_config['num_leads'] 46 | 47 | if 'resnet' in self.ecg_model: 48 | if self.ecg_model == 'resnet18': 49 | model = ResNet18() 50 | self.downconv = nn.Conv1d(in_channels=512, out_channels=self.proj_out, kernel_size=1) 51 | self.att_pool_head = AttentionPool2d(spacial_dim=313, 52 | embed_dim=self.proj_out, 53 | num_heads=4, 54 | output_dim=self.proj_out) 55 | elif 
self.ecg_model == 'resnet34': 56 | model = ResNet34() 57 | self.downconv = nn.Conv1d(in_channels=512, out_channels=self.proj_out, kernel_size=1) 58 | self.att_pool_head = AttentionPool2d(spacial_dim=313, 59 | embed_dim=self.proj_out, 60 | num_heads=4, 61 | output_dim=self.proj_out) 62 | elif self.ecg_model == 'resnet50': 63 | model = ResNet50() 64 | self.downconv = nn.Conv1d(in_channels=2048, out_channels=self.proj_out, kernel_size=1) 65 | self.att_pool_head = AttentionPool2d(spacial_dim=313, 66 | embed_dim=self.proj_out, 67 | num_heads=4, 68 | output_dim=self.proj_out) 69 | elif self.ecg_model == 'resnet101': 70 | model = ResNet101() 71 | self.downconv = nn.Conv1d(in_channels=2048, out_channels=self.proj_out, kernel_size=1) 72 | self.att_pool_head = AttentionPool2d(spacial_dim=313, 73 | embed_dim=self.proj_out, 74 | num_heads=4, 75 | output_dim=self.proj_out) 76 | 77 | self.linear1 = nn.Linear(self.proj_out, self.proj_out, bias=False) 78 | self.linear2 = nn.Linear(self.proj_out, self.proj_out, bias=False) 79 | 80 | if 'vit' in self.ecg_model: 81 | if self.ecg_model == 'vit_tiny': 82 | model = vit_tiny(num_leads=self.num_leads) 83 | elif self.ecg_model == 'vit_small': 84 | model = vit_small(num_leads=self.num_leads) 85 | elif self.ecg_model == 'vit_middle': 86 | model = vit_middle(num_leads=self.num_leads) 87 | elif self.ecg_model == 'vit_base': 88 | model = vit_base(num_leads=self.num_leads) 89 | self.proj_e_input = model.width 90 | self.proj_e = nn.Sequential( 91 | nn.Linear(self.proj_e_input, self.proj_hidden), 92 | nn.BatchNorm1d(self.proj_hidden), 93 | nn.ReLU(inplace=True), 94 | nn.Linear(self.proj_hidden, self.proj_out), 95 | nn.BatchNorm1d(self.proj_out), 96 | ) 97 | self.linear1 = nn.Linear(self.proj_e_input, self.proj_out, bias=False) 98 | self.linear2 = nn.Linear(self.proj_e_input, self.proj_out, bias=False) 99 | 100 | 101 | self.ecg_encoder = model 102 | self.avgpool = nn.AdaptiveAvgPool1d(1) 103 | 104 | 105 | self.dropout1 = nn.Dropout(p=0.1) 106 | 
self.dropout2 = nn.Dropout(p=0.1) 107 | 108 | # text encoder 109 | url = network_config['text_model'] 110 | self.lm_model = AutoModel.from_pretrained( 111 | url, trust_remote_code=True, revision='main') 112 | self.tokenizer = AutoTokenizer.from_pretrained( 113 | url, trust_remote_code=True, revision='main') 114 | 115 | # text projector 116 | self.proj_t = nn.Sequential( 117 | nn.Linear(768, self.proj_hidden), 118 | nn.GELU(), 119 | nn.Linear(self.proj_hidden, self.proj_out), 120 | ) 121 | 122 | 123 | def _tokenize(self, text): 124 | tokenizer_output = self.tokenizer.batch_encode_plus(batch_text_or_text_pairs=text, 125 | add_special_tokens=True, 126 | truncation=True, 127 | max_length=256, 128 | padding='max_length', 129 | return_tensors='pt') 130 | 131 | return tokenizer_output 132 | 133 | @torch.no_grad() 134 | def ext_ecg_emb(self, ecg): 135 | 136 | if 'resnet' in self.ecg_model: 137 | ecg_emb = self.ecg_encoder(ecg) 138 | ecg_emb = self.downconv(ecg_emb) 139 | proj_ecg_emb, att_map = self.att_pool_head(ecg_emb) 140 | proj_ecg_emb = proj_ecg_emb.view(proj_ecg_emb.shape[0], -1) 141 | 142 | if 'vit' in self.ecg_model: 143 | ecg_emb = self.ecg_encoder(ecg) 144 | proj_ecg_emb = self.proj_e(ecg_emb) 145 | 146 | return proj_ecg_emb 147 | 148 | @torch.no_grad() 149 | def get_text_emb(self, input_ids, attention_mask): 150 | text_emb = self.lm_model(input_ids=input_ids, 151 | attention_mask=attention_mask).pooler_output 152 | return text_emb 153 | 154 | def forward(self, ecg, input_ids, attention_mask): 155 | ecg_emb = self.ecg_encoder(ecg) 156 | 157 | if 'resnet' in self.ecg_model: 158 | # attention pooling (only for resnet models) 159 | ecg_emb = self.downconv(ecg_emb) 160 | proj_ecg_emb, _ = self.att_pool_head(ecg_emb) 161 | proj_ecg_emb = proj_ecg_emb.view(proj_ecg_emb.shape[0], -1) 162 | 163 | ecg_emb = self.avgpool(ecg_emb).view(ecg_emb.shape[0], -1) 164 | ecg_emb1 = self.dropout1(self.linear1(ecg_emb)) 165 | ecg_emb2 = self.dropout2(self.linear2(ecg_emb)) 166 | 167 
| if 'vit' in self.ecg_model: 168 | proj_ecg_emb = self.proj_e(ecg_emb) 169 | ecg_emb1 = self.dropout1(self.linear1(ecg_emb)) 170 | ecg_emb2 = self.dropout2(self.linear2(ecg_emb)) 171 | 172 | proj_ecg_emb = normalize(proj_ecg_emb, dim=-1) 173 | 174 | 175 | # get text feature 176 | # text feature extraction is independent of the type of ecg encoder 177 | text_emb = self.get_text_emb(input_ids, attention_mask) 178 | proj_text_emb = self.proj_t(text_emb.contiguous()) 179 | proj_text_emb = normalize(proj_text_emb, dim=-1) 180 | 181 | if self.training: 182 | return {'ecg_emb': [ecg_emb1, ecg_emb2], 183 | 'proj_ecg_emb': [proj_ecg_emb], 184 | 'proj_text_emb': [proj_text_emb]} 185 | else: 186 | return {'ecg_emb': [ecg_emb1, ecg_emb2], 187 | 'proj_ecg_emb': [proj_ecg_emb], 188 | 'proj_text_emb': [proj_text_emb]} 189 | -------------------------------------------------------------------------------- /utils/utils_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from torch.utils.data import Dataset, ConcatDataset 4 | import numpy as np 5 | from sklearn.model_selection import train_test_split 6 | from torchvision.transforms import transforms 7 | from PIL import Image 8 | import wfdb 9 | from tqdm import tqdm 10 | import os 11 | 12 | # these two datasets will read the raw ecg 13 | 14 | class Ori_MIMIC_E_T_Dataset(Dataset): 15 | def __init__(self, ecg_meta_path, transform=None, **args): 16 | self.ecg_meta_path = ecg_meta_path 17 | self.mode = args['train_test'] 18 | self.text_csv = args['text_csv'] 19 | self.record_csv = args['record_csv'] 20 | self.transform = transform 21 | 22 | def __len__(self): 23 | return (self.text_csv.shape[0]) 24 | 25 | def __getitem__(self, idx): 26 | if torch.is_tensor(idx): 27 | idx = idx.tolist() 28 | 29 | # get ecg 30 | study_id = self.text_csv['study_id'].iloc[idx] 31 | if study_id == self.record_csv['study_id'].iloc[idx]: 32 | path = 
self.record_csv['path'].iloc[idx] 33 | else: 34 | print('Error: study_id does not match!') 35 | path = os.path.join(self.ecg_meta_path, path) 36 | ecg = wfdb.rdsamp(path)[0] 37 | ecg = ecg.T 38 | 39 | # check nan and inf; ecg is (leads, samples) after the transpose 40 | if np.isnan(ecg).any(): 41 | for i in range(ecg.shape[0]): 42 | nan_idx = np.where(np.isnan(ecg[i, :]))[0] 43 | for j in nan_idx: 44 | ecg[i, j] = np.nanmean(ecg[i, max(0, j-6):min(j+6, ecg.shape[1])]) 45 | if np.isinf(ecg).any(): 46 | for i in range(ecg.shape[0]): 47 | inf_idx = np.where(np.isinf(ecg[i, :]))[0] 48 | for j in inf_idx: 49 | ecg[i, j] = np.nanmean(np.where(np.isinf(ecg[i, max(0, j-6):min(j+6, ecg.shape[1])]), np.nan, ecg[i, max(0, j-6):min(j+6, ecg.shape[1])])) 50 | 51 | # normalize 52 | ecg = (ecg - np.min(ecg))/(np.max(ecg) - np.min(ecg) + 1e-8) 53 | 54 | # get raw text 55 | report = self.text_csv.iloc[idx][['report_0', 'report_1', 56 | 'report_2', 'report_3', 'report_4', 'report_5', 'report_6', 'report_7', 57 | 'report_8', 'report_9', 'report_10', 'report_11', 'report_12', 58 | 'report_13', 'report_14', 'report_15', 'report_16', 'report_17']] 59 | # keep only non-NaN entries 60 | report = report[~report.isna()] 61 | # concat the report 62 | report = '.
'.join(report) 63 | # preprocessing on raw text 64 | report = report.replace('EKG', 'ECG') 65 | report = report.replace('ekg', 'ecg') 66 | report = report.strip('*** ') 67 | report = report.strip(' ***') 68 | report = report.strip('***') 69 | report = report.strip('=-') 70 | report = report.strip('=') 71 | # convert to all lower case 72 | report = report.lower() 73 | 74 | sample = {'ecg': ecg, 'raw_text': report} 75 | 76 | if self.transform: 77 | if self.mode == 'train': 78 | sample['ecg'] = self.transform(sample['ecg']) 79 | sample['ecg'] = torch.squeeze(sample['ecg'], dim=0) 80 | else: 81 | sample['ecg'] = self.transform(sample['ecg']) 82 | sample['ecg'] = torch.squeeze(sample['ecg'], dim=0) 83 | return sample 84 | 85 | 86 | class Ori_ECG_TEXT_Dsataset: 87 | 88 | def __init__(self, ecg_path, csv_path, dataset_name='mimic'): 89 | # if you use this dataset, please replace ecg_path from config.yaml to the 'your path/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0' 90 | self.ecg_path = ecg_path 91 | self.csv_path = csv_path 92 | self.dataset_name = dataset_name 93 | self.csv = pd.read_csv(self.csv_path, low_memory=False) 94 | self.record_csv = pd.read_csv(os.path.join(self.ecg_path, 'record_list.csv'), low_memory=False) 95 | 96 | # sort and reset index by study_id 97 | self.csv = self.csv.sort_values(by=['study_id']) 98 | self.csv.reset_index(inplace=True, drop=True) 99 | self.record_csv = self.record_csv.sort_values(by=['study_id']) 100 | self.record_csv.reset_index(inplace=True, drop=True) 101 | 102 | # split train and val 103 | self.train_csv, self.val_csv, self.train_record_csv, self.val_record_csv = \ 104 | train_test_split(self.csv, self.record_csv, test_size=0.02, random_state=42) 105 | # sort and reset index by study_id 106 | self.train_csv = self.train_csv.sort_values(by=['study_id']) 107 | self.val_csv = self.val_csv.sort_values(by=['study_id']) 108 | self.train_csv.reset_index(inplace=True, drop=True) 109 | 
self.val_csv.reset_index(inplace=True, drop=True) 110 | 111 | self.train_record_csv = self.train_record_csv.sort_values(by=['study_id']) 112 | self.val_record_csv = self.val_record_csv.sort_values(by=['study_id']) 113 | self.train_record_csv.reset_index(inplace=True, drop=True) 114 | self.val_record_csv.reset_index(inplace=True, drop=True) 115 | 116 | print(f'train size: {self.train_csv.shape[0]}') 117 | print(f'val size: {self.val_csv.shape[0]}') 118 | 119 | def get_dataset(self, train_test, T=None): 120 | 121 | if train_test == 'train': 122 | print('Apply Train-stage Transform!') 123 | 124 | Transforms = transforms.Compose([ 125 | transforms.ToTensor(), 126 | ]) 127 | else: 128 | print('Apply Val-stage Transform!') 129 | 130 | Transforms = transforms.Compose([ 131 | transforms.ToTensor(), 132 | ]) 133 | 134 | 135 | if self.dataset_name == 'mimic': 136 | 137 | if train_test == 'train': 138 | misc_args = {'train_test': train_test, 139 | 'text_csv': self.train_csv, 140 | 'record_csv': self.train_record_csv} 141 | else: 142 | misc_args = {'train_test': train_test, 143 | 'text_csv': self.val_csv, 144 | 'record_csv': self.val_record_csv} 145 | 146 | 147 | dataset = Ori_MIMIC_E_T_Dataset(ecg_meta_path=self.ecg_path, 148 | transform=Transforms, 149 | **misc_args) 150 | print(f'{train_test} dataset length: ', len(dataset)) 151 | 152 | return dataset 153 | 154 | 155 | # these two datasets read the ecg from a preprocessed npy file 156 | # we suggest using them to accelerate IO 157 | 158 | 159 | class MIMIC_E_T_Dataset(Dataset): 160 | def __init__(self, ecg_meta_path, transform=None, **args): 161 | self.ecg_meta_path = ecg_meta_path 162 | self.mode = args['train_test'] 163 | if self.mode == 'train': 164 | self.ecg_data = os.path.join(ecg_meta_path, 'mimic_ecg_train.npy') 165 | self.ecg_data = np.load(self.ecg_data, 'r') 166 | 167 | else: 168 | self.ecg_data = os.path.join(ecg_meta_path, 'mimic_ecg_val.npy') 169 | self.ecg_data =
np.load(self.ecg_data, 'r') 170 | 171 | 172 | self.text_csv = args['text_csv'] 173 | 174 | self.transform = transform 175 | 176 | def __len__(self): 177 | return (self.text_csv.shape[0]) 178 | 179 | def __getitem__(self, idx): 180 | if torch.is_tensor(idx): 181 | idx = idx.tolist() 182 | 183 | # we have to divide 1000 to get the real value 184 | ecg = self.ecg_data[idx]/1000 185 | # ecg = (ecg - np.min(ecg))/(np.max(ecg) - np.min(ecg) + 1e-8) 186 | 187 | 188 | # get raw text 189 | report = self.text_csv.iloc[idx]['total_report'] 190 | 191 | sample = {'ecg': ecg, 'raw_text': report} 192 | 193 | if self.transform: 194 | if self.mode == 'train': 195 | sample['ecg'] = self.transform(sample['ecg']) 196 | sample['ecg'] = torch.squeeze(sample['ecg'], dim=0) 197 | else: 198 | sample['ecg'] = self.transform(sample['ecg']) 199 | sample['ecg'] = torch.squeeze(sample['ecg'], dim=0) 200 | return sample 201 | 202 | 203 | class ECG_TEXT_Dsataset: 204 | 205 | def __init__(self, data_path, dataset_name='mimic'): 206 | self.data_path = data_path 207 | self.dataset_name = dataset_name 208 | 209 | print(f'Load {dataset_name} dataset!') 210 | self.train_csv = pd.read_csv(os.path.join(self.data_path, 'train.csv'), low_memory=False) 211 | self.val_csv = pd.read_csv(os.path.join(self.data_path, 'val.csv'), low_memory=False) 212 | 213 | print(f'train size: {self.train_csv.shape[0]}') 214 | print(f'val size: {self.val_csv.shape[0]}') 215 | print(f'total size: {self.train_csv.shape[0] + self.val_csv.shape[0]}') 216 | 217 | def get_dataset(self, train_test, T=None): 218 | 219 | if train_test == 'train': 220 | print('Apply Train-stage Transform!') 221 | 222 | Transforms = transforms.Compose([ 223 | transforms.ToTensor(), 224 | ]) 225 | else: 226 | print('Apply Val-stage Transform!') 227 | 228 | Transforms = transforms.Compose([ 229 | transforms.ToTensor(), 230 | ]) 231 | 232 | 233 | if self.dataset_name == 'mimic': 234 | 235 | if train_test == 'train': 236 | misc_args = {'train_test': 
train_test, 237 | 'text_csv': self.train_csv, 238 | } 239 | else: 240 | misc_args = {'train_test': train_test, 241 | 'text_csv': self.val_csv, 242 | } 243 | 244 | 245 | dataset = MIMIC_E_T_Dataset(ecg_meta_path=self.data_path, 246 | transform=Transforms, 247 | **misc_args) 248 | print(f'{train_test} dataset length: ', len(dataset)) 249 | 250 | return dataset 251 | -------------------------------------------------------------------------------- /utils/utils_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Type 3 | import torch 4 | import torch.nn.functional as F 5 | import pandas as pd 6 | from torch.cuda.amp import autocast as autocast 7 | from torch.cuda.amp import GradScaler as GradScaler 8 | from tqdm import tqdm 9 | 10 | def precision_at_k(output: torch.Tensor, target: torch.Tensor, top_k=(1,)): 11 | ''' Compute the accuracy over the k top predictions for the specified values of k''' 12 | with torch.no_grad(): 13 | maxk = max(top_k) 14 | batch_size = target.size(0) 15 | 16 | _, pred = output.topk(maxk, 1, True, True) 17 | pred = pred.t() 18 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 19 | 20 | res = [] 21 | for k in top_k: 22 | correct_k = correct[:k].contiguous( 23 | ).view(-1).float().sum(0, keepdim=True) 24 | res.append(correct_k.mul_(100.0 / batch_size)) 25 | return res 26 | 27 | def clip_loss(x, y, temperature=0.07, device='cuda'): 28 | x = F.normalize(x, dim=-1) 29 | y = F.normalize(y, dim=-1) 30 | 31 | sim = torch.einsum('i d, j d -> i j', x, y) * 1 / temperature 32 | 33 | labels = torch.arange(x.shape[0]).to(device) 34 | 35 | loss_t = F.cross_entropy(sim, labels) 36 | loss_i = F.cross_entropy(sim.T, labels) 37 | 38 | i2t_acc1, i2t_acc5 = precision_at_k( 39 | sim, labels, top_k=(1, 5)) 40 | t2i_acc1, t2i_acc5 = precision_at_k( 41 | sim.T, labels, top_k=(1, 5)) 42 | acc1 = (i2t_acc1 + t2i_acc1) / 2. 43 | acc5 = (i2t_acc5 + t2i_acc5) / 2. 
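The symmetric contrastive objective implemented in `clip_loss` above can be sanity-checked in isolation. This standalone sketch mirrors the similarity matrix and the two cross-entropy terms on CPU; the batch size of 8 and embedding dimension of 32 are illustrative choices, not values from this repository (only the 0.07 temperature matches the function's default):

```python
import torch
import torch.nn.functional as F

# toy batch: 8 paired ECG/text embeddings of dimension 32 (illustrative sizes)
torch.manual_seed(0)
ecg = F.normalize(torch.randn(8, 32), dim=-1)
txt = F.normalize(ecg + 0.01 * torch.randn(8, 32), dim=-1)  # nearly aligned pairs

# temperature-scaled cosine similarities; positives sit on the diagonal
sim = torch.einsum('i d, j d -> i j', ecg, txt) / 0.07
labels = torch.arange(8)

# symmetric cross-entropy: ECG-to-text and text-to-ECG retrieval
loss = F.cross_entropy(sim, labels) + F.cross_entropy(sim.T, labels)
```

Because each ECG row is closest to its own text row, the loss here is near zero; with randomly paired rows it would instead sit roughly around 2·ln(batch size).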
44 | 45 | return (loss_t + loss_i), acc1, acc5 -------------------------------------------------------------------------------- /utils/utils_optimizer.py: -------------------------------------------------------------------------------- 1 | from torch import optim as optim 2 | import torch 3 | 4 | 5 | class LARS(optim.Optimizer): 6 | def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001, 7 | weight_decay_filter=False, lars_adaptation_filter=False): 8 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, 9 | eta=eta, weight_decay_filter=weight_decay_filter, 10 | lars_adaptation_filter=lars_adaptation_filter) 11 | super().__init__(params, defaults) 12 | 13 | def exclude_bias_and_norm(self, p): 14 | return p.ndim == 1 15 | 16 | @torch.no_grad() 17 | def step(self): 18 | for g in self.param_groups: 19 | for p in g['params']: 20 | dp = p.grad 21 | 22 | if dp is None: 23 | continue 24 | 25 | if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p): 26 | dp = dp.add(p, alpha=g['weight_decay']) 27 | 28 | if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p): 29 | param_norm = torch.norm(p) 30 | update_norm = torch.norm(dp) 31 | one = torch.ones_like(param_norm) 32 | q = torch.where(param_norm > 0., 33 | torch.where(update_norm > 0, 34 | (g['eta'] * param_norm / update_norm), one), one) 35 | dp = dp.mul(q) 36 | 37 | param_state = self.state[p] 38 | if 'mu' not in param_state: 39 | param_state['mu'] = torch.zeros_like(p) 40 | mu = param_state['mu'] 41 | mu.mul_(g['momentum']).add_(dp) 42 | 43 | p.add_(mu, alpha=-g['lr']) 44 | -------------------------------------------------------------------------------- /utils/utils_trainer.py: -------------------------------------------------------------------------------- 1 | # package import 2 | # import wandb 3 | import os 4 | from typing import Type 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data.dataloader import DataLoader 9 | from 
torch.cuda.amp import autocast as autocast 10 | from torch.cuda.amp import GradScaler as GradScaler 11 | from torch.utils.data.distributed import DistributedSampler 12 | from torch import distributed as torch_dist 13 | import torch.distributed as dist 14 | 15 | from tqdm import tqdm 16 | import numpy as np 17 | import pandas as pd 18 | from matplotlib import pyplot as plt 19 | import yaml as yaml 20 | 21 | from utils_loss import clip_loss 22 | from zeroshot_val import zeroshot_eval 23 | 24 | import wandb 25 | 26 | class trainer_wBert: 27 | def __init__(self, model, 28 | optimizer, device, model_name, **args): 29 | self.model = model 30 | self.optimizer = optimizer 31 | self.device = device 32 | self.model_name = model_name 33 | self.train_batch_size = args['batch_size'] 34 | self.max_epochs = args['max_epochs'] 35 | self.num_workers = args['num_workers'] 36 | self.checkpoint_interval = args['checkpoint_interval'] 37 | self.val_batch_size = args['val_batch_size'] 38 | 39 | # training process 40 | def train_w_TextEmb(self, train_dataset, val_dataset, args_zeroshot_eval): 41 | 42 | train_loader = DataLoader(train_dataset, batch_size=self.train_batch_size, 43 | num_workers=self.num_workers, 44 | drop_last=True, shuffle=False, 45 | sampler=DistributedSampler(train_dataset)) 46 | 47 | val_loader = DataLoader(val_dataset, batch_size=self.val_batch_size, 48 | num_workers=self.num_workers, 49 | drop_last=True, shuffle=False, 50 | sampler=DistributedSampler(val_dataset)) 51 | 52 | 53 | model_checkpoints_folder = os.path.join('../checkpoints/') 54 | if self.device == 0: 55 | if not os.path.exists(model_checkpoints_folder): 56 | print('creating directory "{}" to save checkpoints'.format( 57 | model_checkpoints_folder)) 58 | print('---------------------------') 59 | os.makedirs(model_checkpoints_folder) 60 | else: 61 | print('directory "{}" already exists for saving checkpoints'.format( 62 | model_checkpoints_folder)) 63 | 64 | # automatically resume from checkpoint if it exists 65 |
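The comment above marks the resume logic: if `<model_name>_checkpoint.pth` exists, the trainer reloads `'epoch'`, `'model_state_dict'` and `'optimizer_state_dict'`; otherwise it starts from epoch 0. A minimal dependency-free sketch of that round-trip (pickle stands in for `torch.save`/`torch.load` so the sketch runs without torch; the helper names are illustrative, not from this repo):

```python
# Dependency-free sketch of the resume-from-checkpoint pattern.
# The real trainer uses torch.save/torch.load on state_dicts; plain pickle
# stands in here. Keys mirror the trainer's checkpoint dict:
# 'epoch', 'model_state_dict', 'optimizer_state_dict'.
import os
import pickle
import tempfile

def save_checkpoint(path, epoch, model_state, optim_state):
    with open(path, 'wb') as f:
        pickle.dump({'epoch': epoch,
                     'model_state_dict': model_state,
                     'optimizer_state_dict': optim_state}, f)

def resume_or_start(path):
    # resume if a checkpoint file exists, otherwise start from epoch 0
    if os.path.exists(path):
        with open(path, 'rb') as f:
            ckpt = pickle.load(f)
        return ckpt['epoch'], ckpt['model_state_dict'], ckpt['optimizer_state_dict']
    return 0, None, None

ckpt_path = os.path.join(tempfile.mkdtemp(), 'demo_checkpoint.pth')
assert resume_or_start(ckpt_path)[0] == 0   # no file yet -> start at epoch 0
save_checkpoint(ckpt_path, 7, {'w': 1.0}, {'lr': 3e-4})
epoch, model_state, _ = resume_or_start(ckpt_path)
print(epoch)  # -> 7
```

Note that because the trainer then iterates `range(start_epoch, max_epochs+1)`, a resumed run repeats the saved epoch once before continuing.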
print('#########################################') 66 | print('Be patient..., checking checkpoint now...') 67 | if os.path.exists(model_checkpoints_folder + self.model_name+'_checkpoint.pth'): 68 | ckpt = torch.load(model_checkpoints_folder + self.model_name+'_checkpoint.pth', 69 | map_location='cpu') 70 | start_epoch = ckpt['epoch'] 71 | self.model.load_state_dict(ckpt['model_state_dict']) 72 | self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) 73 | print('resumed training from checkpoint successfully!') 74 | else: 75 | start_epoch = 0 76 | print('starting training from epoch 0') 77 | 78 | print('#########################################') 79 | print('training start!') 80 | 81 | # scheduler 82 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 83 | self.optimizer, 84 | T_0=5000, 85 | T_mult=1, 86 | eta_min=1e-8, 87 | ) 88 | niter = 1 89 | 90 | skip_scheduler = False 91 | scaler = GradScaler() 92 | 93 | f1_total = [] 94 | acc_total = [] 95 | auc_total = [] 96 | 97 | zeroshot_csv = pd.DataFrame() 98 | best_auc = 0 99 | 100 | for epoch_counter in tqdm(range(start_epoch, self.max_epochs+1)): 101 | 102 | epoch_loss = 0 103 | epoch_acc1 = [] 104 | epoch_acc5 = [] 105 | self.model.train() 106 | for data in tqdm(train_loader): 107 | self.model.train() 108 | # get raw text 109 | report = data['raw_text'] 110 | 111 | # get ecg 112 | ecg = data['ecg'].to(torch.float32).to( 113 | self.device).contiguous() 114 | 115 | self.optimizer.zero_grad() 116 | 117 | with autocast(): 118 | report_tokenize_output = self.model.module._tokenize(report) 119 | 120 | input_ids = report_tokenize_output.input_ids.to( 121 | self.device).contiguous() 122 | attention_mask = report_tokenize_output.attention_mask.to( 123 | self.device).contiguous() 124 | 125 | output_dict = self.model(ecg, input_ids, attention_mask) 126 | ecg_emb, proj_ecg_emb, proj_text_emb = output_dict['ecg_emb'],\ 127 | output_dict['proj_ecg_emb'],\ 128 | output_dict['proj_text_emb'] 129 | 130 | 131 | world_size =
torch_dist.get_world_size() 132 | with torch.no_grad(): 133 | agg_proj_img_emb = [torch.zeros_like(proj_ecg_emb[0]) for _ in range(world_size)] 134 | agg_proj_text_emb = [torch.zeros_like(proj_text_emb[0]) for _ in range(world_size)] 135 | 136 | dist.all_gather(agg_proj_img_emb, proj_ecg_emb[0]) 137 | dist.all_gather(agg_proj_text_emb, proj_text_emb[0]) 138 | 139 | agg_proj_ecg_emb1 = [torch.zeros_like(ecg_emb[0]) for _ in range(world_size)] 140 | agg_proj_ecg_emb2 = [torch.zeros_like(ecg_emb[1]) for _ in range(world_size)] 141 | dist.all_gather(agg_proj_ecg_emb1, ecg_emb[0]) 142 | dist.all_gather(agg_proj_ecg_emb2, ecg_emb[1]) 143 | # get current rank 144 | rank = torch_dist.get_rank() 145 | 146 | agg_proj_img_emb[rank] = proj_ecg_emb[0] 147 | agg_proj_text_emb[rank] = proj_text_emb[0] 148 | 149 | agg_proj_ecg_emb1[rank] = ecg_emb[0] 150 | agg_proj_ecg_emb2[rank] = ecg_emb[1] 151 | 152 | agg_proj_img_emb = torch.cat(agg_proj_img_emb, dim=0) 153 | agg_proj_text_emb = torch.cat(agg_proj_text_emb, dim=0) 154 | 155 | agg_proj_ecg_emb1 = torch.cat(agg_proj_ecg_emb1, dim=0) 156 | agg_proj_ecg_emb2 = torch.cat(agg_proj_ecg_emb2, dim=0) 157 | 158 | cma_loss, acc1, acc5 = clip_loss(agg_proj_img_emb, agg_proj_text_emb, device=self.device) 159 | uma_loss, _, _ = clip_loss(agg_proj_ecg_emb1, agg_proj_ecg_emb2, device=self.device) 160 | loss = cma_loss + uma_loss 161 | 162 | if self.device == 0: 163 | print(f'loss is {loss.item()}, acc1 is {acc1.item()}, acc5 is {acc5.item()}, cma_loss is {cma_loss.item()}, uma_loss is {uma_loss.item()}') 164 | 165 | wandb.log({ 166 | 'train_step_uma_loss': uma_loss.item(), 167 | 'train_step_cma_loss': cma_loss.item(), 168 | 'train_step_total_loss': loss.item(), 169 | 'train_step_acc1': acc1.item(), 170 | 'train_step_acc5': acc5.item()} 171 | ) 172 | 173 | # accumulate loss for logging 174 | epoch_loss += loss.item() 175 | epoch_acc1.append(acc1.item()) 176 | epoch_acc5.append(acc5.item()) 177 | 178 | scaler.scale(loss).backward() 179 |
scaler.step(self.optimizer) 180 | scaler.update() 181 | 182 | if not skip_scheduler: 183 | scheduler.step() 184 | niter += 1 185 | 186 | # eval stage 187 | val_log = self.val(val_loader) 188 | 189 | if self.device == 0: 190 | # average train metric 191 | epoch_acc1 = np.array(epoch_acc1).mean() 192 | epoch_acc5 = np.array(epoch_acc5).mean() 193 | 194 | epoch_iter = (len(train_dataset)//self.train_batch_size) 195 | print(f'{epoch_counter} epoch loss is {epoch_loss/epoch_iter},\ 196 | acc1 is {epoch_acc1}, acc5 is {epoch_acc5}') 197 | 198 | # log train and val epoch metric 199 | wandb.log({ 200 | 'train_epoch_loss': epoch_loss/epoch_iter, 201 | 'train_epoch_acc1': epoch_acc1, 202 | 'train_epoch_acc5': epoch_acc5, 203 | 'val_cma_loss': val_log['val_cma_loss'], 204 | 'val_uma_loss': val_log['val_uma_loss'], 205 | 'val_epoch_loss': val_log['val_loss'], 206 | 'val_epoch_acc1': val_log['val_acc1'], 207 | 'val_epoch_acc5': val_log['val_acc5']} 208 | ) 209 | 210 | 211 | # zero-shot eval 212 | avg_f1, avg_acc, avg_auc = 0, 0, 0 213 | for set_name in args_zeroshot_eval['val_sets'].keys(): 214 | 215 | f1, acc, auc, _, _, _, res_dict = \ 216 | zeroshot_eval(model=self.model, 217 | set_name=set_name, 218 | device=self.device, 219 | args_zeroshot_eval=args_zeroshot_eval) 220 | 221 | avg_f1 += f1 222 | avg_acc += acc 223 | avg_auc += auc 224 | 225 | # log each val set zeroshot performance 226 | wandb.log({ 227 | f'{set_name}_f1': f1, 228 | f'{set_name}_acc': acc, 229 | f'{set_name}_AUROC': auc 230 | } 231 | ) 232 | 233 | avg_f1 = avg_f1/len(args_zeroshot_eval['val_sets'].keys()) 234 | avg_acc = avg_acc/len(args_zeroshot_eval['val_sets'].keys()) 235 | avg_auc = avg_auc/len(args_zeroshot_eval['val_sets'].keys()) 236 | wandb.log({ 237 | 'avg_f1': avg_f1, 238 | 'avg_acc': avg_acc, 239 | 'avg_auc': avg_auc 240 | } 241 | ) 242 | 243 | f1_total.append(f1) 244 | acc_total.append(acc) 245 | auc_total.append(auc) 246 | 247 | best_metric = avg_auc 248 | if best_metric > best_auc: 249 | 
best_auc = best_metric 250 | torch.save(self.model.module.state_dict(), 251 | model_checkpoints_folder + self.model_name+f'_bestZeroShotAll_ckpt.pth') 252 | torch.save(self.model.module.ecg_encoder.state_dict(), 253 | model_checkpoints_folder + self.model_name+f'_bestZeroShotAll_encoder.pth') 254 | 255 | if epoch_counter % self.checkpoint_interval == 0: 256 | self.save_checkpoints(epoch_counter, model_checkpoints_folder + self.model_name + f'_{epoch_counter}_ckpt.pth') 257 | 258 | if self.checkpoint_interval != 1: 259 | # save final ecg_encoder 260 | torch.save(self.model.module.ecg_encoder.state_dict(), 261 | model_checkpoints_folder + self.model_name + '_final_encoder.pth') 262 | # save final total model 263 | torch.save(self.model.module.state_dict(), 264 | model_checkpoints_folder + self.model_name + '_final_total.pth') 265 | 266 | def val(self, loader): 267 | print('start validation') 268 | self.model.eval() 269 | val_cma_loss = 0 270 | val_uma_loss = 0 271 | val_loss = 0 272 | val_epoch_acc1 = [] 273 | val_epoch_acc5 = [] 274 | 275 | for data in tqdm(loader): 276 | # get raw text 277 | report = data['raw_text'] 278 | # get ecg 279 | ecg = data['ecg'].to(torch.float32).to( 280 | self.device).contiguous() 281 | 282 | report_tokenize_output = self.model.module._tokenize(report) 283 | 284 | input_ids = report_tokenize_output.input_ids.to( 285 | self.device).contiguous() 286 | attention_mask = report_tokenize_output.attention_mask.to( 287 | self.device).contiguous() 288 | 289 | with torch.no_grad(): 290 | output_dict = self.model(ecg, input_ids, attention_mask) 291 | ecg_emb, proj_ecg_emb, proj_text_emb = output_dict['ecg_emb'],\ 292 | output_dict['proj_ecg_emb'],\ 293 | output_dict['proj_text_emb'] 294 | 295 | 296 | world_size = torch_dist.get_world_size() 297 | with torch.no_grad(): 298 | agg_proj_img_emb = [torch.zeros_like(proj_ecg_emb[0]) for _ in range(world_size)] 299 | agg_proj_text_emb = [torch.zeros_like(proj_text_emb[0]) for _ in range(world_size)] 300 
| 301 | dist.all_gather(agg_proj_img_emb, proj_ecg_emb[0]) 302 | dist.all_gather(agg_proj_text_emb, proj_text_emb[0]) 303 | 304 | agg_proj_ecg_emb1 = [torch.zeros_like(ecg_emb[0]) for _ in range(world_size)] 305 | agg_proj_ecg_emb2 = [torch.zeros_like(ecg_emb[1]) for _ in range(world_size)] 306 | dist.all_gather(agg_proj_ecg_emb1, ecg_emb[0]) 307 | dist.all_gather(agg_proj_ecg_emb2, ecg_emb[1]) 308 | # get current rank 309 | rank = torch_dist.get_rank() 310 | 311 | agg_proj_img_emb[rank] = proj_ecg_emb[0] 312 | agg_proj_text_emb[rank] = proj_text_emb[0] 313 | 314 | agg_proj_ecg_emb1[rank] = ecg_emb[0] 315 | agg_proj_ecg_emb2[rank] = ecg_emb[1] 316 | 317 | agg_proj_img_emb = torch.cat(agg_proj_img_emb, dim=0) 318 | agg_proj_text_emb = torch.cat(agg_proj_text_emb, dim=0) 319 | 320 | agg_proj_ecg_emb1 = torch.cat(agg_proj_ecg_emb1, dim=0) 321 | agg_proj_ecg_emb2 = torch.cat(agg_proj_ecg_emb2, dim=0) 322 | 323 | cma_loss, acc1, acc5 = clip_loss(agg_proj_img_emb, agg_proj_text_emb, device=self.device) 324 | uma_loss, _, _ = clip_loss(agg_proj_ecg_emb1, agg_proj_ecg_emb2, device=self.device) 325 | loss = cma_loss + uma_loss 326 | 327 | # accumulate loss for logging 328 | val_cma_loss += cma_loss.item() 329 | val_uma_loss += uma_loss.item() 330 | val_loss += loss.item() 331 | val_epoch_acc1.append(acc1.item()) 332 | val_epoch_acc5.append(acc5.item()) 333 | 334 | if self.device == 0: 335 | val_cma_loss = val_cma_loss/len(val_epoch_acc1) 336 | val_uma_loss = val_uma_loss/len(val_epoch_acc1) 337 | val_loss = val_loss/len(val_epoch_acc1) 338 | val_epoch_acc1 = np.array(val_epoch_acc1).mean() 339 | val_epoch_acc5 = np.array(val_epoch_acc5).mean() 340 | 341 | val_log = {'val_loss': val_loss, 342 | 'val_cma_loss': val_cma_loss, 343 | 'val_uma_loss': val_uma_loss, 344 | 'val_acc1': val_epoch_acc1, 345 | 'val_acc5': val_epoch_acc5} 346 | return val_log 347 | else: 348 | return None 349 | 350 | def save_checkpoints(self, epoch, PATH): 351 | 352 | torch.save({ 353 | 'epoch': epoch,
354 | 'model_state_dict': self.model.state_dict(), 355 | 'optimizer_state_dict': self.optimizer.state_dict()}, 356 | PATH) 357 | -------------------------------------------------------------------------------- /utils/vit1d.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch Implementation of Vision Transformer 3 | ("An Image is Worth 16X16 Words: Transformers for Image Recognition at Scale") 4 | 5 | Reference 6 | - Paper: https://arxiv.org/abs/2010.11929 7 | - Code: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | from einops import rearrange 12 | from einops.layers.torch import Rearrange 13 | 14 | 15 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py#L137 16 | class DropPath(nn.Module): 17 | ''' 18 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 19 | ''' 20 | def __init__(self, drop_prob: float, scale_by_keep: bool = True): 21 | super(DropPath, self).__init__() 22 | self.drop_prob = drop_prob 23 | self.scale_by_keep = scale_by_keep 24 | 25 | def forward(self, x): 26 | if self.drop_prob <= 0. or not self.training: 27 | return x 28 | keep_prob = 1 - self.drop_prob 29 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 30 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 31 | if keep_prob > 0.0 and self.scale_by_keep: 32 | random_tensor.div_(keep_prob) 33 | return x * random_tensor 34 | 35 | 36 | class PreNorm(nn.Module): 37 | def __init__(self, 38 | dim: int, 39 | fn: nn.Module): 40 | super().__init__() 41 | self.norm = nn.LayerNorm(dim) 42 | self.fn = fn 43 | 44 | def forward(self, x, **kwargs): 45 | return self.fn(self.norm(x), **kwargs) 46 | 47 | 48 | class FeedForward(nn.Module): 49 | """ 50 | MLP Module with GELU activation fn + dropout. 
51 | """ 52 | def __init__(self, 53 | input_dim: int, 54 | output_dim: int, 55 | hidden_dim: int, 56 | drop_out_rate=0.): 57 | super().__init__() 58 | self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim), 59 | nn.GELU(), 60 | nn.Dropout(drop_out_rate), 61 | nn.Linear(hidden_dim, output_dim), 62 | nn.Dropout(drop_out_rate)) 63 | 64 | def forward(self, x): 65 | return self.net(x) 66 | 67 | 68 | class Attention(nn.Module): 69 | def __init__(self, 70 | input_dim: int, 71 | output_dim: int, 72 | heads: int = 8, 73 | dim_head: int = 64, 74 | qkv_bias: bool = True, 75 | drop_out_rate: float = 0., 76 | attn_drop_out_rate: float = 0.): 77 | super().__init__() 78 | inner_dim = dim_head * heads 79 | project_out = not (heads == 1 and dim_head == input_dim) 80 | 81 | self.heads = heads 82 | self.scale = dim_head ** -0.5 83 | 84 | self.attend = nn.Softmax(dim=-1) 85 | self.dropout = nn.Dropout(attn_drop_out_rate) 86 | self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias) 87 | 88 | if project_out: 89 | self.to_out = nn.Sequential(nn.Linear(inner_dim, output_dim), 90 | nn.Dropout(drop_out_rate)) 91 | else: 92 | self.to_out = nn.Identity() 93 | 94 | def forward(self, x): 95 | qkv = self.to_qkv(x).chunk(3, dim=-1) 96 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 97 | 98 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 99 | 100 | attn = self.attend(dots) 101 | attn = self.dropout(attn) 102 | out = torch.matmul(attn, v) 103 | out = rearrange(out, 'b h n d -> b n (h d)') 104 | out = self.to_out(out) 105 | return out 106 | 107 | 108 | class TransformerBlock(nn.Module): 109 | def __init__(self, 110 | input_dim: int, 111 | output_dim: int, 112 | hidden_dim: int, 113 | heads: int = 8, 114 | dim_head: int = 32, 115 | qkv_bias: bool = True, 116 | drop_out_rate: float = 0., 117 | attn_drop_out_rate: float = 0., 118 | drop_path_rate: float = 0.): 119 | super().__init__() 120 | attn = Attention(input_dim=input_dim, 121 | 
output_dim=output_dim, 122 | heads=heads, 123 | dim_head=dim_head, 124 | qkv_bias=qkv_bias, 125 | drop_out_rate=drop_out_rate, 126 | attn_drop_out_rate=attn_drop_out_rate) 127 | self.attn = PreNorm(dim=input_dim, 128 | fn=attn) 129 | self.droppath1 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() 130 | 131 | ff = FeedForward(input_dim=output_dim, 132 | output_dim=output_dim, 133 | hidden_dim=hidden_dim, 134 | drop_out_rate=drop_out_rate) 135 | self.ff = PreNorm(dim=output_dim, 136 | fn=ff) 137 | self.droppath2 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() 138 | 139 | def forward(self, x): 140 | x = self.droppath1(self.attn(x)) + x 141 | x = self.droppath2(self.ff(x)) + x 142 | return x 143 | 144 | 145 | class ViT(nn.Module): 146 | def __init__(self, 147 | num_leads: int, 148 | seq_len: int, 149 | patch_size: int, 150 | width: int = 768, 151 | depth: int = 12, 152 | mlp_dim: int = 3072, 153 | heads: int = 12, 154 | dim_head: int = 64, 155 | qkv_bias: bool = True, 156 | drop_out_rate: float = 0., 157 | attn_drop_out_rate: float = 0., 158 | drop_path_rate: float = 0., 159 | **kwargs): 160 | super().__init__() 161 | assert seq_len % patch_size == 0, 'The sequence length must be divisible by the patch size.' 
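The divisibility assert above guards the patch arithmetic: the non-overlapping `Conv1d` patcher emits exactly `seq_len // patch_size` tokens of dimension `width`. A quick pure-Python check of that arithmetic, using the 12-lead, 5000-sample, patch-size-50 defaults of the `vit_*` constructors in this file (illustrative numbers only, no torch needed):

```python
# Patch-embedding arithmetic behind the divisibility assert.
# nn.Conv1d(num_leads, width, kernel_size=patch_size, stride=patch_size)
# maps (batch, num_leads, seq_len) -> (batch, width, seq_len // patch_size).
seq_len, patch_size, num_leads, width = 5000, 50, 12, 768

assert seq_len % patch_size == 0, 'sequence length must divide evenly into patches'
num_patches = seq_len // patch_size   # tokens seen by the transformer blocks
print(num_patches)  # -> 100

# per-sample token shape after rearrange('b c n -> b n c'):
token_shape = (num_patches, width)    # (100, 768), plus the batch dimension
```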
162 | num_patches = seq_len // patch_size 163 | 164 | # conv patch start 165 | self.to_patch_embedding = nn.Conv1d(num_leads, width, kernel_size=patch_size, stride=patch_size, bias=False) 166 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, width)) 167 | 168 | self.dropout = nn.Dropout(drop_out_rate) 169 | 170 | 171 | self.depth = depth 172 | self.width = width 173 | drop_path_rate_list = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 174 | for i in range(depth): 175 | block = TransformerBlock(input_dim=width, 176 | output_dim=width, 177 | hidden_dim=mlp_dim, 178 | heads=heads, 179 | dim_head=dim_head, 180 | qkv_bias=qkv_bias, 181 | drop_out_rate=drop_out_rate, 182 | attn_drop_out_rate=attn_drop_out_rate, 183 | drop_path_rate=drop_path_rate_list[i]) 184 | self.add_module(f'block{i}', block) 185 | 186 | self.norm = nn.LayerNorm(width) 187 | self.head = nn.Identity() 188 | 189 | def forward_encoding(self, series): 190 | 191 | # for conv patch 192 | x = self.to_patch_embedding(series) 193 | x = rearrange(x, 'b c n -> b n c') 194 | x = x + self.pos_embedding 195 | 196 | # transformer blocks 197 | x = self.dropout(x) 198 | for i in range(self.depth): 199 | x = getattr(self, f'block{i}')(x) 200 | 201 | x = torch.mean(x, dim=1) # global average pooling 202 | 203 | return self.norm(x) 204 | 205 | def forward(self, series): 206 | x = self.forward_encoding(series) 207 | x = self.head(x) 208 | return x 209 | 210 | def reset_head(self, num_classes=1): 211 | del self.head 212 | self.head = nn.Linear(self.width, num_classes) 213 | 214 | 215 | def vit_tiny(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs): 216 | model_args = dict(num_leads=num_leads, 217 | num_classes=num_classes, 218 | seq_len=seq_len, 219 | patch_size=patch_size, 220 | width=192, 221 | depth=12, 222 | heads=3, 223 | mlp_dim=768, 224 | **kwargs) 225 | return ViT(**model_args) 226 | 227 | 228 | def vit_small(num_leads, num_classes=1, seq_len=5000, patch_size=50, 
**kwargs): 229 | model_args = dict(num_leads=num_leads, 230 | num_classes=num_classes, 231 | seq_len=seq_len, 232 | patch_size=patch_size, 233 | width=384, 234 | depth=12, 235 | heads=6, 236 | mlp_dim=1536, 237 | **kwargs) 238 | return ViT(**model_args) 239 | 240 | 241 | def vit_middle(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs): 242 | model_args = dict(num_leads=num_leads, 243 | num_classes=num_classes, 244 | seq_len=seq_len, 245 | patch_size=patch_size, 246 | width=512, 247 | depth=12, 248 | heads=8, 249 | mlp_dim=2048, 250 | **kwargs) 251 | return ViT(**model_args) 252 | 253 | 254 | def vit_base(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs): 255 | model_args = dict(num_leads=num_leads, 256 | num_classes=num_classes, 257 | seq_len=seq_len, 258 | patch_size=patch_size, 259 | width=768, 260 | depth=12, 261 | heads=12, 262 | mlp_dim=3072, 263 | **kwargs) 264 | return ViT(**model_args) 265 | -------------------------------------------------------------------------------- /utils/zeroshot_val.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.utils.data.dataloader import DataLoader 4 | from torch.cuda.amp import autocast as autocast 5 | from torch.cuda.amp import GradScaler as GradScaler 6 | import os 7 | from tqdm import tqdm 8 | import numpy as np 9 | import pandas as pd 10 | from matplotlib import pyplot as plt 11 | from sklearn.metrics import roc_auc_score,precision_recall_curve,accuracy_score, f1_score 12 | import yaml as yaml 13 | import sys 14 | sys.path.append("../finetune/") 15 | 16 | from finetune_dataset import getdataset as get_zero_dataset 17 | 18 | def compute_AUCs(gt, pred, n_class): 19 | """Computes Area Under the Curve (AUC) from prediction scores. 20 | Args: 21 | gt: Pytorch tensor on GPU, shape = [n_samples, n_classes] 22 | true binary labels. 
23 | pred: PyTorch tensor on GPU, shape = [n_samples, n_classes] 24 | can either be probability estimates of the positive class, 25 | confidence values, or binary decisions. 26 | Returns: 27 | List of AUROCs of all classes. 28 | """ 29 | AUROCs = [] 30 | gt_np = gt 31 | pred_np = pred 32 | for i in range(n_class): 33 | AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i], average='macro', multi_class='ovo')) 34 | return AUROCs 35 | 36 | def get_class_emd(model, class_name, device='cuda'): 37 | model.eval() 38 | with torch.no_grad(): 39 | zeroshot_weights = [] 40 | # compute embedding through model for each class 41 | for texts in tqdm(class_name): 42 | texts = texts.lower() 43 | texts = [texts] # convert to list 44 | texts = model._tokenize(texts) # tokenize 45 | class_embeddings = model.get_text_emb(texts.input_ids.to(device=device) 46 | , texts.attention_mask.to(device=device) 47 | ) # embed with text encoder 48 | class_embeddings = model.proj_t(class_embeddings) # project into the shared embedding space 49 | 50 | # normalize class_embeddings 51 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 52 | # average over templates 53 | class_embedding = class_embeddings.mean(dim=0) 54 | # norm over new averaged templates 55 | class_embedding /= class_embedding.norm() 56 | zeroshot_weights.append(class_embedding) 57 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1) 58 | return zeroshot_weights 59 | 60 | def get_ecg_emd(model, loader, zeroshot_weights, device='cuda', softmax_eval=True): 61 | y_pred = [] 62 | model.eval() 63 | with torch.no_grad(): 64 | for i, (ecg, target) in enumerate(tqdm(loader)): 65 | ecg = ecg.to(device=device) 66 | # predict 67 | ecg_emb = model.ext_ecg_emb(ecg) 68 | ecg_emb /= ecg_emb.norm(dim=-1, keepdim=True) 69 | 70 | # obtain logits (cos similarity) 71 | logits = ecg_emb @ zeroshot_weights 72 | logits = torch.squeeze(logits, 0) # (N, num_classes) 73 | if
softmax_eval is False: 74 | norm_logits = (logits - logits.mean()) / (logits.std()) 75 | logits = torch.sigmoid(norm_logits) 76 | 77 | y_pred.append(logits.cpu().data.numpy()) 78 | 79 | y_pred = np.concatenate(y_pred, axis=0) 80 | return np.array(y_pred) 81 | 82 | def zeroshot_eval(model, set_name, device='cuda', args_zeroshot_eval=None): 83 | assert args_zeroshot_eval is not None, "Please specify the test set!" 84 | 85 | 86 | num_workers = args_zeroshot_eval['num_workers'] 87 | batch_size = args_zeroshot_eval['batch_size'] 88 | 89 | meta_data_path = args_zeroshot_eval['meta_data_path'] 90 | 91 | if 'val_sets' not in args_zeroshot_eval.keys(): 92 | data_path = args_zeroshot_eval['test_sets'][set_name]['data_path'] 93 | else: 94 | data_path = args_zeroshot_eval['val_sets'][set_name]['data_path'] 95 | 96 | data_path = os.path.join(meta_data_path, data_path) 97 | 98 | meta_split_path = args_zeroshot_eval['meta_split_path'] 99 | if 'val_sets' not in args_zeroshot_eval.keys(): 100 | split_path = args_zeroshot_eval['test_sets'][set_name]['split_path'] 101 | else: 102 | split_path = args_zeroshot_eval['val_sets'][set_name]['split_path'] 103 | split_path = os.path.join(meta_split_path, split_path) 104 | 105 | 106 | if 'ptbxl' in set_name: 107 | test_dataset = get_zero_dataset(data_path, split_path, mode='test', dataset_name='ptbxl') 108 | else: 109 | test_dataset = get_zero_dataset(data_path, split_path, mode='test', dataset_name=set_name) 110 | class_name = test_dataset.labels_name 111 | 112 | # open json as dict 113 | with open(args_zeroshot_eval['prompt_dict'], 'r') as f: 114 | prompt_dict = yaml.load(f, Loader=yaml.FullLoader) 115 | 116 | # get prompt for each class 117 | target_class = [prompt_dict[i] for i in class_name] 118 | 119 | print('***********************************') 120 | print('zeroshot classification set is {}'.format(set_name)) 121 | 122 | test_dataloader =
DataLoader( 123 | test_dataset, 124 | batch_size=batch_size, 125 | num_workers=num_workers, 126 | pin_memory=True, 127 | sampler=None, 128 | shuffle=False, 129 | collate_fn=None, 130 | drop_last=False, 131 | ) 132 | 133 | # get the target array from testset 134 | gt = test_dataset.labels 135 | 136 | # get class embedding 137 | zeroshot_weights = get_class_emd(model.module, target_class, device=device) 138 | # get ecg prediction 139 | pred = get_ecg_emd(model.module, test_dataloader, 140 | zeroshot_weights, device=device, softmax_eval=True) 141 | 142 | AUROCs = compute_AUCs(gt, pred, len(target_class)) 143 | AUROCs = [i*100 for i in AUROCs] 144 | AUROC_avg = np.array(AUROCs).mean() 145 | 146 | max_f1s = [] 147 | accs = [] 148 | 149 | for i in range(len(target_class)): 150 | gt_np = gt[:, i] 151 | pred_np = pred[:, i] 152 | precision, recall, thresholds = precision_recall_curve(gt_np, pred_np) 153 | numerator = 2 * recall * precision 154 | denom = recall + precision 155 | f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom!=0)) 156 | max_f1 = np.max(f1_scores) 157 | max_f1_thresh = thresholds[np.argmax(f1_scores)] 158 | max_f1s.append(max_f1) 159 | accs.append(accuracy_score(gt_np, pred_np>max_f1_thresh)) 160 | 161 | 162 | max_f1s = [i*100 for i in max_f1s] 163 | accs = [i*100 for i in accs] 164 | f1_avg = np.array(max_f1s).mean() 165 | acc_avg = np.array(accs).mean() 166 | 167 | res_dict = {'AUROC_avg': AUROC_avg, 168 | 'F1_avg': f1_avg, 169 | 'ACC_avg': acc_avg 170 | } 171 | for i in range(len(target_class)): 172 | res_dict.update({f'AUROC_{class_name[i]}': AUROCs[i], 173 | f'F1_{class_name[i]}': max_f1s[i], 174 | f'ACC_{class_name[i]}': accs[i] 175 | }) 176 | 177 | print('-----------------------------------') 178 | print('The average AUROC is {AUROC_avg:.4f}'.format(AUROC_avg=AUROC_avg)) 179 | for i in range(len(target_class)): 180 | print('The AUROC of {} is {}'.format(class_name[i], AUROCs[i])) 181 | 182 | 
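The per-class operating point above is the threshold on the precision–recall curve that maximizes F1. A dependency-free sketch of that selection — a brute-force scan standing in for sklearn's `precision_recall_curve`, run on toy labels and scores (illustrative data, not from the repo):

```python
# Max-F1 threshold selection, mirroring the precision_recall_curve loop
# above but without sklearn: scan every distinct score as a candidate
# threshold and keep the one with the highest F1.
def max_f1_threshold(gt, scores):
    best_f1, best_t = 0.0, 0.0
    for t in sorted(set(scores)):
        pred = [1 if s >= t else 0 for s in scores]
        tp = sum(p and g for p, g in zip(pred, gt))
        fp = sum(p and not g for p, g in zip(pred, gt))
        fn = sum((not p) and g for p, g in zip(pred, gt))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0   # F1 = 2TP / (2TP + FP + FN)
        if f1 > best_f1:
            best_f1, best_t = f1, t
    return best_f1, best_t

gt     = [1, 0, 1, 1, 0, 0]                 # toy binary labels for one class
scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.1]     # toy zero-shot similarities
best_f1, best_t = max_f1_threshold(gt, scores)
print(best_t, best_f1)  # -> 0.4 0.857142857...
```

Accuracy is then computed with `pred > best_t`, exactly as in the loop above.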
print('-----------------------------------') 183 | print('The average f1 is {F1_avg:.4f}'.format(F1_avg=f1_avg)) 184 | for i in range(len(target_class)): 185 | print('The F1 of {} is {}'.format(class_name[i], max_f1s[i])) 186 | 187 | print('-----------------------------------') 188 | print('The average ACC is {ACC_avg:.4f}'.format(ACC_avg=acc_avg)) 189 | for i in range(len(target_class)): 190 | print('The ACC of {} is {}'.format(class_name[i], accs[i])) 191 | print('***********************************') 192 | 193 | return f1_avg, acc_avg, AUROC_avg, max_f1s, accs, AUROCs, res_dict 194 | -------------------------------------------------------------------------------- /zeroshot/CKEPE_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "NDT": "non-diagnostic T abnormalities, non-diagnostic T-waves, non-diagnostic T-wave changes, abnormal T-waves in leads I and aVL, abnormal T-waves in leads II, III and aVF.", 3 | "NST": "non-specific ST changes, non-specific ST segment changes, non-specific ST elevation, non-specific ST depression.", 4 | "DIG": "DIG, suggests digitalis-effect, Digitalis-induced changes, Digitalis-like effect, Digitalis-related changes.", 5 | "LNGQT": "long QT-interval, prolonged QT interval, congenital long QT syndrome, acquired long QT syndrome, Romano-Ward syndrome, Jervell and Lange-Nielsen syndrome.", 6 | "NORM": "normal ECG, normal electrocardiogram, Sinus rhythm, Normal sinus rhythm, Sinus tachycardia, Sinus bradycardia.", 7 | "IMI": "inferior myocardial infarction,inferolateral myocardial infarction,inferoposterolateral myocardial infarction,inferoposterior myocardial infarction,subendocardial injury in inferior leads,posterior wall MI,posterior-inferior MI,posteroinferolateral MI.", 8 | "ASMI": "anteroseptal myocardial infarction, anteroseptal wall infarction, anterolateral myocardial infarction, anterior wall infarction, subendocardial injury in anteroseptal leads.", 9 | "LVH": "left 
ventricular hypertrophy, left ventricular hypertrophy with strain pattern, asymmetric septal hypertrophy, concentric left ventricular hypertrophy, eccentric left ventricular hypertrophy.", 10 | "LAFB": "left anterior fascicular block, left anterior hemiblock, left bundle branch block (LAFB), first-degree LAFB, type I LAFB.", 11 | "ISC": "ischemic ST-T changes, non-specific ischemic ST-T changes, anterior ischemic ST-T changes, inferior ischemic ST-T changes, lateral ischemic ST-T changes, posterior ischemic ST-T changes.", 12 | "IRBBB": "incomplete right bundle branch block, incomplete right bundle branch block type I, incomplete right bundle branch block type II.", 13 | "1AVB": "first degree AV block, first-degree atrioventricular block.", 14 | "IVCD": "non-specific intraventricular conduction disturbance, non-specific intraventricular conduction defect, non-specific intraventricular block.", 15 | "ISCAL": "in anterolateral leads, anterolateral ST segment changes, anterolateral T-wave changes, anteroseptal myocardial infarction.", 16 | "CRBBB": "right bundle branch block, right bundle branch block (RBBB), complete right bundle branch block, incomplete right bundle branch block.", 17 | "CLBBB": "left bundle branch block, left bundle branch block (LBBB), left anterior fascicular block, left posterior fascicular block, bifascicular block.", 18 | "ILMI": "inferolateral myocardial infarction, inferoposterolateral myocardial infarction, inferoposterior myocardial infarction, subendocardial injury in inferolateral leads.", 19 | "LAO/LAE": "left atrial overload, left atrial enlargement, left atrial hypertrophy, left atrial dilatation.", 20 | "AMI": "anterior myocardial infarction, anterolateral myocardial infarction, anteroseptal myocardial infarction, anterobasal myocardial infarction.", 21 | "ALMI": "anterolateral myocardial infarction, anteroseptal myocardial infarction, anterolateral subendocardial injury, anterolateral subepicardial injury.", 22 | "ISCIN": "ischemic 
changes in inferior leads, infero-inferior ischemia, infero-lateral ischemia, subendocardial injury in inferior leads.", 23 | "HYP": "left ventricular hypertrophy, right ventricular hypertrophy, biventricular hypertrophy, ventricular hypertrophy, concentric left ventricular hypertrophy, eccentric left ventricular hypertrophy, concentric remodeling of the left atrium/enlargement.", 24 | "CD": "Conduction Disturbance, Left Anterior Fascicular Block, Right Bundle Branch Block, Left Bundle Branch Block, First-degree atrioventricular block, Second-degree atrioventricular block (Mobitz type I and II), Third-degree atrioventricular block (complete heart block), Intra-atrial conduction delay.", 25 | "INJAS": "ischemic in anteroseptal leads, anteroseptal ischemia, anterior septal ischemia, anterolateral ischemia.", 26 | "VPC": "Ventricular premature complexes, Ventricular premature beats, Ventricular ectopic beats, Premature ventricular contractions.", 27 | "LMI": "lateral myocardial infarction, anterolateral myocardial infarction, anteroseptal and lateral myocardial infarction, posterolateral myocardial infarction.", 28 | "ISCIL": "ischemic in inferolateral leads, inferolateral ischemia, inferolateral ST-segment depression, inferolateral T-wave inversion.", 29 | "ISCI": "ischemic in inferior leads, ischemic in inferolateral leads, inferoposterior myocardial infarction, subendocardial injury in inferior leads.", 30 | "LPFB": "left posterior fascicular block, left posterior fascicular block (LPFB), left posterior hemiblock, left bundle branch block (LBBB) with left anterior fascicular conduction delay.", 31 | "LAFB/LPFB": "left anterior fascicular block, left posterior fascicular block, Left Anterior Hemiblock, Left Posterior Hemiblock.", 32 | "ISCAS": "in anteroseptal leads, anteroseptal ST segment elevation, anteroseptal T-wave inversion, anterolateral ST segment depression.", 33 | "ISCA": "ischemic in anterior leads, ischemic in anteroseptal leads, ischemic in
anterolateral leads, anteroseptal myocardial infarction, anterior myocardial infarction.", 34 | "INJAL": "in anterolateral leads, anterolateral injury, anteroseptal injury, anterior wall myocardial infarction with lateral involvement.", 35 | "ISCLA": "in lateral leads, lateral leads, lateral wall myocardial infarction, inferolateral myocardial infarction.", 36 | "RVH": "right ventricle hypertrophy, right ventricular hypertrophy, cor pulmonale, idiopathic pulmonary hypertension-related right ventricular hypertrophy, chronic thromboembolic pulmonary hypertension-related right ventricular hypertrophy.", 37 | "ANEUR": "ST-T changes compatible with ventricular aneurysm, ventricular aneurysm, left ventricular aneurysm, right ventricular aneurysm.", 38 | "RAO/RAE": "right atrial overload, right atrial enlargement, right atrial dilatation.", 39 | "EL": "compatible with electrolyte abnormalities, electrolyte imbalance, electrolyte disturbances, hypokalemia, hyperkalemia, hypocalcemia, hypercalcemia.", 40 | "WPW": "Wolff-Parkinson-White syndrome, Wolff-Parkinson-White pattern, Pre-excitation syndrome.", 41 | "ILBBB": "incomplete left bundle branch block, left anterior fascicular block, left posterior fascicular block, bifascicular block.", 42 | "IPLMI": "inferoposterolateral myocardial infarction, inferoposterior myocardial infarction, inferolateral myocardial infarction.", 43 | "ISCAN": "in anterior leads, anterior ST-segment changes, anterior T-wave changes, anteroseptal myocardial infarction, anterolateral myocardial infarction.", 44 | "IPMI": "inferoposterior myocardial infarction, posterior myocardial infarction, inferoposterolateral myocardial infarction, subendocardial injury in inferior and lateral leads.", 45 | "SEHYP": "septal hypertrophy, left ventricular septal hypertrophy, right ventricular septal hypertrophy, apical septal hypertrophy, mid-septal hypertrophy.", 46 | "INJIN": "subendocardial injury in inferior leads, subendocardial injury in inferolateral leads, 
subendocardial injury in inferoposterolateral leads, subendocardial injury in inferoposterior leads.", 47 | "INJLA": "subendocardial injury in lateral leads, subendocardial injury in anterolateral leads, subendocardial injury in posterolateral leads.", 48 | "PMI": "posterior myocardial infarction, posterior wall myocardial infarction, posterolateral myocardial infarction, posteroseptal myocardial infarction, subendocardial injury in posterior leads.", 49 | "3AVB": "third degree AV block, complete heart block, high-grade AV block, third-degree atrioventricular block.", 50 | "INJIL": "in inferolateral leads, inferolateral injury, inferolateral myocardial infarction, subendocardial injury in inferolateral leads.", 51 | "2AVB": "second degree AV block, Mobitz type I second degree AV block, Mobitz type II second degree AV block.", 52 | "ABQRS": "abnormal QRS, abnormal QRS axis, abnormal QRS complex, abnormal QRS duration, abnormal QRS morphology, abnormal QRS transition zone.", 53 | "PVC": "ventricular premature complex (beat), ventricular premature contraction, ventricular ectopic beat, ventricular extrasystole.", 54 | "STD": "non-specific ST depression, non-specific ST segment depression, diffuse ST depression, widespread ST depression.", 55 | "VCLVH": "voltage criteria (QRS) for left ventricular hypertrophy, voltage criteria for left ventricular hypertrophy, QRS complex changes for left ventricular hypertrophy.", 56 | "QWAVE": "Q waves present, Q wave abnormalities, Q wave changes, anterior Q waves, posterior Q waves, lateral Q waves.", 57 | "LOWT": "low amplitude T-waves, low voltage T-waves, decreased amplitude T-waves.", 58 | "NT": "non-specific T-wave changes, non-specific T-waves, nonspecific ST-T changes, abnormal T-waves, abnormal ST segment and T waves.", 59 | "PAC": "atrial premature complex, atrial premature beat, atrial ectopic beat, supraventricular premature complex.", 60 | "LPR": "Prolonged PR interval, prolonged PR segment, first-degree 
atrioventricular block.", 61 | "INVT": "inverted T-waves, inverted T-wave changes, inverted T-wave pattern, inverted T-waves in leads V1-V3, inverted T-waves in leads II, III and aVL.", 62 | "LVOLT": "low QRS voltages in the frontal and horizontal leads, low QRS voltage, low QRS amplitude, reduced QRS amplitude in frontal and horizontal leads.", 63 | "HVOLT": "high QRS voltage, high QRS amplitude, increased QRS amplitude, increased R wave amplitude.", 64 | "TAB": "t-wave abnormality, T-wave changes, non-diagnostic T-waves, non-diagnostic T abnormalities.", 65 | "PRC(S)": "premature complex(es), premature ventricular contraction(s), premature atrial contraction(s), premature junctional complex(es).", 66 | "SR": "Sinus Rhythm, Normal Sinus Rhythm, Sinus Node Dysfunction, Sinus Bradycardia, Sinus Tachycardia.", 67 | "AFIB": "Atrial Fibrillation, Paroxysmal Atrial Fibrillation, Persistent Atrial Fibrillation, Long-standing Persistent Atrial Fibrillation, Permanent Atrial Fibrillation.", 68 | "STACH": "sinus tachycardia, supraventricular tachycardia, atrial tachycardia, junctional tachycardia.", 69 | "SARRH": "sinus arrhythmia, sinus arrhythmias, sinus rhythm irregularity, wandering atrial pacemaker.", 70 | "SBRAD": "sinus bradycardia, sinus bradycardia with atrial fibrillation, sinus bradycardia with atrial flutter, sinus bradycardia with premature ventricular contractions.", 71 | "PACE": "normal functioning artificial pacemaker, normal functioning permanent artificial cardiac pacemaker, normal functioning temporary artificial cardiac pacemaker.", 72 | "SVARR": "supraventricular arrhythmia, supraventricular tachycardia, atrial fibrillation, atrial flutter, premature atrial contraction (PAC), junctional rhythm, sinus node reentry tachycardia.", 73 | "BIGU": "bigeminal pattern (unknown origin, SV or ventricular), bigeminal supraventricular rhythm, 
bigeminal ventricular rhythm, supraventricular bigeminus, ventricular bigeminus.", 74 | "AFLT": "atrial flutter, atrial flutter with variable block, atrial flutter with 2:1 block, atrial flutter with 3:1 block, typical atrial flutter, atypical atrial flutter.", 75 | "SVTAC": "supraventricular tachycardia, supraventricular tachyarrhythmia, atrial tachycardia, junctional tachycardia, sinus node reentry tachycardia.", 76 | "PSVT": "paroxysmal supraventricular tachycardia, paroxysmal supraventricular tachyarrhythmia, atrioventricular nodal reentrant tachycardia (AVNRT), atrioventricular reciprocating tachycardia (AVRT), junctional ectopic tachycardia, orthodromic atrioventricular reentrant tachycardia (oAVRT), antidromic AVRT.", 77 | "TRIGU": "trigeminal pattern (unknown origin), trigeminal pattern of unknown origin with supraventricular origin, trigeminal pattern of unknown origin with ventricular origin.", 78 | "2AVB1": "second degree AV block (Type 1), Mobitz type I second-degree atrioventricular block, Wenckebach phenomenon.", 79 | "2AVB2": "second degree AV block (Type 2), Mobitz type II second-degree atrioventricular block, Second-degree atrioventricular block type II, Type II second-degree heart block.", 80 | "ABI": "atrial bigeminy, atrial extrasystole, supraventricular bigeminy.", 81 | "ALS": "Axis left shift, Left axis deviation, Leftward shift of the electrical axis.", 82 | "APB": "atrial premature beats, atrial premature complexes, atrial ectopic beats, supraventricular ectopic beats.", 83 | "AQW": "abnormal Q wave, anterior Q wave, posterior Q wave, lateral Q wave, septal Q wave.", 84 | "ARS": "Axis right shift, Right axis deviation, Rightward axis deviation, Rightward displacement of the mean electrical axis.", 85 | "AVB": "AV block, atrioventricular block, a-v block, second-degree AV block, third-degree AV block.", 86 | "CCR": "counter-clockwise rotation, counterclockwise axis deviation, left axis 
deviation.", 87 | "CR": "clockwise rotation, counterclockwise rotation.", 88 | "ERV": "Early repolarization of the ventricles, Early repolarization, Ventricular early repolarization, Ventricular premature beats with early repolarization.", 89 | "FQRS": "fQRS Wave, fragmented QRS complex, fragmented QRS waveforms, abnormal QRS morphology.", 90 | "IDC": "Interior differences conduction, Left bundle branch block, Right bundle branch block, Left anterior fascicular block, Left posterior fascicular block.", 91 | "IVB": "Intraventricular block, Complete intraventricular block, High-grade intraventricular block, Low-grade intraventricular block, Left bundle branch block (LBBB), Right bundle branch block (RBBB), Left anterior fascicular block (LAFB), Left posterior fascicular block (LPFB), Right anterior fascicular block (RAFB), Right posterior fascicular block (RPFB).", 92 | "JEB": "junctional escape beat, junctional escape rhythm, junctional premature beat, junctional premature complexes.", 93 | "JPT": "junctional premature beat, junctional premature complexes, atrioventricular junctional rhythm, supraventricular junctional rhythm.", 94 | "LBBB": "left bundle branch block, left bundle branch block (LBBB), left anterior fascicular block, left posterior fascicular block, bifascicular block (LBBB + RBBB), trifascicular block (LBBB + RBBB + LPHB).", 95 | "LBBBB": "left back bundle branch block, left posterior fascicular block, left posterior hemiblock, left bundle branch block, complete left bundle branch block.", 96 | "LFBBB": "left front bundle branch block, left anterior fascicular block, left anterior hemiblock.", 97 | "LVQRSAL": "lower voltage QRS in all leads, low voltage QRS complex in all leads, low amplitude QRS complex in all leads.", 98 | "LVQRSCL": "lower voltage QRS in chest lead, low voltage QRS complex, reduced amplitude of QRS complex in chest leads.", 99 | "LVQRSLL": "lower voltage QRS in limb lead, low-voltage QRS complex in limb leads, reduced amplitude 
of QRS complex in limb leads.", 100 | "MI": "myocardial infarction, anterior myocardial infarction, anterolateral myocardial infarction, anteroseptal myocardial infarction, inferior myocardial infarction, inferolateral myocardial infarction, inferoposterolateral myocardial infarction, inferoposterior myocardial infarction, lateral myocardial infarction, posterior myocardial infarction, right ventricular myocardial infarction, subendocardial myocardial infarction.", 101 | "MIBW": "myocardial infarction in back wall, posterior myocardial infarction, posterior-inferior myocardial infarction, posterolateral myocardial infarction, subendocardial injury in posterior leads.", 102 | "MIFW": "Myocardial infarction in the front wall, Anteroseptal myocardial infarction, Anterior myocardial infarction, Anterolateral myocardial infarction.", 103 | "MILW": "Myocardial infarction in the lower wall, Inferior myocardial infarction, Posterior myocardial infarction, Lateral myocardial infarction, Subendocardial injury in inferior and lateral leads.", 104 | "MISW": "Myocardial infarction in the side wall, lateral myocardial infarction, posterolateral myocardial infarction, anterolateral myocardial infarction.", 105 | "PRIE": "PR interval extension, prolonged PR interval, first-degree atrioventricular block.", 106 | "PWC": "P wave Change, P-wave changes, P-wave abnormalities, P-wave elevation, P-wave depression.", 107 | "QTIE": "QT interval extension, prolonged QT interval, prolonged corrected QT interval, Torsade de Pointes, Long QT syndrome.", 108 | "RAH": "right atrial hypertrophy, right atrial enlargement, right atrial dilatation.", 109 | "RBBB": "right bundle branch block, complete right bundle branch block, incomplete right bundle branch block, RBBB pattern on ECG, RBBB with left anterior fascicular block.", 110 | "STDD": "ST drop down, ST segment depression, ST segment elevation.", 111 | "STE": "ST elevation - probable extension, ST elevation myocardial infarction, ST segment 
elevation in anterior leads, ST segment elevation in lateral leads, ST segment elevation in inferior leads.", 112 | "STTC": "ST-T Change, ischemic in anterior leads, ischemic in posterior leads, ischemic in inferior leads, ischemic in lateral leads, non-specific ST changes.", 113 | "STTU": "ST tilt up, ST segment tilt up, ST segment elevation in leads II, III and aVL.", 114 | "TWC": "T wave Change, T-wave changes, non-diagnostic T-waves, abnormal T-waves, ST-T wave changes.", 115 | "TWO": "T wave opposite, T wave inversion, T wave flattening.", 116 | "UW": "U wave, U-waves, abnormal U waves, prominent U waves.", 117 | "VB": "ventricular bigeminy, ventricular extrasystoles, premature ventricular contractions.", 118 | "VEB": "ventricular escape beat, ventricular premature beat, ventricular ectopic beat.", 119 | "VFW": "ventricular fusion wave, ventricular fusion beats, ventricular fusion complexes.", 120 | "VPB": "ventricular premature beat, ventricular premature complexes, ventricular ectopic beats, ventricular extrasystoles.", 121 | "VPE": "ventricular preexcitation, Wolff-Parkinson-White syndrome, accessory pathway conduction, Kent bundle.", 122 | "VET": "ventricular escape trigeminy, ventricular escape beats, ventricular premature beats with a cycle length of 2:1.", 123 | "WAVN": "Wandering in the atrioventricular node, Atrioventricular nodal reentry tachycardia, Atrial fibrillation with variable ventricular response, Multifocal atrial tachycardia, Wandering pacemaker syndrome.", 124 | "SB": "Sinus Bradycardia, Sinus bradycardia, sinus bradycardic rhythm.", 125 | "ST": "Sinus Tachycardia, Sinus tachycardia, Supraventricular tachycardia.", 126 | "AF": "Atrial Flutter, Atrial Fibrillation, Paroxysmal Atrial Flutter, Persistent Atrial Flutter, Long-standing Persistent Atrial Flutter.", 127 | "SA": "Sinus Irregularity, Sinus arrhythmia, Sinus tachycardia, Sinus bradycardia.", 128 | "SVT": "Supraventricular Tachycardia, Supraventricular tachyarrhythmia, Atrial 
tachycardia, Junctional tachycardia, Atrioventricular nodal reentrant tachycardia (AVNRT), Atrioventricular reciprocating tachycardia (AVRT), Wolff-Parkinson-White syndrome.", 129 | "AT": "Atrial Tachycardia, Atrial Tachycardia with block, Atrial Tachycardia with flutter waves, Orthodromic Atrial Tachycardia, Antidromic Atrial Tachycardia.", 130 | "AVNRT": "AV Node Reentrant Tachycardia, AV nodal reentrant tachycardia, Supraventricular tachycardia, Narrow complex tachycardia.", 131 | "AVRT": "AV Reentrant Tachycardia, AV nodal reentrant tachycardia, atrioventricular reentrant tachycardia, supraventricular tachycardia with aberrancy.", 132 | "SAAWR": "Sinus Atrium to Atrial Wandering Rhythm, Sinus Atrial Wandering Rhythm, Atrial Wandering Rhythm, Sinus-Atrial-Wandering-Rhythm." 133 | } -------------------------------------------------------------------------------- /zeroshot/__pycache__/zeroshot_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/MERL-ICML2024/38799b7deb1c9958f41badcfacc6a95192787cc4/zeroshot/__pycache__/zeroshot_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /zeroshot/test_zeroshot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import yaml 4 | from tqdm import tqdm 5 | import pandas as pd 6 | import numpy as np 7 | 8 | import torch 9 | 10 | import sys 11 | sys.path.append("../utils") 12 | import utils_builder 13 | from zeroshot_val import zeroshot_eval 14 | 15 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 16 | 17 | device_id = 'cuda' 18 | 19 | config = yaml.load(open("zeroshot_config.yaml", "r"), Loader=yaml.FullLoader) 20 | 21 | torch.manual_seed(42) 22 | random.seed(0) 23 | np.random.seed(0) 24 | 25 | model = utils_builder.ECGCLIP(config['network']) 26 | ckpt = 'your_ckpt_path' 27 | ckpt = 
torch.load(ckpt, map_location='cpu') 28 | model.load_state_dict(ckpt) 29 | model = model.to(device_id) 30 | model = torch.nn.DataParallel(model) 31 | 32 | args_zeroshot_eval = config['zeroshot'] 33 | 34 | avg_f1, avg_acc, avg_auc = 0, 0, 0 35 | for set_name in args_zeroshot_eval['test_sets'].keys(): 36 | 37 | f1, acc, auc, _, _, _, res_dict = \ 38 | zeroshot_eval(model=model, 39 | set_name=set_name, 40 | device=device_id, 41 | args_zeroshot_eval=args_zeroshot_eval) 42 | 43 | avg_f1 += f1 44 | avg_acc += acc 45 | avg_auc += auc 46 | 47 | avg_f1 = avg_f1/len(args_zeroshot_eval['test_sets']) 48 | avg_acc = avg_acc/len(args_zeroshot_eval['test_sets']) 49 | avg_auc = avg_auc/len(args_zeroshot_eval['test_sets']) -------------------------------------------------------------------------------- /zeroshot/zeroshot.sh: -------------------------------------------------------------------------------- 1 | cd your_path/MERL/zeroshot 2 | python test_zeroshot.py 3 | -------------------------------------------------------------------------------- /zeroshot/zeroshot_config.yaml: -------------------------------------------------------------------------------- 1 | network: 2 | # ecg_model: resnet18 3 | ecg_model: vit_tiny 4 | num_leads: 12 5 | ### this part does not control builder/trainer 6 | text_model: ncbi/MedCPT-Query-Encoder 7 | free_layers: 6 # set 12 to freeze all layers in bert 8 | feature_dim: 768 9 | 10 | projection_head: 11 | mlp_hidden_size: 256 12 | projection_size: 256 13 | ### 14 | 15 | dataset: 16 | dataset_name: 'mimic' 17 | data_path: 'your_path/' # add your ECG data path here 18 | 19 | # params for trainer 20 | trainer: 21 | batch_size: 1024 22 | val_batch_size: 512 23 | checkpoint_interval: 1 24 | max_epochs: 100 25 | num_workers: 8 26 | 27 | optimizer: 28 | params: 29 | lr: 1.0e-3 30 | weight_decay: 1.0e-8 31 | 32 | # params for zeroshot eval 33 | zeroshot: 34 | prompt_type: 'CKEPE' 35 | prompt_dict: 
'your_path/MERL/zeroshot/CKEPE_prompt.json' 36 | batch_size: 256 37 | num_workers: 8 38 | meta_data_path: 'your_path/downstream' 39 | meta_split_path: 'your_path/MERL/finetune/data_split' 40 | 41 | test_sets: 42 | ### 43 | ptbxl_super_class: 44 | data_path: 'ptbxl' 45 | split_path: 'ptbxl/super_class/ptbxl_super_class_test.csv' 46 | ### 47 | ptbxl_sub_class: 48 | data_path: 'ptbxl' 49 | split_path: 'ptbxl/sub_class/ptbxl_sub_class_test.csv' 50 | ### 51 | ptbxl_form: 52 | data_path: 'ptbxl' 53 | split_path: 'ptbxl/form/ptbxl_form_test.csv' 54 | ### 55 | ptbxl_rhythm: 56 | data_path: 'ptbxl' 57 | split_path: 'ptbxl/rhythm/ptbxl_rhythm_test.csv' 58 | ### 59 | icbeb: 60 | data_path: 'icbeb2018/records500' 61 | split_path: 'icbeb/icbeb_test.csv' 62 | ### 63 | chapman: 64 | data_path: '' 65 | split_path: 'chapman/chapman_test.csv' 66 | 67 | # your model name 68 | wandb_name: 'None' 69 | --------------------------------------------------------------------------------
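The CKEPE prompt file above stores each label's clinically knowledge-enhanced prompts as a single comma-separated string. A minimal sketch of how such a dictionary might be flattened into per-class prompt lists before text encoding; the helper name `build_prompt_lists` and the inline sample are illustrative, not part of the repo:

```python
import json  # in practice the dict would come from json.load on the prompt_dict path in zeroshot_config.yaml

def build_prompt_lists(prompt_dict):
    """Split each comma-separated CKEPE entry into a list of prompt variants."""
    return {
        label: [p.strip().rstrip(".") for p in text.split(",") if p.strip()]
        for label, text in prompt_dict.items()
    }

# Illustrative sample mirroring two entries of CKEPE_prompt.json.
sample = {
    "1AVB": "first degree AV block, first-degree atrioventricular block.",
    "AFIB": "Atrial Fibrillation, Paroxysmal Atrial Fibrillation.",
}
prompts = build_prompt_lists(sample)
print(prompts["1AVB"])  # ['first degree AV block', 'first-degree atrioventricular block']
```

Each per-class prompt list would then presumably be encoded by the text encoder and aggregated into one class embedding for zero-shot matching, which is what `zeroshot_val.zeroshot_eval` consumes via the `prompt_dict` path in the config.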