├── dmp ├── code │ ├── data │ │ ├── data_loaders_trans.py │ │ ├── __pycache__ │ │ │ ├── data_loaders.cpython-310.pyc │ │ │ └── transforms.cpython-310.pyc │ │ ├── data_loaders.py │ │ ├── preprocess.py │ │ ├── preprocess_amos.py │ │ └── transforms.py │ ├── chd_data │ │ └── splits │ │ │ ├── labeled_2p.txt │ │ │ ├── labeled_5p.txt │ │ │ ├── labeled_10p.txt │ │ │ ├── eval.txt │ │ │ ├── test.txt │ │ │ ├── labeled_20p.txt │ │ │ ├── labeled_30p.txt │ │ │ ├── labeled_50p.txt │ │ │ ├── unlabeled_50p.txt │ │ │ ├── unlabeled_30p.txt │ │ │ ├── unlabeled_20p.txt │ │ │ ├── unlabeled_10p.txt │ │ │ ├── unlabeled_5p.txt │ │ │ ├── unlabeled_2p.txt │ │ │ └── train.txt │ ├── synapse_data │ │ └── splits │ │ │ ├── labeled_10p.txt │ │ │ ├── eval.txt │ │ │ ├── labeled_20p.txt │ │ │ ├── test.txt │ │ │ ├── labeled_40p.txt │ │ │ ├── unlabeled_40p.txt │ │ │ ├── unlabeled_20p.txt │ │ │ ├── unlabeled_10p.txt │ │ │ └── train.txt │ ├── models │ │ ├── __pycache__ │ │ │ ├── unet.cpython-310.pyc │ │ │ ├── vnet.cpython-310.pyc │ │ │ ├── unet_ds.cpython-310.pyc │ │ │ └── vnet_dst.cpython-310.pyc │ │ ├── unet.py │ │ ├── unet_ds.py │ │ ├── vnet.py │ │ └── vnet_dst.py │ ├── evaluate.py │ ├── test.py │ ├── evaluate_Ntimes.py │ ├── train_cps.py │ ├── train_depl.py │ ├── train_crest.py │ ├── train_dhc.py │ └── train_cdifw.py ├── utils │ ├── __pycache__ │ │ ├── config.cpython-310.pyc │ │ ├── loss.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ └── __init__.cpython-310.pyc.23362804253232 │ ├── config.py │ ├── loss.py │ └── __init__.py ├── train3times_seeds_10p.sh ├── train3times_seeds_40p.sh └── train3times_seeds_20p.sh ├── LICENSE ├── README.md └── requirements.txt /dmp/code/data/data_loaders_trans.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/labeled_2p.txt: -------------------------------------------------------------------------------- 1 | ct_1011 2 | 
ct_1178 3 | -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/labeled_10p.txt: -------------------------------------------------------------------------------- 1 | 0023 2 | 0028 3 | -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/eval.txt: -------------------------------------------------------------------------------- 1 | 0006 2 | 0025 3 | 0026 4 | 0040 5 | -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/labeled_20p.txt: -------------------------------------------------------------------------------- 1 | 0002 2 | 0023 3 | 0034 4 | 0039 5 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/labeled_5p.txt: -------------------------------------------------------------------------------- 1 | ct_1035 2 | ct_1077 3 | ct_1139 4 | ct_1150 5 | -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/test.txt: -------------------------------------------------------------------------------- 1 | 0004 2 | 0007 3 | 0010 4 | 0033 5 | 0035 6 | 0036 7 | -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/labeled_40p.txt: -------------------------------------------------------------------------------- 1 | 0002 2 | 0009 3 | 0021 4 | 0023 5 | 0027 6 | 0031 7 | 0032 8 | 0037 9 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/labeled_10p.txt: -------------------------------------------------------------------------------- 1 | ct_1014 2 | ct_1015 3 | ct_1025 4 | ct_1043 5 | ct_1067 6 | ct_1070 7 | ct_1133 8 | ct_1146 9 | -------------------------------------------------------------------------------- /dmp/utils/__pycache__/config.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/utils/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/utils/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/utils/__pycache__/loss.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/code/models/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/code/models/__pycache__/unet.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/code/models/__pycache__/vnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/code/models/__pycache__/vnet.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/code/models/__pycache__/unet_ds.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/code/models/__pycache__/unet_ds.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/code/data/__pycache__/data_loaders.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/code/data/__pycache__/data_loaders.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/code/data/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/code/data/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/code/models/__pycache__/vnet_dst.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/code/models/__pycache__/vnet_dst.cpython-310.pyc -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/unlabeled_40p.txt: -------------------------------------------------------------------------------- 1 | 0001 2 | 0003 3 | 0005 4 | 0008 5 | 0022 6 | 0024 7 | 0028 8 | 0029 9 | 0030 10 | 0034 11 | 0038 12 | 0039 13 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/eval.txt: -------------------------------------------------------------------------------- 1 | ct_1001 2 | ct_1008 3 | ct_1029 4 | ct_1054 5 | ct_1056 6 | ct_1059 7 | ct_1061 8 | ct_1091 9 | ct_1102 10 | ct_1138 11 | ct_1148 -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/test.txt: -------------------------------------------------------------------------------- 1 | ct_1002 2 | ct_1005 3 | ct_1018 4 | ct_1092 5 | ct_1098 6 | ct_1109 7 | ct_1113 8 | ct_1119 9 | ct_1129 10 | ct_1132 11 | ct_1170 12 | -------------------------------------------------------------------------------- /dmp/utils/__pycache__/__init__.cpython-310.pyc.23362804253232: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lzhang30/Double-Mix/HEAD/dmp/utils/__pycache__/__init__.cpython-310.pyc.23362804253232 -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/unlabeled_20p.txt: -------------------------------------------------------------------------------- 1 | 0001 2 | 0003 3 | 0005 4 | 0008 5 | 0009 6 | 0021 7 | 0022 8 | 0024 9 | 0027 10 | 0028 11 | 0029 12 | 0030 13 | 0031 14 | 0032 15 | 0037 16 | 0038 17 | -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/unlabeled_10p.txt: -------------------------------------------------------------------------------- 1 | 0001 2 | 0002 3 | 0003 4 | 0005 5 | 0008 6 | 0009 7 | 0021 8 | 0022 9 | 0024 10 | 0027 11 | 0029 12 | 0030 13 | 0031 14 | 0032 15 | 0034 16 | 0037 17 | 0038 18 | 0039 19 | -------------------------------------------------------------------------------- /dmp/code/synapse_data/splits/train.txt: -------------------------------------------------------------------------------- 1 | 0001 2 | 0002 3 | 0003 4 | 0005 5 | 0008 6 | 0009 7 | 0021 8 | 0022 9 | 0023 10 | 0024 11 | 0027 12 | 0028 13 | 0029 14 | 0030 15 | 0031 16 | 0032 17 | 0034 18 | 0037 19 | 0038 20 | 0039 21 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/labeled_20p.txt: -------------------------------------------------------------------------------- 1 | ct_1003 2 | ct_1023 3 | ct_1030 4 | ct_1032 5 | ct_1033 6 | ct_1036 7 | ct_1043 8 | ct_1050 9 | ct_1051 10 | ct_1080 11 | ct_1101 12 | ct_1116 13 | ct_1117 14 | ct_1120 15 | ct_1135 16 | ct_1144 17 | ct_1146 18 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/labeled_30p.txt: -------------------------------------------------------------------------------- 1 | ct_1007 2 | ct_1011 3 | ct_1016 4 | ct_1020 5 | ct_1022 6 | ct_1023 7 | ct_1024 8 | ct_1025 9 
| ct_1030 10 | ct_1035 11 | ct_1046 12 | ct_1051 13 | ct_1067 14 | ct_1077 15 | ct_1083 16 | ct_1085 17 | ct_1103 18 | ct_1106 19 | ct_1110 20 | ct_1114 21 | ct_1122 22 | ct_1133 23 | ct_1135 24 | ct_1139 25 | ct_1141 26 | ct_1178 27 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/labeled_50p.txt: -------------------------------------------------------------------------------- 1 | ct_1003 2 | ct_1004 3 | ct_1013 4 | ct_1014 5 | ct_1020 6 | ct_1021 7 | ct_1023 8 | ct_1024 9 | ct_1028 10 | ct_1032 11 | ct_1033 12 | ct_1042 13 | ct_1044 14 | ct_1046 15 | ct_1047 16 | ct_1048 17 | ct_1050 18 | ct_1051 19 | ct_1052 20 | ct_1062 21 | ct_1063 22 | ct_1064 23 | ct_1067 24 | ct_1074 25 | ct_1075 26 | ct_1077 27 | ct_1083 28 | ct_1085 29 | ct_1088 30 | ct_1099 31 | ct_1103 32 | ct_1110 33 | ct_1117 34 | ct_1120 35 | ct_1122 36 | ct_1124 37 | ct_1127 38 | ct_1128 39 | ct_1133 40 | ct_1145 41 | ct_1147 42 | ct_1158 43 | ct_1161 44 | ct_1178 45 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/unlabeled_50p.txt: -------------------------------------------------------------------------------- 1 | ct_1007 2 | ct_1010 3 | ct_1011 4 | ct_1012 5 | ct_1015 6 | ct_1016 7 | ct_1017 8 | ct_1019 9 | ct_1022 10 | ct_1025 11 | ct_1030 12 | ct_1035 13 | ct_1036 14 | ct_1037 15 | ct_1039 16 | ct_1041 17 | ct_1043 18 | ct_1053 19 | ct_1060 20 | ct_1066 21 | ct_1070 22 | ct_1072 23 | ct_1078 24 | ct_1079 25 | ct_1080 26 | ct_1081 27 | ct_1101 28 | ct_1105 29 | ct_1106 30 | ct_1111 31 | ct_1112 32 | ct_1114 33 | ct_1116 34 | ct_1121 35 | ct_1125 36 | ct_1126 37 | ct_1135 38 | ct_1139 39 | ct_1140 40 | ct_1141 41 | ct_1143 42 | ct_1144 43 | ct_1146 44 | ct_1150 45 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/unlabeled_30p.txt: -------------------------------------------------------------------------------- 1 | 
ct_1003 2 | ct_1004 3 | ct_1010 4 | ct_1012 5 | ct_1013 6 | ct_1014 7 | ct_1015 8 | ct_1017 9 | ct_1019 10 | ct_1021 11 | ct_1028 12 | ct_1032 13 | ct_1033 14 | ct_1036 15 | ct_1037 16 | ct_1039 17 | ct_1041 18 | ct_1042 19 | ct_1043 20 | ct_1044 21 | ct_1047 22 | ct_1048 23 | ct_1050 24 | ct_1052 25 | ct_1053 26 | ct_1060 27 | ct_1062 28 | ct_1063 29 | ct_1064 30 | ct_1066 31 | ct_1070 32 | ct_1072 33 | ct_1074 34 | ct_1075 35 | ct_1078 36 | ct_1079 37 | ct_1080 38 | ct_1081 39 | ct_1088 40 | ct_1099 41 | ct_1101 42 | ct_1105 43 | ct_1111 44 | ct_1112 45 | ct_1116 46 | ct_1117 47 | ct_1120 48 | ct_1121 49 | ct_1124 50 | ct_1125 51 | ct_1126 52 | ct_1127 53 | ct_1128 54 | ct_1140 55 | ct_1143 56 | ct_1144 57 | ct_1145 58 | ct_1146 59 | ct_1147 60 | ct_1150 61 | ct_1158 62 | ct_1161 63 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/unlabeled_20p.txt: -------------------------------------------------------------------------------- 1 | ct_1004 2 | ct_1007 3 | ct_1010 4 | ct_1011 5 | ct_1012 6 | ct_1013 7 | ct_1014 8 | ct_1015 9 | ct_1016 10 | ct_1017 11 | ct_1019 12 | ct_1020 13 | ct_1021 14 | ct_1022 15 | ct_1024 16 | ct_1025 17 | ct_1028 18 | ct_1035 19 | ct_1037 20 | ct_1039 21 | ct_1041 22 | ct_1042 23 | ct_1044 24 | ct_1046 25 | ct_1047 26 | ct_1048 27 | ct_1052 28 | ct_1053 29 | ct_1060 30 | ct_1062 31 | ct_1063 32 | ct_1064 33 | ct_1066 34 | ct_1067 35 | ct_1070 36 | ct_1072 37 | ct_1074 38 | ct_1075 39 | ct_1077 40 | ct_1078 41 | ct_1079 42 | ct_1081 43 | ct_1083 44 | ct_1085 45 | ct_1088 46 | ct_1099 47 | ct_1103 48 | ct_1105 49 | ct_1106 50 | ct_1110 51 | ct_1111 52 | ct_1112 53 | ct_1114 54 | ct_1121 55 | ct_1122 56 | ct_1124 57 | ct_1125 58 | ct_1126 59 | ct_1127 60 | ct_1128 61 | ct_1133 62 | ct_1139 63 | ct_1140 64 | ct_1141 65 | ct_1143 66 | ct_1145 67 | ct_1147 68 | ct_1150 69 | ct_1158 70 | ct_1161 71 | ct_1178 72 | 
-------------------------------------------------------------------------------- /dmp/code/chd_data/splits/unlabeled_10p.txt: -------------------------------------------------------------------------------- 1 | ct_1003 2 | ct_1004 3 | ct_1007 4 | ct_1010 5 | ct_1011 6 | ct_1012 7 | ct_1013 8 | ct_1016 9 | ct_1017 10 | ct_1019 11 | ct_1020 12 | ct_1021 13 | ct_1022 14 | ct_1023 15 | ct_1024 16 | ct_1028 17 | ct_1030 18 | ct_1032 19 | ct_1033 20 | ct_1035 21 | ct_1036 22 | ct_1037 23 | ct_1039 24 | ct_1041 25 | ct_1042 26 | ct_1044 27 | ct_1046 28 | ct_1047 29 | ct_1048 30 | ct_1050 31 | ct_1051 32 | ct_1052 33 | ct_1053 34 | ct_1060 35 | ct_1062 36 | ct_1063 37 | ct_1064 38 | ct_1066 39 | ct_1072 40 | ct_1074 41 | ct_1075 42 | ct_1077 43 | ct_1078 44 | ct_1079 45 | ct_1080 46 | ct_1081 47 | ct_1083 48 | ct_1085 49 | ct_1088 50 | ct_1099 51 | ct_1101 52 | ct_1103 53 | ct_1105 54 | ct_1106 55 | ct_1110 56 | ct_1111 57 | ct_1112 58 | ct_1114 59 | ct_1116 60 | ct_1117 61 | ct_1120 62 | ct_1121 63 | ct_1122 64 | ct_1124 65 | ct_1125 66 | ct_1126 67 | ct_1127 68 | ct_1128 69 | ct_1135 70 | ct_1139 71 | ct_1140 72 | ct_1141 73 | ct_1143 74 | ct_1144 75 | ct_1145 76 | ct_1147 77 | ct_1150 78 | ct_1158 79 | ct_1161 80 | ct_1178 81 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/unlabeled_5p.txt: -------------------------------------------------------------------------------- 1 | ct_1003 2 | ct_1004 3 | ct_1007 4 | ct_1010 5 | ct_1011 6 | ct_1012 7 | ct_1013 8 | ct_1014 9 | ct_1015 10 | ct_1016 11 | ct_1017 12 | ct_1019 13 | ct_1020 14 | ct_1021 15 | ct_1022 16 | ct_1023 17 | ct_1024 18 | ct_1025 19 | ct_1028 20 | ct_1030 21 | ct_1032 22 | ct_1033 23 | ct_1036 24 | ct_1037 25 | ct_1039 26 | ct_1041 27 | ct_1042 28 | ct_1043 29 | ct_1044 30 | ct_1046 31 | ct_1047 32 | ct_1048 33 | ct_1050 34 | ct_1051 35 | ct_1052 36 | ct_1053 37 | ct_1060 38 | ct_1062 39 | ct_1063 40 | ct_1064 41 | ct_1066 42 | ct_1067 43 | ct_1070 
44 | ct_1072 45 | ct_1074 46 | ct_1075 47 | ct_1078 48 | ct_1079 49 | ct_1080 50 | ct_1081 51 | ct_1083 52 | ct_1085 53 | ct_1088 54 | ct_1099 55 | ct_1101 56 | ct_1103 57 | ct_1105 58 | ct_1106 59 | ct_1110 60 | ct_1111 61 | ct_1112 62 | ct_1114 63 | ct_1116 64 | ct_1117 65 | ct_1120 66 | ct_1121 67 | ct_1122 68 | ct_1124 69 | ct_1125 70 | ct_1126 71 | ct_1127 72 | ct_1128 73 | ct_1133 74 | ct_1135 75 | ct_1140 76 | ct_1141 77 | ct_1143 78 | ct_1144 79 | ct_1145 80 | ct_1146 81 | ct_1147 82 | ct_1158 83 | ct_1161 84 | ct_1178 85 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/unlabeled_2p.txt: -------------------------------------------------------------------------------- 1 | ct_1003 2 | ct_1004 3 | ct_1007 4 | ct_1010 5 | ct_1012 6 | ct_1013 7 | ct_1014 8 | ct_1015 9 | ct_1016 10 | ct_1017 11 | ct_1019 12 | ct_1020 13 | ct_1021 14 | ct_1022 15 | ct_1023 16 | ct_1024 17 | ct_1025 18 | ct_1028 19 | ct_1030 20 | ct_1032 21 | ct_1033 22 | ct_1035 23 | ct_1036 24 | ct_1037 25 | ct_1039 26 | ct_1041 27 | ct_1042 28 | ct_1043 29 | ct_1044 30 | ct_1046 31 | ct_1047 32 | ct_1048 33 | ct_1050 34 | ct_1051 35 | ct_1052 36 | ct_1053 37 | ct_1060 38 | ct_1062 39 | ct_1063 40 | ct_1064 41 | ct_1066 42 | ct_1067 43 | ct_1070 44 | ct_1072 45 | ct_1074 46 | ct_1075 47 | ct_1077 48 | ct_1078 49 | ct_1079 50 | ct_1080 51 | ct_1081 52 | ct_1083 53 | ct_1085 54 | ct_1088 55 | ct_1099 56 | ct_1101 57 | ct_1103 58 | ct_1105 59 | ct_1106 60 | ct_1110 61 | ct_1111 62 | ct_1112 63 | ct_1114 64 | ct_1116 65 | ct_1117 66 | ct_1120 67 | ct_1121 68 | ct_1122 69 | ct_1124 70 | ct_1125 71 | ct_1126 72 | ct_1127 73 | ct_1128 74 | ct_1133 75 | ct_1135 76 | ct_1139 77 | ct_1140 78 | ct_1141 79 | ct_1143 80 | ct_1144 81 | ct_1145 82 | ct_1146 83 | ct_1147 84 | ct_1150 85 | ct_1158 86 | ct_1161 87 | -------------------------------------------------------------------------------- /dmp/code/chd_data/splits/train.txt: 
-------------------------------------------------------------------------------- 1 | ct_1003 2 | ct_1004 3 | ct_1007 4 | ct_1010 5 | ct_1011 6 | ct_1012 7 | ct_1013 8 | ct_1014 9 | ct_1015 10 | ct_1016 11 | ct_1017 12 | ct_1019 13 | ct_1020 14 | ct_1021 15 | ct_1022 16 | ct_1023 17 | ct_1024 18 | ct_1025 19 | ct_1028 20 | ct_1030 21 | ct_1032 22 | ct_1033 23 | ct_1035 24 | ct_1036 25 | ct_1037 26 | ct_1039 27 | ct_1041 28 | ct_1042 29 | ct_1043 30 | ct_1044 31 | ct_1046 32 | ct_1047 33 | ct_1048 34 | ct_1050 35 | ct_1051 36 | ct_1052 37 | ct_1053 38 | ct_1060 39 | ct_1062 40 | ct_1063 41 | ct_1064 42 | ct_1066 43 | ct_1067 44 | ct_1070 45 | ct_1072 46 | ct_1074 47 | ct_1075 48 | ct_1077 49 | ct_1078 50 | ct_1079 51 | ct_1080 52 | ct_1081 53 | ct_1083 54 | ct_1085 55 | ct_1088 56 | ct_1099 57 | ct_1101 58 | ct_1103 59 | ct_1105 60 | ct_1106 61 | ct_1110 62 | ct_1111 63 | ct_1112 64 | ct_1114 65 | ct_1116 66 | ct_1117 67 | ct_1120 68 | ct_1121 69 | ct_1122 70 | ct_1124 71 | ct_1125 72 | ct_1126 73 | ct_1127 74 | ct_1128 75 | ct_1133 76 | ct_1135 77 | ct_1139 78 | ct_1140 79 | ct_1141 80 | ct_1143 81 | ct_1144 82 | ct_1145 83 | ct_1146 84 | ct_1147 85 | ct_1150 86 | ct_1158 87 | ct_1161 88 | ct_1178 89 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 morilab_lzhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 
all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dmp/train3times_seeds_10p.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | while getopts 'm:e:c:t:l:w:' OPT; do 5 | case $OPT in 6 | m) method=$OPTARG;; 7 | e) exp=$OPTARG;; 8 | c) cuda=$OPTARG;; 9 | t) task=$OPTARG;; 10 | l) lr=$OPTARG;; 11 | w) cps_w=$OPTARG;; 12 | esac 13 | done 14 | echo $method 15 | echo $cuda 16 | 17 | epoch=200 18 | echo $epoch 19 | 20 | labeled_data="labeled_10p" 21 | unlabeled_data="unlabeled_10p" 22 | folder="Task_"${task}"_10p/" 23 | cps="AB" 24 | 25 | echo $folder 26 | 27 | : <<'END_COMMENT' 28 | END_COMMENT 29 | 30 | # 31 | python code/train_${method}.py --task ${task} --exp ${folder}${method}${exp}/fold1 --seed 0 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r 32 | python code/test.py --task ${task} --exp ${folder}${method}${exp}/fold1 -g ${cuda} --cps ${cps} 33 | 34 | python code/train_${method}.py --task ${task} --exp ${folder}${method}${exp}/fold2 --seed 1 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r 35 | python code/test.py --task ${task} --exp ${folder}${method}${exp}/fold2 -g ${cuda} --cps ${cps} 36 | #python code/evaluate_Ntimes.py --exp ${folder}${method}${exp} --folds 2 --cps ${cps} 37 | # 38 | python code/train_${method}.py --task 
${task} --exp ${folder}${method}${exp}/fold3 --seed 666 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r 39 | python code/test.py --task ${task} --exp ${folder}${method}${exp}/fold3 -g ${cuda} --cps ${cps} 40 | 41 | python code/evaluate_Ntimes.py --task ${task} --exp ${folder}${method}${exp} --cps ${cps} 42 | -------------------------------------------------------------------------------- /dmp/train3times_seeds_40p.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | while getopts 'm:e:c:t:l:w:' OPT; do 5 | case $OPT in 6 | m) method=$OPTARG;; 7 | e) exp=$OPTARG;; 8 | c) cuda=$OPTARG;; 9 | t) task=$OPTARG;; 10 | l) lr=$OPTARG;; 11 | w) cps_w=$OPTARG;; 12 | esac 13 | done 14 | echo $method 15 | echo $cuda 16 | 17 | epoch=300 18 | echo $epoch 19 | 20 | labeled_data="labeled_40p" 21 | unlabeled_data="unlabeled_40p" 22 | folder="Task_"${task}"_40p/" 23 | cps="AB" 24 | 25 | echo $folder 26 | 27 | 28 | python code/train_${method}.py --task ${task} --exp ${folder}${method}${exp}/fold1 --seed 549 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r #--start_mix 50 29 | python code/test.py --task ${task} --exp ${folder}${method}${exp}/fold1 -g ${cuda} --cps ${cps} 30 | 31 | python code/train_${method}.py --task ${task} --exp ${folder}${method}${exp}/fold2 --seed 521 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r #--start_mix 50 32 | python code/test.py --task ${task} --exp ${folder}${method}${exp}/fold2 -g ${cuda} --cps ${cps} 33 | #python code/evaluate_Ntimes.py --exp ${folder}${method}${exp} --folds 2 --cps ${cps} 34 | #END_COMMENT 35 | python code/train_${method}.py --task ${task} --exp ${folder}${method}${exp}/fold3 --seed 999 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r #--start_mix 50 36 | python code/test.py --task 
${task} --exp ${folder}${method}${exp}/fold3 -g ${cuda} --cps ${cps} 37 | 38 | python code/evaluate_Ntimes.py --task ${task} --exp ${folder}${method}${exp} --folds 3 --cps ${cps} -------------------------------------------------------------------------------- /dmp/code/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | from medpy import metric 5 | from tqdm import tqdm 6 | 7 | from utils import read_list, read_nifti 8 | from utils import config 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--exp', type=str, default="fully") 15 | args = parser.parse_args() 16 | 17 | test_cls = [i for i in range(1, config.num_cls)] 18 | values = np.zeros((len(test_cls), 2)) # dice and asd 19 | ids_list = read_list('test') 20 | for data_id in tqdm(ids_list): 21 | pred = read_nifti(os.path.join("./logs",args.exp, "predictions",f'{data_id}.nii.gz')) 22 | label = read_nifti(os.path.join(config.base_dir, 'labelsTr', f'label{data_id}.nii.gz')) 23 | 24 | dd, ww, hh = label.shape 25 | label = torch.FloatTensor(label).unsqueeze(0).unsqueeze(0) 26 | label = F.interpolate(label, size=(dd, ww//2, hh//2),mode='trilinear', align_corners=False) 27 | label = label.squeeze().numpy() 28 | 29 | for i in test_cls: 30 | 31 | pred_i = (pred == i) 32 | label_i = (label == i) 33 | if pred_i.sum() > 0 and label_i.sum() > 0: 34 | dice = metric.binary.dc(pred == i, label == i) * 100 35 | hd95 = metric.binary.hd95(pred == i, label == i) 36 | values[i - 1] += np.array([dice, hd95]) 37 | 38 | values /= len(ids_list) 39 | print("====== Dice ======") 40 | print(np.round(values[:,0],1)) 41 | print("====== HD ======") 42 | print(np.round(values[:,1],1)) 43 | print(np.mean(values, axis=0)[0], np.mean(values, axis=0)[1]) 44 | -------------------------------------------------------------------------------- 
/dmp/train3times_seeds_20p.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | while getopts 'm:e:c:t:l:w:' OPT; do 5 | case $OPT in 6 | m) method=$OPTARG;; 7 | e) exp=$OPTARG;; 8 | c) cuda=$OPTARG;; 9 | t) task=$OPTARG;; 10 | l) lr=$OPTARG;; 11 | w) cps_w=$OPTARG;; 12 | esac 13 | done 14 | echo $method 15 | echo $cuda 16 | 17 | epoch=300 18 | echo $epoch 19 | 20 | labeled_data="labeled_20p" 21 | unlabeled_data="unlabeled_20p" 22 | folder="Task_"${task}"_20p/" 23 | cps="AB" 24 | 25 | echo $folder 26 | 27 | : <<'END_COMMENT' 28 | END_COMMENT 29 | 30 | python code/train_${method}.py --task ${task} --exp ${folder}${method}${exp}/fold1 --seed 0 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r #--start_mix 50 31 | python code/test.py --task ${task} --exp ${folder}${method}${exp}/fold1 -g ${cuda} --cps ${cps} 32 | 33 | python code/train_${method}.py --task ${task} --exp ${folder}${method}${exp}/fold2 --seed 1 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r #--start_mix 50 34 | python code/test.py --task ${task} --exp ${folder}${method}${exp}/fold2 -g ${cuda} --cps ${cps} 35 | #python code/evaluate_Ntimes.py --exp ${folder}${method}${exp} --folds 2 --cps ${cps} 36 | 37 | python code/train_${method}.py --task ${task} --exp ${folder}${method}${exp}/fold3 --seed 666 -g ${cuda} --base_lr ${lr} -w ${cps_w} -ep ${epoch} -sl ${labeled_data} -su ${unlabeled_data} -r #--start_mix 50 38 | python code/test.py --task ${task} --exp ${folder}${method}${exp}/fold3 -g ${cuda} --cps ${cps} 39 | 40 | python code/evaluate_Ntimes.py --task ${task} --exp ${folder}${method}${exp} --folds 3 --cps ${cps} 41 | #python code/evaluate_Ntimes.py --task ${task} --exp ${folder}${method}${exp} --cps ${cps} 42 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Double-Mix Pseudo-Label Framework 2 | 3 | Here is the code for our proposed Double-mix pseudo-label framework: enhancing semi-supervised segmentation on category-imbalanced CT volumes. 4 | 5 | ## Data Preprocess 6 | To prepare the dataset, you can follow the work of [DHC](https://github.com/xmed-lab/DHC). 7 | 8 | You can also use the ``` dmp/code/data/preprocess_amos.py ``` to prepare the dataset. 9 | 10 | The splits are available at ``` {dataset}_data/split ```. 11 | 12 | ## Model Training 13 | Run 14 | ``` 15 | cd dmp 16 | bash train3times_seeds_20p.sh -c 0 -t synapse -m cdifw_dmp_ours -e 'test' -l 3e-2 -w 0.1 17 | ``` 18 | 19 | ### Training Data Percentage: 20 | 21 | The notation 20p represents training with 20% labeled data. You can modify this value to `train3times_seeds_40p`, `train3times_seeds_5p`, etc., to indicate training with 40%, 5%, and so on. 22 | 23 | ### Command-line Parameters: 24 | 25 | `-c`: Specifies which GPU to use for training. 26 | 27 | `-t`: Defines the task, which can be either synapse or amos. 28 | 29 | `-m`: Specifies the training method. The available methods include: 30 | 31 | *(i)* `cdifw_dmp_ours` (our proposed method) 32 | 33 | *(ii)*`cdifw` (ablation studies) 34 | 35 | `-e`: Defines the name of the current experiment. default: `'test'` 36 | 37 | `-l`: Sets the learning rate. In this experiment, it was set to `0.1` 38 | 39 | `-w`: Specifies the weight of the unsupervised loss. 40 | 41 | 42 | Have fun. 
43 | 44 | ## Cite 45 | 46 | If this code is helpful for your study, welcome to cite our paper 47 | ``` 48 | @article{zhang2025double, 49 | title={Double-mix pseudo-label framework: enhancing semi-supervised segmentation on category-imbalanced CT volumes}, 50 | author={Zhang, Luyang and Hayashi, Yuichiro and Oda, Masahiro and Mori, Kensaku}, 51 | journal={International Journal of Computer Assisted Radiology and Surgery}, 52 | pages={1--12}, 53 | year={2025}, 54 | publisher={Springer} 55 | } 56 | ``` 57 | 58 | ## License 59 | 60 | This repository is released under MIT License. 61 | 62 | 63 | -------------------------------------------------------------------------------- /dmp/code/data/data_loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from torch.utils.data import Dataset 4 | 5 | from utils import read_list, read_data, softmax 6 | from utils.config import Config 7 | 8 | 9 | class Synapse_AMOS(Dataset): 10 | def __init__(self, split='train', repeat=None, transform=None, unlabeled=False, is_val=False, task="synapse", num_cls=1): 11 | self.ids_list = read_list(split, task=task) 12 | self.config = Config(task) 13 | self.split = split 14 | 15 | self.repeat = repeat 16 | self.task=task 17 | if self.repeat is None: 18 | self.repeat = len(self.ids_list) 19 | print('total {} datas'.format(self.repeat)) 20 | self.transform = transform 21 | self.unlabeled = unlabeled 22 | self.num_cls = num_cls 23 | self._weight = None 24 | self.is_val = is_val 25 | if self.is_val: 26 | self.data_list = {} 27 | for data_id in tqdm(self.ids_list): # <-- load data to memory 28 | image, label = read_data(data_id, task=task) 29 | self.data_list[data_id] = (image, label) 30 | 31 | 32 | def __len__(self): 33 | return self.repeat 34 | 35 | def _get_data(self, data_id): 36 | # [160, 384, 384] 37 | #print(data_id) 38 | if self.is_val: 39 | image, label = self.data_list[data_id] 40 | else: 41 | image, 
label = read_data(data_id, task=self.task,unlabeled = self.unlabeled) 42 | return data_id, image, label 43 | 44 | 45 | def __getitem__(self, index): 46 | index = index % len(self.ids_list) 47 | data_id = self.ids_list[index] 48 | #print(data_id) 49 | _, image, label = self._get_data(data_id) 50 | if self.unlabeled: # <-- for safety 51 | label[:] = 0 52 | label[label>self.config.num_cls-1] = 0 53 | # print("before",image.min(), image.max()) 54 | # image = (image - image.min()) / (image.max() - image.min()) 55 | if self.task == 'chd' or self.task == 'covid': 56 | min_val = np.percentile(image, 5) 57 | max_val = np.percentile(image, 95) 58 | image = image.clip(min=min_val, max=max_val) 59 | elif self.task == 'colon' : 60 | image = image.clip(min=-250, max=275) 61 | else: 62 | image = image.clip(min=-75, max=275) 63 | 64 | image = (image - image.min()) / (image.max() - image.min()) 65 | # image = (image - image.mean()) / (image.std() + 1e-8) 66 | # print("after",image.min(), image.max()) 67 | # print("ss",image.max()) 68 | # image = image.astype(np.float32) 69 | # label = label.astype(np.int8) 70 | 71 | # print(image.shape, label.shape) 72 | 73 | sample = {'image': image, 'label': label} 74 | 75 | # print(sample['image']) 76 | 77 | if self.transform: 78 | # if not self.unlabeled and not self.is_val: 79 | # sample = self.transform(sample, weights=self.transform.weights) 80 | # else: 81 | sample = self.transform(sample) 82 | 83 | return sample 84 | -------------------------------------------------------------------------------- /dmp/utils/config.py: -------------------------------------------------------------------------------- 1 | 2 | class Config: 3 | def __init__(self,task): 4 | self.task = task 5 | if task == "synapse": 6 | self.base_dir = '/homes/lzhang/data/ssl/MALBCV/Abdomen/RawData/Training/' 7 | self.save_dir = 'code/synapse_data' 8 | self.patch_size = (64, 128, 128) 9 | self.num_cls = 14 10 | self.num_channels = 1 11 | self.n_filters = 32 12 | 
class Config:
    """Per-dataset configuration: data paths, patch geometry, model width and
    early-stopping patience.

    The legacy implementation repeated an ``if task == ...`` chain that
    duplicated identical defaults for every task; this table-driven version
    sets the exact same attribute values. Looking up an unknown task name
    leaves only ``self.task`` set (no error), matching the old behavior.
    """

    # Defaults shared by every task unless overridden in _TASKS.
    _COMMON = {
        'num_channels': 1,
        'n_filters': 32,
        'early_stop_patience': 50,
    }

    # Per-task data locations and geometry (values copied verbatim from the
    # original if-chain).
    _TASKS = {
        'synapse': dict(
            base_dir='/homes/lzhang/data/ssl/MALBCV/Abdomen/RawData/Training/',
            save_dir='code/synapse_data',
            patch_size=(64, 128, 128),
            num_cls=14,
            early_stop_patience=30,
        ),
        'synapse_re': dict(
            base_dir='/homes/lzhang/data/ssl/MALBCV/Abdomen/RawData/Training/',
            save_dir='code/synapse_data',
            patch_size=(64, 128, 128),
            num_cls=14,
        ),
        'word': dict(
            base_dir='/homes/lzhang/data/WORD-V0.1.0/ssl',
            save_dir='/homes/lzhang/data/ssl/DHC/code/word_data_better',
            patch_size=(128, 128, 128),
            num_cls=17,
        ),
        'chd': dict(
            base_dir='/homes/lzhang/data/contrast/positional_cl-main/dataset/dataset/CHD/better/',
            save_dir='code/chd_data',
            patch_size=(64, 128, 128),
            num_cls=8,
        ),
        'amos': dict(
            base_dir='/homes/lzhang/data/ssl/DHC/code/data/Datasets/amos22',
            save_dir='code/amos_data',
            patch_size=(64, 128, 128),
            num_cls=16,
        ),
        'acc': dict(
            base_dir='/homes/lzhang/CT',
            save_dir='acc_data',
            patch_size=(512, 512, 512),
            num_cls=8,
        ),
        'acc_s': dict(
            base_dir='/homes/lzhang/CT',
            save_dir='/homes/lzhang/data/ssl/dhc2/DHC/code/acc_t_data',
            patch_size=(64, 128, 128),
            num_cls=8,
        ),
        'covid': dict(
            base_dir='/homes/lzhang/data/ssl/COVID-19-20/COVID-19-20_v2',
            save_dir='/homes/lzhang/data/ssl/dhc2/DHC/code/covid_data',
            patch_size=(96, 128, 128),
            num_cls=2,
        ),
        'colon': dict(
            base_dir='/homes/lzhang/data/colon/labeled/nii',
            save_dir='/homes/lzhang/data/ssl/dhc2/DHC/code/colon_data',
            patch_size=(96, 128, 128),
            num_cls=3,
            early_stop_patience=30,
        ),
        'colon_u': dict(
            base_dir='/homes/lzhang/data/colon/imagesTr',
            save_dir='/homes/lzhang/data/ssl/dhc2/DHC/code/colon_u_data',
            patch_size=(96, 128, 128),
            num_cls=2,
        ),
    }

    def __init__(self, task):
        self.task = task
        settings = self._TASKS.get(task)
        if settings is not None:
            # Task overrides win over the shared defaults.
            for name, value in {**self._COMMON, **settings}.items():
                setattr(self, name, value)
