├── Data splitting.py ├── ID3QNE-deepQnet.py ├── ID3QNE-evaluate.py ├── ID3QNE-main.py ├── README.md └── experiment ├── action distribution ├── 3D_dongzuo.py └── fig.3.png ├── bearman error └── 贝尔曼误差画图 │ ├── record_bellman_error_D3QN.npy │ ├── record_bellman_error_DQN.npy │ ├── record_bellman_error_WD3QN.npy │ └── 画图.py ├── expected return └── return_wuICU.py ├── sensitivity analysis ├── __pycache__ │ └── fangfa.cpython-38.pyc ├── 数据查看.py ├── 画图_loss.py └── 画图_return.py └── survival rate ├── __pycache__ └── deepQnet.cpython-38.pyc ├── deepQnet.py ├── main_shengcunlv-37.py ├── shengcunlv.py └── 图 ├── fig.5(a).png └── fig.5(b).png /Data splitting.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | import numpy as np 4 | 5 | if __name__ == '__main__': 6 | # 计算时间 7 | start = time.perf_counter() 8 | # 载入总数据 9 | with open('患者总数据.pkl', 'rb') as file: 10 | MIMICtable = pickle.load(file) 11 | 12 | #####################模型参数设置############################## 13 | ncv = 10 # nr of crossvalidation runs (each is 80% training / 20% test)交叉验证运行的Nr(每次为80%训练/ 20%测试) 14 | icustayidlist = MIMICtable['icustayid'] 15 | icuuniqueids = np.unique(icustayidlist) 16 | N = icuuniqueids.size 17 | grp = np.floor(ncv * np.random.rand(N, 18 | 1) + 1) 19 | crossval1 = 1 # 利用等于1和不等于1 20 | crossval2 = 2 # 利用等于1和不等于1 21 | trainidx = icuuniqueids[np.where(grp > crossval2)[0]] 22 | validationidx = icuuniqueids[np.where(grp == crossval1)[0]] 23 | testidx = icuuniqueids[np.where(grp == crossval2)[0]] 24 | train = np.isin(icustayidlist, trainidx) 25 | validation = np.isin(icustayidlist, validationidx) 26 | test = np.isin(icustayidlist, testidx) 27 | # 保存到文件 28 | np.save('数据集/train.npy', train) 29 | np.save('数据集/validation.npy', validation) 30 | np.save('数据集/test.npy', test) 31 | 32 | -------------------------------------------------------------------------------- /ID3QNE-deepQnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim 5 | import torch.nn.functional as F 6 | import copy 7 | 8 | # gamma = 0.9 9 | device = 'cpu' 10 | 11 | 12 | class DistributionalDQN(nn.Module): 13 | def __init__(self, state_dim, n_actions, ): 14 | super(DistributionalDQN, self).__init__() 15 | 16 | self.conv = nn.Sequential( 17 | nn.Linear(state_dim, 128), 18 | nn.ReLU(), 19 | nn.Linear(128, 128), 20 | nn.ReLU(), 21 | ) 22 | self.fc_val = nn.Sequential( 23 | nn.Linear(128, 256), 24 | nn.ReLU(), 25 | nn.Linear(256, 1) 26 | ) 27 | self.fc_adv = nn.Sequential( 28 | nn.Linear(128, 256), 29 | nn.ReLU(), 30 | nn.Linear(256, n_actions) 31 | ) 32 | 33 | def forward(self, state): 34 | conv_out = self.conv(state) 35 | val = self.fc_val(conv_out) 36 | adv = self.fc_adv(conv_out) 37 | return val + adv - adv.mean(dim=1, keepdim=True) 38 | 39 | 40 | class Dist_DQN(object): 41 | def __init__(self, 42 | state_dim=37, 43 | num_actions=25, 44 | device='cpu', 45 | gamma=0.999, 46 | tau=0.1, ): 47 | self.device = device 48 | self.Q = DistributionalDQN(state_dim, num_actions, ).to(device) 49 | self.Q_target = copy.deepcopy(self.Q) 50 | self.tau = tau 51 | self.gamma = gamma 52 | self.num_actions = num_actions 53 | self.optimizer = torch.optim.Adam(self.Q.parameters(), lr=0.0001) 54 | 55 | def train(self, batchs, epoch): 56 | (state, next_state, action, next_action, reward, done, bloc_num, SOFAS) = batchs 57 | batch_s = 128 58 | uids = np.unique(bloc_num) 59 | num_batch = uids.shape[0] 
// batch_s  # number of mini-batches (transitions are grouped by trajectory id bloc_num)
60 |         record_loss = []
61 |         sum_q_loss = 0
62 |         Batch = 0
63 |         for batch_idx in range(num_batch + 1):  # MAX_Iteration
64 |             # =============== two sampling schemes: 1. random sampling; 2. sequential sampling ===============
65 |             # ----- 1. random sampling -----
66 |             # batch_uids = np.random.choice(uids, batch_s)  # replace=False
67 |             # ----- 2. sequential sampling -----
68 |             batch_uids = uids[batch_idx * batch_s: (batch_idx + 1) * batch_s]
69 |             batch_user = np.isin(bloc_num, batch_uids)
70 |             state_user = state[batch_user, :]
71 |             next_state_user = next_state[batch_user, :]
72 |             action_user = action[batch_user]
73 |             next_action_user = next_action[batch_user]
74 |             reward_user = reward[batch_user]
75 |             done_user = done[batch_user]
76 |             SOFAS_user = SOFAS[batch_user]
77 |             batch = (state_user, next_state_user, action_user, next_action_user, reward_user, done_user, SOFAS_user)
78 |             loss = self.compute_loss(batch)
79 |             sum_q_loss += loss.item()
80 |             self.optimizer.zero_grad()  # clear gradients
81 |             loss.backward()  # backpropagate
82 |             self.optimizer.step()  # update the online network
83 |             if Batch % 25 == 0:
84 |                 print('Epoch :', epoch, 'Batch :', Batch, 'Average Loss :', sum_q_loss / (Batch + 1))
85 |                 record_loss1 = sum_q_loss / (Batch + 1)
86 |                 record_loss.append(record_loss1)
87 |             if Batch % 100 == 0:
88 |                 self.polyak_target_update()
89 |             Batch += 1
90 |         return record_loss
91 | 
92 |     def polyak_target_update(self):  # despite the name, this performs a hard copy of Q into Q_target
93 |         for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()):
94 |             target_param.data.copy_(param.data)
95 | 
96 |     def compute_loss(self, batch):
97 |         state, next_state, action, next_action, reward, done, SOFA = batch
98 |         gamma = 0.99
99 |         end_multiplier = 1 - done
100 |         batch_size = state.shape[0]
101 |         range_batch = torch.arange(batch_size).long().to(device)
102 |         log_Q_dist_prediction = self.Q(state)
103 |         log_Q_dist_prediction1 = log_Q_dist_prediction[range_batch, action]
104 |         q_eval4nex = self.Q(next_state)
105 |         max_eval_next = torch.argmax(q_eval4nex, dim=1)
106 |         with torch.no_grad():
107 |             Q_dist_target = self.Q_target(next_state)
108 |             Q_target = Q_dist_target.clone().detach()
109 |             Q_dist_eval = Q_dist_target[range_batch, max_eval_next]
110 |             max_target_next = torch.argmax(Q_dist_target, dim=1)
111 |             Q_dist_tar = Q_dist_target[range_batch, max_target_next]
112 |             Q_target_pro = F.softmax(Q_target, dim=1)  # dim given explicitly to avoid the implicit-dim deprecation warning
113 |             pro1 = Q_target_pro[range_batch, max_eval_next]
114 |             pro2 = Q_target_pro[range_batch, max_target_next]
115 |             Q_dist_star = (pro1 / (pro1 + pro2)) * Q_dist_eval + (pro2 / (pro1 + pro2)) * Q_dist_tar
116 |             log_Q_experience = Q_dist_target[range_batch, next_action.squeeze(1)]
117 |             Q_experi = torch.where(SOFA < 4, log_Q_experience, Q_dist_star)  # expert rule: low SOFA -> clinician's action value, else weighted double-Q estimate
118 |         targetQ1 = reward + (gamma * Q_experi * end_multiplier)
119 |         return nn.SmoothL1Loss()(targetQ1, log_Q_dist_prediction1)
120 | 
121 |     def get_action(self, state):
122 |         with torch.no_grad():
123 |             batch_size = state.shape[0]
124 |             Q_dist = self.Q(state)
125 |             a_star = torch.argmax(Q_dist, dim=1)
126 |             return a_star
127 | 
--------------------------------------------------------------------------------
/ID3QNE-evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import warnings
4 | import torch.nn.functional as F
5 | 
6 | warnings.filterwarnings("ignore")
7 | device = 'cpu'
8 | 
9 | 
10 | def do_eval(model, batchs, batch_size=128):
11 |     (state, next_state, action, next_action, reward, done) = batchs
12 |     Q_value = model.Q(state)
13 |     agent_actions = torch.argmax(Q_value, dim=1)
14 |     phy_actions = action
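    # Note on the lines below: F.softmax turns each row of Q-values into a pseudo-probability
    # over the 25 discrete dose combinations; the probability picked out for the greedy action
    # is returned as Q_value_pro and later saved as rec_agent_q_pro, which can be read as a
    # rough confidence score for the agent's recommended action.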
15 | Q_value_pro1 = F.softmax(Q_value) 16 | Q_value_pro_ind = torch.argmax(Q_value_pro1, dim=1) 17 | Q_value_pro_ind1 = range(len(Q_value_pro_ind)) 18 | Q_value_pro = Q_value_pro1[Q_value_pro_ind1, Q_value_pro_ind] 19 | return Q_value, agent_actions, phy_actions, Q_value_pro 20 | 21 | 22 | def do_test(model, Xtest, actionbloctest, bloctest, Y90, SOFA, reward_value, beat): 23 | bloc_max = max(bloctest) # 最大才20个阶段 24 | r = np.array([reward_value, -reward_value]).reshape(1, -1) 25 | r2 = r * (2 * (1 - Y90.reshape(-1, 1)) - 1) 26 | R3 = r2[:, 0] 27 | R4 = (R3 + reward_value) / (2 * reward_value) 28 | RNNstate = Xtest 29 | print('#### 生成测试集轨迹 ####') 30 | statesize = int(RNNstate.shape[1]) 31 | states = np.zeros((np.floor(RNNstate.shape[0]).astype('int64'), statesize)) 32 | actions = np.zeros((np.floor(RNNstate.shape[0]).astype('int64'), 1), dtype=int) 33 | next_actions = np.zeros((np.floor(RNNstate.shape[0]).astype('int64'), 1), dtype=int) 34 | rewards = np.zeros((np.floor(RNNstate.shape[0]).astype('int64'), 1)) 35 | next_states = np.zeros((np.floor(RNNstate.shape[0]).astype('int64'), statesize)) 36 | done_flags = np.zeros((np.floor(RNNstate.shape[0]).astype('int64'), 1)) 37 | bloc_num = np.zeros((np.floor(RNNstate.shape[0]).astype('int64'), 1)) 38 | blocnum1 = 1 39 | c = 0 40 | 41 | bloc_num_reward = 0 42 | for i in range(RNNstate.shape[0] - 1): # 每一行循环 43 | states[c] = RNNstate[i, :] 44 | actions[c] = actionbloctest[i] 45 | bloc_num[c] = blocnum1 46 | if (bloctest[i + 1] == 1): # end of trace for this patient 47 | next_states1 = np.zeros(statesize) 48 | next_actions1 = -1 49 | done_flags1 = 1 50 | blocnum1 = blocnum1 + 1 51 | bloc_num_reward += 1 52 | reward1 = -beat[0] * (SOFA[i]) + R3[i] 53 | bloc_num_reward = 0 54 | else: 55 | next_states1 = RNNstate[i + 1, :] 56 | next_actions1 = actionbloctest[i + 1] 57 | done_flags1 = 0 58 | blocnum1 = blocnum1 59 | reward1 = - beat[1] * (SOFA[i + 1] - SOFA[i]) 60 | bloc_num_reward += 1 61 | next_states[c] = next_states1 62 | next_actions[c] = next_actions1 63 | rewards[c] = reward1 64 | done_flags[c] = done_flags1 65 | c = c + 1 # 从0开始 66 | states[c] = RNNstate[c, :] 67 | actions[c] = actionbloctest[c] 68 | bloc_num[c] = blocnum1 69 | 70 | next_states1 = np.zeros(statesize) 71 | next_actions1 = -1 72 | done_flags1 = 1 73 | blocnum1 = blocnum1 + 1 74 | bloc_num_reward += 1 75 | reward1 = -beat[0] * (SOFA[c]) + R3[c] 76 | 77 | bloc_num_reward = 0 78 | next_states[c] = next_states1 79 | next_actions[c] = next_actions1 80 | rewards[c] = reward1 81 | done_flags[c] = done_flags1 82 | c = c + 1 # 从0开始 83 | bloc_num = bloc_num[:c, :] 84 | states = states[: c, :] 85 | next_states = next_states[: c, :] 86 | actions = actions[: c, :] 87 | next_actions = next_actions[: c, :] 88 | rewards = rewards[: c, :] 89 | done_flags = done_flags[: c, :] 90 | 91 | bloc_num = np.squeeze(bloc_num) 92 | actions = np.squeeze(actions) 93 | rewards = np.squeeze(rewards) 94 | done_flags = np.squeeze(done_flags) 95 | 96 | # numpy形式转化为tensor形式 97 | batch_size = states.shape[0] 98 | state = torch.FloatTensor(states).to(device) 99 | next_state = torch.FloatTensor(next_states).to(device) 100 | action = torch.LongTensor(actions).to(device) 101 | next_action = torch.LongTensor(next_actions).to(device) 102 | reward = torch.FloatTensor(rewards).to(device) 103 | done = torch.FloatTensor(done_flags).to(device) 104 | batchs = (state, next_state, action, next_action, reward, done, bloc_num) 105 | 106 | rec_phys_q = [] 107 | rec_agent_q = [] 108 | rec_agent_q_pro = [] 109 | rec_phys_a = [] 110 | 
rec_agent_a = []
111 |     rec_sur = []
112 |     rec_reward_user = []
113 |     batch_s = 128
114 |     uids = np.unique(bloc_num)
115 |     num_batch = uids.shape[0] // batch_s  # number of mini-batches
116 |     for batch_idx in range(num_batch + 1):
117 |         batch_uids = uids[batch_idx * batch_s: (batch_idx + 1) * batch_s]
118 |         batch_user = np.isin(bloc_num, batch_uids)
119 |         state_user = state[batch_user, :]
120 |         next_state_user = next_state[batch_user, :]
121 |         action_user = action[batch_user]
122 |         next_action_user = next_action[batch_user]
123 |         reward_user = reward[batch_user]
124 |         done_user = done[batch_user]
125 |         sur_Y90 = Y90[batch_user]
126 | 
127 |         batch = (state_user, next_state_user, action_user, next_action_user, reward_user, done_user)
128 |         q_output, agent_actions, phys_actions, Q_value_pro = do_eval(model, batch)
129 | 
130 |         q_output_len = range(len(q_output))
131 |         agent_q = q_output[q_output_len, agent_actions]
132 |         phys_q = q_output[q_output_len, phys_actions]
133 | 
134 |         rec_agent_q.extend(agent_q.detach().numpy())
135 |         rec_agent_q_pro.extend(Q_value_pro.detach().numpy())
136 | 
137 |         rec_phys_q.extend(phys_q.detach().numpy())
138 |         rec_agent_a.extend(agent_actions.detach().numpy())
139 |         rec_phys_a.extend(phys_actions.detach().numpy())
140 |         rec_sur.extend(sur_Y90)
141 |         rec_reward_user.extend(reward_user.detach().numpy())
142 | 
143 |     np.save('Q值/shencunlv.npy', rec_sur)
144 |     np.save('Q值/agent_bQ.npy', rec_agent_q)
145 |     np.save('Q值/phys_bQ.npy', rec_phys_q)
146 |     np.save('Q值/reward.npy', rec_reward_user)
147 | 
148 |     np.save('Q值/agent_actionsb.npy', rec_agent_a)
149 |     np.save('Q值/phys_actionsb.npy', rec_phys_a)
150 | 
151 |     np.save('Q值/rec_agent_q_pro.npy', rec_agent_q_pro)
152 | 
--------------------------------------------------------------------------------
/ID3QNE-main.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## A value-based deep reinforcement learning model with human expertise in optimal treatment of sepsis
2 | # WD3QNE-algorithm
3 | Architecture of the WD3QNE algorithm and the real-time treatment process of the WD3QNE agent. Patients' vital signs and drug levels are monitored and recorded in real time, from which a continuous state space and a discrete action space are constructed. The agent makes decisions based on states, actions, and the clinician's experience, and patients are treated dynamically.
4 | 
5 | ![Personalized Cell Therapies to Combat COVID-19(9)](https://user-images.githubusercontent.com/33052678/163176041-0e827223-3392-4aff-96e3-53243547a6fd.png)
6 | ![ID3QNE algorithm structure](https://user-images.githubusercontent.com/33052678/163176391-2eb188d4-a716-4819-be9d-5456ef65719a.PNG)
7 | 
8 | ## Acknowledgement
9 | If you use code or concepts available in this repository, we would be grateful if you would
10 | cite the paper:
11 | 
12 | GB/T 7714: Wu X D, Li R C, He Z, et al. A value-based deep reinforcement learning model with human expertise in optimal treatment of sepsis[J]. npj Digital Medicine, 2023, 6(1): 15.
13 | 
14 | MLA: Wu, XiaoDan, et al. "A value-based deep reinforcement learning model with human expertise in optimal treatment of sepsis." npj Digital Medicine 6.1 (2023): 15.
15 | 
16 | APA: Wu, X., Li, R., He, Z., Yu, T., & Cheng, C. (2023). A value-based deep reinforcement learning model with human expertise in optimal treatment of sepsis. npj Digital Medicine, 6(1), 15.
17 | 18 | @article{wu2023value, 19 | title={A value-based deep reinforcement learning model with human expertise in optimal treatment of sepsis}, 20 | author={Wu, XiaoDan and Li, RuiChang and He, Zhen and Yu, TianZhi and Cheng, ChangQing}, 21 | journal={npj Digital Medicine}, 22 | volume={6}, 23 | number={1}, 24 | pages={15}, 25 | year={2023}, 26 | publisher={Nature Publishing Group UK London} 27 | } 28 | 29 | 30 | The open-source MIMIC-III data used in this present study can be retrieved from https://physionet.org/content/mimiciii/1.4/. 31 | The open-source eICU-RI data: http://eicu-crd.mit.edu/ 32 | 33 | -------------------------------------------------------------------------------- /experiment/action distribution/3D_dongzuo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib import cm 4 | 5 | if __name__ == '__main__': 6 | 7 | fig = plt.figure(figsize=(20, 4), facecolor='w') 8 | 9 | ax = fig.add_subplot(131, projection='3d') 10 | 11 | D3Q_actions = np.load('D3Q-37/agent_actionsb.npy', allow_pickle=True) 12 | phys_actions = np.load('phys_actionsb.npy', allow_pickle=True) 13 | 14 | inv_action_map = {} 15 | count = 0 16 | for i in range(5): 17 | for j in range(5): 18 | inv_action_map[count] = [i, j] 19 | count += 1 20 | 21 | phys_actions_tuple = [None for i in range(len(phys_actions))] 22 | 23 | for i in range(len(phys_actions)): 24 | phys_actions_tuple[i] = inv_action_map[phys_actions[i]] 25 | 26 | phys_actions_tuple = np.array(phys_actions_tuple) 27 | 28 | phys_actions_iv = phys_actions_tuple[:, 0] 29 | phys_actions_vaso = phys_actions_tuple[:, 1] 30 | hist, x_edges, y_edges = np.histogram2d(phys_actions_iv, phys_actions_vaso, bins=5) 31 | 32 | x_edges = np.arange(-0.5, 5) 33 | y_edges = np.arange(-0.5, 5) 34 | 35 | Z = hist.reshape(25) 36 | z_sun = sum(Z) 37 | height = np.zeros_like(Z) 38 | xx, yy = np.meshgrid(x_edges, y_edges) # 网格化坐标 39 | X, Y = xx.ravel(), yy.ravel() 40 | 41 | width = depth = 0.6 # 柱子的长和宽 42 | 43 | for i in range(0, 5): 44 | for j in range(0, 5): 45 | z = hist[i][j] # 该柱的高 46 | norm = plt.Normalize(0, 0.25) 47 | norm_values = norm(z / z_sun) 48 | map_vir = cm.get_cmap(name='spring') 49 | colors = map_vir(norm_values) 50 | sc1 = ax.bar3d(j, i, height, width, depth, z, color=colors) 51 | 52 | 53 | font1 = {'family': 'Arial', 'weight': 'normal'} 54 | 55 | ax.set_xlabel('VP dose', fontsize=14, fontname="Arial") 56 | ax.set_ylabel('IV fluid dose', fontsize=14, fontname="Arial") 57 | ax.set_zlabel('Action counts', fontsize=14, fontname="Arial") 58 | ax.set_title("Physician policy", fontsize=15, fontname="Arial") 59 | 60 | sm = cm.ScalarMappable(cmap=map_vir, norm=norm) # norm设置最大最小值 61 | sm.set_array([]) 62 | cb = plt.colorbar(sm) 63 | cb.ax.tick_params(labelsize=10) # 设置色标刻度字体大小。 64 | cb.ax.set_yticklabels([0, 1500, 3000, 4500, 6000, 7500], fontname="Arial") 65 | 66 | ax.set_ylim(5, -0.5) 67 | ax.view_init(32, 43) 68 | 69 | # =========================第二个图=================================================================== 70 | ax2 = fig.add_subplot(133, projection='3d') 71 | 72 | IDDE_actions = np.load('IDDE/Q值/agent_actionsb.npy', allow_pickle=True) 73 | 74 | IDDE_actions_tuple = [None for i in range(len(IDDE_actions))] 75 | for i in range(len(IDDE_actions)): 76 | IDDE_actions_tuple[i] = inv_action_map[IDDE_actions[i]] 77 | 78 | IDDE_actions_tuple = np.array(IDDE_actions_tuple) 79 | IDDE_actions_iv = IDDE_actions_tuple[:, 0] 80 | IDDE_actions_vaso = 
IDDE_actions_tuple[:, 1] 81 | IDDE_hist, x_edges, y_edges = np.histogram2d(IDDE_actions_iv, IDDE_actions_vaso, bins=5) 82 | 83 | x_edges = np.arange(-0.5, 5) 84 | y_edges = np.arange(-0.5, 5) 85 | 86 | Z = hist.reshape(25) 87 | z_sun = sum(Z) 88 | height = np.zeros_like(Z) 89 | xx, yy = np.meshgrid(x_edges, y_edges) # 网格化坐标 90 | X, Y = xx.ravel(), yy.ravel() 91 | 92 | for i in range(0, 5): 93 | for j in range(0, 5): 94 | z = IDDE_hist[i][j] # 该柱的高 95 | norm = plt.Normalize(0, 0.25) 96 | norm_values = norm(z / z_sun) 97 | map_vir = cm.get_cmap(name='cool') 98 | colors = map_vir(norm_values) 99 | sc2 = ax2.bar3d(j, i, height, width, depth, z, color=colors) 100 | 101 | ax2.set_xlabel('VP dose', fontsize=14, fontname="Arial") 102 | ax2.set_ylabel('IV fluid dose', fontsize=14, fontname="Arial") 103 | ax2.set_zlabel('Action counts', fontsize=14, fontname="Arial") 104 | ax2.set_title("ID3QNE policy", fontsize=15, fontname="Arial") 105 | 106 | sm2 = cm.ScalarMappable(cmap=map_vir, norm=norm) # norm设置最大最小值 107 | sm2.set_array([]) 108 | cb2 = plt.colorbar(sm2) 109 | cb2.ax.tick_params(labelsize=10) # 设置色标刻度字体大小。 110 | cb2.ax.set_yticklabels([0, 1500, 3000, 4500, 6000, 7500], fontname="Arial") 111 | 112 | ax2.set_ylim(5, -0.5) 113 | ax2.view_init(32, 43) 114 | 115 | # ================第三张图====================================================================== 116 | 117 | ax3 = fig.add_subplot(132, projection='3d') 118 | 119 | D3Q_actions_tuple = [None for i in range(len(D3Q_actions))] 120 | for i in range(len(D3Q_actions)): 121 | D3Q_actions_tuple[i] = inv_action_map[D3Q_actions[i]] 122 | 123 | D3Q_actions_tuple = np.array(D3Q_actions_tuple) 124 | D3Q_actions_iv = D3Q_actions_tuple[:, 0] 125 | D3Q_actions_vaso = D3Q_actions_tuple[:, 1] 126 | D3Q_hist, x_edges, y_edges = np.histogram2d(D3Q_actions_iv, D3Q_actions_vaso, bins=5) 127 | 128 | x_edges = np.arange(-0.5, 5) 129 | y_edges = np.arange(-0.5, 5) 130 | 131 | Z = hist.reshape(25) 132 | z_sun = sum(Z) 133 | height = np.zeros_like(Z) 134 | xx, yy = np.meshgrid(x_edges, y_edges) # 网格化坐标 135 | X, Y = xx.ravel(), yy.ravel() 136 | 137 | for i in range(0, 5): 138 | for j in range(0, 5): 139 | z = D3Q_hist[i][j] # 该柱的高 140 | norm = plt.Normalize(0, 0.25) 141 | norm_values = norm(z / z_sun) 142 | map_vir = cm.get_cmap(name='winter') 143 | colors = map_vir(norm_values) 144 | sc2 = ax3.bar3d(j, i, height, width, depth, z, color=colors) 145 | 146 | ax3.set_xlabel('VP dose', fontsize=14, fontname="Arial") 147 | ax3.set_ylabel('IV fluid dose', fontsize=14, fontname="Arial") 148 | ax3.set_zlabel('Action counts', fontsize=14, fontname="Arial") 149 | ax3.set_title("D3QN-37 policy", fontsize=15, fontname="Arial") 150 | 151 | sm3 = cm.ScalarMappable(cmap=map_vir, norm=norm) # norm设置最大最小值 152 | sm3.set_array([]) 153 | cb3 = plt.colorbar(sm3) 154 | cb3.ax.tick_params(labelsize=10) # 设置色标刻度字体大小。 155 | cb3.ax.set_yticklabels([0, 1500, 3000, 4500, 6000, 7500], fontname="Arial") 156 | 157 | ax3.set_ylim(5, -0.5) 158 | ax3.view_init(32, 43) 159 | 160 | plt.show() 161 | -------------------------------------------------------------------------------- /experiment/action distribution/fig.3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaryLi666/ID3QNE-algorithm/1fcb321c175c85ea49be25d7edf553a1161de2bd/experiment/action distribution/fig.3.png -------------------------------------------------------------------------------- /experiment/bearman error/贝尔曼误差画图/record_bellman_error_D3QN.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaryLi666/ID3QNE-algorithm/1fcb321c175c85ea49be25d7edf553a1161de2bd/experiment/bearman error/贝尔曼误差画图/record_bellman_error_D3QN.npy -------------------------------------------------------------------------------- /experiment/bearman error/贝尔曼误差画图/record_bellman_error_DQN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaryLi666/ID3QNE-algorithm/1fcb321c175c85ea49be25d7edf553a1161de2bd/experiment/bearman error/贝尔曼误差画图/record_bellman_error_DQN.npy -------------------------------------------------------------------------------- /experiment/bearman error/贝尔曼误差画图/record_bellman_error_WD3QN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaryLi666/ID3QNE-algorithm/1fcb321c175c85ea49be25d7edf553a1161de2bd/experiment/bearman error/贝尔曼误差画图/record_bellman_error_WD3QN.npy -------------------------------------------------------------------------------- /experiment/bearman error/贝尔曼误差画图/画图.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | bellman_error_DQN=np.load("record_bellman_error_DQN.npy",allow_pickle=True) 5 | bellman_error_D3QN=np.load("record_bellman_error_D3QN.npy",allow_pickle=True) 6 | bellman_error_WD3QN=np.load("record_bellman_error_WD3QN.npy",allow_pickle=True) 7 | 8 | x_length_list = list(range(len(bellman_error_DQN))) 9 | plt.figure() 10 | 11 | plt.plot(x_length_list, bellman_error_DQN,'--',label= 'Dueling DQN') 12 | plt.plot(x_length_list, bellman_error_D3QN,'-.',label= 'D3QN') 13 | plt.plot(x_length_list, bellman_error_WD3QN,label= 'WD3QN',color='red') 14 | 15 | font1 = {'family': 'Arial', 'weight': 'normal','size':12} 16 | 17 | plt.xticks([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100], fontname="Arial", ) 18 | 19 | plt.legend(prop=font1) 20 | plt.xlabel("Epochs", font1) 21 | plt.ylabel("Bellman error", font1) 22 | # 23 | plt.show() -------------------------------------------------------------------------------- /experiment/expected return/return_wuICU.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | if __name__ == '__main__': 5 | DQN_45 = np.load('ICU/DQN-45.npy', allow_pickle=True) 6 | DQN_37 = np.load('ICU/DQN-37.npy', allow_pickle=True) 7 | DDQN_45 = np.load('ICU/DDQN-45.npy', allow_pickle=True) 8 | DDQN_37 = np.load('ICU/DDQN-37.npy', allow_pickle=True) 9 | D3QN_45 = np.load('ICU/D3QN-45.npy', allow_pickle=True) 10 | D3QN_37 = np.load('ICU/D3QN-37.npy', allow_pickle=True) 11 | ID3QN = np.load('ICU/ID3QN.npy', allow_pickle=True) 12 | ID3QNE = np.load('ICU/ID3QNE.npy', allow_pickle=True) 13 | x_length_list = list(range(len(DQN_45))) 14 | ax = plt.figure() 15 | plt.plot(x_length_list, DQN_45, '-.', label='DQN_45', markersize=3) 16 | plt.plot(x_length_list, DQN_37, 'v', label='DQN_37', markersize=1) 17 | plt.plot(x_length_list, DDQN_45, '.', label='DDQN_45', markersize=3) 18 | plt.plot(x_length_list, DDQN_37, '*', label='DDQN_37', markersize=3) 19 | plt.plot(x_length_list, D3QN_45, '<', label='D3QN_45', markersize=3) 20 | plt.plot(x_length_list, D3QN_37, '+', label='D3QN_37', markersize=3) 21 | plt.plot(x_length_list, ID3QN, '--', color='y', label='ID3QN') 22 | plt.plot(x_length_list, ID3QNE, '-', color='r', label='ID3QNE') 23 | 
font1 = {'family': 'Arial', 'weight': 'normal'}
24 |     plt.legend(prop=font1)
25 |     plt.xticks([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100], fontname="Arial")
26 |     plt.yticks([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24], fontname="Arial")
27 |     plt.xlabel("Epochs", font1)
28 |     plt.ylabel("Expected return", font1)
29 |     plt.show()
30 | 
31 | 
--------------------------------------------------------------------------------
/experiment/sensitivity analysis/__pycache__/fangfa.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CaryLi666/ID3QNE-algorithm/1fcb321c175c85ea49be25d7edf553a1161de2bd/experiment/sensitivity analysis/__pycache__/fangfa.cpython-38.pyc
--------------------------------------------------------------------------------
/experiment/sensitivity analysis/数据查看.py:
--------------------------------------------------------------------------------
1 | # Utility script: re-samples the dataset to a coarser timestep
2 | 
3 | import time
4 | import pickle  # to read the .pkl file
5 | import numpy as np
6 | 
7 | from scipy.stats import zscore, rankdata
8 | def my_zscore(x):
9 |     return zscore(x, ddof=1), np.mean(x, axis=0), np.std(x, axis=0, ddof=1)
10 | with open('1小时.pkl', 'rb') as file:
11 |     MIMICtable = pickle.load(file)  # DataFrame of [278751 rows x 59 columns]
12 | reformat5 = MIMICtable.values.copy()
13 | 
14 | 
15 | 
16 | 
17 | # ----------------------- selected features: 37 in total --------------------------------
18 | colnorm = ['SOFA', 'age', 'Weight_kg', 'GCS', 'HR', 'SysBP', 'MeanBP', 'DiaBP', 'RR', 'Temp_C',
19 |            'Sodium', 'Chloride', 'Glucose', 'Calcium', 'Hb', 'WBC_count', 'Platelets_count',
20 |            'PTT', 'PT', 'Arterial_pH', 'paO2', 'paCO2', 'HCO3', 'Arterial_lactate', 'Shock_Index',
21 |            'PaO2_FiO2', 'cumulated_balance', 'CO2_mEqL', 'Ionised_Ca']
22 | ## the 8 log-transformed indicators
23 | collog = ['SpO2', 'BUN', 'Creatinine', 'SGOT', 'Total_bili', 'INR', 'input_total', 'output_total']
24 | 
25 | colnorm = np.where(np.isin(MIMICtable.columns, colnorm))[0]
26 | collog = np.where(np.isin(MIMICtable.columns, collog))[0]
27 | scaleMIMIC = np.concatenate([zscore(reformat5[:, colnorm], ddof=1),
28 |                              zscore(np.log(0.1 + reformat5[:, collog]), ddof=1)], axis=1)
29 | MIMICtablecopy = MIMICtable
30 | stayID = reformat5[:, 1]
31 | blocID = reformat5[:, 1]
32 | bloc = 1
33 | cout = 0
34 | timestep = 8
35 | for i in range(0, len(blocID) - timestep, timestep):
36 |     MIMICtablecopy.iloc[cout, 1:] = MIMICtable.iloc[i, 1:]
37 | 
38 |     MIMICtablecopy.iloc[cout, 0] = bloc
39 |     cout = cout + 1
40 |     bloc = bloc + 1
41 |     if stayID[i] != stayID[i + timestep]:
42 |         bloc = 1
43 | MIMICtablecopydown = MIMICtablecopy.iloc[0:cout, :]
44 | 
45 | 
46 | 
47 | 
--------------------------------------------------------------------------------
/experiment/sensitivity analysis/画图_loss.py:
--------------------------------------------------------------------------------
1 | # Loss curves for the sensitivity analysis
2 | import pickle  # to read .pkl files
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | 
6 | 
7 | loss_8=np.load('第一次数据保存全/8小时验证集/loss.npy')
8 | loss_6=np.load('第一次数据保存全/6小时验证集/loss.npy')
9 | loss_4=np.load('第一次数据保存全/4小时验证集/loss.npy')
10 | loss_2=np.load('第一次数据保存全/2小时验证集/loss.npy')
11 | loss_1=np.load('第一次数据保存全/1小时验证集/loss.npy')
12 | loss_length = list(range(len(loss_8)))  # x-axis; this definition was missing here (it mirrors 画图_return.py)
13 | plt.figure()
14 | # plt.title('Training')
15 | plt.plot(loss_length,loss_1,'-o',label= '1 hour',markersize=3,color='burlywood')
16 | plt.plot(loss_length,loss_2,'-v',label= '2 hours',markersize=3,color='plum')
17 | plt.plot(loss_length,loss_4,'-',label= '4 hours',markersize=3,color='red')
18 | plt.plot(loss_length,loss_6,'--',label= '6
hours',markersize=3,color='cornflowerblue') 19 | plt.plot(loss_length,loss_8,'-.',label= '8 hours',markersize=3,color='lime') 20 | 21 | 22 | font1 = {'family': 'Arial', 'weight': 'normal','size':12} 23 | 24 | plt.xticks([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100], fontname="Arial", ) 25 | plt.yticks([6, 8, 10, 12, 14, 16, 18, 20, 22], fontname="Arial") 26 | 27 | plt.legend(prop=font1) 28 | plt.xlabel("Epochs", font1) 29 | plt.ylabel("Loss value", font1) 30 | plt.show() 31 | 32 | print() 33 | -------------------------------------------------------------------------------- /experiment/sensitivity analysis/画图_return.py: -------------------------------------------------------------------------------- 1 | #灵敏度分析中的 期望回报图 2 | 3 | import pickle #打开pkl包 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | loss_8=np.load('第一次数据保存全/8小时验证集/mean_agent_q.npy',allow_pickle=True) 9 | loss_6=np.load('第一次数据保存全/6小时验证集/mean_agent_q.npy',allow_pickle=True) 10 | loss_4=np.load('第一次数据保存全/4小时验证集/mean_agent_q.npy',allow_pickle=True) 11 | loss_2=np.load('第一次数据保存全/2小时验证集/mean_agent_q.npy',allow_pickle=True) 12 | loss_1=np.load('第一次数据保存全/1小时验证集/mean_agent_q.npy',allow_pickle=True) 13 | 14 | loss_length = list(range(len(loss_8))) 15 | plt.figure() 16 | 17 | 18 | plt.plot(loss_length ,loss_8,'-o',label= '8 hour',markersize=3,color='burlywood') 19 | plt.plot(loss_length ,loss_6,'-v',label= '6 hours',markersize=3,color='plum') 20 | plt.plot(loss_length ,loss_4,'-',label= '4 hours',markersize=3,color='red') 21 | plt.plot(loss_length ,loss_2,'--',label= '2 hours',markersize=3,color='cornflowerblue') 22 | plt.plot(loss_length ,loss_1,'-.',label= '1 hours',markersize=3,color='lime') 23 | 24 | 25 | font1 = {'family': 'Arial', 'weight': 'normal','size':12} 26 | 27 | plt.xticks([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100], fontname="Arial", ) 28 | plt.yticks([0, 5, 10, 15, 20, 25], fontname="Arial") 29 | 30 | plt.legend(prop=font1) 31 | plt.xlabel("Epochs", font1) 32 | plt.ylabel("Expected return", font1) 33 | 34 | plt.show() 35 | -------------------------------------------------------------------------------- /experiment/survival rate/__pycache__/deepQnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaryLi666/ID3QNE-algorithm/1fcb321c175c85ea49be25d7edf553a1161de2bd/experiment/survival rate/__pycache__/deepQnet.cpython-38.pyc -------------------------------------------------------------------------------- /experiment/survival rate/deepQnet.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle #打开pkl包 3 | import numpy as np 4 | from ismember import ismember 5 | import pandas as pd 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim 9 | import torch.nn.functional as F 10 | import copy 11 | import os 12 | gamma = 0.9 13 | device='cpu' 14 | 15 | 16 | 17 | class DistributionalDQN(nn.Module): 18 | def __init__(self, state_dim, n_actions, N_ATOMS): 19 | super(DistributionalDQN, self).__init__() 20 | 21 | self.input_layer = nn.Linear(state_dim, 128) 22 | self.hiddens = nn.ModuleList([nn.Sequential( 23 | nn.Linear(128, 128), 24 | nn.ReLU()) for _ in range(7)]) 25 | 26 | self.out = nn.Linear(128, n_actions) 27 | 28 | def forward(self, state): 29 | batch_size = state.size()[0] 30 | out = self.input_layer(state) 31 | for layer in self.hiddens: 32 | out = layer(out) 33 | 34 | out = self.out(out) 35 | return out 36 | 37 | 38 | class dist_DQN(object): 39 | def 
__init__(self, 40 | state_dim=37, 41 | num_actions=25, 42 | v_max=20, 43 | v_min=-20, 44 | device='cpu', 45 | gamma=0.9, 46 | tau=0.005, 47 | n_atoms=51 48 | ): 49 | self.device=device 50 | 51 | self.Q = DistributionalDQN(state_dim, num_actions, n_atoms).to(self.device) 52 | self.Q_target = copy.deepcopy(self.Q) 53 | self.optimizer = torch.optim.Adam(self.Q.parameters(), lr=0.000005) #0.00001 54 | self.tau = tau 55 | self.gamma = gamma 56 | self.v_min = v_min 57 | self.v_max = v_max 58 | 59 | self.num_actions = num_actions 60 | self.atoms = n_atoms 61 | 62 | def train(self,batchs,epoch): 63 | 64 | (state, next_state, action, next_action, reward, done,bloc_num)=batchs 65 | states_num = state.shape[0] 66 | batch_s = 128 67 | uids = np.unique(bloc_num) 68 | num_batch = uids.shape[0] // batch_s # 分批次 69 | record_loss_num = 0 70 | record_loss = [] 71 | 72 | sum_q_loss = 0 73 | Batch = 0 74 | for batch_idx in range(num_batch): 75 | batch_uids = uids[batch_idx * batch_s: (batch_idx + 1) * batch_s] 76 | batch_user = np.isin(bloc_num, batch_uids) 77 | state_user = state[batch_user, :] 78 | next_state_user = next_state[batch_user, :] 79 | action_user = action[batch_user] 80 | next_action_user = next_action[batch_user] 81 | reward_user = reward[batch_user] 82 | done_user = done[batch_user] 83 | batch = (state_user, next_state_user, action_user, next_action_user,reward_user, done_user) 84 | loss = self.compute_loss(batch) 85 | sum_q_loss += loss.item() 86 | if Batch % 25 == 0: 87 | print('Epoch :', epoch, 'Batch :', Batch, 'Average Loss :', sum_q_loss / (Batch + 1)) 88 | record_loss1 = sum_q_loss / (Batch + 1) 89 | record_loss.append(record_loss1) 90 | 91 | self.optimizer.zero_grad() #梯度清零 92 | loss.backward() #反向传播 93 | self.optimizer.step() #更新 94 | 95 | #更新Q_target网络参数 96 | if num_batch%50==0: 97 | self.polyak_target_update() 98 | 99 | Batch += 1 100 | 101 | return record_loss 102 | 103 | 104 | def polyak_target_update(self): #更新网络 105 | for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()): 106 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 107 | 108 | def compute_loss(self, batch): 109 | state, next_state, action, next_action, reward, done = batch 110 | batch_size = state.shape[0] 111 | range_batch = torch.arange(batch_size).long().to(device) 112 | 113 | #利用神经网络输出动作 114 | log_Q_dist_prediction = self.Q(state) 115 | log_Q_dist_prediction1 = log_Q_dist_prediction[range_batch, action] # 1810 一个一维的数 116 | 117 | with torch.no_grad(): 118 | Q_dist_target= self.Q_target(next_state) 119 | 120 | #求最大值 121 | a_star = torch.argmax(Q_dist_target, dim=1) 122 | 123 | log_Q_experience = Q_dist_target[range_batch, next_action.squeeze(1)] 124 | 125 | #最大的Q值 126 | Q_dist_star = Q_dist_target[range_batch, a_star] 127 | 128 | # 更新 targetQ Q值============================= 129 | end_multiplier = 1 - done 130 | eplion=0.2 131 | 132 | targetQ = reward + (gamma * end_multiplier* (Q_dist_star+eplion*(log_Q_experience-Q_dist_star))) 133 | 134 | td_error = torch.square(targetQ - log_Q_dist_prediction1) 135 | old_loss = torch.mean(td_error) 136 | 137 | return old_loss 138 | 139 | def get_action(self, state): 140 | with torch.no_grad(): 141 | batch_size = state.shape[0] 142 | Q_dist= self.Q(state) 143 | a_star = torch.argmax(Q_dist, dim=1) 144 | return a_star 145 | -------------------------------------------------------------------------------- /experiment/survival rate/main_shengcunlv-37.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import matplotlib.pyplot as plt 7 | from deepQnet import dist_DQN 8 | 9 | from torchsummary import summary 10 | 11 | device = 'cpu' 12 | from scipy.stats import zscore, rankdata 13 | 14 | def my_zscore(x): 15 | return zscore(x, ddof=1), np.mean(x, axis=0), np.std(x, axis=0, ddof=1) 16 | 17 | def sliding_mean(data_array, window=1): 18 | new_list = [] 19 | for i in range(len(data_array)): 20 | indices = range(max(i - window + 1, 0), 21 | min(i + window + 1, len(data_array))) 22 | avg = 0 23 | for j in indices: 24 | avg += data_array[j] 25 | avg /= float(len(indices)) 26 | new_list.append(avg) 27 | return np.array(new_list) 28 | 29 | 30 | if __name__ == '__main__': 31 | from scipy.stats import sem 32 | 33 | agent_bQ = np.load('生存率数据/Q值-ICU的-37/agent_bQ.npy') 34 | phys_bQ = np.load('生存率数据/Q值-ICU的-37/phys_bQ.npy') 35 | shencunlv1 = np.load('生存率数据/Q值-ICU的-37/shencunlv.npy') 36 | shencunlv = 1 - shencunlv1 37 | agent_b = agent_bQ[:] 38 | phys_b = phys_bQ[:] 39 | 40 | max_agent_b = max(agent_b) 41 | min_agent_b = min(agent_b) 42 | 43 | max_phys_b = max(phys_b) 44 | min_phys_b = min(phys_b) 45 | 46 | bin_medians = [] 47 | mort = [] 48 | mort_std = [] 49 | 50 | i = -7 51 | while i <= 25: 52 | shengcun_bool = (agent_bQ[:] > i - 0.5) & (agent_bQ[:] < i + 0.5) 53 | count = shencunlv[shengcun_bool == True] 54 | try: 55 | res = sum(count) / float(len(count)) 56 | if len(count) >= 2: 57 | bin_medians.append(i) 58 | mort.append(res) 59 | mort_std.append(sem(count)) 60 | except ZeroDivisionError: 61 | pass 62 | i += 1 63 | i = -7 64 | phys_bin_medians = [] 65 | phys_mort = [] 66 | phys_mort_std = [] 67 | 68 | while i <= 8: 69 | shengcun_bool = (phys_bQ[:] > i - 0.5) & (phys_bQ[:] < i + 0.5) 70 | count = shencunlv[shengcun_bool == True] 71 | 72 | try: 73 | res = sum(count) / float(len(count)) 74 | if len(count) >= 2: 75 | phys_bin_medians.append(i) 76 | phys_mort.append(res) 77 | phys_mort_std.append(sem(count)) 78 | except ZeroDivisionError: 79 | pass 80 | i += 1 81 | 82 | plt.plot(bin_medians, sliding_mean(mort), color='g') 83 | plt.fill_between(bin_medians, sliding_mean(mort) - 1 * mort_std, 84 | sliding_mean(mort) + 1 * mort_std, color='palegreen') 85 | 86 | x_r = [i / 1.0 for i in range(-7, 27, 3)] 87 | y_r = [i / 10.0 for i in range(0, 11, 1)] 88 | plt.xticks(x_r, fontname="Arial") 89 | plt.yticks(y_r, fontname="Arial") 90 | 91 | plt.title('The reward with 37 observation features', fontname="Arial") 92 | plt.xlabel("Expected Return", fontname="Arial") 93 | plt.ylabel("Survival Rate", fontname="Arial") 94 | 95 | plt.show() 96 | -------------------------------------------------------------------------------- /experiment/survival rate/shengcunlv.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /experiment/survival rate/图/fig.5(a).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaryLi666/ID3QNE-algorithm/1fcb321c175c85ea49be25d7edf553a1161de2bd/experiment/survival rate/图/fig.5(a).png -------------------------------------------------------------------------------- /experiment/survival rate/图/fig.5(b).png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CaryLi666/ID3QNE-algorithm/1fcb321c175c85ea49be25d7edf553a1161de2bd/experiment/survival rate/图/fig.5(b).png --------------------------------------------------------------------------------
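--------------------------------------------------------------------------------
Usage sketch (not a file in this repository):
--------------------------------------------------------------------------------
ID3QNE-main.py is empty in this snapshot, so the snippet below is a minimal, hypothetical
sketch of how the Dist_DQN agent defined in ID3QNE-deepQnet.py could be exercised on synthetic
data. The tensor shapes, the synthetic values, and the importlib call for the hyphenated module
name are assumptions inferred from train() and compute_loss() above, not the authors' actual
pipeline; real experiments would build these tensors from the MIMIC-III trajectories produced
by Data splitting.py and evaluate the learned policy with ID3QNE-evaluate.py.

    import numpy as np
    import torch
    from importlib import import_module

    # the file name contains a hyphen, so a plain `import` statement cannot be used
    dqn_mod = import_module('ID3QNE-deepQnet')  # assumes the repository root is on sys.path

    N, STATE_DIM, N_ACTIONS = 256, 37, 25
    state = torch.randn(N, STATE_DIM)                   # 37 normalized observation features
    next_state = torch.randn(N, STATE_DIM)
    action = torch.randint(0, N_ACTIONS, (N,))          # clinician action taken at step t
    next_action = torch.randint(0, N_ACTIONS, (N, 1))   # clinician action at t+1; 2-D, squeezed in compute_loss
    reward = torch.randn(N)
    done = torch.zeros(N)                               # 1.0 marks the final transition of a trajectory
    bloc_num = np.repeat(np.arange(16), N // 16)        # trajectory (patient) id for each transition
    SOFAS = torch.randint(0, 24, (N,)).float()          # SOFA scores; drive the expert-knowledge switch

    agent = dqn_mod.Dist_DQN(state_dim=STATE_DIM, num_actions=N_ACTIONS, device='cpu')
    batchs = (state, next_state, action, next_action, reward, done, bloc_num, SOFAS)
    loss_curve = agent.train(batchs, epoch=0)           # one pass over the synthetic batch
    greedy_actions = agent.get_action(state)            # indices 0-24 into the 5x5 IV-fluid/vasopressor dose grid
    print(loss_curve, greedy_actions[:10])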