├── .DS_Store ├── LICENSE ├── README.md ├── evaluation.py ├── model_train.py └── results └── checkpoints ├── checkpoint ├── dnn_ckpt_base_model.data-00000-of-00001 ├── dnn_ckpt_base_model.index ├── dnn_ckpt_model_final.data-00000-of-00001 └── dnn_ckpt_model_final.index /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-spatial/fair-ai-in-space/8ab1cac8f24e950ae2c500829c206830051bf826/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fairness-by-location 2 | 3 | ==================================================== 4 | ## Overview 5 | Code for paper: Fairness by "Where": A Statistically-Robust and Model-Agnostic Bi-Level Learning Framework. AAAI 2022. 
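Fairness here is measured over space: the study area is split into candidate grid partitionings (up to 5 rows by 5 columns), and the model is pushed toward a similar class-frequency-weighted F1 score in every partition. Below is a minimal sketch of the fairness score that `evaluation.py` prints, assuming the per-partition weighted F1 scores have already been computed; the function and variable names are illustrative only and are not part of the repository.

```python
import numpy as np

def unfairness(global_f1, per_partition_f1):
    """per_partition_f1 maps each candidate partitioning, e.g. (rows, cols),
    to the list of weighted F1 scores of its partitions; lower means fairer."""
    return sum(np.mean(np.abs(global_f1 - np.asarray(scores)))
               for scores in per_partition_f1.values())

# Toy usage with two candidate partitionings of the study area:
print(unfairness(0.80, {(1, 2): [0.78, 0.83], (2, 2): [0.75, 0.82, 0.79, 0.84]}))
```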
6 | 
7 | 
8 | ====================================================
9 | ## Explanation of files
10 | (Data download link: https://pitt-my.sharepoint.com/:f:/g/personal/erh108_pitt_edu/EmqOrtnsaCVFnD_PA1cFjt8BLX1zg6Ws0smAF0hr90JKjw?e=NCHtar)
11 | 
12 | * X_train.npy: all training samples extracted from the satellite-based crop monitoring dataset.
13 | * y_train.npy: the corresponding labels for the training samples.
14 | * train_id.pickle: training samples' indices for all partitions within each candidate partitioning.
15 | * X_test.npy: all testing samples (not overlapping with the training samples).
16 | * y_test.npy: the corresponding labels for the testing samples.
17 | * test_id.pickle: testing samples' indices for all partitions within each candidate partitioning.
18 | * results: example model checkpoints (the base model and the final model).
19 | 
20 | ## Explanation of the code
21 | 
22 | #### Procedures
23 | 
24 | model_train.py:
25 | 1. Train a base model on the training data for 300 epochs.
26 | 2. Apply the stochastic bi-level training strategy to the base model for 50 more epochs.
27 | 
28 | evaluation.py:
29 | 
30 | 3. Compare the overall performance and fairness of the base model and the final model.
31 | 
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | device_name = tf.test.gpu_device_name()
4 | if device_name != '/device:GPU:0':
5 |     raise SystemError('GPU device not found')
6 | print('Found GPU at: {}'.format(device_name))
7 | print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
8 | 
9 | # ---------------------------------------------------------------------
10 | import numpy as np
11 | from tensorflow import keras
12 | from keras import backend as K
13 | import pickle
14 | 
15 | # ---------------------------------------------------------------------
16 | # Global variables
17 | model_dir = 'results'
18 | dir_ckpt = model_dir + '/' + 'checkpoints'
19 | 
20 | MAX_ROW_PARTITION = 5
21 | MAX_COL_PARTITION = 5
22 | 
23 | BATCH_SIZE = 256 * 256
24 | 
25 | INPUT_SIZE = 10 # number of features
26 | NUM_LAYERS_DNN = 8 # number of non-input layers (7 hidden layers + 1 output layer)
27 | NUM_CLASS = 23
28 | EPOCH_TRAIN = 300
29 | 
30 | CKPT_FOLDER_PATH = dir_ckpt
31 | LEARNING_RATE = 0.001
32 | 
33 | CLASSIFICATION_LOSS = 'categorical_crossentropy'
34 | REGRESSION_LOSS = 'mean_squared_error'
35 | 
36 | # ---------------------------------------------------------------------
37 | # Read the X testing dataset and the Y testing dataset
38 | X = np.load('X_test.npy')
39 | y = np.load('y_test.npy')
40 | 
41 | # read sample indices for all partitions within each candidate partitioning
42 | with open('test_id.pickle', 'rb') as handle:
43 |     test_id = pickle.load(handle)
44 | 
45 | 
46 | # ---------------------------------------------------------------------
47 | # DNN model
48 | class DenseNet(tf.keras.Model):
49 | 
50 |     def __init__(self, layer_size=INPUT_SIZE, num_class=NUM_CLASS, ckpt_path=CKPT_FOLDER_PATH):
51 |         super(DenseNet, self).__init__()
52 | 
53 |         initializer = tf.keras.initializers.TruncatedNormal(stddev=0.5) # mean=0.0, seed=None
54 |         self.dense = []
55 |         for i in range(7):
56 |             self.dense.append(tf.keras.layers.Dense(layer_size, activation=tf.nn.relu, kernel_initializer=initializer))
57 | 
58 |         self.out = tf.keras.layers.Dense(num_class, kernel_initializer=initializer)
59 | 
60 |         self.ckpt_path = ckpt_path
61 |         self.model_name = 'dnn'
62 |         print('check ckpt path: ' + 
ckpt_path) 63 | 64 | def call(self, inputs): # , num_layers = NUM_LAYERS_DNN 65 | 66 | for i in range(7): 67 | if i == 0: 68 | layer = self.dense[i](inputs) 69 | else: 70 | layer = self.dense[i](layer) 71 | 72 | out_layer = tf.nn.softmax(self.out(layer)) 73 | 74 | return out_layer 75 | 76 | def save_branch(self, branch_id): 77 | # save the current branch 78 | # branch_id should include the current branch (not after added to X_branch_id) 79 | self.save_weights(self.ckpt_path + '/' + self.model_name + '_ckpt_' + branch_id) 80 | return 81 | 82 | def load_base_branch(self, branch_id): 83 | # load the base branch before further fine-tuning 84 | self.load_weights(self.ckpt_path + '/' + self.model_name + '_ckpt_' + branch_id) 85 | return 86 | 87 | 88 | def dice(y_true, y_pred): 89 | # y_pred is softmax output of shape (num_samples, num_classes) 90 | # y_true is one hot encoding of target (shape= (num_samples, num_classes)) 91 | 92 | intersect = K.sum(y_pred * y_true, axis=0) + K.epsilon() 93 | denominator = K.sum(y_pred, axis=0) + K.sum(y_true, axis=0) 94 | dice_scores = K.constant(2) * intersect / (denominator + K.epsilon()) 95 | return 1 - dice_scores 96 | 97 | 98 | def custom_loss(y_true, y_pred): 99 | loss = dice(y_true, y_pred) 100 | return loss 101 | 102 | 103 | optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE) 104 | 105 | 106 | def model_compile(model): 107 | # optimizer = keras.optimizers.Adam(learning_rate = LEARNING_RATE) 108 | global optimizer 109 | 110 | model.compile(optimizer=optimizer, loss=custom_loss, metrics=['accuracy']) 111 | 112 | 113 | # --------------------------------------------------------------------- 114 | def get_class_wise_accuracy(y_true, y_pred, prf=False): 115 | num_class = y_true.shape[1] 116 | stat = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 117 | 118 | true_pred_w_class = y_true * np.expand_dims(stat, 1) 119 | true = np.sum(true_pred_w_class, axis=0).reshape(-1) 120 | total = np.sum(y_true, axis=0).reshape(-1) 121 | 122 | if prf: 123 | pred_w_class = tf.math.argmax(y_pred, axis=1) 124 | pred_w_class = tf.one_hot(pred_w_class, depth=NUM_CLASS).numpy() 125 | pred_total = np.sum(pred_w_class, axis=0).reshape(-1) 126 | return true, total, pred_total 127 | else: 128 | return true, total 129 | 130 | 131 | def get_overall_accuracy(y_true, y_pred): 132 | stat = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 133 | true = np.sum(stat) 134 | total = stat.shape[0] 135 | 136 | return true, total 137 | 138 | 139 | def get_class_wise_list(y_true, y_pred): 140 | stat = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 141 | 142 | true_pred_w_class = y_true * np.expand_dims(stat, 1) 143 | pred_w_class = tf.math.argmax(y_pred, axis=1) 144 | pred_w_class = tf.one_hot(pred_w_class, depth=NUM_CLASS).numpy() 145 | 146 | return true_pred_w_class, pred_w_class 147 | 148 | 149 | def get_prf(true_class, total_class, pred_class): 150 | pre = true_class / pred_class 151 | rec = true_class / total_class 152 | 153 | pre_fix = np.nan_to_num(pre, nan=np.nanmean(pre)) 154 | rec_fix = np.nan_to_num(rec, nan=np.nanmean(rec)) 155 | f1 = 2 / (pre_fix ** (-1) + rec_fix ** (-1)) 156 | return pre, rec, f1 157 | 158 | 159 | def get_avg_f1(f1, total_class): 160 | avg_f1 = np.sum(f1 * total_class / np.sum(total_class)) 161 | 162 | return avg_f1 163 | 164 | 165 | def get_fairness_loss_all(y_test, y_pred, partition_list, all_partitioning_data_list): 166 | global MAX_ROW_PARTITION, MAX_COL_PARTITION, GLOBAL_MEAN 167 | 168 | fairness_loss_list = 
np.zeros((MAX_ROW_PARTITION * MAX_COL_PARTITION - 1), dtype='float') 169 | 170 | true_pred_w_class, pred_w_class = get_class_wise_list(y_test, y_pred) 171 | 172 | for (index1, index2) in partition_list: 173 | 174 | f1_list = np.zeros(index1 * index2, dtype='float') 175 | data_list = all_partitioning_data_list[(index1, index2)] 176 | 177 | for i in range(index1 * index2): 178 | true_class_part = np.sum(true_pred_w_class[data_list[i]], axis=0).reshape(-1) 179 | total_class_part = np.sum(y_test[data_list[i]], axis=0).reshape(-1) 180 | total_pred_part = np.sum(pred_w_class[data_list[i]], axis=0).reshape(-1) 181 | 182 | pre, rec, f1 = get_prf(true_class_part, total_class_part, total_pred_part) 183 | 184 | f1_list[i] = get_avg_f1(f1, total_class_part) 185 | 186 | fairness_loss_list[(index1 - 1) * MAX_COL_PARTITION + index2 - 1 - 1] = np.mean( 187 | np.abs(GLOBAL_MEAN - f1_list)) 188 | 189 | return fairness_loss_list 190 | 191 | 192 | def get_partition_data(index1, index2, X_data, y_data, all_partitioning_data_list): 193 | X_test = [] 194 | y_test = [] 195 | 196 | data_list = all_partitioning_data_list[(index1, index2)] 197 | for i in range(index1 * index2): 198 | X_test.append(X_data[data_list[i]]) 199 | y_test.append(y_data[data_list[i]]) 200 | 201 | return X_test, y_test 202 | 203 | 204 | ROW_LIST = list(range(1, MAX_ROW_PARTITION + 1)) 205 | COL_LIST = list(range(1, MAX_COL_PARTITION + 1)) 206 | 207 | PARTITIONINGS = [] 208 | 209 | for r in ROW_LIST: 210 | for c in COL_LIST: 211 | if r == 1 and c == 1: 212 | continue 213 | 214 | PARTITIONINGS.append((r, c)) 215 | 216 | print(PARTITIONINGS) 217 | 218 | # ----------------------------------------------------------------------------- 219 | base_model = DenseNet() 220 | model_compile(base_model) 221 | base_model.load_base_branch('base_model') 222 | 223 | y_pred = base_model.predict(X, batch_size=BATCH_SIZE) 224 | true_part, total_part = get_overall_accuracy(y, y_pred) 225 | 226 | true_class_part, total_class_part, total_pred_part = get_class_wise_accuracy(y, y_pred, prf=True) 227 | pre, rec, f1 = get_prf(true_class_part, total_class_part, total_pred_part) 228 | 229 | loss = dice(y.astype('float32'), y_pred) 230 | loss = tf.reduce_mean(loss).numpy() 231 | 232 | with np.printoptions(precision=4, suppress=True): 233 | print('te_accuracy = {:.4f}\nloss = {}\nf1 = {}'.format(true_part / total_part, loss, f1)) 234 | 235 | GLOBAL_MEAN = get_avg_f1(f1, total_class_part) 236 | fairness_loss_iter = get_fairness_loss_all(y, y_pred, PARTITIONINGS, test_id) 237 | base_fairness = np.sum(fairness_loss_iter) 238 | 239 | print('base', GLOBAL_MEAN, base_fairness) 240 | 241 | # ----------------------------------------------------------------------------- 242 | model = DenseNet() 243 | model_compile(model) 244 | model.load_base_branch('model_final') 245 | 246 | y_pred = model.predict(X, batch_size=BATCH_SIZE) 247 | true_part, total_part = get_overall_accuracy(y, y_pred) 248 | 249 | true_class_part, total_class_part, total_pred_part = get_class_wise_accuracy(y, y_pred, prf=True) 250 | pre, rec, f1 = get_prf(true_class_part, total_class_part, total_pred_part) 251 | 252 | loss = dice(y.astype('float32'), y_pred) 253 | loss = tf.reduce_mean(loss).numpy() 254 | 255 | with np.printoptions(precision=4, suppress=True): 256 | print('te_accuracy = {:.4f}\nloss = {}\nf1 = {}'.format(true_part / total_part, loss, f1)) 257 | 258 | GLOBAL_MEAN = get_avg_f1(f1, total_class_part) 259 | fairness_loss_iter = get_fairness_loss_all(y, y_pred, PARTITIONINGS, test_id) 260 | base_fairness 
= np.sum(fairness_loss_iter) 261 | 262 | print('SPAD', GLOBAL_MEAN, base_fairness) 263 | -------------------------------------------------------------------------------- /model_train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | device_name = tf.test.gpu_device_name() 4 | if device_name != '/device:GPU:0': 5 | raise SystemError('GPU device not found') 6 | print('Found GPU at: {}'.format(device_name)) 7 | print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU'))) 8 | 9 | # --------------------------------------------------------------------- 10 | import numpy as np 11 | from tensorflow import keras 12 | from keras.callbacks import LearningRateScheduler, LambdaCallback 13 | from keras import backend as K 14 | import pickle 15 | import random 16 | import time 17 | 18 | # --------------------------------------------------------------------- 19 | # Global Variable 20 | model_dir = 'results' 21 | dir_ckpt = model_dir + '/' + 'checkpoints' 22 | 23 | MAX_ROW_PARTITION = 5 24 | MAX_COL_PARTITION = 5 25 | 26 | BATCH_SIZE = 256 * 256 27 | 28 | INPUT_SIZE = 10 # number of features 29 | NUM_LAYERS_DNN = 8 # num_layers is the number of non-input layers, one more LSTM 30 | NUM_CLASS = 23 31 | EPOCH_TRAIN = 300 32 | 33 | CKPT_FOLDER_PATH = dir_ckpt 34 | LEARNING_RATE = 0.001 35 | 36 | CLASSIFICATION_LOSS = 'categorical_crossentropy' 37 | REGRESSION_LOSS = 'mean_squared_error' 38 | 39 | # --------------------------------------------------------------------- 40 | # Read the X training dataset and the Y training dataset 41 | X = np.load('X_train.npy') 42 | y = np.load('y_train.npy') 43 | 44 | # read sample indices for all partitions within each candidate partitioning 45 | with open('train_id.pickle', 'rb') as handle: 46 | train_id = pickle.load(handle) 47 | 48 | 49 | # --------------------------------------------------------------------- 50 | # DNN model 51 | class DenseNet(tf.keras.Model): 52 | 53 | def __init__(self, layer_size=INPUT_SIZE, num_class=NUM_CLASS, ckpt_path=CKPT_FOLDER_PATH): 54 | super(DenseNet, self).__init__() 55 | 56 | initializer = tf.keras.initializers.TruncatedNormal(stddev=0.5) # mean=0.0, seed=None 57 | self.dense = [] 58 | for i in range(7): 59 | self.dense.append(tf.keras.layers.Dense(layer_size, activation=tf.nn.relu, kernel_initializer=initializer)) 60 | 61 | self.out = tf.keras.layers.Dense(num_class, kernel_initializer=initializer) 62 | 63 | self.ckpt_path = ckpt_path 64 | self.model_name = 'dnn' 65 | print('check ckpt path: ' + ckpt_path) 66 | 67 | def call(self, inputs): # , num_layers = NUM_LAYERS_DNN 68 | 69 | for i in range(7): 70 | if i == 0: 71 | layer = self.dense[i](inputs) 72 | else: 73 | layer = self.dense[i](layer) 74 | 75 | out_layer = tf.nn.softmax(self.out(layer)) 76 | 77 | return out_layer 78 | 79 | def save_branch(self, branch_id): 80 | # save the current branch 81 | # branch_id should include the current branch (not after added to X_branch_id) 82 | self.save_weights(self.ckpt_path + '/' + self.model_name + '_ckpt_' + branch_id) 83 | return 84 | 85 | def load_base_branch(self, branch_id): 86 | # load the base branch before further fine-tuning 87 | self.load_weights(self.ckpt_path + '/' + self.model_name + '_ckpt_' + branch_id) 88 | return 89 | 90 | 91 | def dice(y_true, y_pred): 92 | # y_pred is softmax output of shape (num_samples, num_classes) 93 | # y_true is one hot encoding of target (shape= (num_samples, num_classes)) 94 | 95 | intersect = K.sum(y_pred * y_true, 
axis=0) + K.epsilon() 96 | denominator = K.sum(y_pred, axis=0) + K.sum(y_true, axis=0) 97 | dice_scores = K.constant(2) * intersect / (denominator + K.epsilon()) 98 | return 1 - dice_scores 99 | 100 | 101 | def custom_loss(y_true, y_pred): 102 | loss = dice(y_true, y_pred) 103 | return loss 104 | 105 | 106 | optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE) 107 | 108 | 109 | def model_compile(model): 110 | # optimizer = keras.optimizers.Adam(learning_rate = LEARNING_RATE) 111 | global optimizer 112 | 113 | model.compile(optimizer=optimizer, loss=custom_loss, metrics=['accuracy']) 114 | 115 | 116 | # --------------------------------------------------------------------- 117 | 118 | def model_train(model, X_in, y_in, init_epoch_number=0, epoch_train=EPOCH_TRAIN): # 119 | ''' 120 | Input model is complied! 121 | ''' 122 | 123 | # train the model for epoch_train number of epochs 124 | 125 | model.fit(X_in, y_in, batch_size=BATCH_SIZE, 126 | initial_epoch=init_epoch_number, epochs=init_epoch_number + epoch_train, shuffle=True) 127 | 128 | 129 | # --------------------------------------------------------------------- 130 | def get_class_wise_accuracy(y_true, y_pred, prf=False): 131 | num_class = y_true.shape[1] 132 | stat = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 133 | 134 | true_pred_w_class = y_true * np.expand_dims(stat, 1) 135 | true = np.sum(true_pred_w_class, axis=0).reshape(-1) 136 | total = np.sum(y_true, axis=0).reshape(-1) 137 | 138 | if prf: 139 | pred_w_class = tf.math.argmax(y_pred, axis=1) 140 | pred_w_class = tf.one_hot(pred_w_class, depth=NUM_CLASS).numpy() 141 | pred_total = np.sum(pred_w_class, axis=0).reshape(-1) 142 | return true, total, pred_total 143 | else: 144 | return true, total 145 | 146 | 147 | def get_overall_accuracy(y_true, y_pred): 148 | stat = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 149 | true = np.sum(stat) 150 | total = stat.shape[0] 151 | 152 | return true, total 153 | 154 | 155 | def get_class_wise_list(y_true, y_pred): 156 | stat = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 157 | 158 | true_pred_w_class = y_true * np.expand_dims(stat, 1) 159 | pred_w_class = tf.math.argmax(y_pred, axis=1) 160 | pred_w_class = tf.one_hot(pred_w_class, depth=NUM_CLASS).numpy() 161 | 162 | return true_pred_w_class, pred_w_class 163 | 164 | 165 | def get_prf(true_class, total_class, pred_class): 166 | pre = true_class / pred_class 167 | rec = true_class / total_class 168 | 169 | pre_fix = np.nan_to_num(pre, nan=np.nanmean(pre)) 170 | rec_fix = np.nan_to_num(rec, nan=np.nanmean(rec)) 171 | f1 = 2 / (pre_fix ** (-1) + rec_fix ** (-1)) 172 | return pre, rec, f1 173 | 174 | 175 | def get_avg_f1(f1, total_class): 176 | avg_f1 = np.sum(f1 * total_class / np.sum(total_class)) 177 | 178 | return avg_f1 179 | 180 | 181 | # --------------------------------------------------------------------- 182 | model = DenseNet() 183 | model_compile(model) 184 | model_train(model, X, y, init_epoch_number=0, epoch_train=EPOCH_TRAIN) 185 | model.save_branch('base_model') 186 | 187 | # --------------------------------------------------------------------- 188 | # training data 189 | 190 | y_pred = model.predict(X, batch_size=BATCH_SIZE) 191 | true_part, total_part = get_overall_accuracy(y, y_pred) 192 | 193 | true_class_part, total_class_part, total_pred_part = get_class_wise_accuracy(y, y_pred, prf=True) 194 | pre, rec, f1 = get_prf(true_class_part, total_class_part, total_pred_part) 195 | 196 | loss = dice(y.astype('float32'), y_pred) 197 | loss = 
tf.reduce_mean(loss).numpy() 198 | 199 | with np.printoptions(precision=4, suppress=True, threshold=NUM_CLASS): 200 | print('tr_accuracy = {:.4f}\nloss = {}\nf1 = {}'.format(true_part / total_part, loss, f1)) 201 | 202 | GLOBAL_MEAN = get_avg_f1(f1, total_class_part) 203 | print(GLOBAL_MEAN) 204 | 205 | 206 | # ------------------------------------------------------------------------------------- 207 | def get_partition_data(index1, index2, X_data, y_data, all_partitioning_data_list): 208 | X_train = [] 209 | y_train = [] 210 | 211 | data_list = all_partitioning_data_list[(index1, index2)] 212 | for i in range(index1 * index2): 213 | X_train.append(X_data[data_list[i]]) 214 | y_train.append(y_data[data_list[i]]) 215 | 216 | return X_train, y_train 217 | 218 | 219 | def get_weighted_f1_one(model, X_test, y_test, all_partitioning_data_list, index1, index2): 220 | w_f1_list = np.zeros((index1 * index2), dtype='float') 221 | 222 | y_pred = model.predict(X_test, batch_size=BATCH_SIZE) 223 | true_pred_w_class, pred_w_class = get_class_wise_list(y_test, y_pred) 224 | 225 | data_list = all_partitioning_data_list[(index1, index2)] 226 | 227 | for i in range(index1 * index2): 228 | true_class_part = np.sum(true_pred_w_class[data_list[i]], axis=0).reshape(-1) 229 | total_class_part = np.sum(y_test[data_list[i]], axis=0).reshape(-1) 230 | total_pred_part = np.sum(pred_w_class[data_list[i]], axis=0).reshape(-1) 231 | pre, rec, f1 = get_prf(true_class_part, total_class_part, total_pred_part) 232 | 233 | w_f1_list[i] = get_avg_f1(f1, total_class_part) 234 | 235 | return w_f1_list 236 | 237 | 238 | def set_lr_weight(index1, index2, w_f1): 239 | global lr_list, GLOBAL_MEAN 240 | 241 | lr_list = np.zeros((index1 * index2)) 242 | 243 | print(w_f1) 244 | 245 | lr_list = (GLOBAL_MEAN - w_f1) 246 | # percentage 247 | print(lr_list) 248 | lambda_value = 0.0005 249 | 250 | lr_list = tf.nn.relu(lr_list) 251 | lr_list = lr_list.numpy() 252 | 253 | if np.sum(lr_list) > 0: 254 | lr_list = lr_list / np.sum(lr_list) * lambda_value * (index1 * index2) 255 | 256 | # maximum to lambda_value 257 | lr_list = lr_list / (np.amax(lr_list) / lambda_value) 258 | 259 | # avoid zero 260 | lr_list += K.epsilon() 261 | 262 | return lr_list 263 | 264 | 265 | # ------------------------------------------------------------------------------ 266 | ROW_LIST = list(range(1, MAX_ROW_PARTITION + 1)) 267 | COL_LIST = list(range(1, MAX_COL_PARTITION + 1)) 268 | 269 | PARTITIONINGS = [] 270 | 271 | for r in ROW_LIST: 272 | for c in COL_LIST: 273 | if r == 1 and c == 1: 274 | continue 275 | 276 | PARTITIONINGS.append((r, c)) 277 | 278 | print(PARTITIONINGS) 279 | 280 | # --------------------------------------------------------------------------------- 281 | # Bi-level 282 | 283 | start_time = time.time() 284 | 285 | # RECURSIVELY ------ fairness training ------ 286 | loop_list = [4, 30] 287 | epoch_list = [5, 1] 288 | 289 | F1_MEAN = GLOBAL_MEAN 290 | 291 | for i in range(len(loop_list)): 292 | 293 | epochs = epoch_list[i] 294 | loops = loop_list[i] 295 | 296 | for l in range(loops): 297 | 298 | random.shuffle(PARTITIONINGS) 299 | 300 | for j in range(len(PARTITIONINGS)): 301 | (index1, index2) = PARTITIONINGS[j] 302 | 303 | X_train, y_train = get_partition_data(index1, index2, X, y, train_id) 304 | 305 | print(index1, index2) 306 | print('start training') 307 | 308 | w_f1 = get_weighted_f1_one(model, X, y, train_id, index1, index2) 309 | lr_list = set_lr_weight(index1, index2, w_f1) 310 | 311 | for e in range(epochs): 312 | 313 | for p in 
range(index1 * index2):
314 | 
315 |                     optimizer.learning_rate = lr_list[p]  # apply the per-partition learning rate computed by set_lr_weight
316 |                     for var in optimizer.variables():  # reset the optimizer state (step count and Adam moments)
317 |                         var.assign(tf.zeros_like(var))
318 | 
319 |                     model(tf.ones((1, INPUT_SIZE)))  # dummy forward pass so the model's variables are built
320 |                     model_train(model, X_train[p], y_train[p], init_epoch_number=e, epoch_train=1)
321 | 
322 | model.save_branch('model_final')
323 | 
324 | print("Time: %f s" % (time.time() - start_time))
--------------------------------------------------------------------------------
/results/checkpoints/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "dnn_ckpt_model_final"
2 | all_model_checkpoint_paths: "dnn_ckpt_model_final"
3 | 
--------------------------------------------------------------------------------
/results/checkpoints/dnn_ckpt_base_model.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-spatial/fair-ai-in-space/8ab1cac8f24e950ae2c500829c206830051bf826/results/checkpoints/dnn_ckpt_base_model.data-00000-of-00001
--------------------------------------------------------------------------------
/results/checkpoints/dnn_ckpt_base_model.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-spatial/fair-ai-in-space/8ab1cac8f24e950ae2c500829c206830051bf826/results/checkpoints/dnn_ckpt_base_model.index
--------------------------------------------------------------------------------
/results/checkpoints/dnn_ckpt_model_final.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-spatial/fair-ai-in-space/8ab1cac8f24e950ae2c500829c206830051bf826/results/checkpoints/dnn_ckpt_model_final.data-00000-of-00001
--------------------------------------------------------------------------------
/results/checkpoints/dnn_ckpt_model_final.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-spatial/fair-ai-in-space/8ab1cac8f24e950ae2c500829c206830051bf826/results/checkpoints/dnn_ckpt_model_final.index
--------------------------------------------------------------------------------