def process_split_fully(train_ratio=0.8):
    """Randomly split the training cases into train/eval/test id lists.

    Writes ``train.txt``, ``eval.txt`` and ``test.txt`` under
    ``<save_dir>/splits``. ``train_ratio`` of the cases form the train+eval
    pool (5/6 train, 1/6 eval); the remainder becomes the test set.
    NOTE(review): the split depends on the global NumPy RNG state — seed it
    for reproducible splits.
    """
    # exist_ok=True replaces the racy exists()+makedirs() pair.
    os.makedirs(os.path.join(config.save_dir, 'splits'), exist_ok=True)
    for tag in ['Tr']:
        img_ids = []
        for path in tqdm(glob.glob(os.path.join(config.base_dir, f'images{tag}', '*.nii.gz'))):
            # Strip the 'img' prefix and the '.nii.gz' extension to get the id.
            img_id = path.split('/')[-1].split('.')[0][3:]
            img_ids.append(img_id)

        if tag == 'Tr':
            img_ids = np.random.permutation(img_ids)
            split_idx = int(len(img_ids) * train_ratio)
            train_val_ids = img_ids[:split_idx]
            test_ids = sorted(img_ids[split_idx:])

            # 5/6 of the train+val pool trains, the rest validates.
            split_idx = int(len(train_val_ids) * 5 / 6)
            train_ids = sorted(train_val_ids[:split_idx])
            eval_ids = sorted(train_val_ids[split_idx:])
            write_txt(
                train_ids,
                os.path.join(config.save_dir, 'splits/train.txt')
            )
            write_txt(
                eval_ids,
                os.path.join(config.save_dir, 'splits/eval.txt')
            )

            test_ids = sorted(test_ids)
            write_txt(
                test_ids,
                os.path.join(config.save_dir, 'splits/test.txt')
            )
ids_list = read_list(split, task="synapse") 101 | ids_list = np.random.permutation(ids_list) 102 | 103 | split_idx = int(len(ids_list) * labeled_ratio/100) 104 | labeled_ids = sorted(ids_list[:split_idx]) 105 | unlabeled_ids = sorted(ids_list[split_idx:]) 106 | 107 | 108 | write_txt( 109 | labeled_ids, 110 | os.path.join(config.save_dir, f'splits/labeled_{labeled_ratio}p.txt') 111 | ) 112 | write_txt( 113 | unlabeled_ids, 114 | os.path.join(config.save_dir, f'splits/unlabeled_{labeled_ratio}p.txt') 115 | ) 116 | 117 | 118 | if __name__ == '__main__': 119 | process_npy() 120 | process_split_fully() 121 | process_split_semi(labeled_ratio=10) 122 | process_split_semi(labeled_ratio=20) 123 | process_split_semi(labeled_ratio=40) 124 | -------------------------------------------------------------------------------- /dmp/code/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--task', type=str, default='synapse') 6 | parser.add_argument('--exp', type=str, default='fully') 7 | parser.add_argument('--split', type=str, default='test') 8 | parser.add_argument('--speed', type=int, default=0) 9 | parser.add_argument('-g', '--gpu', type=str, default='0') 10 | parser.add_argument('--cps', type=str, default=None) 11 | parser.add_argument('--fold', type=str, default=None) 12 | args = parser.parse_args() 13 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 14 | 15 | import torch 16 | from models.vnet import VNet, VNet4SSNet 17 | from models.vnet_dst import VNet_Decoupled 18 | from models.unet_ds import unet_3D_ds 19 | from models.unet import unet_3D 20 | from utils import test_all_case, read_list, maybe_mkdir, test_all_case_AB 21 | from utils.config import Config 22 | config = Config(args.task) 23 | 24 | if __name__ == '__main__': 25 | stride_dict = { 26 | 0: (32, 16), 27 | 1: (64, 16), 28 | 2: (128, 32), 29 | } 30 | stride = 
stride_dict[args.speed] 31 | 32 | snapshot_path = f'./logs/{args.exp}/' 33 | test_save_path = f'./logs/{args.exp}/predictions_{args.cps}/' 34 | maybe_mkdir(test_save_path) 35 | 36 | if "fully" in args.exp: 37 | model = VNet( 38 | n_channels=config.num_channels, 39 | n_classes=config.num_cls, 40 | n_filters=config.n_filters, 41 | normalization='batchnorm', 42 | has_dropout=False 43 | ).cuda() 44 | model.eval() 45 | args.cps = None 46 | 47 | 48 | elif "dst" in args.exp: 49 | model_A = VNet_Decoupled( 50 | n_channels=config.num_channels, 51 | n_classes=config.num_cls, 52 | n_filters=config.n_filters, 53 | normalization='batchnorm', 54 | has_dropout=False 55 | ).cuda() 56 | model_B = VNet_Decoupled( 57 | n_channels=config.num_channels, 58 | n_classes=config.num_cls, 59 | n_filters=config.n_filters, 60 | normalization='batchnorm', 61 | has_dropout=False 62 | ).cuda() 63 | model_A.eval() 64 | model_B.eval() 65 | 66 | elif "urpc" in args.exp: 67 | model = unet_3D_ds(n_classes=config.num_cls, in_channels=1).cuda() 68 | model.eval() 69 | args.cps = None 70 | # elif "acisis" in args.exp: 71 | # model = unet_3D(n_classes=config.num_cls, in_channels=1).cuda() 72 | # model.eval() 73 | # args.cps = None 74 | 75 | elif "uamt" in args.exp or "acisis" in args.exp: 76 | model = VNet( 77 | n_channels=config.num_channels, 78 | n_classes=config.num_cls, 79 | n_filters=config.n_filters, 80 | normalization='batchnorm', 81 | has_dropout=False 82 | ).cuda() 83 | model.eval() 84 | args.cps = None 85 | elif "ssnet" in args.exp: 86 | model = VNet4SSNet( 87 | n_channels=config.num_channels, 88 | n_classes=config.num_cls, 89 | n_filters=config.n_filters, 90 | normalization='batchnorm', 91 | has_dropout=False).cuda() 92 | model.eval() 93 | args.cps = None 94 | else: 95 | model_A = VNet( 96 | n_channels=config.num_channels, 97 | n_classes=config.num_cls, 98 | n_filters=config.n_filters, 99 | normalization='batchnorm', 100 | has_dropout=False 101 | ).cuda() 102 | model_B = VNet( 103 | 
n_channels=config.num_channels, 104 | n_classes=config.num_cls, 105 | n_filters=config.n_filters, 106 | normalization='batchnorm', 107 | has_dropout=False 108 | ).cuda() 109 | model_A.eval() 110 | model_B.eval() 111 | 112 | 113 | ckpt_path = os.path.join(snapshot_path, f'ckpts/best_model.pth') 114 | 115 | if args.task == 'colon': 116 | args.split = args.split+f"_{args.fold}" 117 | 118 | with torch.no_grad(): 119 | if args.cps == "AB": 120 | model_A.load_state_dict(torch.load(ckpt_path)["A"]) 121 | model_B.load_state_dict(torch.load(ckpt_path)["B"]) 122 | print(f'load checkpoint from {ckpt_path}') 123 | test_all_case_AB( 124 | model_A, model_B, 125 | read_list(args.split, task=args.task), 126 | task=args.task, 127 | num_classes=config.num_cls, 128 | patch_size=config.patch_size, 129 | stride_xy=stride[0], 130 | stride_z=stride[1], 131 | test_save_path=test_save_path 132 | ) 133 | else: 134 | if args.cps: 135 | model.load_state_dict(torch.load(ckpt_path)[args.cps]) 136 | else: # for full-supervision 137 | model.load_state_dict(torch.load(ckpt_path)) 138 | print(f'load checkpoint from {ckpt_path}') 139 | test_all_case( 140 | model, 141 | read_list(args.split, task=args.task), 142 | task=args.task, 143 | num_classes=config.num_cls, 144 | patch_size=config.patch_size, 145 | stride_xy=stride[0], 146 | stride_z=stride[1], 147 | test_save_path=test_save_path 148 | ) 149 | -------------------------------------------------------------------------------- /dmp/code/data/preprocess_amos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn.functional as F 7 | from utils import read_list, read_nifti 8 | from utils.config import Config 9 | 10 | 11 | 12 | class Config: 13 | def __init__(self,task): 14 | self.task = task 15 | if task == 'colon_u': 16 | self.base_dir = '/homes/lzhang/data/colon' 17 | self.save_dir = 
def process_npy():
    """Resample each NIfTI volume (and its label, when one exists) and save as .npy.

    Target shape is the configured patch size plus a 25% margin per axis.
    Images are resampled trilinearly, labels with nearest-neighbor so class
    ids are preserved. Volumes without a label file (unlabeled data) get only
    an ``*_image.npy`` output.
    """
    # exist_ok=True replaces the racy exists()+makedirs() pair.
    os.makedirs(os.path.join(config.save_dir, 'npy'), exist_ok=True)
    for tag in ['Tr', 'Va']:
        img_ids = []
        for path in tqdm(glob.glob(os.path.join(config.base_dir, f'images{tag}', '*.nii.gz'))):
            print(path)
            img_id = path.split('/')[-1].split('.')[0]
            print(img_id)
            img_ids.append(img_id)

            image_path = os.path.join(config.base_dir, f'images{tag}', f'{img_id}.nii.gz')
            # Colon labels use a '.seg.nii.gz' suffix; other tasks mirror the image name.
            if config.task == 'colon':
                label_path = os.path.join(config.base_dir, f'labels{tag}', f'{img_id}.seg.nii.gz')
            else:
                label_path = os.path.join(config.base_dir, f'labels{tag}', f'{img_id}.nii.gz')

            # Patch size plus a 25% margin on each axis.
            resize_shape = (config.patch_size[0] + config.patch_size[0] // 4,
                            config.patch_size[1] + config.patch_size[1] // 4,
                            config.patch_size[2] + config.patch_size[2] // 4)

            image = read_nifti(image_path)
            print(label_path)
            try:
                label = read_nifti(label_path)
                islabel = True
            except Exception:
                # Missing label file => unlabeled volume. BUG FIX: this was a
                # bare `except:` which also swallowed KeyboardInterrupt/SystemExit.
                islabel = False

            image = image.astype(np.float32)
            if islabel:
                label = label.astype(np.int8)

            image = torch.FloatTensor(image).unsqueeze(0).unsqueeze(0)
            if islabel:
                label = torch.FloatTensor(label).unsqueeze(0).unsqueeze(0)

            # Trilinear for intensities, nearest for labels (keeps class ids intact).
            image = F.interpolate(image, size=resize_shape, mode='trilinear', align_corners=False)
            if islabel:
                label = F.interpolate(label, size=resize_shape, mode='nearest')
            image = image.squeeze().numpy()
            if islabel:
                label = label.squeeze().numpy()

            np.save(
                os.path.join(config.save_dir, 'npy', f'{img_id}_image.npy'),
                image
            )
            if islabel:
                np.save(
                    os.path.join(config.save_dir, 'npy', f'{img_id}_label.npy'),
                    label
                )
class CenterCrop(object):
    """Center-crop every array in a sample dict to `output_size`, zero-padding first
    when any axis is not strictly larger than the target size.

    All keys ('image', 'label', ...) are padded/cropped identically so they
    stay spatially aligned.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image = sample['image']
        padding_flag = image.shape[0] <= self.output_size[0] or \
                       image.shape[1] <= self.output_size[1] or \
                       image.shape[2] <= self.output_size[2]

        # Symmetric pad amounts (+3 voxel slack), clamped at zero per axis.
        if padding_flag:
            pw = max((self.output_size[0] - image.shape[0]) // 2 + 3, 0)
            ph = max((self.output_size[1] - image.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - image.shape[2]) // 2 + 3, 0)

        w1, h1, d1 = None, None, None
        ret_dict = {}
        for key in sample.keys():
            item = sample[key]
            if padding_flag:
                item = np.pad(item, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

            # BUG FIX: the crop origin used to be computed from the *unpadded*
            # image shape; when padding occurred it went negative and produced
            # wrong or empty crops. Compute it from the array actually sliced.
            if w1 is None:
                (w, h, d) = item.shape
                w1 = int(round((w - self.output_size[0]) / 2.))
                h1 = int(round((h - self.output_size[1]) / 2.))
                d1 = int(round((d - self.output_size[2]) / 2.))

            item = item[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
            ret_dict[key] = item

        return ret_dict
class RandomFlip_LR:
    """Randomly flip every array in a sample dict along axis 1 with probability `prob`."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def _flip(self, img, prob):
        # Flip when the pre-drawn random number falls below the threshold.
        if prob[0] <= self.prob:
            img = np.flip(img, 1).copy()
        return img

    def __call__(self, sample):
        # One draw shared by all keys so image and label stay aligned.
        prob = (np.random.uniform(0, 1), np.random.uniform(0, 1))
        ret_dict = {}
        for key in sample.keys():
            # BUG FIX: the flipped array was previously discarded
            # (`self._flip(item, prob)` without using the return value),
            # making this transform a silent no-op.
            ret_dict[key] = self._flip(sample[key], prob)
        return ret_dict


class RandomFlip_UD:
    """Randomly flip every array in a sample dict along axis 2 with probability `prob`."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def _flip(self, img, prob):
        if prob[1] <= self.prob:
            img = np.flip(img, 2).copy()
        return img

    def __call__(self, sample):
        prob = (np.random.uniform(0, 1), np.random.uniform(0, 1))
        ret_dict = {}
        for key in sample.keys():
            # BUG FIX: keep the flipped array (the return value was discarded before).
            ret_dict[key] = self._flip(sample[key], prob)
        return ret_dict
be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=conda_forge 5 | _openmp_mutex=4.5=2_kmp_llvm 6 | absl-py=2.1.0=pypi_0 7 | acvl-utils=0.2=pypi_0 8 | argparse=1.4.0=pypi_0 9 | asttokens=2.4.1=pypi_0 10 | batchgenerators=0.25=pypi_0 11 | batchgeneratorsv2=0.1.1=pypi_0 12 | blas=1.0=mkl 13 | bleach=6.1.0=pyhd8ed1ab_0 14 | bzip2=1.0.8=h5eee18b_5 15 | ca-certificates=2024.7.4=hbcca054_0 16 | certifi=2024.7.4=pypi_0 17 | charset-normalizer=2.0.4=pyhd3eb1b0_0 18 | colorama=0.4.6=pyhd8ed1ab_0 19 | comm=0.2.2=pypi_0 20 | connected-components-3d=3.16.0=pypi_0 21 | contourpy=1.2.0=pypi_0 22 | cuda-cudart=11.8.89=0 23 | cuda-cupti=11.8.87=0 24 | cuda-libraries=11.8.0=0 25 | cuda-nvrtc=11.8.89=0 26 | cuda-nvtx=11.8.86=0 27 | cuda-runtime=11.8.0=0 28 | cycler=0.12.1=pypi_0 29 | debugpy=1.8.2=py311h4332511_0 30 | decorator=5.1.1=pypi_0 31 | dicom2nifti=2.4.11=pypi_0 32 | dynamic-network-architectures=0.3.1=pypi_0 33 | efficientnet-pytorch=0.7.1=pypi_0 34 | einops=0.8.0=pypi_0 35 | executing=2.0.1=pypi_0 36 | ffmpeg=4.3=hf484d3e_0 37 | fft-conv-pytorch=1.2.0=pypi_0 38 | filelock=3.9.0=pypi_0 39 | fonttools=4.49.0=pypi_0 40 | freetype=2.12.1=h4a9f257_0 41 | fsspec=2024.6.0=pypi_0 42 | future=1.0.0=pypi_0 43 | gmp=6.2.1=h295c915_3 44 | gmpy2=2.1.2=py311hc9b5ff0_0 45 | gnutls=3.6.15=he1e5248_0 46 | grpcio=1.62.1=pypi_0 47 | h5py=3.10.0=pypi_0 48 | huggingface-hub=0.24.6=pypi_0 49 | idna=3.4=py311h06a4308_0 50 | imagecodecs=2024.6.1=pypi_0 51 | imageio=2.34.0=pypi_0 52 | importlib-metadata=8.0.0=pypi_0 53 | importlib_metadata=8.0.0=hd8ed1ab_0 54 | intel-openmp=2023.1.0=hdb19cb5_46306 55 | ipykernel=6.28.0=py311h06a4308_0 56 | ipython=8.25.0=py311h06a4308_0 57 | jedi=0.19.1=pypi_0 58 | jinja2=3.1.2=pypi_0 59 | joblib=1.3.2=pypi_0 60 | jpeg=9e=h5eee18b_1 61 | jupyter-client=8.6.2=pypi_0 62 | jupyter_client=8.6.2=pyhd8ed1ab_0 63 | jupyter_core=5.7.2=py311h38be061_0 64 | kaggle=1.6.14=pyhd8ed1ab_0 65 | 
keyutils=1.6.1=h166bdaf_0 66 | kiwisolver=1.4.5=pypi_0 67 | krb5=1.21.3=h659f571_0 68 | lame=3.100=h7b6447c_0 69 | lazy-loader=0.3=pypi_0 70 | lcms2=2.12=h3be6417_0 71 | ld_impl_linux-64=2.38=h1181459_1 72 | lerc=3.0=h295c915_0 73 | libcublas=11.11.3.6=0 74 | libcufft=10.9.0.58=0 75 | libcufile=1.8.1.2=0 76 | libcurand=10.3.4.107=0 77 | libcusolver=11.4.1.48=0 78 | libcusparse=11.7.5.86=0 79 | libdeflate=1.17=h5eee18b_1 80 | libedit=3.1.20191231=he28a2e2_2 81 | libffi=3.4.4=h6a678d5_0 82 | libgcc-ng=13.2.0=h77fa898_7 83 | libiconv=1.16=h7f8727e_2 84 | libidn2=2.3.4=h5eee18b_0 85 | libjpeg-turbo=2.0.0=h9bf148f_0 86 | libnpp=11.8.0.86=0 87 | libnvjpeg=11.9.0.86=0 88 | libpng=1.6.39=h5eee18b_0 89 | libsodium=1.0.18=h36c2ea0_1 90 | libstdcxx-ng=13.2.0=hc0a3c3a_7 91 | libtasn1=4.19.0=h5eee18b_0 92 | libtiff=4.5.1=h6a678d5_0 93 | libunistring=0.9.10=h27cfd23_0 94 | libuuid=1.41.5=h5eee18b_0 95 | libwebp-base=1.3.2=h5eee18b_0 96 | linecache2=1.0.0=pypi_0 97 | llvm-openmp=14.0.6=h9e868ea_0 98 | lz4-c=1.9.4=h6a678d5_0 99 | markdown=3.6=pypi_0 100 | markupsafe=2.1.3=py311h5eee18b_0 101 | matplotlib=3.8.3=pypi_0 102 | matplotlib-inline=0.1.7=pypi_0 103 | medpy=0.4.0=pypi_0 104 | mkl=2023.1.0=h213fc3f_46344 105 | mkl-service=2.4.0=py311h5eee18b_1 106 | mkl_fft=1.3.8=py311h5eee18b_0 107 | mkl_random=1.2.4=py311hdb19cb5_0 108 | mpc=1.1.0=h10f8cd9_1 109 | mpfr=4.0.2=hb69a4c5_1 110 | mpmath=1.3.0=py311h06a4308_0 111 | munch=4.0.0=pypi_0 112 | ncurses=6.4=h6a678d5_0 113 | nest-asyncio=1.6.0=pypi_0 114 | nettle=3.7.3=hbbd107a_1 115 | networkx=3.2.1=pypi_0 116 | nibabel=5.2.1=pypi_0 117 | nnunetv2=2.5=pypi_0 118 | numpy=1.24.0=pypi_0 119 | numpy-base=1.26.4=py311hf175353_0 120 | nvidia-cublas-cu11=11.11.3.6=pypi_0 121 | nvidia-cuda-cupti-cu11=11.8.87=pypi_0 122 | nvidia-cuda-nvrtc-cu11=11.8.89=pypi_0 123 | nvidia-cuda-runtime-cu11=11.8.89=pypi_0 124 | nvidia-cudnn-cu11=8.7.0.84=pypi_0 125 | nvidia-cufft-cu11=10.9.0.58=pypi_0 126 | nvidia-curand-cu11=10.3.0.86=pypi_0 127 | 
nvidia-cusolver-cu11=11.4.1.48=pypi_0 128 | nvidia-cusparse-cu11=11.7.5.86=pypi_0 129 | nvidia-nccl-cu11=2.19.3=pypi_0 130 | nvidia-nvtx-cu11=11.8.86=pypi_0 131 | opencv-python=4.9.0.80=pypi_0 132 | openh264=2.1.1=h4ff587b_0 133 | openjpeg=2.4.0=h3ad879b_0 134 | openssl=3.3.1=h4ab18f5_1 135 | packaging=23.2=pypi_0 136 | pandas=2.2.2=pypi_0 137 | parso=0.8.4=pypi_0 138 | pexpect=4.9.0=pypi_0 139 | pillow=10.2.0=py311h5eee18b_0 140 | pip=23.3.1=py311h06a4308_0 141 | platformdirs=4.2.2=pypi_0 142 | pretrainedmodels=0.7.4=pypi_0 143 | prompt-toolkit=3.0.47=pypi_0 144 | prompt_toolkit=3.0.47=hd8ed1ab_0 145 | protobuf=4.25.3=pypi_0 146 | psutil=6.0.0=py311h331c9d8_0 147 | ptyprocess=0.7.0=pypi_0 148 | pure-eval=0.2.2=pypi_0 149 | pure_eval=0.2.2=pyhd8ed1ab_0 150 | pydensecrf=1.0=pypi_0 151 | pydicom=2.4.4=pypi_0 152 | pygments=2.18.0=pypi_0 153 | pyparsing=3.1.1=pypi_0 154 | python=3.11.8=h955ad1f_0 155 | python-dateutil=2.9.0.post0=pypi_0 156 | python-gdcm=3.0.23=pypi_0 157 | python-graphviz=0.20.3=pypi_0 158 | python-slugify=8.0.4=pyhd8ed1ab_0 159 | python_abi=3.11=2_cp311 160 | pytorch=2.2.1=py3.11_cuda11.8_cudnn8.7.0_0 161 | pytorch-crf=0.7.2=pypi_0 162 | pytorch-cuda=11.8=h7e8668a_5 163 | pytorch-mutex=1.0=cuda 164 | pytz=2024.1=pypi_0 165 | pyyaml=6.0.1=py311h5eee18b_0 166 | pyzmq=26.0.3=py311h08a0b41_0 167 | readline=8.2=h5eee18b_0 168 | requests=2.31.0=py311h06a4308_1 169 | safetensors=0.4.3=pypi_0 170 | scikit-image=0.22.0=pypi_0 171 | scikit-learn=1.4.1.post1=pypi_0 172 | scipy=1.12.0=pypi_0 173 | seaborn=0.13.2=pypi_0 174 | segmentation-models-pytorch=0.3.4=pypi_0 175 | setuptools=68.2.2=py311h06a4308_0 176 | simpleitk=2.3.1=pypi_0 177 | six=1.16.0=pyh6c4a22f_0 178 | sqlite=3.41.2=h5eee18b_0 179 | stack-data=0.6.2=pypi_0 180 | stack_data=0.6.2=pyhd8ed1ab_0 181 | sympy=1.12=py311h06a4308_0 182 | tbb=2021.8.0=hdb19cb5_0 183 | tensorboard=2.16.2=pypi_0 184 | tensorboard-data-server=0.7.2=pypi_0 185 | tensorboardx=2.6.2.2=pypi_0 186 | 
text-unidecode=1.3=pyhd8ed1ab_1 187 | threadpoolctl=3.3.0=pypi_0 188 | tifffile=2024.2.12=pypi_0 189 | timm=0.9.7=pypi_0 190 | tk=8.6.12=h1ccaba5_0 191 | torch=2.2.1+cu118=pypi_0 192 | torchaudio=2.2.1+cu118=pypi_0 193 | torchtriton=2.2.0=py311 194 | torchvision=0.17.1+cu118=pypi_0 195 | tornado=6.4.1=py311h331c9d8_0 196 | tqdm=4.66.2=pypi_0 197 | traceback2=1.4.0=pypi_0 198 | traitlets=5.14.3=pypi_0 199 | triton=2.2.0=pypi_0 200 | typing-extensions=4.8.0=pypi_0 201 | typing_extensions=4.9.0=py311h06a4308_1 202 | tzdata=2024.1=pypi_0 203 | unittest2=1.1.0=pypi_0 204 | urllib3=2.2.1=pypi_0 205 | wcwidth=0.2.13=pypi_0 206 | webencodings=0.5.1=pyhd8ed1ab_2 207 | werkzeug=3.0.1=pypi_0 208 | wheel=0.41.2=py311h06a4308_0 209 | xz=5.4.6=h5eee18b_0 210 | yacs=0.1.8=pypi_0 211 | yaml=0.2.5=h7b6447c_0 212 | zeromq=4.3.5=h75354e8_4 213 | zipp=3.19.2=pypi_0 214 | zlib=1.2.13=h5eee18b_0 215 | zstd=1.5.5=hc292b87_0 216 | -------------------------------------------------------------------------------- /dmp/code/evaluate_Ntimes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | from medpy import metric 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--task', type=str, default='synapse') 9 | parser.add_argument('--exp', type=str, default="fully") 10 | parser.add_argument('--folds', type=int, default=3) 11 | parser.add_argument('--cps', type=str, default=None) 12 | args = parser.parse_args() 13 | import nibabel as nib 14 | 15 | from utils import read_list, read_nifti 16 | import torch 17 | import torch.nn.functional as F 18 | from utils.config import Config 19 | config = Config(args.task) 20 | 21 | if __name__ == '__main__': 22 | 23 | 24 | 25 | results_all_folds = [] 26 | 27 | txt_path = "./logs/"+args.exp+"/evaluation_res.txt" 28 | # print(txt_path) 29 | print("\n Evaluating...") 30 | fw = open(txt_path, 'w') 31 | for fold in range(1, 
args.folds+1): 32 | if args.task == "colon": 33 | ids_list = read_list(f'test_{fold}', task=args.task) 34 | else: 35 | ids_list = read_list('test', task=args.task) 36 | test_cls = [i for i in range(1, config.num_cls)] 37 | values = np.zeros((len(ids_list), len(test_cls), 2)) # dice and asd 38 | 39 | for idx, data_id in enumerate(tqdm(ids_list)): 40 | # if idx > 1: 41 | # break 42 | # print(os.path.join("./logs",args.exp, "fold"+str(fold), "predictions_"+args.cps,f'{data_id}.nii.gz')) 43 | pred = read_nifti(os.path.join("./logs",args.exp, "fold"+str(fold), "predictions_"+str(args.cps),f'{data_id}.nii.gz')) 44 | #pred = read_nifti(os.path.join("./logs",args.exp, "fold"+str(fold), "predictions_"+str(args.cps),f'{data_id}.nii.gz')) 45 | if args.task == "amos": 46 | label = read_nifti(os.path.join(config.base_dir, 'labelsVa', f'{data_id}.nii.gz')) 47 | image = read_nifti(os.path.join(config.base_dir, 'imagesVa', f'{data_id}.nii.gz')) 48 | elif args.task == "chd": 49 | image = read_nifti(os.path.join(config.base_dir, 'imagesTr', f'{data_id}.nii.gz')) 50 | label = read_nifti(os.path.join(config.base_dir, 'labelsTr', f'{data_id}.nii.gz')) 51 | 52 | 53 | elif args.task == "acc_s": 54 | image = read_nifti(os.path.join(config.base_dir, 'imagesTr', f'{data_id}.nii.gz')) 55 | label =read_nifti(os.path.join(config.base_dir, f'labelsTr', f'{data_id[:-3]}gt.nii.gz')) 56 | 57 | elif args.task == "covid": 58 | image = read_nifti(os.path.join(config.base_dir, 'imagesTr', f'{data_id}.nii.gz')) 59 | label =read_nifti(os.path.join(config.base_dir, f'labelsTr', f'{data_id[:-2]}seg.nii.gz')) 60 | elif args.task == "colon": 61 | print(f"data id: {data_id}") 62 | image = read_nifti(os.path.join(config.base_dir, 'imagesTr', f'{data_id}.nii.gz')) 63 | label =read_nifti(os.path.join(config.base_dir, f'labelsTr', f'{data_id}.seg.nii.gz')) 64 | 65 | else: 66 | label = read_nifti(os.path.join(config.base_dir, 'labelsTr', f'label{data_id}.nii.gz')) 67 | image = 
read_nifti(os.path.join(config.base_dir, 'img', f'img{data_id}.nii.gz')) 68 | label = label.astype(np.int8) 69 | dd, ww, hh = label.shape 70 | label = torch.FloatTensor(label).unsqueeze(0).unsqueeze(0) 71 | resize_shape=(config.patch_size[0]+config.patch_size[0]//4, 72 | config.patch_size[1]+config.patch_size[1]//4, 73 | config.patch_size[2]+config.patch_size[2]//4) 74 | label = F.interpolate(label, size=resize_shape,mode='nearest') 75 | label = label.squeeze().numpy() 76 | 77 | pred_t = nib.Nifti1Image(label, np.eye(4)) 78 | os.makedirs(os.path.join('./logs',f'label_{args.task}',args.exp), exist_ok=True) 79 | nib.save(pred_t, os.path.join('./logs',f'label_{args.task}',args.exp, f'{data_id}.nii.gz')) 80 | for i in test_cls: 81 | pred_i = (pred == i) 82 | label_i = (label == i) 83 | if pred_i.sum() > 0 and label_i.sum() > 0: 84 | dice = metric.binary.dc(pred == i, label == i) * 100 85 | hd95 = metric.binary.asd(pred == i, label == i) 86 | values[idx][i-1] = np.array([dice, hd95]) 87 | elif pred_i.sum() > 0 and label_i.sum() == 0: 88 | dice, hd95 = 0, 128 89 | elif pred_i.sum() == 0 and label_i.sum() > 0: 90 | dice, hd95 = 0, 128 91 | elif pred_i.sum() == 0 and label_i.sum() == 0: 92 | dice, hd95 = 1, 0 93 | 94 | values[idx][i-1] = np.array([dice, hd95]) 95 | #print(values) 96 | # print(values.shape) 97 | # values /= len(ids_list) 98 | values_mean_cases = np.mean(values, axis=0) 99 | results_all_folds.append(values) 100 | fw.write("Fold" + str(fold) + '\n') 101 | fw.write("------ Dice ------" + '\n') 102 | fw.write(str(np.round(values_mean_cases[:,0],1)) + '\n') 103 | fw.write("------ ASD ------" + '\n') 104 | fw.write(str(np.round(values_mean_cases[:,1],1)) + '\n') 105 | fw.write('Average Dice:'+str(np.mean(values_mean_cases, axis=0)[0]) + '\n') 106 | fw.write('Average ASD:'+str(np.mean(values_mean_cases, axis=0)[1]) + '\n') 107 | fw.write("=================================") 108 | print("Fold", fold) 109 | print("------ Dice ------") 110 | 
print(np.round(values_mean_cases[:,0],1)) 111 | print("------ ASD ------") 112 | print(np.round(values_mean_cases[:,1],1)) 113 | print(np.mean(values_mean_cases, axis=0)[0], np.mean(values_mean_cases, axis=0)[1]) 114 | 115 | #print(f'************{results_all_folds}********************') 116 | 117 | results_all_folds = np.array(results_all_folds) 118 | 119 | # print(results_all_folds.shape) 120 | 121 | fw.write('\n\n\n') 122 | fw.write('All folds' + '\n') 123 | 124 | results_folds_mean = results_all_folds.mean(0) 125 | 126 | for i in range(results_folds_mean.shape[0]): 127 | fw.write("="*5 + " Case-" + str(ids_list[i]) + '\n') 128 | fw.write('\tDice:'+str(np.round(results_folds_mean[i][:,0],2).tolist()) + '\n') 129 | fw.write('\t ASD:'+str(np.round(results_folds_mean[i][:,1],2).tolist()) + '\n') 130 | fw.write('\t'+'Average Dice:'+str(np.mean(results_folds_mean[i], axis=0)[0]) + '\n') 131 | fw.write('\t'+'Average ASD:'+str(np.mean(results_folds_mean[i], axis=0)[1]) + '\n') 132 | 133 | fw.write("=================================\n") 134 | fw.write('Final Dice of each class\n') 135 | fw.write(str([round(x,1) for x in results_folds_mean.mean(0)[:,0].tolist()]) + '\n') 136 | fw.write('Final ASD of each class\n') 137 | fw.write(str([round(x,1) for x in results_folds_mean.mean(0)[:,1].tolist()]) + '\n') 138 | print("=================================") 139 | print('Final Dice of each class') 140 | print(str([round(x,1) for x in results_folds_mean.mean(0)[:,0].tolist()])) 141 | print('Final ASD of each class') 142 | print(str([round(x,1) for x in results_folds_mean.mean(0)[:,1].tolist()])) 143 | std_dice = np.std(results_all_folds.mean(1).mean(1)[:,0]) 144 | std_hd = np.std(results_all_folds.mean(1).mean(1)[:,1]) 145 | 146 | fw.write('Final Avg Dice: '+str(round(results_folds_mean.mean(0).mean(0)[0], 2)) +'±' + str(round(std_dice,2)) + '\n') 147 | fw.write('Final Avg ASD: '+str(round(results_folds_mean.mean(0).mean(0)[1], 2)) +'±' + str(round(std_hd,2)) + '\n') 148 | 149 | 
print('Final Avg Dice: '+str(round(results_folds_mean.mean(0).mean(0)[0], 2)) +'±' + str(round(std_dice,2))) 150 | print('Final Avg ASD: '+str(round(results_folds_mean.mean(0).mean(0)[1], 2)) +'±' + str(round(std_hd,2))) 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /dmp/code/models/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | # from networks.networks_other import init_weights 6 | from torch.nn import init 7 | 8 | def weights_init_normal(m): 9 | classname = m.__class__.__name__ 10 | #print(classname) 11 | if classname.find('Conv') != -1: 12 | init.normal(m.weight.data, 0.0, 0.02) 13 | elif classname.find('Linear') != -1: 14 | init.normal(m.weight.data, 0.0, 0.02) 15 | elif classname.find('BatchNorm') != -1: 16 | init.normal(m.weight.data, 1.0, 0.02) 17 | init.constant(m.bias.data, 0.0) 18 | 19 | 20 | def weights_init_xavier(m): 21 | classname = m.__class__.__name__ 22 | #print(classname) 23 | if classname.find('Conv') != -1: 24 | init.xavier_normal(m.weight.data, gain=1) 25 | elif classname.find('Linear') != -1: 26 | init.xavier_normal(m.weight.data, gain=1) 27 | elif classname.find('BatchNorm') != -1: 28 | init.normal(m.weight.data, 1.0, 0.02) 29 | init.constant(m.bias.data, 0.0) 30 | 31 | 32 | def weights_init_kaiming(m): 33 | classname = m.__class__.__name__ 34 | #print(classname) 35 | if classname.find('Conv') != -1: 36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 37 | elif classname.find('Linear') != -1: 38 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 39 | elif classname.find('BatchNorm') != -1: 40 | init.normal_(m.weight.data, 1.0, 0.02) 41 | init.constant_(m.bias.data, 0.0) 42 | 43 | 44 | def weights_init_orthogonal(m): 45 | classname = m.__class__.__name__ 46 | #print(classname) 47 | if classname.find('Conv') != -1: 48 | 
init.orthogonal(m.weight.data, gain=1) 49 | elif classname.find('Linear') != -1: 50 | init.orthogonal(m.weight.data, gain=1) 51 | elif classname.find('BatchNorm') != -1: 52 | init.normal(m.weight.data, 1.0, 0.02) 53 | init.constant(m.bias.data, 0.0) 54 | 55 | 56 | def init_weights(net, init_type='normal'): 57 | #print('initialization method [%s]' % init_type) 58 | if init_type == 'normal': 59 | net.apply(weights_init_normal) 60 | elif init_type == 'xavier': 61 | net.apply(weights_init_xavier) 62 | elif init_type == 'kaiming': 63 | net.apply(weights_init_kaiming) 64 | elif init_type == 'orthogonal': 65 | net.apply(weights_init_orthogonal) 66 | else: 67 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 68 | 69 | class UnetConv3(nn.Module): 70 | def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): 71 | super(UnetConv3, self).__init__() 72 | 73 | if is_batchnorm: 74 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), 75 | nn.InstanceNorm3d(out_size), 76 | nn.ReLU(inplace=True),) 77 | self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 78 | nn.InstanceNorm3d(out_size), 79 | nn.ReLU(inplace=True),) 80 | else: 81 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), 82 | nn.ReLU(inplace=True),) 83 | self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 84 | nn.ReLU(inplace=True),) 85 | 86 | # initialise the blocks 87 | for m in self.children(): 88 | init_weights(m, init_type='kaiming') 89 | 90 | def forward(self, inputs): 91 | outputs = self.conv1(inputs) 92 | outputs = self.conv2(outputs) 93 | return outputs 94 | 95 | 96 | 97 | class UnetUp3_CT(nn.Module): 98 | def __init__(self, in_size, out_size, is_batchnorm=True): 99 | super(UnetUp3_CT, self).__init__() 100 | self.conv = UnetConv3(in_size + out_size, 
out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 101 | self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear') 102 | 103 | # initialise the blocks 104 | for m in self.children(): 105 | if m.__class__.__name__.find('UnetConv3') != -1: continue 106 | init_weights(m, init_type='kaiming') 107 | 108 | def forward(self, inputs1, inputs2): 109 | outputs2 = self.up(inputs2) 110 | offset = outputs2.size()[2] - inputs1.size()[2] 111 | padding = 2 * [offset // 2, offset // 2, 0] 112 | outputs1 = F.pad(inputs1, padding) 113 | return self.conv(torch.cat([outputs1, outputs2], 1)) 114 | 115 | 116 | class UnetDsv3(nn.Module): 117 | def __init__(self, in_size, out_size, scale_factor): 118 | super(UnetDsv3, self).__init__() 119 | self.dsv = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0), 120 | nn.Upsample(scale_factor=scale_factor, mode='trilinear'), ) 121 | 122 | def forward(self, input): 123 | return self.dsv(input) 124 | 125 | class unet_3D(nn.Module): 126 | 127 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 128 | super(unet_3D, self).__init__() 129 | self.is_deconv = is_deconv 130 | self.in_channels = in_channels 131 | self.is_batchnorm = is_batchnorm 132 | self.feature_scale = feature_scale 133 | 134 | filters = [64, 128, 256, 512, 1024] 135 | filters = [int(x / self.feature_scale) for x in filters] 136 | 137 | # downsampling 138 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 139 | 3, 3, 3), padding_size=(1, 1, 1)) 140 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 141 | 142 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 143 | 3, 3, 3), padding_size=(1, 1, 1)) 144 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 145 | 146 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 147 | 3, 3, 3), padding_size=(1, 1, 1)) 148 | self.maxpool3 = 
nn.MaxPool3d(kernel_size=(2, 2, 2)) 149 | 150 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 151 | 3, 3, 3), padding_size=(1, 1, 1)) 152 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 153 | 154 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 155 | 3, 3, 3), padding_size=(1, 1, 1)) 156 | 157 | # upsampling 158 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 159 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 160 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 161 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 162 | 163 | # final conv (without any concat) 164 | self.final = nn.Conv3d(filters[0], n_classes, 1) 165 | 166 | self.dropout1 = nn.Dropout(p=0.3) 167 | self.dropout2 = nn.Dropout(p=0.3) 168 | 169 | # initialise weights 170 | for m in self.modules(): 171 | if isinstance(m, nn.Conv3d): 172 | init_weights(m, init_type='kaiming') 173 | elif isinstance(m, nn.BatchNorm3d): 174 | init_weights(m, init_type='kaiming') 175 | 176 | def forward(self, inputs): 177 | conv1 = self.conv1(inputs) 178 | maxpool1 = self.maxpool1(conv1) 179 | 180 | conv2 = self.conv2(maxpool1) 181 | maxpool2 = self.maxpool2(conv2) 182 | 183 | conv3 = self.conv3(maxpool2) 184 | maxpool3 = self.maxpool3(conv3) 185 | 186 | conv4 = self.conv4(maxpool3) 187 | maxpool4 = self.maxpool4(conv4) 188 | 189 | center = self.center(maxpool4) 190 | center = self.dropout1(center) 191 | up4 = self.up_concat4(conv4, center) 192 | up3 = self.up_concat3(conv3, up4) 193 | up2 = self.up_concat2(conv2, up3) 194 | up1 = self.up_concat1(conv1, up2) 195 | up1 = self.dropout2(up1) 196 | 197 | final = self.final(up1) 198 | 199 | return final 200 | 201 | @staticmethod 202 | def apply_argmax_softmax(pred): 203 | log_p = F.softmax(pred, dim=1) 204 | 205 | return log_p -------------------------------------------------------------------------------- 
/dmp/code/models/unet_ds.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | # from networks.networks_other import init_weights 6 | from torch.nn import init 7 | 8 | def weights_init_normal(m): 9 | classname = m.__class__.__name__ 10 | #print(classname) 11 | if classname.find('Conv') != -1: 12 | init.normal(m.weight.data, 0.0, 0.02) 13 | elif classname.find('Linear') != -1: 14 | init.normal(m.weight.data, 0.0, 0.02) 15 | elif classname.find('BatchNorm') != -1: 16 | init.normal(m.weight.data, 1.0, 0.02) 17 | init.constant(m.bias.data, 0.0) 18 | 19 | 20 | def weights_init_xavier(m): 21 | classname = m.__class__.__name__ 22 | #print(classname) 23 | if classname.find('Conv') != -1: 24 | init.xavier_normal(m.weight.data, gain=1) 25 | elif classname.find('Linear') != -1: 26 | init.xavier_normal(m.weight.data, gain=1) 27 | elif classname.find('BatchNorm') != -1: 28 | init.normal(m.weight.data, 1.0, 0.02) 29 | init.constant(m.bias.data, 0.0) 30 | 31 | 32 | def weights_init_kaiming(m): 33 | classname = m.__class__.__name__ 34 | #print(classname) 35 | if classname.find('Conv') != -1: 36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 37 | elif classname.find('Linear') != -1: 38 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 39 | elif classname.find('BatchNorm') != -1: 40 | init.normal_(m.weight.data, 1.0, 0.02) 41 | init.constant_(m.bias.data, 0.0) 42 | 43 | 44 | def weights_init_orthogonal(m): 45 | classname = m.__class__.__name__ 46 | #print(classname) 47 | if classname.find('Conv') != -1: 48 | init.orthogonal(m.weight.data, gain=1) 49 | elif classname.find('Linear') != -1: 50 | init.orthogonal(m.weight.data, gain=1) 51 | elif classname.find('BatchNorm') != -1: 52 | init.normal(m.weight.data, 1.0, 0.02) 53 | init.constant(m.bias.data, 0.0) 54 | 55 | 56 | def init_weights(net, init_type='normal'): 57 | 
#print('initialization method [%s]' % init_type) 58 | if init_type == 'normal': 59 | net.apply(weights_init_normal) 60 | elif init_type == 'xavier': 61 | net.apply(weights_init_xavier) 62 | elif init_type == 'kaiming': 63 | net.apply(weights_init_kaiming) 64 | elif init_type == 'orthogonal': 65 | net.apply(weights_init_orthogonal) 66 | else: 67 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 68 | 69 | class UnetConv3(nn.Module): 70 | def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): 71 | super(UnetConv3, self).__init__() 72 | 73 | if is_batchnorm: 74 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), 75 | nn.InstanceNorm3d(out_size), 76 | nn.ReLU(inplace=True),) 77 | self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 78 | nn.InstanceNorm3d(out_size), 79 | nn.ReLU(inplace=True),) 80 | else: 81 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), 82 | nn.ReLU(inplace=True),) 83 | self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 84 | nn.ReLU(inplace=True),) 85 | 86 | # initialise the blocks 87 | for m in self.children(): 88 | init_weights(m, init_type='kaiming') 89 | 90 | def forward(self, inputs): 91 | outputs = self.conv1(inputs) 92 | outputs = self.conv2(outputs) 93 | return outputs 94 | 95 | 96 | 97 | class UnetUp3_CT(nn.Module): 98 | def __init__(self, in_size, out_size, is_batchnorm=True): 99 | super(UnetUp3_CT, self).__init__() 100 | self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 101 | self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear') 102 | 103 | # initialise the blocks 104 | for m in self.children(): 105 | if m.__class__.__name__.find('UnetConv3') != -1: continue 106 | init_weights(m, init_type='kaiming') 107 | 
108 | def forward(self, inputs1, inputs2): 109 | outputs2 = self.up(inputs2) 110 | offset = outputs2.size()[2] - inputs1.size()[2] 111 | padding = 2 * [offset // 2, offset // 2, 0] 112 | outputs1 = F.pad(inputs1, padding) 113 | return self.conv(torch.cat([outputs1, outputs2], 1)) 114 | 115 | 116 | class UnetDsv3(nn.Module): 117 | def __init__(self, in_size, out_size, scale_factor): 118 | super(UnetDsv3, self).__init__() 119 | self.dsv = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0), 120 | nn.Upsample(scale_factor=scale_factor, mode='trilinear'), ) 121 | 122 | def forward(self, input): 123 | return self.dsv(input) 124 | 125 | class unet_3D_ds(nn.Module): 126 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 127 | super(unet_3D_ds, self).__init__() 128 | self.is_deconv = is_deconv 129 | self.in_channels = in_channels 130 | self.is_batchnorm = is_batchnorm 131 | self.feature_scale = feature_scale 132 | 133 | filters = [64, 128, 256, 512, 1024] 134 | filters = [int(x / self.feature_scale) for x in filters] 135 | 136 | # downsampling 137 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 138 | 3, 3, 3), padding_size=(1, 1, 1)) 139 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 140 | 141 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 142 | 3, 3, 3), padding_size=(1, 1, 1)) 143 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 144 | 145 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 146 | 3, 3, 3), padding_size=(1, 1, 1)) 147 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 148 | 149 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 150 | 3, 3, 3), padding_size=(1, 1, 1)) 151 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 152 | 153 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 154 | 3, 3, 3), 
padding_size=(1, 1, 1)) 155 | 156 | # upsampling 157 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 158 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 159 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 160 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 161 | 162 | # deep supervision 163 | self.dsv4 = UnetDsv3( 164 | in_size=filters[3], out_size=n_classes, scale_factor=8) 165 | self.dsv3 = UnetDsv3( 166 | in_size=filters[2], out_size=n_classes, scale_factor=4) 167 | self.dsv2 = UnetDsv3( 168 | in_size=filters[1], out_size=n_classes, scale_factor=2) 169 | self.dsv1 = nn.Conv3d( 170 | in_channels=filters[0], out_channels=n_classes, kernel_size=1) 171 | 172 | self.dropout1 = nn.Dropout3d(p=0.5) 173 | self.dropout2 = nn.Dropout3d(p=0.3) 174 | self.dropout3 = nn.Dropout3d(p=0.2) 175 | self.dropout4 = nn.Dropout3d(p=0.1) 176 | 177 | # initialise weights 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv3d): 180 | init_weights(m, init_type='kaiming') 181 | elif isinstance(m, nn.BatchNorm3d): 182 | init_weights(m, init_type='kaiming') 183 | 184 | def forward(self, inputs): 185 | # print(inputs.size()) 186 | conv1 = self.conv1(inputs) 187 | maxpool1 = self.maxpool1(conv1) 188 | 189 | conv2 = self.conv2(maxpool1) 190 | maxpool2 = self.maxpool2(conv2) 191 | 192 | conv3 = self.conv3(maxpool2) 193 | maxpool3 = self.maxpool3(conv3) 194 | 195 | conv4 = self.conv4(maxpool3) 196 | maxpool4 = self.maxpool4(conv4) 197 | 198 | center = self.center(maxpool4) 199 | 200 | up4 = self.up_concat4(conv4, center) 201 | up4 = self.dropout1(up4) 202 | 203 | up3 = self.up_concat3(conv3, up4) 204 | up3 = self.dropout2(up3) 205 | 206 | up2 = self.up_concat2(conv2, up3) 207 | up2 = self.dropout3(up2) 208 | 209 | up1 = self.up_concat1(conv1, up2) 210 | up1 = self.dropout4(up1) 211 | 212 | # Deep Supervision 213 | dsv4 = self.dsv4(up4) 214 | dsv3 = self.dsv3(up3) 215 | dsv2 = self.dsv2(up2) 216 | dsv1 = 
self.dsv1(up1) 217 | 218 | if not self.training: 219 | return dsv1 220 | 221 | return dsv1, dsv2, dsv3, dsv4 222 | 223 | @staticmethod 224 | def apply_argmax_softmax(pred): 225 | log_p = F.softmax(pred, dim=1) 226 | 227 | return log_p -------------------------------------------------------------------------------- /dmp/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | softmax_helper = lambda x: F.softmax(x, 1) 8 | 9 | 10 | def sum_tensor(inp, axes, keepdim=False): 11 | axes = np.unique(axes).astype(int) 12 | if keepdim: 13 | for ax in axes: 14 | inp = inp.sum(int(ax), keepdim=True) 15 | else: 16 | for ax in sorted(axes, reverse=True): 17 | inp = inp.sum(int(ax)) 18 | return inp 19 | 20 | 21 | def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): 22 | """ 23 | net_output must be (b, c, x, y(, z))) 24 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) 25 | if mask is provided it must have shape (b, 1, x, y(, z))) 26 | :param net_output: 27 | :param gt: 28 | :param axes: can be (, ) = no summation 29 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels 30 | :param square: if True then fp, tp and fn will be squared before summation 31 | :return: 32 | """ 33 | if axes is None: 34 | axes = tuple(range(2, len(net_output.size()))) 35 | 36 | shp_x = net_output.shape 37 | shp_y = gt.shape 38 | 39 | with torch.no_grad(): 40 | if len(shp_x) != len(shp_y): 41 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 42 | 43 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 44 | # if this is the case then gt is probably already a one hot encoding 45 | y_onehot = gt 46 | else: 47 | gt = gt.long() 48 | y_onehot = torch.zeros(shp_x) 49 | if net_output.device.type == "cuda": 50 | y_onehot = y_onehot.cuda(net_output.device.index) 
51 | y_onehot.scatter_(1, gt, 1) 52 | # print(y_onehot.size()) 53 | 54 | tp = net_output * y_onehot 55 | fp = net_output * (1 - y_onehot) 56 | fn = (1 - net_output) * y_onehot 57 | tn = (1 - net_output) * (1 - y_onehot) 58 | 59 | if mask is not None: 60 | tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) 61 | fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) 62 | fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) 63 | tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1) 64 | 65 | if square: 66 | tp = tp ** 2 67 | fp = fp ** 2 68 | fn = fn ** 2 69 | tn = tn ** 2 70 | 71 | if len(axes) > 0: 72 | tp = sum_tensor(tp, axes, keepdim=False) 73 | fp = sum_tensor(fp, axes, keepdim=False) 74 | fn = sum_tensor(fn, axes, keepdim=False) 75 | tn = sum_tensor(tn, axes, keepdim=False) 76 | 77 | return tp, fp, fn, tn 78 | 79 | class SoftDiceLoss(nn.Module): 80 | def __init__(self, weight=None, apply_nonlin=None, batch_dice=True, do_bg=False, smooth=1.): 81 | """ 82 | """ 83 | super(SoftDiceLoss, self).__init__() 84 | if weight is not None: 85 | weight = torch.FloatTensor(weight).cuda() 86 | 87 | self.do_bg = do_bg 88 | self.batch_dice = batch_dice 89 | self.apply_nonlin = apply_nonlin 90 | self.smooth = smooth 91 | self.weight = weight 92 | 93 | def forward(self, x, y, loss_mask=None, is_training=True): 94 | shp_x = x.shape 95 | 96 | if self.batch_dice: 97 | axes = [0] + list(range(2, len(shp_x))) 98 | else: 99 | axes = list(range(2, len(shp_x))) 100 | 101 | if self.apply_nonlin is not None: 102 | x = self.apply_nonlin(x) 103 | 104 | tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False) 105 | 106 | nominator = 2 * tp + self.smooth 107 | denominator = 2 * tp + fp + fn + self.smooth 108 | 109 | dc = nominator / (denominator + 1e-8) 110 | 111 | if not self.do_bg: 112 | if self.batch_dice: 113 | dc = dc[1:] 114 | else: 115 | dc = dc[:, 1:] 116 | 
117 | if self.weight is not None: # <-- 118 | if not self.do_bg and self.batch_dice: 119 | dc *= self.weight[1:] 120 | else: 121 | raise NotImplementedError 122 | 123 | if not is_training: 124 | return dc 125 | else: 126 | return -dc.mean() 127 | 128 | 129 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss): 130 | """ 131 | this is just a compatibility layer because my target tensor is float and has an extra dimension 132 | """ 133 | def __init__(self, weight=None): 134 | if weight is not None: 135 | weight = torch.FloatTensor(weight).cuda() 136 | super().__init__(weight=weight) 137 | 138 | def forward(self, input, target): 139 | if len(target.shape) == len(input.shape): 140 | assert target.shape[1] == 1 141 | target = target[:, 0] 142 | return super().forward(input, target.long()) 143 | 144 | def update_weight(self, weight): 145 | self.weight = weight 146 | 147 | 148 | class DC_and_CE_loss(nn.Module): 149 | def __init__(self, w_dc=None, w_ce=None, aggregate="sum", weight_ce=1, weight_dice=1, 150 | log_dice=False, ignore_label=None): 151 | """ 152 | CAREFUL. Weights for CE and Dice do not need to sum to one. You can set whatever you want. 
153 | :param aggregate: 154 | :param square_dice: 155 | :param weight_ce: 156 | :param weight_dice: 157 | """ 158 | super().__init__() 159 | 160 | ce_kwargs = {'weight': w_ce} 161 | if ignore_label is not None: 162 | ce_kwargs['reduction'] = 'none' 163 | 164 | self.log_dice = log_dice 165 | self.weight_dice = weight_dice 166 | self.weight_ce = weight_ce 167 | self.aggregate = aggregate 168 | self.ce = RobustCrossEntropyLoss(**ce_kwargs) 169 | 170 | self.ignore_label = ignore_label 171 | self.dc = SoftDiceLoss(weight=w_dc, apply_nonlin=softmax_helper) 172 | 173 | def forward(self, net_output, target): 174 | """ 175 | target must be b, c, x, y(, z) with c=1 176 | :param net_output: 177 | :param target: 178 | :return: 179 | """ 180 | if self.ignore_label is not None: 181 | assert target.shape[1] == 1, 'not implemented for one hot encoding' 182 | mask = target != self.ignore_label 183 | target[~mask] = 0 184 | mask = mask.float() 185 | else: 186 | mask = None 187 | 188 | dc_loss = self.dc(net_output, target, loss_mask=mask) if self.weight_dice != 0 else 0 189 | if self.log_dice: 190 | dc_loss = -torch.log(-dc_loss) 191 | 192 | ce_loss = self.ce(net_output, target[:, 0].long()) if self.weight_ce != 0 else 0 193 | if self.ignore_label is not None: 194 | ce_loss *= mask[:, 0] 195 | ce_loss = ce_loss.sum() / mask.sum() 196 | 197 | if self.aggregate == "sum": 198 | result = self.weight_ce * ce_loss + self.weight_dice * dc_loss 199 | else: 200 | raise NotImplementedError("nah son") # reserved for other stuff (later) 201 | return result 202 | 203 | def update_weight(self, weight): 204 | self.dc.weight = weight 205 | self.ce.weight = weight 206 | 207 | 208 | 209 | class WeightedCrossEntropyLoss(nn.CrossEntropyLoss): 210 | def __init__(self, weight=None): 211 | if weight is not None: 212 | weight = torch.FloatTensor(weight).cuda() 213 | super().__init__(weight=weight, reduction='none') 214 | 215 | def forward(self, input, target, weight_map=None): 216 | ''' 217 | - input: B, C, 
[WHD] 218 | - target: B, [WHD] / B, 1, [WHD] 219 | ''' 220 | b = input.shape[0] 221 | 222 | if len(target.shape) == len(input.shape): 223 | assert target.shape[1] == 1 224 | target = target[:, 0] 225 | 226 | # print("\n",input.size(), target.size()) 227 | loss = super().forward(input, target.long()) # B, [WHD] 228 | loss = loss.view(b, -1) 229 | 230 | if weight_map is not None: 231 | weight = weight_map.view(b, -1).detach() 232 | loss = loss * weight 233 | return torch.mean(loss) 234 | 235 | def update_weight(self, weight): 236 | self.weight = weight 237 | 238 | 239 | class ClassDependent_WeightedCrossEntropyLoss(nn.CrossEntropyLoss): 240 | def __init__(self, weight=None, reduction='none'): 241 | if weight is not None: 242 | weight = torch.FloatTensor(weight).cuda() 243 | super().__init__(weight=weight) 244 | self.reduction = reduction 245 | 246 | def forward(self, input, target, weight_map=None): 247 | ''' 248 | - input: B, C, [WHD] 249 | - target: B, [WHD] / B, 1, [WHD] 250 | ''' 251 | b, c = input.shape[0], input.shape[1] 252 | 253 | if len(target.shape) == len(input.shape): 254 | assert target.shape[1] == 1 255 | target = target[:, 0] 256 | 257 | if weight_map is not None: 258 | loss = super().forward(input*weight_map.detach(), target.long()) # B, [WHD] 259 | else: 260 | loss = super().forward(input, target.long()) 261 | 262 | loss = loss.view(b, -1) 263 | 264 | return torch.mean(loss) 265 | 266 | def update_weight(self, weight): 267 | self.weight = weight 268 | 269 | 270 | 271 | class Onehot_WeightedCrossEntropyLoss(nn.CrossEntropyLoss): 272 | def __init__(self, weight=None): 273 | if weight is not None: 274 | weight = torch.FloatTensor(weight).cuda() 275 | super().__init__(weight=weight, reduction='none') 276 | 277 | def forward(self, input, target, weight_map=None): 278 | ''' 279 | - input: B, C, [WHD] 280 | - target: B, [WHD] / B, 1, [WHD] 281 | ''' 282 | b = input.shape[0] 283 | 284 | # print("\n",input.size(), target.size()) 285 | loss = 
super().forward(input, target) # B, [WHD] 286 | loss = loss.view(b, -1) 287 | 288 | if weight_map is not None: 289 | weight = weight_map.view(b, -1).detach() 290 | loss = loss * weight 291 | return torch.mean(loss) 292 | 293 | def update_weight(self, weight): 294 | self.weight = weight 295 | -------------------------------------------------------------------------------- /dmp/code/train_cps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | from tqdm import tqdm 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--task', type=str, default='synapse') 9 | parser.add_argument('--exp', type=str, default='cps') 10 | parser.add_argument('--seed', type=int, default=0) 11 | parser.add_argument('-sl', '--split_labeled', type=str, default='labeled_20p') 12 | parser.add_argument('-su', '--split_unlabeled', type=str, default='unlabeled_80p') 13 | parser.add_argument('-se', '--split_eval', type=str, default='eval') 14 | parser.add_argument('-m', '--mixed_precision', action='store_true', default=True) # <-- 15 | parser.add_argument('-ep', '--max_epoch', type=int, default=500) 16 | parser.add_argument('--cps_loss', type=str, default='wce') 17 | parser.add_argument('--sup_loss', type=str, default='w_ce+dice') 18 | parser.add_argument('--batch_size', type=int, default=2) 19 | parser.add_argument('--num_workers', type=int, default=2) 20 | parser.add_argument('--base_lr', type=float, default=0.001) 21 | parser.add_argument('-g', '--gpu', type=str, default='0') 22 | parser.add_argument('-w', '--cps_w', type=float, default=1) 23 | parser.add_argument('-r', '--cps_rampup', action='store_true', default=False) # <-- 24 | parser.add_argument('-cr', '--consistency_rampup', type=float, default=None) 25 | args = parser.parse_args() 26 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 27 | 28 | import numpy as np 29 | import torch 30 | from torch import nn 31 | import 
torch.optim as optim 32 | from torchvision import transforms 33 | from torch.utils.data import DataLoader 34 | from torch.utils.tensorboard import SummaryWriter 35 | from torch.cuda.amp import GradScaler, autocast 36 | 37 | from models.vnet import VNet 38 | from utils import maybe_mkdir, get_lr, fetch_data, seed_worker, poly_lr 39 | from utils.loss import DC_and_CE_loss, RobustCrossEntropyLoss, SoftDiceLoss 40 | from data.transforms import RandomCrop, CenterCrop, ToTensor, RandomFlip_LR, RandomFlip_UD 41 | from data.data_loaders import Synapse_AMOS 42 | from utils.config import Config 43 | config = Config(args.task) 44 | 45 | def sigmoid_rampup(current, rampup_length): 46 | '''Exponential rampup from https://arxiv.org/abs/1610.02242''' 47 | if rampup_length == 0: 48 | return 1.0 49 | else: 50 | current = np.clip(current, 0.0, rampup_length) 51 | phase = 1.0 - current / rampup_length 52 | return float(np.exp(-5.0 * phase * phase)) 53 | 54 | 55 | def get_current_consistency_weight(epoch): 56 | if args.cps_rampup: 57 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 58 | if args.consistency_rampup is None: 59 | args.consistency_rampup = args.max_epoch 60 | return args.cps_w * sigmoid_rampup(epoch, args.consistency_rampup) 61 | else: 62 | return args.cps_w 63 | 64 | 65 | def kaiming_normal_init_weight(model): 66 | for m in model.modules(): 67 | if isinstance(m, nn.Conv3d): 68 | torch.nn.init.kaiming_normal_(m.weight) 69 | elif isinstance(m, nn.BatchNorm3d): 70 | m.weight.data.fill_(1) 71 | m.bias.data.zero_() 72 | return model 73 | 74 | 75 | def xavier_normal_init_weight(model): 76 | for m in model.modules(): 77 | if isinstance(m, nn.Conv3d): 78 | torch.nn.init.xavier_normal_(m.weight) 79 | elif isinstance(m, nn.BatchNorm3d): 80 | m.weight.data.fill_(1) 81 | m.bias.data.zero_() 82 | return model 83 | 84 | 85 | def make_loss_function(name, weight=None): 86 | if name == 'ce': 87 | return RobustCrossEntropyLoss() 88 | elif name == 'wce': 89 | return 
def make_loss_function(name, weight=None):
    """Look up a supervised/CPS loss by its short name; `weight` is the per-class weight."""
    if name == 'ce':
        return RobustCrossEntropyLoss()
    elif name == 'wce':
        return RobustCrossEntropyLoss(weight=weight)
    elif name == 'ce+dice':
        return DC_and_CE_loss()
    elif name == 'wce+dice':
        return DC_and_CE_loss(w_ce=weight)
    elif name == 'w_ce+dice':
        return DC_and_CE_loss(w_dc=weight, w_ce=weight)
    else:
        raise ValueError(name)


def make_loader(split, fold=0, dst_cls=Synapse_AMOS, repeat=None, is_training=True, unlabeled=False):
    """Build a DataLoader for `split`; random crops/flips for training, center crop for eval."""
    if is_training:
        train_set = dst_cls(
            task=args.task,
            split=split,
            repeat=repeat,
            unlabeled=unlabeled,
            num_cls=config.num_cls,
            transform=transforms.Compose([
                RandomCrop(config.patch_size),
                RandomFlip_LR(),
                RandomFlip_UD(),
                ToTensor()
            ])
        )
        return DataLoader(
            train_set,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            worker_init_fn=seed_worker  # keep workers deterministic given the global seed
        )
    eval_set = dst_cls(
        task=args.task,
        split=split,
        is_val=True,
        num_cls=config.num_cls,
        transform=transforms.Compose([
            CenterCrop(config.patch_size),
            ToTensor()
        ])
    )
    return DataLoader(eval_set, pin_memory=True)


def make_model_all():
    """Create one VNet on GPU plus its SGD optimizer."""
    model = VNet(
        n_channels=config.num_channels,
        n_classes=config.num_cls,
        n_filters=config.n_filters,
        normalization='batchnorm',
        has_dropout=True
    ).cuda()
    optimizer = optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=3e-5,
        nesterov=True
    )
    return model, optimizer


if __name__ == '__main__':
    import random
    SEED = args.seed
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    # make logger file
    snapshot_path = f'./logs/{args.exp}/'
    fold = str(args.exp[-1])
    maybe_mkdir(snapshot_path)
    maybe_mkdir(os.path.join(snapshot_path, 'ckpts'))

    # make logger (file + stdout)
    writer = SummaryWriter(os.path.join(snapshot_path, 'tensorboard'))
    logging.basicConfig(
        filename=os.path.join(snapshot_path, 'train.log'),
        level=logging.INFO,
        format='[%(asctime)s.%(msecs)03d] %(message)s',
        datefmt='%H:%M:%S'
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    # make data loader; colon task uses per-fold split files
    if args.task == 'colon':
        # args.split_unlabeled = args.split_unlabeled + '_' + fold
        args.split_labeled = args.split_labeled + '_' + fold
        args.split_eval = args.split_eval + '_' + fold

    unlabeled_loader = make_loader(args.split_unlabeled, unlabeled=True)
    # Repeat the labeled split so both loaders have the same length under zip().
    labeled_loader = make_loader(args.split_labeled, repeat=len(unlabeled_loader.dataset))
    eval_loader = make_loader(args.split_eval, is_training=False)

    logging.info(f'{len(labeled_loader)} itertations per epoch (labeled)')
    logging.info(f'{len(unlabeled_loader)} itertations per epoch (unlabeled)')

    # make model, optimizer, and lr scheduler
    # NOTE(review): both models use Kaiming init here, unlike train_depl.py
    # which uses Kaiming/Xavier for A/B — confirm this is intended.
    model_A, optimizer_A = make_model_all()
    model_B, optimizer_B = make_model_all()
    model_A = kaiming_normal_init_weight(model_A)
    model_B = kaiming_normal_init_weight(model_B)

    logging.info(optimizer_A)

    loss_func = make_loss_function(args.sup_loss)
    cps_loss_func = make_loss_function(args.cps_loss)

    if args.mixed_precision:
        amp_grad_scaler = GradScaler()

    cps_w = get_current_consistency_weight(0)
    best_eval = 0.0
    best_epoch = 0
    for epoch_num in range(args.max_epoch + 1):
        loss_list = []
        loss_cps_list = []
        loss_sup_list = []

        model_A.train()
        model_B.train()
        for batch_l, batch_u in tqdm(zip(labeled_loader, unlabeled_loader)):
            optimizer_A.zero_grad()
            optimizer_B.zero_grad()

            image_l, label_l = fetch_data(batch_l)
            image_u = fetch_data(batch_u, labeled=False)
            image = torch.cat([image_l, image_u], dim=0)
            tmp_bs = image.shape[0] // 2  # labeled samples come first in the batch

            if args.mixed_precision:
                with autocast():
                    output_A = model_A(image)
                    output_B = model_B(image)
                    del image

                    # supervised loss (ce + dice) on the labeled half only
                    output_A_l = output_A[:tmp_bs, ...]
                    output_B_l = output_B[:tmp_bs, ...]
                    loss_sup = loss_func(output_A_l, label_l) + loss_func(output_B_l, label_l)

                    # cross pseudo supervision (ce only): each model learns
                    # from the other's hard, detached predictions
                    max_A = torch.argmax(output_A.detach(), dim=1, keepdim=True).long()
                    max_B = torch.argmax(output_B.detach(), dim=1, keepdim=True).long()
                    loss_cps = cps_loss_func(output_A, max_B) + cps_loss_func(output_B, max_A)
                    loss = loss_sup + cps_w * loss_cps

                # backward passes should not be under autocast.
                amp_grad_scaler.scale(loss).backward()
                amp_grad_scaler.step(optimizer_A)
                amp_grad_scaler.step(optimizer_B)
                amp_grad_scaler.update()
            else:
                raise NotImplementedError

            loss_list.append(loss.item())
            loss_sup_list.append(loss_sup.item())
            loss_cps_list.append(loss_cps.item())

        writer.add_scalar('lr', get_lr(optimizer_A), epoch_num)
        writer.add_scalar('cps_w', cps_w, epoch_num)
        writer.add_scalar('loss/loss', np.mean(loss_list), epoch_num)
        writer.add_scalar('loss/sup', np.mean(loss_sup_list), epoch_num)
        writer.add_scalar('loss/cps', np.mean(loss_cps_list), epoch_num)
        logging.info(f'epoch {epoch_num} : loss : {np.mean(loss_list)}, cpsw:{cps_w} lr: {get_lr(optimizer_A)}')

        # polynomial LR decay plus refreshed consistency weight
        optimizer_A.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9)
        optimizer_B.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9)
        cps_w = get_current_consistency_weight(epoch_num)

        # evaluate every 10 epochs early on, every epoch after 50
        if epoch_num % 10 == 0 or epoch_num > 50:
            dice_list = [[] for _ in range(config.num_cls - 1)]  # foreground classes only
            model_A.eval()
            model_B.eval()
            dice_func = SoftDiceLoss(smooth=1e-8)
            for batch in tqdm(eval_loader):
                with torch.no_grad():
                    image, gt = fetch_data(batch)
                    # ensemble the two models by averaging their logits
                    output = (model_A(image) + model_B(image)) / 2.0
                    del image
                    print(gt.max())
                    shp = output.shape
                    gt = gt.long()
                    y_onehot = torch.zeros(shp).cuda()
                    y_onehot.scatter_(1, gt, 1)
                    x_onehot = torch.zeros(shp).cuda()
                    output = torch.argmax(output, dim=1, keepdim=True).long()
                    x_onehot.scatter_(1, output, 1)

                    dice = dice_func(x_onehot, y_onehot, is_training=False)
                    dice = dice.data.cpu().numpy()
                    for i, d in enumerate(dice):
                        dice_list[i].append(d)

            dice_mean = [np.mean(per_cls) for per_cls in dice_list]
            logging.info(f'evaluation epoch {epoch_num}, dice: {np.mean(dice_mean)}, {dice_mean}')

            if np.mean(dice_mean) > best_eval:
                best_eval = np.mean(dice_mean)
                best_epoch = epoch_num
                save_path = os.path.join(snapshot_path, f'ckpts/best_model.pth')
                torch.save({
                    'A': model_A.state_dict(),
                    'B': model_B.state_dict()
                }, save_path)
                logging.info(f'saving best model to {save_path}')
            logging.info(f'\t best eval dice is {best_eval} in epoch {best_epoch}')

            if epoch_num - best_epoch == config.early_stop_patience:
                logging.info(f'Early stop.')
                break

    writer.close()
# --------------------------------------------------------------------------
# dmp/code/models/vnet.py
# --------------------------------------------------------------------------
import torch
from torch import nn
class ConvBlock(nn.Module):
    """Stack of `n_stages` (Conv3d 3x3x3 -> optional norm -> ReLU) stages.

    The first stage maps n_filters_in -> n_filters_out; later stages keep
    n_filters_out channels.
    """

    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            # Only the first conv changes the channel count.
            input_channel = n_filters_in if i == 0 else n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False  # unknown normalization name
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsamplingConvBlock(nn.Module):
    """Strided Conv3d (kernel == stride, non-overlapping) -> optional norm -> ReLU.

    Halves each spatial dimension for the default stride of 2.
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        # The conv is identical in every branch of the original code, so it is
        # built once here; only the optional norm layer differs.
        ops = [nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False  # unknown normalization name
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    """ConvTranspose3d (kernel == stride) -> optional norm -> ReLU.

    Doubles each spatial dimension for the default stride of 2.
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        # Same dedup as DownsamplingConvBlock: one deconv, then optional norm.
        ops = [nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False  # unknown normalization name
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class VNet(nn.Module):
    """3D VNet encoder/decoder with additive skip connections.

    Args:
        n_channels: input channels.
        n_classes: output segmentation classes.
        n_filters: base channel width (doubles at each of 4 downsamplings).
        normalization: 'batchnorm' | 'groupnorm' | 'instancenorm' | 'none'.
        has_dropout: apply Dropout3d(0.5) at the bottleneck and before out_conv.
        return_layers: if True, forward() also returns [x1..x9] intermediate
            feature maps.
    """

    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none',
                 has_dropout=False, return_layers=False):
        super(VNet, self).__init__()

        self.return_layers = return_layers
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def encoder(self, input):
        """Encode `input` into the five multi-scale feature maps (x1..x5)."""
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        if self.has_dropout:
            x5 = self.dropout(x5)

        return x1, x2, x3, x4, x5

    def decoder(self, features):
        """Decode (x1..x5) back to full resolution; skips are added, not concatenated."""
        x1, x2, x3, x4, x5 = features
        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        if self.has_dropout:
            x9 = self.dropout(x9)

        out = self.out_conv(x9)
        if self.return_layers:
            # BUG FIX: the original built this list as
            # [globals()[f'x{j}'] for j in range(1, 10)], but x1..x9 are
            # locals of this method, so return_layers=True always raised
            # KeyError. List the locals directly instead.
            return out, [x1, x2, x3, x4, x5, x6, x7, x8, x9]
        return out

    def forward(self, image, turnoff_drop=False):
        """Run the full network; `turnoff_drop` temporarily disables dropout."""
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False

        features = self.encoder(image)
        out = self.decoder(features)

        if turnoff_drop:
            self.has_dropout = has_dropout

        return out


class VNet4SSNet(nn.Module):
    """VNet variant for SSNet-style contrastive training.

    Adds projection/prediction MLP heads and per-class selector MLPs on top of
    the plain VNet backbone; training-mode forward() also returns the final
    decoder embedding (x9).
    """

    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet4SSNet, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        # MLP heads operating on per-voxel feature vectors of size n_filters.
        dim_in = n_filters
        feat_dim = n_filters * 2
        self.projection_head = nn.Sequential(
            nn.Linear(dim_in, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        self.prediction_head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )

        # One scalar-scoring selector MLP per class, for current features ...
        for class_c in range(n_classes):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_' + str(class_c), selector)

        # ... and one per class for memory-bank features.
        for class_c in range(n_classes):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_memory' + str(class_c), selector)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def encoder(self, input):
        """Encode `input` into the five multi-scale feature maps (x1..x5)."""
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        if self.has_dropout:
            x5 = self.dropout(x5)

        return x1, x2, x3, x4, x5

    def decoder(self, features):
        """Decode to logits; also return the last pre-logit feature map x9."""
        x1, x2, x3, x4, x5 = features
        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        if self.has_dropout:
            x9 = self.dropout(x9)

        out = self.out_conv(x9)
        return out, x9

    def forward_projection_head(self, features):
        return self.projection_head(features)

    def forward_prediction_head(self, features):
        return self.prediction_head(features)

    def forward(self, image, turnoff_drop=False):
        """Return logits in eval mode, (logits, embedding) in training mode."""
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False

        features = self.encoder(image)
        out, embedding = self.decoder(features)

        if turnoff_drop:
            self.has_dropout = has_dropout

        if not self.training:
            return out

        return out, embedding
# --------------------------------------------------------------------------
# dmp/code/models/vnet_dst.py
# --------------------------------------------------------------------------
import torch
from torch import nn
from typing import Optional, Any, Tuple
import numpy as np
from torch.autograd import Function
class ConvBlock(nn.Module):
    """Stack of `n_stages` (Conv3d 3x3x3 -> optional norm -> ReLU) stages.

    Identical to the block in models/vnet.py; kept local so this module stays
    self-contained.
    """

    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            # Only the first conv changes the channel count.
            input_channel = n_filters_in if i == 0 else n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False  # unknown normalization name
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsamplingConvBlock(nn.Module):
    """Strided Conv3d (kernel == stride) -> optional norm -> ReLU; halves spatial dims."""

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        # The conv is the same with or without normalization; build it once.
        ops = [nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False  # unknown normalization name
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    """ConvTranspose3d (kernel == stride) -> optional norm -> ReLU; doubles spatial dims."""

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = [nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False  # unknown normalization name
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class GradientReverseFunction(Function):
    """Identity in the forward pass; scales the gradient by -coeff in backward."""

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        # Multiply by 1.0 so autograd records a node (a plain return would alias).
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        # No gradient flows to `coeff`, hence the trailing None.
        return grad_output.neg() * ctx.coeff, None


class WarmStartGradientReverseLayer(nn.Module):
    r"""Gradient Reverse Layer :math:`\mathcal{R}(x)` with warm start
    The forward and backward behaviours are:
    .. math::
        \mathcal{R}(x) = x,
        \dfrac{ d\mathcal{R}} {dx} = - \lambda I.
    :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule:
    .. math::
        \lambda = \dfrac{2(hi-lo)}{1+\exp(- \alpha \dfrac{i}{N})} - (hi-lo) + lo
    where :math:`i` is the iteration step.
    Args:
        alpha (float, optional): :math:`\alpha`. Default: 1.0
        lo (float, optional): Initial value of :math:`\lambda`. Default: 0.0
        hi (float, optional): Final value of :math:`\lambda`. Default: 1.0
        max_iters (int, optional): :math:`N`. Default: 1000
        auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called.
            Otherwise use function `step` to increase :math:`i`. Default: False
    """

    def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,
                 max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False):
        super(WarmStartGradientReverseLayer, self).__init__()
        self.alpha = alpha
        self.lo = lo
        self.hi = hi
        self.iter_num = 0
        self.max_iters = max_iters
        self.auto_step = auto_step

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Identity forward; records the warm-start coefficient for backward."""
        # BUG FIX: np.float was removed in NumPy 1.24 (deprecated since 1.20),
        # so the original np.float(...) raises AttributeError on current NumPy.
        # The builtin float() produces the identical value.
        coeff = float(
            2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))
            - (self.hi - self.lo) + self.lo
        )
        if self.auto_step:
            self.step()
        return GradientReverseFunction.apply(input, coeff)

    def step(self):
        """Increase iteration number :math:`i` by 1"""
        self.iter_num += 1
class Decoder(nn.Module):
    """VNet decoder head over encoder features (x1..x5).

    With worst_case=True, forward() routes through decode_worst(), which passes
    every skip feature through the warm-start gradient reverse layer first.
    """

    def __init__(self, n_classes=2, n_filters=16, normalization=None, worst_case=False):
        super(Decoder, self).__init__()
        self.worst_case = worst_case

        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.head = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        # NOTE(review): the default normalization=None (vs the string 'none')
        # would trip the assert inside the conv blocks; callers always pass the
        # value explicitly — confirm before relying on the default.
        self.grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=500,
                                                       auto_step=False)
        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def decode(self, features, has_dropout):
        """Standard decoding path with additive skip connections."""
        x1, x2, x3, x4, x5 = features
        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        if has_dropout:
            x9 = self.dropout(x9)
        out = self.head(x9)
        return out

    def decode_worst(self, features, has_dropout):
        """Same computation as decode(), but every encoder feature is passed
        through the gradient reverse layer first (identity forward, so the
        produced logits are identical; only gradients are reversed)."""
        reversed_features = tuple(self.grl_layer(feat) for feat in features)
        return self.decode(reversed_features, has_dropout)

    def decode_worst_grl_last(self, features, has_dropout):
        """Variant with the GRL applied only to the final fused feature map."""
        x1, x2, x3, x4, x5 = features
        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x8_up = self.grl_layer(x8_up)
        x9 = self.block_nine(x8_up)
        if has_dropout:
            x9 = self.dropout(x9)

        out = self.head(x9)
        return out

    def forward(self, features, has_dropout):
        if self.worst_case:
            return self.decode_worst(features, has_dropout)
        return self.decode(features, has_dropout)


class VNet_Decoupled(nn.Module):
    """VNet with one shared encoder and three decoupled decoders.

    Training-mode forward() returns (out, out_pseudo, out_worst); eval mode
    returns only the main decoder's output.
    """

    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet_Decoupled, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)

        # Main, pseudo-label, and gradient-reversed ("worst case") decoders.
        self.decoder = Decoder(n_classes=n_classes, n_filters=n_filters, normalization=normalization)
        self.decoder_pseudo = Decoder(n_classes=n_classes, n_filters=n_filters, normalization=normalization)
        self.decoder_worst = Decoder(n_classes=n_classes, n_filters=n_filters, normalization=normalization, worst_case=True)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def encoder(self, input):
        """Encode `input` into the five multi-scale feature maps shared by all decoders."""
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        if self.has_dropout:
            x5 = self.dropout(x5)

        return x1, x2, x3, x4, x5

    def forward(self, image, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False

        features = self.encoder(image)

        out = self.decoder(features, self.has_dropout)

        # Inference only needs the main decoder.
        if not self.training:
            return out

        out_pseudo = self.decoder_pseudo(features, self.has_dropout)
        out_worst = self.decoder_worst(features, self.has_dropout)

        if turnoff_drop:
            self.has_dropout = has_dropout

        return out, out_pseudo, out_worst

    def step(self):
        """Advance the warm-start schedule of the worst-case decoder's GRL."""
        self.decoder_worst.grl_layer.step()
# --------------------------------------------------------------------------
# dmp/code/train_depl.py
# --------------------------------------------------------------------------
import os
import sys
import logging
from tqdm import tqdm
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='synapse')
parser.add_argument('--exp', type=str, default='cps')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('-sl', '--split_labeled', type=str, default='labeled_20p')
parser.add_argument('-su', '--split_unlabeled', type=str, default='unlabeled_80p')
parser.add_argument('-se', '--split_eval', type=str, default='eval')
parser.add_argument('-m', '--mixed_precision', action='store_true', default=True)  # <--
parser.add_argument('-ep', '--max_epoch', type=int, default=500)
parser.add_argument('--cps_loss', type=str, default='wce')
parser.add_argument('--sup_loss', type=str, default='w_ce+dice')
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--base_lr', type=float, default=0.001)
parser.add_argument('-g', '--gpu', type=str, default='0')
parser.add_argument('-w', '--cps_w', type=float, default=1)
parser.add_argument('-r', '--cps_rampup', action='store_true', default=True)  # <--
parser.add_argument('-cr', '--consistency_rampup', type=float, default=None)
args = parser.parse_args()
# Pin the visible GPU before torch is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast

from models.vnet import VNet
from utils import maybe_mkdir, get_lr, fetch_data, seed_worker, poly_lr, xavier_normal_init_weight, kaiming_normal_init_weight
from utils.loss import DC_and_CE_loss, RobustCrossEntropyLoss, SoftDiceLoss, WeightedCrossEntropyLoss
from data.transforms import RandomCrop, CenterCrop, ToTensor, RandomFlip_LR, RandomFlip_UD
from data.data_loaders import Synapse_AMOS
import torch.nn.functional as F
from utils.config import Config
config = Config(args.task)


def sigmoid_rampup(current, rampup_length):
    '''Exponential rampup from https://arxiv.org/abs/1610.02242'''
    if rampup_length == 0:
        return 1.0
    # Clamp so the ramp saturates at 1.0 once `current` passes rampup_length.
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))


def get_current_consistency_weight(epoch):
    """Return the CPS consistency weight for `epoch` (ramped up if enabled)."""
    if not args.cps_rampup:
        return args.cps_w
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    if args.consistency_rampup is None:
        args.consistency_rampup = args.max_epoch
    return args.cps_w * sigmoid_rampup(epoch, args.consistency_rampup)


def make_loss_function(name, weight=None):
    """Look up a loss by short name; `weight` is the per-class weighting."""
    if name == 'ce':
        return RobustCrossEntropyLoss()
    elif name == 'wce':
        return RobustCrossEntropyLoss(weight=weight)
    elif name == 'ce+dice':
        return DC_and_CE_loss()
    elif name == 'wce+dice':
        return DC_and_CE_loss(w_ce=weight)
    elif name == 'w_ce+dice':
        return DC_and_CE_loss(w_dc=weight, w_ce=weight)
    else:
        raise ValueError(name)


def make_loader(split, dst_cls=Synapse_AMOS, repeat=None, is_training=True, unlabeled=False):
    """Build a DataLoader for `split`; random crops/flips for training, center crop for eval."""
    if is_training:
        train_set = dst_cls(
            task=args.task,
            split=split,
            repeat=repeat,
            unlabeled=unlabeled,
            num_cls=config.num_cls,
            transform=transforms.Compose([
                RandomCrop(config.patch_size),
                RandomFlip_LR(),
                RandomFlip_UD(),
                ToTensor()
            ])
        )
        return DataLoader(
            train_set,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            worker_init_fn=seed_worker  # keep workers deterministic given the global seed
        )
    eval_set = dst_cls(
        task=args.task,
        split=split,
        is_val=True,
        num_cls=config.num_cls,
        transform=transforms.Compose([
            CenterCrop(config.patch_size),
            ToTensor()
        ])
    )
    return DataLoader(eval_set, pin_memory=True)
def causal_inference(current_logit, qhat, tau=0.5):
    """De-bias logits with the running class marginal `qhat`.

    Implements softmax(logit - tau * log(qhat)) along the class axis
    (logit adjustment), so classes the model over-predicts are damped.
    """
    adjusted_logit = current_logit - tau * torch.log(qhat)
    return F.softmax(adjusted_logit, dim=1)


def initial_qhat(class_num=14):
    """Uniform initial estimate of the per-voxel class marginal.

    Shape is (1, class_num, *patch_size) so it broadcasts over the batch.
    """
    shape = (1, class_num) + config.patch_size
    qhat = torch.full(shape, 1.0 / class_num, dtype=torch.float)
    print("qhat size: ", qhat.size())
    return qhat.cuda()


def update_qhat(probs, qhat, momentum):
    """EMA update of `qhat` with the batch-mean predicted probabilities."""
    batch_mean_prob = probs.detach().mean(dim=0)
    return momentum * qhat + (1.0 - momentum) * batch_mean_prob
unlabeled_loader = make_loader(args.split_unlabeled, unlabeled=True) 187 | labeled_loader = make_loader(args.split_labeled, repeat=len(unlabeled_loader.dataset)) 188 | eval_loader = make_loader(args.split_eval, is_training=False) 189 | 190 | 191 | logging.info(f'{len(labeled_loader)} itertations per epoch (labeled)') 192 | logging.info(f'{len(unlabeled_loader)} itertations per epoch (unlabeled)') 193 | 194 | # make model, optimizer, and lr scheduler 195 | model_A, optimizer_A = make_model_all() 196 | model_B, optimizer_B = make_model_all() 197 | model_A = kaiming_normal_init_weight(model_A) 198 | model_B = xavier_normal_init_weight(model_B) 199 | 200 | # make loss function 201 | # weight,_ = labeled_loader.dataset.weight() 202 | # print(weight) 203 | # print(sum(weight)) 204 | loss_func = make_loss_function(args.sup_loss) 205 | cps_loss_func = WeightedCrossEntropyLoss() 206 | 207 | if args.mixed_precision: 208 | amp_grad_scaler = GradScaler() 209 | 210 | cps_w = get_current_consistency_weight(0) 211 | best_eval = 0.0 212 | best_epoch = 0 213 | qhat_A = initial_qhat(class_num=config.num_cls) 214 | qhat_B = initial_qhat(class_num=config.num_cls) 215 | tau = 1.0 216 | for epoch_num in range(args.max_epoch + 1): 217 | loss_list = [] 218 | loss_cps_list = [] 219 | loss_sup_list = [] 220 | 221 | model_A.train() 222 | model_B.train() 223 | for batch_l, batch_u in tqdm(zip(labeled_loader, unlabeled_loader)): 224 | optimizer_A.zero_grad() 225 | optimizer_B.zero_grad() 226 | 227 | image_l, label_l = fetch_data(batch_l) 228 | image_u = fetch_data(batch_u, labeled=False) 229 | image = torch.cat([image_l, image_u], dim=0) 230 | tmp_bs = image.shape[0] // 2 231 | 232 | if args.mixed_precision: 233 | with autocast(): 234 | output_A = model_A(image) 235 | output_B = model_B(image) 236 | del image 237 | 238 | # sup (ce + dice) 239 | output_A_l, output_A_u = output_A[:tmp_bs, ...], output_A[tmp_bs:, ...] 240 | output_B_l, output_B_u = output_B[:tmp_bs, ...], output_B[tmp_bs:, ...] 
241 | loss_sup = loss_func(output_A_l, label_l) + loss_func(output_B_l, label_l) 242 | 243 | # cps (ce only) 244 | 245 | pseudo_label_A = causal_inference(output_A.detach(), qhat_A, tau=tau) 246 | pseudo_label_B = causal_inference(output_B.detach(), qhat_B, tau=tau) 247 | 248 | max_probs_A, pseudo_targets_A = torch.max(pseudo_label_A, dim=1) 249 | mask_A = max_probs_A.ge(0.9).float() 250 | max_probs_B, pseudo_targets_B = torch.max(pseudo_label_B, dim=1) 251 | mask_B = max_probs_B.ge(0.9).float() 252 | 253 | # update qhat 254 | qhat_A = update_qhat(torch.softmax(output_A_u.detach(), dim=1), qhat_A, momentum=0.99) 255 | qhat_B = update_qhat(torch.softmax(output_B_u.detach(), dim=1), qhat_B, momentum=0.99) 256 | 257 | # adaptive marginal loss 258 | delta_logits_B = torch.log(qhat_B) 259 | output_A_u = output_A_u + tau * delta_logits_B 260 | 261 | delta_logits_A = torch.log(qhat_A) 262 | output_B_u = output_B_u + tau * delta_logits_A 263 | 264 | # pseudo_targets_A = torch.argmax(pseudo_label_A, dim=1, keepdim=True).long() 265 | # pseudo_targets_B = torch.argmax(pseudo_label_B, dim=1, keepdim=True).long() 266 | 267 | 268 | loss_cps = cps_loss_func(output_A, pseudo_targets_B, mask_B) + cps_loss_func(output_B, pseudo_targets_A, mask_A) 269 | # loss prop 270 | loss = loss_sup + cps_w * loss_cps 271 | 272 | 273 | # backward passes should not be under autocast. 
274 | amp_grad_scaler.scale(loss).backward() 275 | amp_grad_scaler.step(optimizer_A) 276 | amp_grad_scaler.step(optimizer_B) 277 | amp_grad_scaler.update() 278 | 279 | else: 280 | raise NotImplementedError 281 | 282 | loss_list.append(loss.item()) 283 | loss_sup_list.append(loss_sup.item()) 284 | loss_cps_list.append(loss_cps.item()) 285 | 286 | writer.add_scalar('lr', get_lr(optimizer_A), epoch_num) 287 | writer.add_scalar('cps_w', cps_w, epoch_num) 288 | writer.add_scalar('loss/loss', np.mean(loss_list), epoch_num) 289 | writer.add_scalar('loss/sup', np.mean(loss_sup_list), epoch_num) 290 | writer.add_scalar('loss/cps', np.mean(loss_cps_list), epoch_num) 291 | logging.info(f'epoch {epoch_num} : loss : {np.mean(loss_list)}') 292 | 293 | optimizer_A.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9) 294 | optimizer_B.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9) 295 | cps_w = get_current_consistency_weight(epoch_num) 296 | 297 | if epoch_num % 10 == 0 or epoch_num>50: 298 | 299 | # ''' ===== evaluation 300 | dice_list = [[] for _ in range(config.num_cls-1)] 301 | model_A.eval() 302 | model_B.eval() 303 | dice_func = SoftDiceLoss(smooth=1e-8) 304 | for batch in tqdm(eval_loader): 305 | with torch.no_grad(): 306 | image, gt = fetch_data(batch) 307 | # output = model_A(image) 308 | output = (model_A(image) + model_B(image)) / 2.0 309 | del image 310 | 311 | shp = output.shape 312 | gt = gt.long() 313 | y_onehot = torch.zeros(shp).cuda() 314 | y_onehot.scatter_(1, gt, 1) 315 | 316 | x_onehot = torch.zeros(shp).cuda() 317 | output = torch.argmax(output, dim=1, keepdim=True).long() 318 | x_onehot.scatter_(1, output, 1) 319 | 320 | dice = dice_func(x_onehot, y_onehot, is_training=False) 321 | dice = dice.data.cpu().numpy() 322 | for i, d in enumerate(dice): 323 | dice_list[i].append(d) 324 | 325 | dice_mean = [] 326 | for dice in dice_list: 327 | dice_mean.append(np.mean(dice)) 328 | logging.info(f'evaluation 
epoch {epoch_num}, dice: {np.mean(dice_mean)}, {dice_mean}') 329 | # ''' 330 | if np.mean(dice_mean) > best_eval: 331 | best_eval = np.mean(dice_mean) 332 | best_epoch = epoch_num 333 | save_path = os.path.join(snapshot_path, f'ckpts/best_model.pth') 334 | torch.save({ 335 | 'A': model_A.state_dict(), 336 | 'B': model_B.state_dict() 337 | }, save_path) 338 | logging.info(f'saving best model to {save_path}') 339 | logging.info(f'\t best eval dice is {best_eval} in epoch {best_epoch}') 340 | 341 | if epoch_num - best_epoch == config.early_stop_patience: 342 | logging.info(f'Early stop.') 343 | break 344 | 345 | writer.close() 346 | -------------------------------------------------------------------------------- /dmp/code/train_crest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | from tqdm import tqdm 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--task', type=str, default='synapse') 9 | parser.add_argument('--exp', type=str, default='cps') 10 | parser.add_argument('--seed', type=int, default=0) 11 | parser.add_argument('-sl', '--split_labeled', type=str, default='labeled_20p') 12 | parser.add_argument('-su', '--split_unlabeled', type=str, default='unlabeled_80p') 13 | parser.add_argument('-se', '--split_eval', type=str, default='eval') 14 | parser.add_argument('-m', '--mixed_precision', action='store_true', default=True) # <-- 15 | parser.add_argument('-ep', '--max_epoch', type=int, default=500) 16 | parser.add_argument('--cps_loss', type=str, default='wce') 17 | parser.add_argument('--sup_loss', type=str, default='w_ce+dice') 18 | parser.add_argument('--batch_size', type=int, default=2) 19 | parser.add_argument('--num_workers', type=int, default=2) 20 | parser.add_argument('--base_lr', type=float, default=0.001) 21 | parser.add_argument('-g', '--gpu', type=str, default='0') 22 | parser.add_argument('-w', '--cps_w', type=float, 
def sigmoid_rampup(current, rampup_length):
    '''Exponential rampup from https://arxiv.org/abs/1610.02242'''
    if rampup_length == 0:
        return 1.0
    progress = np.clip(current, 0.0, rampup_length) / rampup_length
    return float(np.exp(-5.0 * (1.0 - progress) ** 2))


def get_current_consistency_weight(epoch):
    """Current CPS weight; sigmoid ramp-up when --cps_rampup is set."""
    if not args.cps_rampup:
        return args.cps_w
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    if args.consistency_rampup is None:
        args.consistency_rampup = args.max_epoch
    return args.cps_w * sigmoid_rampup(epoch, args.consistency_rampup)


def make_loss_function(name, weight=None):
    """Map a loss name to a criterion; `weight` feeds the weighted variants."""
    if name == 'ce':
        return RobustCrossEntropyLoss()
    if name == 'wce':
        return RobustCrossEntropyLoss(weight=weight)
    if name == 'ce+dice':
        return DC_and_CE_loss()
    if name == 'wce+dice':
        return DC_and_CE_loss(w_ce=weight)
    if name == 'w_ce+dice':
        return DC_and_CE_loss(w_dc=weight, w_ce=weight)
    raise ValueError(name)


def make_loader(split, dst_cls=Synapse_AMOS, repeat=None, is_training=True, unlabeled=False):
    """DataLoader factory: augmented/shuffled for training, center-crop for eval."""
    if is_training:
        train_transform = transforms.Compose([
            RandomCrop(config.patch_size),
            RandomFlip_LR(),
            RandomFlip_UD(),
            ToTensor(),
        ])
        dataset = dst_cls(
            task=args.task,
            split=split,
            repeat=repeat,
            unlabeled=unlabeled,
            num_cls=config.num_cls,
            transform=train_transform,
        )
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            worker_init_fn=seed_worker,
        )
    val_transform = transforms.Compose([
        CenterCrop(config.patch_size),
        ToTensor(),
    ])
    dataset = dst_cls(
        task=args.task,
        split=split,
        is_val=True,
        num_cls=config.num_cls,
        transform=val_transform,
    )
    return DataLoader(dataset, pin_memory=True)


def make_model_all():
    """Construct a VNet (batch-norm, dropout) on GPU and its SGD optimizer."""
    net = VNet(
        n_channels=config.num_channels,
        n_classes=config.num_cls,
        n_filters=config.n_filters,
        normalization='batchnorm',
        has_dropout=True,
    ).cuda()
    sgd = optim.SGD(
        net.parameters(),
        lr=args.base_lr,
        weight_decay=3e-5,
        momentum=0.9,
        nesterov=True,
    )
    return net, sgd
# Methods of CReST (class header lies outside this span).
def _weight_wbg(self, num_each_class):
    """Rank-reversed class weights (CReST eq (1)).

    mu_l = (n_l / n_max)^(1/3); each class receives the mu of the class with
    the mirrored frequency rank, so the most frequent class gets the smallest
    weight and vice versa. Uses self.num_cls consistently (the original mixed
    in the module-global config.num_cls, which tied the class to one script).
    """
    weight = np.zeros(self.num_cls)
    mu_all = np.power((num_each_class + 1e-8) / (np.amax(num_each_class) + 1e-8), 1 / 3)
    # stable sort by mu descending, keeping the class index alongside
    ranked = sorted(enumerate(mu_all), key=lambda kv: kv[1], reverse=True)
    for rank, (cls_idx, _) in enumerate(ranked):
        weight[cls_idx] = ranked[self.num_cls - 1 - rank][1]  # eq (1)
    return weight

def init_crest_weight(self, labeled_dataset):
    """Initialise self.weights from the labeled split's voxel histogram.

    Raises ValueError if handed an unlabeled dataset. Returns the numpy
    weights; a CUDA tensor copy is cached on self.weights.
    """
    if labeled_dataset.unlabeled:
        raise ValueError
    num_each_class = np.zeros(self.num_cls)
    for data_id in labeled_dataset.ids_list:
        _, _, label = labeled_dataset._get_data(data_id)
        tmp, _ = np.histogram(label.reshape(-1), range(self.num_cls + 1))
        num_each_class += tmp
    weights = self._weight_wbg(num_each_class)
    self.weights = torch.FloatTensor(weights).cuda()
    return weights

def cal_cur_weight(self, pseudo_label):
    """Class weights from the hard pseudo-labels of the current batch."""
    hard = torch.argmax(pseudo_label.detach(), dim=1, keepdim=True).long()
    num_each_class = self._cal_class_num(hard.data.cpu().numpy())
    return torch.FloatTensor(self._weight_wbg(num_each_class)).cuda()

def get_weights(self, pseudo_label):
    """EMA-smoothed weights; also refreshes self.weights in place."""
    cur_weights = self.cal_cur_weight(pseudo_label)
    self.weights = EMA(cur_weights, self.weights, momentum=self.momentum)
    return self.weights
#args.split_unlabeled = args.split_unlabeled+'_'+fold 213 | args.split_labeled = args.split_labeled+'_'+fold 214 | args.split_eval = args.split_eval+'_'+fold 215 | 216 | # make logger file 217 | snapshot_path = f'./logs/{args.exp}/' 218 | maybe_mkdir(snapshot_path) 219 | maybe_mkdir(os.path.join(snapshot_path, 'ckpts')) 220 | 221 | # make logger 222 | writer = SummaryWriter(os.path.join(snapshot_path, 'tensorboard')) 223 | logging.basicConfig( 224 | filename=os.path.join(snapshot_path, 'train.log'), 225 | level=logging.INFO, 226 | format='[%(asctime)s.%(msecs)03d] %(message)s', 227 | datefmt='%H:%M:%S' 228 | ) 229 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 230 | logging.info(str(args)) 231 | fold = str(args.exp[-1]) 232 | 233 | # make data loader 234 | unlabeled_loader = make_loader(args.split_unlabeled, unlabeled=True) 235 | labeled_loader = make_loader(args.split_labeled, repeat=len(unlabeled_loader.dataset)) 236 | eval_loader = make_loader(args.split_eval, is_training=False) 237 | 238 | 239 | logging.info(f'{len(labeled_loader)} itertations per epoch (labeled)') 240 | logging.info(f'{len(unlabeled_loader)} itertations per epoch (unlabeled)') 241 | 242 | # make model, optimizer, and lr scheduler 243 | model_A, optimizer_A = make_model_all() 244 | model_B, optimizer_B = make_model_all() 245 | model_A = kaiming_normal_init_weight(model_A) 246 | #model_B = xavier_normal_init_weight(model_B) 247 | model_B = kaiming_normal_init_weight(model_B) 248 | logging.info(optimizer_A) 249 | 250 | crest_A = CReST(config.num_cls, do_bg=True, momentum=0.99) 251 | crest_B = CReST(config.num_cls, do_bg=True, momentum=0.99) 252 | # crest_B = CReST() 253 | 254 | # make loss function 255 | weight_A = crest_A.init_crest_weight(labeled_loader.dataset) 256 | weight_B = crest_B.init_crest_weight(labeled_loader.dataset) 257 | print(weight_A) 258 | # print(torch.nn.functional.softmax(weight, dim=0)) 259 | # print(softmax_weight) 260 | loss_func_A = 
make_loss_function(args.sup_loss, weight=weight_A) 261 | loss_func_B = make_loss_function(args.sup_loss, weight=weight_B) 262 | cps_loss_func_A = make_loss_function(args.cps_loss, weight=weight_A) 263 | cps_loss_func_B = make_loss_function(args.cps_loss, weight=weight_B) 264 | 265 | # weight_A = weight_B = torch.FloatTensor(weight).cuda() 266 | 267 | 268 | if args.mixed_precision: 269 | amp_grad_scaler = GradScaler() 270 | 271 | cps_w = get_current_consistency_weight(0) 272 | best_eval = 0.0 273 | best_epoch = 0 274 | for epoch_num in range(args.max_epoch + 1): 275 | loss_list = [] 276 | loss_cps_list = [] 277 | loss_sup_list = [] 278 | 279 | model_A.train() 280 | model_B.train() 281 | for batch_l, batch_u in tqdm(zip(labeled_loader, unlabeled_loader)): 282 | optimizer_A.zero_grad() 283 | optimizer_B.zero_grad() 284 | 285 | image_l, label_l = fetch_data(batch_l) 286 | image_u = fetch_data(batch_u, labeled=False) 287 | image = torch.cat([image_l, image_u], dim=0) 288 | tmp_bs = image.shape[0] // 2 289 | 290 | if args.mixed_precision: 291 | with autocast(): 292 | output_A = model_A(image) 293 | output_B = model_B(image) 294 | del image 295 | 296 | # sup (ce + dice) 297 | output_A_l, output_A_u = output_A[:tmp_bs, ...], output_A[tmp_bs:, ...] 298 | output_B_l, output_B_u = output_B[:tmp_bs, ...], output_B[tmp_bs:, ...] 
299 | 300 | 301 | # cps (ce only) 302 | max_A = torch.argmax(output_A.detach(), dim=1, keepdim=True).long() 303 | max_B = torch.argmax(output_B.detach(), dim=1, keepdim=True).long() 304 | 305 | 306 | # if epoch_num>0: 307 | 308 | weight_A = crest_A.get_weights(output_A_u) 309 | weight_B = crest_B.get_weights(output_B_u) 310 | # print(output_A) 311 | # print(new_weight_A) 312 | # print(softmax_weight) 313 | 314 | # weight_A = EMA(new_weight_A, weight_A, momentum=0.99) 315 | # weight_B = EMA(new_weight_B, weight_B, momentum=0.99) 316 | 317 | # weight_A = weight_A / torch.max(weight_A) 318 | # weight_B = weight_B / torch.max(weight_B) 319 | 320 | cps_loss_func_A.update_weight(weight_A) 321 | cps_loss_func_B.update_weight(weight_B) 322 | loss_func_A.update_weight(weight_A) 323 | loss_func_B.update_weight(weight_B) 324 | 325 | # softmax_weight = new_weight / torch.sum(new_weight) 326 | # new_weight = torch.pow(torch.amax(softmax_weight) / softmax_weight, 1/3) 327 | # new_weight = torch.cat([torch.tensor([1.0]).cuda(), new_weight]) 328 | # print("====",new_weight) 329 | loss_sup = loss_func_A(output_A_l, label_l) + loss_func_B(output_B_l, label_l) 330 | loss_cps = cps_loss_func_A(output_A, max_B) + cps_loss_func_B(output_B, max_A) 331 | # loss prop 332 | loss = loss_sup + cps_w * loss_cps 333 | 334 | 335 | 336 | # backward passes should not be under autocast. 
337 | amp_grad_scaler.scale(loss).backward() 338 | amp_grad_scaler.step(optimizer_A) 339 | amp_grad_scaler.step(optimizer_B) 340 | amp_grad_scaler.update() 341 | # if epoch_num>0: 342 | 343 | else: 344 | raise NotImplementedError 345 | 346 | loss_list.append(loss.item()) 347 | loss_sup_list.append(loss_sup.item()) 348 | loss_cps_list.append(loss_cps.item()) 349 | 350 | writer.add_scalar('lr', get_lr(optimizer_A), epoch_num) 351 | writer.add_scalar('cps_w', cps_w, epoch_num) 352 | writer.add_scalar('loss/loss', np.mean(loss_list), epoch_num) 353 | writer.add_scalar('loss/sup', np.mean(loss_sup_list), epoch_num) 354 | writer.add_scalar('loss/cps', np.mean(loss_cps_list), epoch_num) 355 | writer.add_scalars('class_weights/A', dict(zip([str(i) for i in range(config.num_cls)] ,print_func(weight_A))), epoch_num) 356 | writer.add_scalars('class_weights/B', dict(zip([str(i) for i in range(config.num_cls)] ,print_func(weight_B))), epoch_num) 357 | logging.info(f'epoch {epoch_num} : loss : {np.mean(loss_list)}') 358 | # logging.info(f' cps_w: {cps_w}') 359 | # if epoch_num>0: 360 | logging.info(f" Class Weights: {print_func(weight_A)}, lr: {get_lr(optimizer_A)}") 361 | logging.info(f" Class Weights: {print_func(weight_B)}") 362 | # lr_scheduler_A.step() 363 | # lr_scheduler_B.step() 364 | optimizer_A.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9) 365 | optimizer_B.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9) 366 | # print(optimizer_A.param_groups[0]['lr']) 367 | cps_w = get_current_consistency_weight(epoch_num) 368 | 369 | if epoch_num % 10 == 0 or epoch_num >= 50: 370 | 371 | # ''' ===== evaluation 372 | dice_list = [[] for _ in range(config.num_cls-1)] 373 | model_A.eval() 374 | model_B.eval() 375 | dice_func = SoftDiceLoss(smooth=1e-8) 376 | for batch in tqdm(eval_loader): 377 | with torch.no_grad(): 378 | image, gt = fetch_data(batch) 379 | # output = model_A(image) 380 | output = (model_A(image) + 
model_B(image)) / 2.0 381 | del image 382 | 383 | shp = output.shape 384 | gt = gt.long() 385 | y_onehot = torch.zeros(shp).cuda() 386 | y_onehot.scatter_(1, gt, 1) 387 | 388 | x_onehot = torch.zeros(shp).cuda() 389 | output = torch.argmax(output, dim=1, keepdim=True).long() 390 | x_onehot.scatter_(1, output, 1) 391 | 392 | 393 | dice = dice_func(x_onehot, y_onehot, is_training=False) 394 | dice = dice.data.cpu().numpy() 395 | for i, d in enumerate(dice): 396 | dice_list[i].append(d) 397 | 398 | dice_mean = [] 399 | for dice in dice_list: 400 | dice_mean.append(np.mean(dice)) 401 | logging.info(f'evaluation epoch {epoch_num}, dice: {np.mean(dice_mean)}, {dice_mean}') 402 | # ''' 403 | if np.mean(dice_mean) > best_eval: 404 | best_eval = np.mean(dice_mean) 405 | best_epoch = epoch_num 406 | save_path = os.path.join(snapshot_path, f'ckpts/best_model.pth') 407 | torch.save({ 408 | 'A': model_A.state_dict(), 409 | 'B': model_B.state_dict() 410 | }, save_path) 411 | logging.info(f'saving best model to {save_path}') 412 | logging.info(f'\t best eval dice is {best_eval} in epoch {best_epoch}') 413 | 414 | if epoch_num - best_epoch == config.early_stop_patience: 415 | logging.info(f'Early stop.') 416 | break 417 | 418 | writer.close() 419 | -------------------------------------------------------------------------------- /dmp/code/train_dhc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | from tqdm import tqdm 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--task', type=str, default='synapse') 9 | parser.add_argument('--exp', type=str) 10 | parser.add_argument('--seed', type=int, default=0) 11 | parser.add_argument('-sl', '--split_labeled', type=str, default='labeled_20p') 12 | parser.add_argument('-su', '--split_unlabeled', type=str, default='unlabeled_80p') 13 | parser.add_argument('-se', '--split_eval', type=str, default='eval') 14 | 
def sigmoid_rampup(current, rampup_length):
    '''Exponential rampup from https://arxiv.org/abs/1610.02242'''
    # Zero-length ramp means the weight is always at full strength.
    if rampup_length == 0:
        return 1.0
    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length
    # exp(-5 * (1 - t)^2): ~0.0067 at t=0, exactly 1.0 at t=1.
    return float(np.exp(-5.0 * remaining * remaining))
def make_loss_function(name, weight=None):
    """Resolve a supervision-loss name to a criterion instance.

    `weight` is forwarded to the weighted CE / Dice variants.
    Raises ValueError for an unrecognised name.
    """
    if name == 'ce':
        return RobustCrossEntropyLoss()
    if name == 'wce':
        return RobustCrossEntropyLoss(weight=weight)
    if name == 'ce+dice':
        return DC_and_CE_loss()
    if name == 'wce+dice':
        return DC_and_CE_loss(w_ce=weight)
    if name == 'w_ce+dice':
        return DC_and_CE_loss(w_dc=weight, w_ce=weight)
    raise ValueError(name)


def make_loader(split, dst_cls=Synapse_AMOS, repeat=None, is_training=True, unlabeled=False):
    """DataLoader factory shared by the DHC trainer.

    Training: random crop + LR/UD flips, shuffled, seeded workers.
    Evaluation: deterministic center crop, default DataLoader settings.
    """
    if is_training:
        train_transform = transforms.Compose([
            RandomCrop(config.patch_size),
            RandomFlip_LR(),
            RandomFlip_UD(),
            ToTensor(),
        ])
        dataset = dst_cls(
            task=args.task,
            split=split,
            repeat=repeat,
            unlabeled=unlabeled,
            num_cls=config.num_cls,
            transform=train_transform,
        )
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            worker_init_fn=seed_worker,
        )
    val_transform = transforms.Compose([
        CenterCrop(config.patch_size),
        ToTensor(),
    ])
    dataset = dst_cls(
        task=args.task,
        split=split,
        is_val=True,
        num_cls=config.num_cls,
        transform=val_transform,
    )
    return DataLoader(dataset, pin_memory=True)
class DistDW:
    """Distribution-aware dynamic weighting (DHC).

    Derives per-class weights from inverse class frequency (log-scaled) and
    smooths them with an EMA over unlabeled-batch pseudo-labels.
    """

    def __init__(self, num_cls, do_bg=False, momentum=0.95):
        self.num_cls = num_cls
        # NOTE(review): do_bg is never read in this class; kept for interface
        # compatibility with existing callers.
        self.do_bg = do_bg
        self.momentum = momentum

    def _cal_weights(self, num_each_class):
        """Log inverse-frequency weights normalised to [0, 1]."""
        num_each_class = torch.FloatTensor(num_each_class).cuda()
        P = (num_each_class.max() + 1e-8) / (num_each_class + 1e-8)
        P_log = torch.log(P)
        # FIX: when every class count is equal, P_log is all zeros and the
        # original `P_log / P_log.max()` produced NaNs (0/0). Clamp the
        # divisor so the degenerate case yields zeros instead of NaNs.
        weight = P_log / torch.clamp(P_log.max(), min=1e-8)
        return weight

    def init_weights(self, labeled_dataset):
        """Initialise self.weights from the labeled split's voxel histogram.

        Raises ValueError if handed an unlabeled dataset; returns numpy copy.
        """
        if labeled_dataset.unlabeled:
            raise ValueError
        num_each_class = np.zeros(self.num_cls)
        for data_id in labeled_dataset.ids_list:
            _, _, label = labeled_dataset._get_data(data_id)
            tmp, _ = np.histogram(label.reshape(-1), range(self.num_cls + 1))
            num_each_class += tmp
        weights = self._cal_weights(num_each_class)
        self.weights = weights * self.num_cls
        return self.weights.data.cpu().numpy()

    def get_ema_weights(self, pseudo_label):
        """EMA update of self.weights from unlabeled pseudo-labels."""
        hard = torch.argmax(pseudo_label.detach(), dim=1, keepdim=True).long()
        label_numpy = hard.data.cpu().numpy()
        num_each_class = np.zeros(self.num_cls)
        for i in range(label_numpy.shape[0]):
            tmp, _ = np.histogram(label_numpy[i].reshape(-1), range(self.num_cls + 1))
            num_each_class += tmp
        cur_weights = self._cal_weights(num_each_class) * self.num_cls
        self.weights = EMA(cur_weights, self.weights, momentum=self.momentum)
        return self.weights
# Methods of DiffDW (class header lies outside this span).
def init_weights(self):
    """All-ones initial weights scaled by the class count; caches CUDA copy."""
    init = np.ones(config.num_cls) * self.num_cls
    self.weights = torch.FloatTensor(init).cuda()
    return init

def cal_weights(self, pred, label):
    """Difficulty-aware class weights from a labeled-batch prediction.

    Tracks per-class Dice improvement: classes whose Dice is falling
    ("unlearned") or still low are up-weighted; all EMA accumulators on
    self are updated as a side effect.
    """
    hard_pred = torch.argmax(pred, dim=1, keepdim=True).long()
    pred_onehot = torch.zeros(pred.shape).cuda()
    pred_onehot.scatter_(1, hard_pred, 1)
    gt_onehot = torch.zeros(pred.shape).cuda()
    gt_onehot.scatter_(1, label, 1)
    cur_dice = self.dice_func(pred_onehot, gt_onehot, is_training=False)
    delta_dice = cur_dice - self.last_dice
    # positive dice change -> the class is being learned; non-positive -> unlearned
    log_ratio = torch.log(cur_dice / self.last_dice)
    cur_cls_learn = torch.where(delta_dice > 0, delta_dice, 0) * log_ratio
    cur_cls_unlearn = torch.where(delta_dice <= 0, delta_dice, 0) * log_ratio
    self.last_dice = cur_dice
    ema_m = (self.accumulate_iters - 1) / self.accumulate_iters
    self.cls_learn = EMA(cur_cls_learn, self.cls_learn, momentum=ema_m)
    self.cls_unlearn = EMA(cur_cls_unlearn, self.cls_unlearn, momentum=ema_m)
    # unlearn/learn ratio, softened by a fifth root
    cur_diff = torch.pow((self.cls_unlearn + 1e-8) / (self.cls_learn + 1e-8), 1 / 5)
    self.dice_weight = EMA(1.0 - cur_dice, self.dice_weight, momentum=ema_m)
    weights = cur_diff * self.dice_weight
    return weights / weights.max() * self.num_cls
DiffDW(config.num_cls, accumulate_iters=50) 268 | distdw = DistDW(config.num_cls, momentum=0.99) 269 | 270 | weight_A = diffdw.init_weights() 271 | weight_B = distdw.init_weights(labeled_loader.dataset) 272 | 273 | loss_func_A = make_loss_function(args.sup_loss, weight_A) 274 | loss_func_B = make_loss_function(args.sup_loss, weight_B) 275 | cps_loss_func_A = make_loss_function(args.cps_loss, weight_A) 276 | cps_loss_func_B = make_loss_function(args.cps_loss, weight_B) 277 | 278 | 279 | if args.mixed_precision: 280 | amp_grad_scaler = GradScaler() 281 | 282 | cps_w = get_current_consistency_weight(0) 283 | best_eval = 0.0 284 | best_epoch = 0 285 | for epoch_num in range(args.max_epoch + 1): 286 | loss_list = [] 287 | loss_cps_list = [] 288 | loss_sup_list = [] 289 | 290 | model_A.train() 291 | model_B.train() 292 | for batch_l, batch_u in tqdm(zip(labeled_loader, unlabeled_loader)): 293 | optimizer_A.zero_grad() 294 | optimizer_B.zero_grad() 295 | 296 | image_l, label_l = fetch_data(batch_l) 297 | image_u = fetch_data(batch_u, labeled=False) 298 | image = torch.cat([image_l, image_u], dim=0) 299 | tmp_bs = image.shape[0] // 2 300 | 301 | if args.mixed_precision: 302 | with autocast(): 303 | output_A = model_A(image) 304 | output_B = model_B(image) 305 | del image 306 | 307 | # sup (ce + dice) 308 | output_A_l, output_A_u = output_A[:tmp_bs, ...], output_A[tmp_bs:, ...] 309 | output_B_l, output_B_u = output_B[:tmp_bs, ...], output_B[tmp_bs:, ...] 
310 | 311 | 312 | # cps (ce only) 313 | max_A = torch.argmax(output_A.detach(), dim=1, keepdim=True).long() 314 | max_B = torch.argmax(output_B.detach(), dim=1, keepdim=True).long() 315 | 316 | 317 | weight_A = diffdw.cal_weights(output_A_l.detach(), label_l.detach()) 318 | weight_B = distdw.get_ema_weights(output_B_u.detach()) 319 | 320 | 321 | 322 | loss_func_A.update_weight(weight_A) 323 | loss_func_B.update_weight(weight_B) 324 | cps_loss_func_A.update_weight(weight_A) 325 | cps_loss_func_B.update_weight(weight_B) 326 | 327 | 328 | loss_sup = loss_func_A(output_A_l, label_l) + loss_func_B(output_B_l, label_l) 329 | loss_cps = cps_loss_func_A(output_A, max_B) + cps_loss_func_B(output_B, max_A) 330 | loss = loss_sup + cps_w * loss_cps 331 | 332 | 333 | 334 | # backward passes should not be under autocast. 335 | amp_grad_scaler.scale(loss).backward() 336 | amp_grad_scaler.step(optimizer_A) 337 | amp_grad_scaler.step(optimizer_B) 338 | amp_grad_scaler.update() 339 | # if epoch_num>0: 340 | 341 | else: 342 | raise NotImplementedError 343 | 344 | loss_list.append(loss.item()) 345 | loss_sup_list.append(loss_sup.item()) 346 | loss_cps_list.append(loss_cps.item()) 347 | 348 | writer.add_scalar('lr', get_lr(optimizer_A), epoch_num) 349 | writer.add_scalar('cps_w', cps_w, epoch_num) 350 | writer.add_scalar('loss/loss', np.mean(loss_list), epoch_num) 351 | writer.add_scalar('loss/sup', np.mean(loss_sup_list), epoch_num) 352 | writer.add_scalar('loss/cps', np.mean(loss_cps_list), epoch_num) 353 | # print(dict(zip([i for i in range(config.num_cls)] ,print_func(weight_A)))) 354 | writer.add_scalars('class_weights/A', dict(zip([str(i) for i in range(config.num_cls)] ,print_func(weight_A))), epoch_num) 355 | writer.add_scalars('class_weights/B', dict(zip([str(i) for i in range(config.num_cls)] ,print_func(weight_B))), epoch_num) 356 | logging.info(f'epoch {epoch_num} : loss : {np.mean(loss_list)}') 357 | # logging.info(f' cps_w: {cps_w}') 358 | # if epoch_num>0: 359 | 
logging.info(f" Class Weights A: {print_func(weight_A)}, lr: {get_lr(optimizer_A)}") 360 | logging.info(f" Class Weights B: {print_func(weight_B)}") 361 | # logging.info(f" Class Weights u: {print_func(weight_u)}") 362 | # lr_scheduler_A.step() 363 | # lr_scheduler_B.step() 364 | optimizer_A.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9) 365 | optimizer_B.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9) 366 | # print(optimizer_A.param_groups[0]['lr']) 367 | cps_w = get_current_consistency_weight(epoch_num) 368 | if config.task == 'synapse': 369 | start = 60 370 | else: 371 | start = 3000 372 | if epoch_num % 10 == 0 or epoch_num >= 50: 373 | 374 | # ''' ===== evaluation 375 | dice_list = [[] for _ in range(config.num_cls-1)] 376 | model_A.eval() 377 | model_B.eval() 378 | dice_func = SoftDiceLoss(smooth=1e-8) 379 | for batch in tqdm(eval_loader): 380 | with torch.no_grad(): 381 | image, gt = fetch_data(batch) 382 | output = (model_A(image) + model_B(image))/2.0 383 | # output = model_B(image) 384 | del image 385 | 386 | shp = output.shape 387 | gt = gt.long() 388 | y_onehot = torch.zeros(shp).cuda() 389 | y_onehot.scatter_(1, gt, 1) 390 | 391 | x_onehot = torch.zeros(shp).cuda() 392 | output = torch.argmax(output, dim=1, keepdim=True).long() 393 | x_onehot.scatter_(1, output, 1) 394 | 395 | 396 | dice = dice_func(x_onehot, y_onehot, is_training=False) 397 | dice = dice.data.cpu().numpy() 398 | for i, d in enumerate(dice): 399 | dice_list[i].append(d) 400 | 401 | dice_mean = [] 402 | for dice in dice_list: 403 | dice_mean.append(np.mean(dice)) 404 | logging.info(f'evaluation epoch {epoch_num}, dice: {np.mean(dice_mean)}, {dice_mean}') 405 | # ''' 406 | if np.mean(dice_mean) > best_eval: 407 | best_eval = np.mean(dice_mean) 408 | best_epoch = epoch_num 409 | save_path = os.path.join(snapshot_path, f'ckpts/best_model.pth') 410 | torch.save({ 411 | 'A': model_A.state_dict(), 412 | 'B': model_B.state_dict() 413 
| }, save_path) 414 | logging.info(f'saving best model to {save_path}') 415 | logging.info(f'\t best eval dice is {best_eval} in epoch {best_epoch}') 416 | if epoch_num - best_epoch == config.early_stop_patience: 417 | logging.info(f'Early stop.') 418 | break 419 | 420 | writer.close() 421 | -------------------------------------------------------------------------------- /dmp/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from tqdm import tqdm 4 | import numpy as np 5 | import random 6 | import SimpleITK as sitk 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | from utils.config import Config 12 | 13 | 14 | 15 | 16 | def EMA(cur_weight, past_weight, momentum=0.9): 17 | new_weight = momentum * past_weight + (1 - momentum) * cur_weight 18 | return new_weight 19 | 20 | 21 | def kaiming_normal_init_weight(model): 22 | for m in model.modules(): 23 | if isinstance(m, nn.Conv3d): 24 | torch.nn.init.kaiming_normal_(m.weight) 25 | elif isinstance(m, nn.BatchNorm3d): 26 | m.weight.data.fill_(1) 27 | m.bias.data.zero_() 28 | return model 29 | 30 | 31 | def xavier_normal_init_weight(model): 32 | for m in model.modules(): 33 | if isinstance(m, nn.Conv3d): 34 | torch.nn.init.xavier_normal_(m.weight) 35 | elif isinstance(m, nn.BatchNorm3d): 36 | m.weight.data.fill_(1) 37 | m.bias.data.zero_() 38 | return model 39 | 40 | 41 | 42 | def print_func(item): 43 | # print(type(item)) 44 | if type(item) == torch.Tensor: 45 | return [round(x,4) for x in item.data.cpu().numpy().tolist()] 46 | elif type(item) == np.ndarray: 47 | return [round(x,4) for x in item.tolist()] 48 | else: 49 | raise TypeError 50 | 51 | 52 | def softmax(x): 53 | x_exp = np.exp(x) 54 | x_sum = np.sum(x_exp, axis=0, keepdims=True) 55 | s = x_exp / (x_sum) 56 | return s 57 | 58 | def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9): 59 | return initial_lr * (1 - epoch / max_epochs)**exponent 
def maybe_mkdir(path):
    """Create *path* (including parents) if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path)


def read_nifti(path):
    """Load a NIfTI file from *path* and return its voxels as a numpy array."""
    itk_img = sitk.ReadImage(path)
    itk_arr = sitk.GetArrayFromImage(itk_img)
    return itk_arr


def read_list(split, task="synapse"):
    """Return the sorted list of case ids in ``<save_dir>/splits/<split>.txt``.

    Args:
        split: split-file stem, e.g. ``'labeled_20p'`` or ``'eval'``.
        task: dataset key resolved through :class:`Config` to find ``save_dir``.

    Fix: ``ndmin=1`` keeps the loadtxt result 1-D even when the split file
    holds a single id; without it ``.tolist()`` yields a plain string and
    ``sorted`` would split that one id into individual characters.
    """
    config = Config(task)

    ids_list = np.loadtxt(
        os.path.join(config.save_dir, 'splits', f'{split}.txt'),
        dtype=str,
        ndmin=1
    ).tolist()
    return sorted(ids_list)


def read_data(data_id, task, nifti=False, test=False, normalize=False, unlabeled=False):
    """Load the preprocessed ``.npy`` image (and label) arrays for one case.

    Args:
        data_id: case identifier; builds ``<id>_image.npy`` / ``<id>_label.npy``
            paths under ``<save_dir>/npy``.
        task: dataset key resolved through :class:`Config`.
        nifti, test: accepted for call-site compatibility but unused here.
        normalize: if True, clip intensities with a task-specific window and
            rescale to [0, 1] as float32.
        unlabeled: if True, do not load a label; return an all-zero array of
            the image's shape instead.

    Returns:
        Tuple ``(image, label)`` of numpy arrays.

    Raises:
        ValueError: if a labeled case is missing its image or label file.
    """
    config = Config(task)
    im_path = os.path.join(config.save_dir, 'npy', f'{data_id}_image.npy')
    if not unlabeled:
        lb_path = os.path.join(config.save_dir, 'npy', f'{data_id}_label.npy')
        if not os.path.exists(im_path) or not os.path.exists(lb_path):
            print(f'data_id: {data_id}')
            raise ValueError(data_id)

        image = np.load(im_path)
        label = np.load(lb_path)
    else:
        image = np.load(im_path)
        label = np.zeros_like(image)

    if normalize:
        if task == 'chd':
            # CHD scans: clip to the 5th-95th intensity percentiles.
            min_val = np.percentile(image, 5)
            max_val = np.percentile(image, 95)
            image = image.clip(min=min_val, max=max_val)
        elif task == 'colon':
            image = image.clip(min=-250, max=275)
        else:
            # Default CT window (synapse / amos).
            image = image.clip(min=-75, max=275)
        image = (image - image.min()) / (image.max() - image.min())
        image = image.astype(np.float32)

    return image, label


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    for param_group in optimizer.param_groups:
        return param_group['lr']


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed numpy/random per worker from torch.

    Derives the worker seed from ``torch.initial_seed()`` so each worker gets
    a distinct but reproducible stream.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def fetch_data(batch, labeled=True):
    """Move a loader batch to GPU; return ``(image, label)`` or just ``image``.

    The label gains a channel axis via ``unsqueeze(1)`` so downstream
    ``scatter_``-style one-hot encoding can index dim 1.
    """
    image = batch['image'].cuda()
    if labeled:
        label = batch['label'].cuda().unsqueeze(1)
        return image, label
    else:
        return image
132 | 133 | 134 | def test_all_case(net, ids_list, task, num_classes, patch_size, stride_xy, stride_z, test_save_path=None): 135 | for data_id in tqdm(ids_list): 136 | image, _ = read_data(data_id, task, test=True, normalize=True) 137 | pred, _ = test_single_case( 138 | net, 139 | image, 140 | stride_xy, 141 | stride_z, 142 | patch_size, 143 | num_classes=num_classes 144 | ) 145 | out = sitk.GetImageFromArray(pred.astype(np.float32)) 146 | sitk.WriteImage(out, f'{test_save_path}/{data_id}.nii.gz') 147 | 148 | 149 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes): 150 | image = image[np.newaxis] 151 | _, dd, ww, hh = image.shape 152 | # print(image.shape) 153 | # resize_shape=(patch_size[0]+patch_size[0]//4, 154 | # patch_size[1]+patch_size[1]//4, 155 | # patch_size[2]+patch_size[2]//4) 156 | # 157 | # image = torch.FloatTensor(image).unsqueeze(0) 158 | # image = F.interpolate(image, size=resize_shape,mode='trilinear', align_corners=False) 159 | # image = image.squeeze(0).numpy() 160 | 161 | image = image.transpose(0, 3, 2, 1) # <-- take care the shape 162 | # print(image.shape) 163 | patch_size = (patch_size[2], patch_size[1], patch_size[0]) 164 | _, ww, hh, dd = image.shape 165 | 166 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 167 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 168 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 169 | 170 | score_map = np.zeros((num_classes, ) + image.shape[1:4]).astype(np.float32) 171 | cnt = np.zeros(image.shape[1:4]).astype(np.float32) 172 | # print("score_map", score_map.shape) 173 | for x in range(sx): 174 | xs = min(stride_xy*x, ww-patch_size[0]) 175 | for y in range(sy): 176 | ys = min(stride_xy*y, hh-patch_size[1]) 177 | for z in range(sz): 178 | zs = min(stride_z*z, dd-patch_size[2]) 179 | test_patch = image[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 180 | # print("test", test_patch.shape) 181 | test_patch = np.expand_dims(test_patch, 
axis=0).astype(np.float32) 182 | test_patch = torch.from_numpy(test_patch).cuda() 183 | # print("===",test_patch.size()) 184 | # <-- [1, 1, Z, Y, X] => [1, 1, X, Y, Z] 185 | test_patch = test_patch.transpose(2, 4) 186 | y1 = net(test_patch) # <-- 187 | y = F.softmax(y1, dim=1) # <-- 188 | y = y.cpu().data.numpy() 189 | y = y[0, ...] 190 | y = y.transpose(0, 3, 2, 1) 191 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] += y 192 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] += 1 193 | # print("score_map", score_map.shape) 194 | # print("score_map", cnt.shape) 195 | 196 | score_map = score_map / np.expand_dims(cnt, axis=0) # [Z, Y, X] 197 | score_map = score_map.transpose(0, 3, 2, 1) # => [X, Y, Z] 198 | label_map = np.argmax(score_map, axis=0) 199 | return label_map, score_map 200 | 201 | 202 | 203 | def test_all_case_AB(net_A, net_B, ids_list, task, num_classes, patch_size, stride_xy, stride_z, test_save_path=None): 204 | for data_id in tqdm(ids_list): 205 | image, _ = read_data(data_id, task, test=True, normalize=True) 206 | pred, _ = test_single_case_AB( 207 | net_A, net_B, 208 | image, 209 | stride_xy, 210 | stride_z, 211 | patch_size, 212 | num_classes=num_classes 213 | ) 214 | out = sitk.GetImageFromArray(pred.astype(np.float32)) 215 | sitk.WriteImage(out, f'{test_save_path}/{data_id}.nii.gz') 216 | 217 | 218 | def test_single_case_AB(net_A, net_B, image, stride_xy, stride_z, patch_size, num_classes): 219 | image = image[np.newaxis] 220 | 221 | _, dd, ww, hh = image.shape 222 | #print(image.shape) 223 | # resize_shape=(patch_size[0]+patch_size[0]//4, 224 | # patch_size[1]+patch_size[1]//4, 225 | # patch_size[2]+patch_size[2]//4) 226 | 227 | # image = torch.FloatTensor(image).unsqueeze(0) 228 | # image = F.interpolate(image, size=resize_shape,mode='trilinear', align_corners=False) 229 | # image = image.squeeze(0).numpy() 230 | 231 | image = image.transpose(0, 3, 2, 1) # <-- take care the shape 232 | # 
print(image.shape) 233 | patch_size = (patch_size[2], patch_size[1], patch_size[0]) 234 | _, ww, hh, dd = image.shape 235 | 236 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 237 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 238 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 239 | 240 | score_map = np.zeros((num_classes, ) + image.shape[1:4]).astype(np.float32) 241 | cnt = np.zeros(image.shape[1:4]).astype(np.float32) 242 | # print("score_map", score_map.shape) 243 | for x in range(sx): 244 | xs = min(stride_xy*x, ww-patch_size[0]) 245 | for y in range(sy): 246 | ys = min(stride_xy*y, hh-patch_size[1]) 247 | for z in range(sz): 248 | zs = min(stride_z*z, dd-patch_size[2]) 249 | test_patch = image[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 250 | # print("test", test_patch.shape) 251 | test_patch = np.expand_dims(test_patch, axis=0).astype(np.float32) 252 | test_patch = torch.from_numpy(test_patch).cuda() 253 | # print("===",test_patch.size()) 254 | # <-- [1, 1, Z, Y, X] => [1, 1, X, Y, Z] 255 | test_patch = test_patch.transpose(2, 4) 256 | y1 = (net_A(test_patch) + net_B(test_patch)) / 2.0 # <-- 257 | y = F.softmax(y1, dim=1) # <-- 258 | y = y.cpu().data.numpy() 259 | y = y[0, ...] 
260 | y = y.transpose(0, 3, 2, 1) 261 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] += y 262 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] += 1 263 | # print("score_map", score_map.shape) 264 | # print("score_map", cnt.shape) 265 | 266 | score_map = score_map / np.expand_dims(cnt, axis=0) # [Z, Y, X] 267 | score_map = score_map.transpose(0, 3, 2, 1) # => [X, Y, Z] 268 | label_map = np.argmax(score_map, axis=0) 269 | return label_map, score_map 270 | 271 | 272 | def eval_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1, use_softmax = True): 273 | w, h, d = image.shape 274 | 275 | # if the size of image is less than patch_size, then padding it 276 | add_pad = False 277 | if w < patch_size[0]: 278 | w_pad = patch_size[0] - w 279 | add_pad = True 280 | else: 281 | w_pad = 0 282 | if h < patch_size[1]: 283 | h_pad = patch_size[1] - h 284 | add_pad = True 285 | else: 286 | h_pad = 0 287 | if d < patch_size[2]: 288 | d_pad = patch_size[2] - d 289 | add_pad = True 290 | else: 291 | d_pad = 0 292 | wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2 293 | hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2 294 | dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2 295 | if add_pad: 296 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), 297 | (dl_pad, dr_pad)], mode='constant', constant_values=0) 298 | ww, hh, dd = image.shape 299 | 300 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 301 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 302 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 303 | # print("{}, {}, {}".format(sx, sy, sz)) 304 | score_map = np.zeros((num_classes,) + image.shape).astype(np.float32) 305 | cnt = np.zeros(image.shape).astype(np.float32) 306 | 307 | for x in range(0, sx): 308 | xs = min(stride_xy * x, ww - patch_size[0]) 309 | for y in range(0, sy): 310 | ys = min(stride_xy * y, hh - patch_size[1]) 311 | for z in range(0, sz): 312 | zs = min(stride_z * z, dd - 
patch_size[2]) 313 | test_patch = image[xs:xs + patch_size[0], 314 | ys:ys + patch_size[1], zs:zs + patch_size[2]] 315 | test_patch = np.expand_dims(np.expand_dims( 316 | test_patch, axis=0), axis=0).astype(np.float32) 317 | test_patch = torch.from_numpy(test_patch).cuda() 318 | 319 | with torch.no_grad(): 320 | y1 = net(test_patch) 321 | # ensemble 322 | if use_softmax: 323 | y = torch.softmax(y1, dim=1) 324 | else: 325 | y = y1 326 | y = y.cpu().data.numpy() 327 | y = y[0, :, :, :, :] 328 | score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 329 | = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y 330 | cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 331 | = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1 332 | score_map = score_map / np.expand_dims(cnt, axis=0) 333 | label_map = np.argmax(score_map, axis=0) 334 | 335 | if add_pad: 336 | label_map = label_map[wl_pad:wl_pad + w, 337 | hl_pad:hl_pad + h, dl_pad:dl_pad + d] 338 | score_map = score_map[:, wl_pad:wl_pad + 339 | w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 340 | return label_map 341 | 342 | 343 | 344 | 345 | 346 | def eval_single_case_A(netA, image, stride_xy, stride_z, patch_size, num_classes=1, use_softmax = True): 347 | w, h, d = image.shape 348 | 349 | # if the size of image is less than patch_size, then padding it 350 | add_pad = False 351 | if w < patch_size[0]: 352 | w_pad = patch_size[0] - w 353 | add_pad = True 354 | else: 355 | w_pad = 0 356 | if h < patch_size[1]: 357 | h_pad = patch_size[1] - h 358 | add_pad = True 359 | else: 360 | h_pad = 0 361 | if d < patch_size[2]: 362 | d_pad = patch_size[2] - d 363 | add_pad = True 364 | else: 365 | d_pad = 0 366 | wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2 367 | hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2 368 | dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2 369 | if add_pad: 370 | image = np.pad(image, [(wl_pad, wr_pad), 
(hl_pad, hr_pad), 371 | (dl_pad, dr_pad)], mode='constant', constant_values=0) 372 | ww, hh, dd = image.shape 373 | 374 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 375 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 376 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 377 | # print("{}, {}, {}".format(sx, sy, sz)) 378 | score_map = np.zeros((num_classes,) + image.shape).astype(np.float32) 379 | cnt = np.zeros(image.shape).astype(np.float32) 380 | 381 | for x in range(0, sx): 382 | xs = min(stride_xy * x, ww - patch_size[0]) 383 | for y in range(0, sy): 384 | ys = min(stride_xy * y, hh - patch_size[1]) 385 | for z in range(0, sz): 386 | zs = min(stride_z * z, dd - patch_size[2]) 387 | test_patch = image[xs:xs + patch_size[0], 388 | ys:ys + patch_size[1], zs:zs + patch_size[2]] 389 | test_patch = np.expand_dims(np.expand_dims( 390 | test_patch, axis=0), axis=0).astype(np.float32) 391 | test_patch = torch.from_numpy(test_patch).cuda() 392 | 393 | with torch.no_grad(): 394 | y1 = netA(test_patch) 395 | # ensemble 396 | if use_softmax: 397 | y = torch.softmax(y1, dim=1) 398 | else: 399 | y = y1 400 | y = y.cpu().data.numpy() 401 | y = y[0, :, :, :, :] 402 | score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 403 | = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y 404 | cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 405 | = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1 406 | score_map = score_map / np.expand_dims(cnt, axis=0) 407 | label_map = np.argmax(score_map, axis=0) 408 | 409 | if add_pad: 410 | label_map = label_map[wl_pad:wl_pad + w, 411 | hl_pad:hl_pad + h, dl_pad:dl_pad + d] 412 | score_map = score_map[:, wl_pad:wl_pad + 413 | w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 414 | return label_map 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | def eval_single_case_AB(netA,netB, image, stride_xy, 
stride_z, patch_size, num_classes=1, use_softmax = True): 426 | w, h, d = image.shape 427 | 428 | # if the size of image is less than patch_size, then padding it 429 | add_pad = False 430 | if w < patch_size[0]: 431 | w_pad = patch_size[0] - w 432 | add_pad = True 433 | else: 434 | w_pad = 0 435 | if h < patch_size[1]: 436 | h_pad = patch_size[1] - h 437 | add_pad = True 438 | else: 439 | h_pad = 0 440 | if d < patch_size[2]: 441 | d_pad = patch_size[2] - d 442 | add_pad = True 443 | else: 444 | d_pad = 0 445 | wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2 446 | hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2 447 | dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2 448 | if add_pad: 449 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), 450 | (dl_pad, dr_pad)], mode='constant', constant_values=0) 451 | ww, hh, dd = image.shape 452 | 453 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 454 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 455 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 456 | # print("{}, {}, {}".format(sx, sy, sz)) 457 | score_map = np.zeros((num_classes,) + image.shape).astype(np.float32) 458 | cnt = np.zeros(image.shape).astype(np.float32) 459 | 460 | for x in range(0, sx): 461 | xs = min(stride_xy * x, ww - patch_size[0]) 462 | for y in range(0, sy): 463 | ys = min(stride_xy * y, hh - patch_size[1]) 464 | for z in range(0, sz): 465 | zs = min(stride_z * z, dd - patch_size[2]) 466 | test_patch = image[xs:xs + patch_size[0], 467 | ys:ys + patch_size[1], zs:zs + patch_size[2]] 468 | test_patch = np.expand_dims(np.expand_dims( 469 | test_patch, axis=0), axis=0).astype(np.float32) 470 | test_patch = torch.from_numpy(test_patch).cuda() 471 | 472 | with torch.no_grad(): 473 | y1 = 0.5*(netA(test_patch)+netB(test_patch)) 474 | # ensemble 475 | if use_softmax: 476 | y = torch.softmax(y1, dim=1) 477 | else: 478 | y = y1 479 | y = y.cpu().data.numpy() 480 | y = y[0, :, :, :, :] 481 | score_map[:, xs:xs + patch_size[0], ys:ys + 
patch_size[1], zs:zs + patch_size[2]] \ 482 | = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y 483 | cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 484 | = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1 485 | score_map = score_map / np.expand_dims(cnt, axis=0) 486 | label_map = np.argmax(score_map, axis=0) 487 | 488 | if add_pad: 489 | label_map = label_map[wl_pad:wl_pad + w, 490 | hl_pad:hl_pad + h, dl_pad:dl_pad + d] 491 | score_map = score_map[:, wl_pad:wl_pad + 492 | w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 493 | return label_map 494 | 495 | 496 | def dice_coefficient(prediction, target, num_classes): 497 | dice = torch.zeros(num_classes) 498 | epsilon = 1e-6 # 避免除以零 499 | 500 | # 对每个类别计算 Dice 系数 501 | for i in range(num_classes): 502 | pred_i = (prediction == i) # 预测为类别 i 的体素 503 | target_i = (target == i) # 真实标签为类别 i 的体素 504 | inter = (pred_i & target_i).float().sum() # 交集 505 | union = pred_i.float().sum() + target_i.float().sum() # 并集 506 | 507 | dice[i] = (2 * inter + epsilon) / (union + epsilon) 508 | 509 | return dice 510 | 511 | 512 | 513 | -------------------------------------------------------------------------------- /dmp/code/train_cdifw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | from tqdm import tqdm 5 | import argparse 6 | import torch.nn.functional as F 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--task', type=str, default='synapse') 9 | parser.add_argument('--exp', type=str) 10 | parser.add_argument('--seed', type=int, default=0) 11 | parser.add_argument('-sl', '--split_labeled', type=str, default='labeled_20p') 12 | parser.add_argument('-su', '--split_unlabeled', type=str, default='unlabeled_80p') 13 | parser.add_argument('-se', '--split_eval', type=str, default='eval') 14 | parser.add_argument('-m', '--mixed_precision', 
action='store_true', default=True) # <-- 15 | parser.add_argument('-ep', '--max_epoch', type=int, default=500) 16 | parser.add_argument('--cps_loss', type=str, default='wce') 17 | parser.add_argument('--sup_loss', type=str, default='w_ce+dice') 18 | parser.add_argument('--batch_size', type=int, default=2) 19 | parser.add_argument('--num_workers', type=int, default=2) 20 | parser.add_argument('--base_lr', type=float, default=0.001) 21 | parser.add_argument('-g', '--gpu', type=str, default='0') 22 | parser.add_argument('-w', '--cps_w', type=float, default=1) 23 | parser.add_argument('-r', '--cps_rampup', action='store_true', default=True) # <-- 24 | parser.add_argument('-cr', '--consistency_rampup', type=float, default=None) 25 | parser.add_argument('-sm', '--start_mix', type=int, default=40) 26 | args = parser.parse_args() 27 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 28 | 29 | import numpy as np 30 | import torch 31 | import torch.optim as optim 32 | from torchvision import transforms 33 | from torch.utils.data import DataLoader 34 | from torch.utils.tensorboard import SummaryWriter 35 | from torch.cuda.amp import GradScaler, autocast 36 | 37 | from models.vnet import VNet 38 | from utils import EMA, maybe_mkdir, get_lr, fetch_data, seed_worker, poly_lr, print_func, kaiming_normal_init_weight 39 | from utils.loss import DC_and_CE_loss, RobustCrossEntropyLoss, SoftDiceLoss 40 | from data.transforms import RandomCrop, CenterCrop, ToTensor, RandomFlip_LR, RandomFlip_UD 41 | from data.data_loaders import Synapse_AMOS 42 | from utils.config import Config 43 | import random 44 | config = Config(args.task) 45 | 46 | 47 | 48 | def create_ema_model(config,model): 49 | ema_model = VNet( 50 | n_channels=config.num_channels, 51 | n_classes=config.num_cls, 52 | n_filters=config.n_filters, 53 | normalization='batchnorm', 54 | has_dropout=True 55 | ).cuda() 56 | mp = list(model.parameters()) 57 | mcp = list(ema_model.parameters()) 58 | n = len(mp) 59 | for i in range(0, n): 
60 | mcp[i].data[:] = mp[i].data[:].clone() 61 | return ema_model 62 | 63 | def update_ema_variables(ema_model, model, alpha_teacher, iteration): 64 | alpha_teacher = min(1 - 1 / (iteration + 1), alpha_teacher) 65 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 66 | ema_param.data[:] = alpha_teacher * ema_param[:].data[:] + (1 - alpha_teacher) * param[:].data[:] 67 | return ema_model 68 | 69 | def mix(mask, data=None, target=None): 70 | # Mix 71 | if not (data is None): 72 | if mask.shape[0] == data.shape[0]: 73 | data = torch.cat([(mask[i] * data[i] + (1 - mask[i]) * data[(i + 1) % data.shape[0]]).unsqueeze(0) for i in range(data.shape[0])]) 74 | elif mask.shape[0] == data.shape[0] // 2: 75 | data = torch.cat((torch.cat([(mask[i] * data[2 * i] + (1 - mask[i]) * data[2 * i + 1]).unsqueeze(0) for i in range(data.shape[0] // 2)]), 76 | torch.cat([((1 - mask[i]) * data[2 * i] + mask[i] * data[2 * i + 1]).unsqueeze(0) for i in range(data.shape[0] // 2)]))) 77 | if not (target is None): 78 | target = torch.cat([(mask[i] * target[i] + (1 - mask[i]) * target[(i + 1) % target.shape[0]]).unsqueeze(0) for i in range(target.shape[0])]) 79 | return data, target 80 | 81 | 82 | def generate_class_mask(pred, classes): 83 | pred, classes = torch.broadcast_tensors(pred.unsqueeze(0), classes.unsqueeze(1).unsqueeze(2)) 84 | N = pred.eq(classes).sum(0) 85 | return N 86 | 87 | 88 | 89 | def sigmoid_rampup(current, rampup_length): 90 | '''Exponential rampup from https://arxiv.org/abs/1610.02242''' 91 | if rampup_length == 0: 92 | return 1.0 93 | else: 94 | current = np.clip(current, 0.0, rampup_length) 95 | phase = 1.0 - current / rampup_length 96 | return float(np.exp(-5.0 * phase * phase)) 97 | 98 | 99 | def get_current_consistency_weight(epoch): 100 | if args.cps_rampup: 101 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 102 | if args.consistency_rampup is None: 103 | args.consistency_rampup = args.max_epoch 104 | return args.cps_w * 
sigmoid_rampup(epoch, args.consistency_rampup) 105 | else: 106 | return args.cps_w 107 | 108 | 109 | 110 | def make_loss_function(name, weight=None): 111 | if name == 'ce': 112 | return RobustCrossEntropyLoss() 113 | elif name == 'wce': 114 | return RobustCrossEntropyLoss(weight=weight) 115 | elif name == 'ce+dice': 116 | return DC_and_CE_loss() 117 | elif name == 'wce+dice': 118 | return DC_and_CE_loss(w_ce=weight) 119 | elif name == 'w_ce+dice': 120 | return DC_and_CE_loss(w_dc=weight, w_ce=weight) 121 | else: 122 | raise ValueError(name) 123 | 124 | 125 | def make_loader(split, dst_cls=Synapse_AMOS, repeat=None, is_training=True, unlabeled=False): 126 | if is_training: 127 | dst = dst_cls( 128 | task=args.task, 129 | split=split, 130 | repeat=repeat, 131 | unlabeled=unlabeled, 132 | num_cls=config.num_cls, 133 | transform=transforms.Compose([ 134 | RandomCrop(config.patch_size), 135 | RandomFlip_LR(), 136 | RandomFlip_UD(), 137 | ToTensor() 138 | ]) 139 | ) 140 | return DataLoader( 141 | dst, 142 | batch_size=args.batch_size, 143 | shuffle=True, 144 | num_workers=args.num_workers, 145 | pin_memory=True, 146 | worker_init_fn=seed_worker 147 | ) 148 | else: 149 | dst = dst_cls( 150 | task=args.task, 151 | split=split, 152 | is_val=True, 153 | num_cls=config.num_cls, 154 | transform=transforms.Compose([ 155 | CenterCrop(config.patch_size), 156 | ToTensor() 157 | ]) 158 | ) 159 | return DataLoader(dst, pin_memory=True) 160 | 161 | 162 | def make_model_all(): 163 | model = VNet( 164 | n_channels=config.num_channels, 165 | n_classes=config.num_cls, 166 | n_filters=config.n_filters, 167 | normalization='batchnorm', 168 | has_dropout=True 169 | ).cuda() 170 | optimizer = optim.SGD( 171 | model.parameters(), 172 | lr=args.base_lr, 173 | momentum=0.9, 174 | weight_decay=3e-5, 175 | nesterov=True 176 | ) 177 | 178 | return model, optimizer 179 | 180 | 181 | 182 | 183 | class DistDW: 184 | def __init__(self, num_cls, do_bg=False, momentum=0.95): 185 | self.num_cls = num_cls 
186 | self.do_bg = do_bg 187 | self.momentum = momentum 188 | 189 | def _cal_weights(self, num_each_class): 190 | num_each_class = torch.FloatTensor(num_each_class).cuda() 191 | P = (num_each_class.max()+1e-8) / (num_each_class+1e-8) 192 | P_log = torch.log(P) 193 | weight = P_log / P_log.max() 194 | return weight 195 | 196 | def init_weights(self, labeled_dataset): 197 | if labeled_dataset.unlabeled: 198 | raise ValueError 199 | num_each_class = np.zeros(self.num_cls) 200 | for data_id in labeled_dataset.ids_list: 201 | _, _, label = labeled_dataset._get_data(data_id) 202 | label = label.reshape(-1) 203 | tmp, _ = np.histogram(label, range(self.num_cls + 1)) 204 | num_each_class += tmp 205 | weights = self._cal_weights(num_each_class) 206 | self.weights = weights * self.num_cls 207 | return self.weights.data.cpu().numpy() 208 | 209 | def get_ema_weights(self, pseudo_label): 210 | pseudo_label = torch.argmax(pseudo_label.detach(), dim=1, keepdim=True).long() 211 | label_numpy = pseudo_label.data.cpu().numpy() 212 | num_each_class = np.zeros(self.num_cls) 213 | for i in range(label_numpy.shape[0]): 214 | label = label_numpy[i].reshape(-1) 215 | tmp, _ = np.histogram(label, range(self.num_cls + 1)) 216 | num_each_class += tmp 217 | 218 | cur_weights = self._cal_weights(num_each_class) * self.num_cls 219 | self.weights = EMA(cur_weights, self.weights, momentum=self.momentum) 220 | return self.weights 221 | 222 | 223 | 224 | class DiffDW: 225 | def __init__(self, num_cls, accumulate_iters=20): 226 | self.last_dice = torch.zeros(num_cls).float().cuda() + 1e-8 227 | self.dice_func = SoftDiceLoss(smooth=1e-8, do_bg=True) 228 | self.cls_learn = torch.zeros(num_cls).float().cuda() 229 | self.cls_unlearn = torch.zeros(num_cls).float().cuda() 230 | self.num_cls = num_cls 231 | self.dice_weight = torch.ones(num_cls).float().cuda() 232 | self.accumulate_iters = accumulate_iters 233 | 234 | def init_weights(self): 235 | weights = np.ones(config.num_cls) * self.num_cls 236 | 
self.weights = torch.FloatTensor(weights).cuda() 237 | return weights 238 | 239 | def cal_weights(self, pred, label): 240 | x_onehot = torch.zeros(pred.shape).cuda() 241 | output = torch.argmax(pred, dim=1, keepdim=True).long() 242 | x_onehot.scatter_(1, output, 1) 243 | y_onehot = torch.zeros(pred.shape).cuda() 244 | y_onehot.scatter_(1, label, 1) 245 | cur_dice = self.dice_func(x_onehot, y_onehot, is_training=False) 246 | delta_dice = cur_dice - self.last_dice 247 | cur_cls_learn = torch.where(delta_dice>0, delta_dice, 0) * torch.log(cur_dice / self.last_dice) 248 | cur_cls_unlearn = torch.where(delta_dice<=0, delta_dice, 0) * torch.log(cur_dice / self.last_dice) 249 | self.last_dice = cur_dice 250 | self.cls_learn = EMA(cur_cls_learn, self.cls_learn, momentum=(self.accumulate_iters-1)/self.accumulate_iters) 251 | self.cls_unlearn = EMA(cur_cls_unlearn, self.cls_unlearn, momentum=(self.accumulate_iters-1)/self.accumulate_iters) 252 | cur_diff = (self.cls_unlearn + 1e-8) / (self.cls_learn + 1e-8) 253 | cur_diff = torch.pow(cur_diff, 1/5) 254 | self.dice_weight = EMA(1. 
- cur_dice, self.dice_weight, momentum=(self.accumulate_iters-1)/self.accumulate_iters) 255 | weights = cur_diff * self.dice_weight 256 | weights = weights / weights.max() 257 | return weights * self.num_cls 258 | 259 | class Conf: 260 | def __init__(self, num_cls): 261 | self.conf_weight = torch.ones(num_cls).float().cuda() 262 | self.conf_cur = torch.zeros(num_cls).float().cuda() + 1e-8 263 | self.iter = 0 264 | def ema(self,new_weight): 265 | self.conf_cur = self.conf_cur/(self.iter+1e-8) 266 | self.conf_weight = EMA(self.conf_cur, self.conf_weight, momentum=0.99) 267 | self.iter = 1 268 | return self.conf_weight 269 | def update(self,this_conf): 270 | self.iter += 1 271 | this_conf = (1-this_conf+1e-2)/((1-this_conf).max()+1e-2) 272 | this_conf = torch.pow(this_conf,0.2) 273 | self.conf_cur += this_conf 274 | 275 | 276 | def get_conf(self): 277 | return self.conf_weight 278 | 279 | if __name__ == '__main__': 280 | import random 281 | SEED=args.seed 282 | random.seed(SEED) 283 | np.random.seed(SEED) 284 | torch.manual_seed(SEED) 285 | torch.cuda.manual_seed(SEED) 286 | torch.cuda.manual_seed_all(SEED) 287 | # make logger file 288 | snapshot_path = f'./logs/{args.exp}/' 289 | maybe_mkdir(snapshot_path) 290 | maybe_mkdir(os.path.join(snapshot_path, 'ckpts')) 291 | fold = str(args.exp[-1]) 292 | 293 | if args.task == 'colon': 294 | #args.split_unlabeled = args.split_unlabeled+'_'+fold 295 | args.split_labeled = args.split_labeled+'_'+fold 296 | args.split_eval = args.split_eval+'_'+fold 297 | # make logger 298 | writer = SummaryWriter(os.path.join(snapshot_path, 'tensorboard')) 299 | logging.basicConfig( 300 | filename=os.path.join(snapshot_path, 'train.log'), 301 | level=logging.INFO, 302 | format='[%(asctime)s.%(msecs)03d] %(message)s', 303 | datefmt='%H:%M:%S' 304 | ) 305 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 306 | logging.info(str(args)) 307 | 308 | # make data loader 309 | unlabeled_loader = make_loader(args.split_unlabeled, 
unlabeled=True) 310 | labeled_loader = make_loader(args.split_labeled, repeat=len(unlabeled_loader.dataset)) 311 | eval_loader = make_loader(args.split_eval, is_training=False) 312 | if args.task == 'colon': 313 | test_loader = make_loader(f'test_{fold}', is_training=False) 314 | else: 315 | test_loader = make_loader('test', is_training=False) 316 | 317 | 318 | 319 | logging.info(f'{len(labeled_loader)} itertations per epoch (labeled)') 320 | logging.info(f'{len(unlabeled_loader)} itertations per epoch (unlabeled)') 321 | 322 | # make model, optimizer, and lr scheduler 323 | model_A, optimizer_A = make_model_all() 324 | model_B, optimizer_B = make_model_all() 325 | model_A = kaiming_normal_init_weight(model_A) 326 | model_B = kaiming_normal_init_weight(model_B) 327 | 328 | ema_model_A = create_ema_model(config, model_A) 329 | ema_model_B = create_ema_model(config, model_B) 330 | 331 | ema_model_A.eval() 332 | ema_model_A = ema_model_A.cuda() 333 | 334 | 335 | ema_model_B.eval() 336 | ema_model_B = ema_model_B.cuda() 337 | 338 | 339 | # make loss function 340 | diffdw = DiffDW(config.num_cls, accumulate_iters=50) 341 | distdw = DistDW(config.num_cls, momentum=0.99) 342 | conf = Conf(config.num_cls) 343 | print(conf.get_conf()) 344 | 345 | weight_A = diffdw.init_weights() 346 | weight_B = distdw.init_weights(labeled_loader.dataset) 347 | 348 | loss_func_A = make_loss_function(args.sup_loss, weight_A) 349 | loss_func_B = make_loss_function(args.sup_loss, weight_B) 350 | cps_loss_func_A = make_loss_function(args.cps_loss, weight_A) 351 | cps_loss_func_B = make_loss_function(args.cps_loss, weight_B) 352 | 353 | loss_func_A_ema = make_loss_function(args.sup_loss, weight_A) 354 | loss_func_B_ema = make_loss_function(args.sup_loss, weight_B) 355 | cps_loss_func_A_ema = make_loss_function(args.cps_loss, weight_A) 356 | cps_loss_func_B_ema = make_loss_function(args.cps_loss, weight_B) 357 | 358 | if args.mixed_precision: 359 | amp_grad_scaler = GradScaler() 360 | 361 | cps_w 
= get_current_consistency_weight(0) 362 | best_eval = 0.0 363 | best_epoch = 0 364 | best_eval_t= 0.0 365 | 366 | average_per_channel_total = torch.ones(config.num_cls).cuda() # 假设通道数为config.num_cls 367 | batch_averages = torch.ones(config.num_cls).cuda() 368 | 369 | for epoch_num in range(args.max_epoch + 1): 370 | loss_list = [] 371 | loss_cps_list = [] 372 | loss_sup_list = [] 373 | 374 | model_A.train() 375 | model_B.train() 376 | for batch_l, batch_u in tqdm(zip(labeled_loader, unlabeled_loader)): 377 | optimizer_A.zero_grad() 378 | optimizer_B.zero_grad() 379 | 380 | image_l, label_l = fetch_data(batch_l) 381 | 382 | image_u = fetch_data(batch_u, labeled=False) 383 | 384 | 385 | image = torch.cat([image_l, image_u], dim=0) 386 | tmp_bs = image.shape[0] // 2 387 | 388 | label_l_backup = label_l.clone().squeeze(1).detach() 389 | conf_mask = torch.nn.functional.one_hot(label_l_backup, config.num_cls).permute(0,4,1,2,3).float() 390 | conf_mask[conf_mask!=0] = 1 391 | ones_per_channel = conf_mask.sum(dim=[0, 2, 3, 4]).detach().cuda() 392 | 393 | if args.mixed_precision: 394 | with autocast(): 395 | output_A = model_A(image) 396 | output_B = model_B(image) 397 | del image 398 | 399 | # sup (ce + dice) 400 | output_A_l, output_A_u = output_A[:tmp_bs, ...], output_A[tmp_bs:, ...] 401 | output_B_l, output_B_u = output_B[:tmp_bs, ...], output_B[tmp_bs:, ...] 
402 | 403 | conf_map_A = F.softmax(output_A_l) * conf_mask #bcdwh 404 | sum_per_channel = conf_map_A.sum(dim=[0, 2, 3, 4]) 405 | conf_weight = (sum_per_channel.detach()+1) / (ones_per_channel.detach()+1) 406 | 407 | conf.update(conf_weight) 408 | conf_weight = conf.get_conf() 409 | 410 | # cps (ce only) 411 | max_A = torch.argmax(output_A.detach(), dim=1, keepdim=True).long() 412 | max_B = torch.argmax(output_B.detach(), dim=1, keepdim=True).long() 413 | 414 | 415 | weight_A = diffdw.cal_weights(output_A_l.detach(), label_l.detach()) 416 | weight_B = distdw.get_ema_weights(output_B_u.detach()) 417 | 418 | 419 | 420 | loss_func_A.update_weight(weight_A*conf_weight) 421 | loss_func_B.update_weight(weight_B) 422 | cps_loss_func_A.update_weight(weight_A*conf_weight) 423 | cps_loss_func_B.update_weight(weight_B) 424 | 425 | 426 | #print(f'label_l: {label_l.shape}, max : {label_l.max()}, max_A: {max_A.shape}, max_A: {max_A.max()}') 427 | loss_sup = loss_func_A(output_A_l, label_l) + loss_func_B(output_B_l, label_l) 428 | loss_cps = cps_loss_func_A(output_A, max_B) + cps_loss_func_B(output_B, max_A) 429 | loss = loss_sup + cps_w * loss_cps 430 | 431 | 432 | 433 | # backward passes should not be under autocast. 
434 | amp_grad_scaler.scale(loss).backward() 435 | amp_grad_scaler.step(optimizer_A) 436 | amp_grad_scaler.step(optimizer_B) 437 | amp_grad_scaler.update() 438 | # if epoch_num>0: 439 | 440 | else: 441 | raise NotImplementedError 442 | 443 | loss_list.append(loss.item()) 444 | loss_sup_list.append(loss_sup.item()) 445 | loss_cps_list.append(loss_cps.item()) 446 | conf.ema(conf_weight) 447 | writer.add_scalar('lr', get_lr(optimizer_A), epoch_num) 448 | writer.add_scalar('cps_w', cps_w, epoch_num) 449 | writer.add_scalar('loss/loss', np.mean(loss_list), epoch_num) 450 | writer.add_scalar('loss/sup', np.mean(loss_sup_list), epoch_num) 451 | writer.add_scalar('loss/cps', np.mean(loss_cps_list), epoch_num) 452 | # print(dict(zip([i for i in range(config.num_cls)] ,print_func(weight_A)))) 453 | writer.add_scalars('class_weights/A', dict(zip([str(i) for i in range(config.num_cls)] ,print_func(weight_A))), epoch_num) 454 | writer.add_scalars('class_weights/B', dict(zip([str(i) for i in range(config.num_cls)] ,print_func(weight_B))), epoch_num) 455 | logging.info(f'epoch {epoch_num} : loss : {np.mean(loss_list)}') 456 | # logging.info(f' cps_w: {cps_w}') 457 | # if epoch_num>0: 458 | logging.info(f" Class Weights A: {print_func(weight_A)}, lr: {get_lr(optimizer_A)}") 459 | logging.info(f" Class Weights B: {print_func(weight_B)}") 460 | logging.info(f" Class Conf : {print_func(conf.get_conf())}") 461 | # logging.info(f" Class Weights u: {print_func(weight_u)}") 462 | # lr_scheduler_A.step() 463 | # lr_scheduler_B.step() 464 | optimizer_A.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9) 465 | optimizer_B.param_groups[0]['lr'] = poly_lr(epoch_num, args.max_epoch, args.base_lr, 0.9) 466 | # print(optimizer_A.param_groups[0]['lr']) 467 | cps_w = get_current_consistency_weight(epoch_num) 468 | 469 | if epoch_num % 10 == 0 or epoch_num >= args.start_mix: 470 | 471 | # ''' ===== evaluation 472 | dice_list = [[] for _ in range(config.num_cls-1)] 473 | 
model_A.eval() 474 | model_B.eval() 475 | dice_func = SoftDiceLoss(smooth=1e-8) 476 | for batch in tqdm(eval_loader): 477 | with torch.no_grad(): 478 | image, gt = fetch_data(batch) 479 | output = (model_A(image) + model_B(image))/2.0 480 | # output = model_B(image) 481 | del image 482 | 483 | shp = output.shape 484 | gt = gt.long() 485 | y_onehot = torch.zeros(shp).cuda() 486 | y_onehot.scatter_(1, gt, 1) 487 | 488 | x_onehot = torch.zeros(shp).cuda() 489 | output = torch.argmax(output, dim=1, keepdim=True).long() 490 | x_onehot.scatter_(1, output, 1) 491 | 492 | 493 | dice = dice_func(x_onehot, y_onehot, is_training=False) 494 | dice = dice.data.cpu().numpy() 495 | for i, d in enumerate(dice): 496 | dice_list[i].append(d) 497 | 498 | dice_mean = [] 499 | for dice in dice_list: 500 | dice_mean.append(np.mean(dice)) 501 | 502 | logging.info(f'evaluation epoch {epoch_num}, dice: {np.mean(dice_mean)}, {dice_mean}') 503 | if np.mean(dice_mean) > best_eval: 504 | best_eval = np.mean(dice_mean) 505 | best_epoch = epoch_num 506 | save_path = os.path.join(snapshot_path, f'ckpts/best_model.pth') 507 | 508 | logging.info(f'saving best model to {save_path}') 509 | logging.info(f'\t best eval dice is {best_eval} in epoch {best_epoch}') 510 | # ''' 511 | 512 | if config.task == 'colon': 513 | args.start_mix = 10 514 | 515 | if epoch_num >= args.start_mix: 516 | 517 | 518 | 519 | 520 | dice_list = [[] for _ in range(config.num_cls-1)] 521 | model_A.eval() 522 | model_B.eval() 523 | dice_func = SoftDiceLoss(smooth=1e-8) 524 | for batch in test_loader: 525 | with torch.no_grad(): 526 | image, gt = fetch_data(batch) 527 | output = (model_A(image) + model_B(image))/2.0 528 | # output = model_B(image) 529 | del image 530 | 531 | shp = output.shape 532 | gt = gt.long() 533 | y_onehot = torch.zeros(shp).cuda() 534 | y_onehot.scatter_(1, gt, 1) 535 | 536 | x_onehot = torch.zeros(shp).cuda() 537 | output = torch.argmax(output, dim=1, keepdim=True).long() 538 | x_onehot.scatter_(1, 
output, 1) 539 | 540 | 541 | dice = dice_func(x_onehot, y_onehot, is_training=False) 542 | dice = dice.data.cpu().numpy() 543 | for i, d in enumerate(dice): 544 | dice_list[i].append(d) 545 | 546 | dice_mean = [] 547 | for dice in dice_list: 548 | dice_mean.append(np.mean(dice)) 549 | print(dice_mean) 550 | 551 | 552 | if np.mean(dice_mean) > best_eval_t: 553 | best_eval_t = np.mean(dice_mean) 554 | #best_epoch = epoch_num 555 | save_path = os.path.join(snapshot_path, f'ckpts/best_model.pth') 556 | torch.save({ 557 | 'A': model_A.state_dict(), 558 | 'B': model_B.state_dict() 559 | }, save_path) 560 | print(f'\t best test dice is {best_eval_t} in epoch {best_epoch}') 561 | #config.early_stop_patience = 100 562 | if epoch_num - best_epoch == config.early_stop_patience: 563 | logging.info(f'Early stop.') 564 | break 565 | 566 | writer.close() 567 | --------------------------------------------------------------------------------