├── model └── PLACEHOLDER ├── Extended Results.pdf ├── data └── PLACEHOLDER ├── code ├── data_range.py ├── reward_fun.py ├── grid_networks.py ├── networks.py ├── util.py ├── constants.py ├── grid_simulator │ └── grid_model.py ├── simulator │ └── disease_model.py ├── grid_train.py └── train.py └── readme.md /model/PLACEHOLDER: -------------------------------------------------------------------------------- 1 | PLACEHOLDER 2 | trained models will be stored in this folder -------------------------------------------------------------------------------- /Extended Results.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsinghua-fib-lab/Large-Scale-MARL-GATMF/HEAD/Extended Results.pdf -------------------------------------------------------------------------------- /data/PLACEHOLDER: -------------------------------------------------------------------------------- 1 | PLACEHOLDER 2 | please download the required data (~10GB) from https://drive.google.com/drive/folders/1-68jPOd6NXVyiC1PWbo-9wrqiktOi4GT?usp=sharing and substitute this folder 3 | -------------------------------------------------------------------------------- /code/data_range.py: -------------------------------------------------------------------------------- 1 | Gmat={'Atlanta_random':10000,'Atlanta':10000} 2 | Idata={'Atlanta_random':20,'Atlanta':1} 3 | Ddata={'Atlanta_random':5,'Atlanta':0.05} 4 | cbg_size={'Atlanta_random':10000,'Atlanta':10000} 5 | 6 | Sdata_diff={'Atlanta_random':0.01,'Atlanta':0.01} 7 | Idata_diff={'Atlanta_random':0.01,'Atlanta':0.01} 8 | Ddata_diff={'Atlanta_random':0.0005,'Atlanta':0.0005} -------------------------------------------------------------------------------- /code/reward_fun.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def reward_fun(s,s1,D_weight=0,option=0): 4 | reward_C=s1[:,:,1]-s[:,:,1] 5 | reward_D=s1[:,:,2]-s[:,:,2] 6 | 7 | if option==0: 8 | reward=-np.clip(np.sqrt(reward_C+D_weight*reward_D),0,0.4)[:,:,np.newaxis] 9 | 10 | elif option==1: 11 | reward=-1000*(reward_C+D_weight*reward_D)[:,:,np.newaxis] 12 | 13 | elif option==2: 14 | reward=-100*(reward_C+D_weight*reward_D)[:,:,np.newaxis] 15 | 16 | elif option==3: 17 | reward=-10*(reward_C+D_weight*reward_D)[:,:,np.newaxis] 18 | 19 | return reward -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # GAT-MF: Graph Attention Mean Field for Very Large Scale Multi-Agent Reinforcement Learning (KDD'23) 2 | 3 | ## To run the grid-world experiments, you need: 4 | - code/grid_simulator/: simultor for the grid-world experiment 5 | - code/grid_networks.py: neural networks definition for the grid-world experiment 6 | - code/grid_train.py: run this Python file to execute the training 7 | - model/: trained models will be stored to this folder 8 | 9 | ## To run the real-world experiments, you need: 10 | - data/: please download the required data (~10GB) from https://drive.google.com/drive/folders/1-68jPOd6NXVyiC1PWbo-9wrqiktOi4GT?usp=sharing and substitute this folder 11 | - code/simulator/: simultor for the real-world experiment 12 | - code/networks.py: neural networks definition for the real-world experiment 13 | - code/constants.py: intrinsic constants for the real-world experiment 14 | - code/data_range.py: data ranges for normalizations in the real-world experiment 15 | - code/reward_fun.py: reward function definition for the real-world experiment 16 | - code/util.py: util functions for the real-world experiment 17 | - code/train.py: run this Python file to execute the training 18 | - model/: trained models will be stored to this folder 19 | -------------------------------------------------------------------------------- /code/grid_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | Actor_in_features=10 6 | Actor_hidden_features1=32 7 | Actor_hidden_features2=16 8 | 9 | Critic_in_features=18 10 | Critic_hidden_features1=32 11 | Critic_hidden_features2=16 12 | 13 | Attention_in_features=5 14 | Attention_hidden_features=32 15 | 16 | class Actor(nn.Module): 17 | def __init__(self): 18 | super(Actor,self).__init__() 19 | 20 | self.lin1=nn.Linear(in_features=Actor_in_features,out_features=Actor_hidden_features1) 21 | self.lin2=nn.Linear(in_features=Actor_hidden_features1,out_features=Actor_hidden_features2) 22 | self.lin3=nn.Linear(in_features=Actor_hidden_features2,out_features=4) 23 | 24 | def forward(self,s): 25 | s=F.leaky_relu(self.lin1(s)) 26 | s=F.leaky_relu(self.lin2(s)) 27 | s=self.lin3(s) 28 | 29 | return s 30 | 31 | class Critic(nn.Module): 32 | def __init__(self): 33 | super(Critic,self).__init__() 34 | 35 | self.lin1=nn.Linear(in_features=Critic_in_features,out_features=Critic_hidden_features1) 36 | self.lin2=nn.Linear(in_features=Critic_hidden_features1,out_features=Critic_hidden_features2) 37 | self.lin3=nn.Linear(in_features=Critic_hidden_features2,out_features=1) 38 | 39 | def forward(self,s,a): 40 | x=torch.concat((s,a),dim=2) 41 | x=F.leaky_relu(self.lin1(x)) 42 | x=F.leaky_relu(self.lin2(x)) 43 | x=self.lin3(x) 44 | 45 | return x 46 | 47 | class Attention(nn.Module): 48 | def __init__(self): 49 | super(Attention,self).__init__() 50 | 51 | self.Qweight=nn.Parameter(torch.rand(Attention_in_features,Attention_hidden_features)*((4/Attention_in_features)**0.5)-(1/Attention_in_features)**0.5) 52 | self.Kweight=nn.Parameter(torch.rand(Attention_in_features,Attention_hidden_features)*((4/Attention_in_features)**0.5)-(1/Attention_in_features)**0.5) 53 | 54 | def forward(self,s,Gmat): 55 | q=torch.einsum('ijk,km->ijm',s,self.Qweight) 56 | k=torch.einsum('ijk,km->ijm',s,self.Kweight).permute(0, 2, 1) 57 | 58 | att=torch.square(torch.bmm(q,k))*Gmat 59 | att=att/(torch.sum(att,dim=2,keepdim=True)+0.001) 60 | 61 | return att -------------------------------------------------------------------------------- /code/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | Actor_in_features=12 7 | Actor_hidden_features1=32 8 | Actor_hidden_features2=16 9 | 10 | Critic_in_features=14 11 | Critic_hidden_features1=32 12 | Critic_hidden_features2=16 13 | 14 | Attention_in_features=6 15 | Attention_hidden_features=32 16 | 17 | class Actor(nn.Module): 18 | def __init__(self): 19 | super(Actor,self).__init__() 20 | 21 | self.lin1=nn.Linear(in_features=Actor_in_features,out_features=Actor_hidden_features1) 22 | self.lin2=nn.Linear(in_features=Actor_hidden_features1,out_features=Actor_hidden_features2) 23 | self.lin3=nn.Linear(in_features=Actor_hidden_features2,out_features=1) 24 | 25 | def forward(self,s): 26 | s=F.leaky_relu(self.lin1(s)) 27 | s=F.leaky_relu(self.lin2(s)) 28 | s=self.lin3(s) 29 | 30 | return s 31 | 32 | class Critic(nn.Module): 33 | def __init__(self): 34 | super(Critic,self).__init__() 35 | 36 | self.lin1=nn.Linear(in_features=Critic_in_features,out_features=Critic_hidden_features1) 37 | self.lin2=nn.Linear(in_features=Critic_hidden_features1,out_features=Critic_hidden_features2) 38 | self.lin3=nn.Linear(in_features=Critic_hidden_features2,out_features=1) 39 | 40 | def forward(self,s,a): 41 | x=torch.concat((s,a),dim=2) 42 | x=F.leaky_relu(self.lin1(x)) 43 | x=F.leaky_relu(self.lin2(x)) 44 | x=self.lin3(x) 45 | 46 | return x 47 | 48 | class Attention(nn.Module): 49 | def __init__(self): 50 | super(Attention,self).__init__() 51 | 52 | self.Qweight=nn.Parameter(torch.rand(Attention_in_features,Attention_hidden_features)*((4/Attention_in_features)**0.5)-(1/Attention_in_features)**0.5) 53 | self.Kweight=nn.Parameter(torch.rand(Attention_in_features,Attention_hidden_features)*((4/Attention_in_features)**0.5)-(1/Attention_in_features)**0.5) 54 | 55 | def forward(self,s,Gmat): 56 | q=torch.einsum('ijk,km->ijm',s,self.Qweight) 57 | k=torch.einsum('ijk,km->ijm',s,self.Kweight).permute(0, 2, 1) 58 | 59 | att=torch.square(torch.bmm(q,k))*Gmat 60 | att=att/(torch.sum(att,dim=2,keepdim=True)+0.001) 61 | 62 | return att -------------------------------------------------------------------------------- /code/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pandas as pd 4 | import numpy as np 5 | 6 | import constants 7 | 8 | def load_data(MSA_name): 9 | print('Loading data...') 10 | MSA_NAME_FULL = constants.MSA_NAME_FULL_DICT[MSA_name] 11 | epic_data_root = '../data' 12 | data=dict() 13 | 14 | f = open(os.path.join(epic_data_root, MSA_name, '%s_2020-03-01_to_2020-05-02_processed.pkl'%MSA_NAME_FULL), 'rb') 15 | poi_cbg_visits_list = pickle.load(f) 16 | f.close() 17 | data['poi_cbg_visits_list']=poi_cbg_visits_list 18 | 19 | d = pd.read_csv(os.path.join(epic_data_root,MSA_name, 'parameters_%s.csv' % MSA_name)) 20 | poi_areas = d['feet'].values 21 | poi_dwell_times = d['median'].values 22 | poi_dwell_time_correction_factors = (poi_dwell_times / (poi_dwell_times+60)) ** 2 23 | data['poi_areas']=poi_areas 24 | data['poi_times']=poi_dwell_times 25 | data['poi_dwell_time_correction_factors']=poi_dwell_time_correction_factors 26 | 27 | cbg_ids_msa = pd.read_csv(os.path.join(epic_data_root,MSA_name,'%s_cbg_ids.csv'%MSA_NAME_FULL)) 28 | cbg_ids_msa.rename(columns={"cbg_id":"census_block_group"}, inplace=True) 29 | 30 | filepath = os.path.join(epic_data_root,"safegraph_open_census_data/data/cbg_b01.csv") 31 | cbg_agesex = pd.read_csv(filepath) 32 | cbg_age_msa = pd.merge(cbg_ids_msa, cbg_agesex, on='census_block_group', how='left') 33 | del cbg_agesex 34 | 35 | for i in range(3,25+1): 36 | male_column = 'B01001e'+str(i) 37 | female_column = 'B01001e'+str(i+24) 38 | cbg_age_msa[constants.DETAILED_AGE_LIST[i-3]] = cbg_age_msa.apply(lambda x : x[male_column]+x[female_column],axis=1) 39 | 40 | cbg_age_msa.rename(columns={'B01001e1':'Sum'},inplace=True) 41 | columns_of_interest = ['census_block_group','Sum'] + constants.DETAILED_AGE_LIST 42 | cbg_age_msa = cbg_age_msa[columns_of_interest].copy() 43 | 44 | cbg_age_msa.fillna(0,inplace=True) 45 | cbg_age_msa['Sum'] = cbg_age_msa['Sum'].apply(lambda x : x if x!=0 else 1) 46 | 47 | cbg_sizes = cbg_age_msa['Sum'].values 48 | cbg_sizes = np.array(cbg_sizes,dtype='int32') 49 | data['cbg_sizes']=cbg_sizes 50 | data['cbg_ages']=cbg_age_msa.to_numpy()[:,2:]/cbg_sizes[:,np.newaxis] 51 | 52 | cbg_death_rates_original = np.loadtxt(os.path.join(epic_data_root, MSA_name, 'cbg_death_rates_original_'+MSA_name)) 53 | cbg_attack_rates_original = np.ones(cbg_death_rates_original.shape) 54 | 55 | attack_scale = 1 56 | cbg_attack_rates_scaled = cbg_attack_rates_original * attack_scale 57 | cbg_death_rates_scaled = cbg_death_rates_original * constants.death_scale_dict[MSA_name] 58 | data['cbg_attack_rates_scaled']=cbg_attack_rates_scaled 59 | data['cbg_death_rates_scaled']=cbg_death_rates_scaled 60 | 61 | return data -------------------------------------------------------------------------------- /code/constants.py: -------------------------------------------------------------------------------- 1 | NUM_AGE_GROUP_FOR_ATTACK_RATES = 9 2 | NUM_AGE_GROUP_FOR_DEATH_RATES = 17 3 | 4 | DETAILED_AGE_LIST =['Under 5 Years','5 To 9 Years','10 To 14 Years','15 To 17 Years','18 To 19 Years','20 Years','21 Years', 5 | '22 To 24 Years','25 To 29 Years','30 To 34 Years','35 To 39 Years','40 To 44 Years','45 To 49 Years', 6 | '50 To 54 Years','55 To 59 Years', '60 To 61 Years','62 To 64 Years','65 To 66 Years','67 To 69 Years', 7 | '70 To 74 Years','75 To 79 Years','80 To 84 Years','85 Years And Over'] 8 | 9 | AGE_GROUPS_FOR_ATTACK_RATES = { 10 | 0:['Under 5 Years','5 To 9 Years'], 11 | 1:['10 To 14 Years','15 To 17 Years','18 To 19 Years'], 12 | 2:['20 Years','21 Years', '22 To 24 Years','25 To 29 Years'], 13 | 3:['30 To 34 Years','35 To 39 Years'], 14 | 4:['40 To 44 Years','45 To 49 Years'], 15 | 5:['50 To 54 Years','55 To 59 Years'], 16 | 6:['60 To 61 Years','62 To 64 Years','65 To 66 Years','67 To 69 Years'], 17 | 7:['70 To 74 Years','75 To 79 Years'], 18 | 8:['80 To 84 Years','85 Years And Over'] 19 | } 20 | 21 | 22 | AGE_GROUPS_FOR_DEATH_RATES = { 23 | 0:['Under 5 Years'], 24 | 1:['5 To 9 Years'], 25 | 2:['10 To 14 Years'], 26 | 3:['15 To 17 Years','18 To 19 Years'], 27 | 4:['20 Years','21 Years', '22 To 24 Years'], 28 | 5:['25 To 29 Years'], 29 | 6:['30 To 34 Years'], 30 | 7:['35 To 39 Years'], 31 | 8:['40 To 44 Years'], 32 | 9:['45 To 49 Years'], 33 | 10:['50 To 54 Years'], 34 | 11:['55 To 59 Years'], 35 | 12:['60 To 61 Years','62 To 64 Years'], 36 | 13:['65 To 66 Years','67 To 69 Years'], 37 | 14:['70 To 74 Years'], 38 | 15:['75 To 79 Years'], 39 | 16:['80 To 84 Years','85 Years And Over'] 40 | } 41 | 42 | 43 | FIPS_CODES_FOR_50_STATES_PLUS_DC = 44 | "10": "Delaware", 45 | "11": "Washington, D.C.", 46 | "12": "Florida", 47 | "13": "Georgia", 48 | "15": "Hawaii", 49 | "16": "Idaho", 50 | "17": "Illinois", 51 | "18": "Indiana", 52 | "19": "Iowa", 53 | "20": "Kansas", 54 | "21": "Kentucky", 55 | "22": "Louisiana", 56 | "23": "Maine", 57 | "24": "Maryland", 58 | "25": "Massachusetts", 59 | "26": "Michigan", 60 | "27": "Minnesota", 61 | "28": "Mississippi", 62 | "29": "Missouri", 63 | "30": "Montana", 64 | "31": "Nebraska", 65 | "32": "Nevada", 66 | "33": "New Hampshire", 67 | "34": "New Jersey", 68 | "35": "New Mexico", 69 | "36": "New York", 70 | "37": "North Carolina", 71 | "38": "North Dakota", 72 | "39": "Ohio", 73 | "40": "Oklahoma", 74 | "41": "Oregon", 75 | "42": "Pennsylvania", 76 | "44": "Rhode Island", 77 | "45": "South Carolina", 78 | "46": "South Dakota", 79 | "47": "Tennessee", 80 | "48": "Texas", 81 | "49": "Utah", 82 | "50": "Vermont", 83 | "51": "Virginia", 84 | "53": "Washington", 85 | "54": "West Virginia", 86 | "55": "Wisconsin", 87 | "56": "Wyoming", 88 | "01": "Alabama", 89 | "02": "Alaska", 90 | "04": "Arizona", 91 | "05": "Arkansas", 92 | "06": "California", 93 | "08": "Colorado", 94 | "09": "Connecticut", 95 | } 96 | 97 | 98 | MSA_NAME_LIST = ['Atlanta','Chicago','Dallas','Houston', 'LosAngeles','Miami','Philadelphia','SanFrancisco','WashingtonDC'] 99 | MSA_NAME_FULL_DICT = { 100 | 'Atlanta':'Atlanta_Sandy_Springs_Roswell_GA', 101 | 'Chicago':'Chicago_Naperville_Elgin_IL_IN_WI', 102 | 'Dallas':'Dallas_Fort_Worth_Arlington_TX', 103 | 'Houston':'Houston_The_Woodlands_Sugar_Land_TX', 104 | 'LosAngeles':'Los_Angeles_Long_Beach_Anaheim_CA', 105 | 'Miami':'Miami_Fort_Lauderdale_West_Palm_Beach_FL', 106 | 'Philadelphia':'Philadelphia_Camden_Wilmington_PA_NJ_DE_MD', 107 | 'SanFrancisco':'San_Francisco_Oakland_Hayward_CA', 108 | 'WashingtonDC':'Washington_Arlington_Alexandria_DC_VA_MD_WV', 109 | 'Toy10':'Toy10_Washington_Arlington_Alexandria_DC_VA_MD_WV', 110 | 'Toy20':'Toy20_Washington_Arlington_Alexandria_DC_VA_MD_WV', 111 | 'Toy100':'Toy100_Washington_Arlington_Alexandria_DC_VA_MD_WV', 112 | 'Toy1000':'Toy1000_Washington_Arlington_Alexandria_DC_VA_MD_WV' 113 | } 114 | 115 | parameters_dict = {'Atlanta':[2e-4, 0.0037, 2388], 116 | 'Chicago': [1e-4,0.0063,2076], 117 | 'Dallas':[2e-4, 0.0063, 1452], 118 | 'Houston': [5e-4, 0.0037,1139], 119 | 'LosAngeles': [2e-4,0.0088,1452], 120 | 'Miami': [5e-4, 0.0012, 1764], 121 | 'Philadelphia': [0.001, 0.0037, 827], 122 | 'SanFrancisco': [5e-4, 0.0037, 1139], 123 | 'WashingtonDC': [5e-5, 0.0037, 2700], 124 | 'Toy10': [6e-3, 0.0037, 2700], 125 | 'Toy20': [2.5e-4, 0.0037, 2700], 126 | 'Toy100': [5e-5, 0.0037, 2700], 127 | 'Toy1000': [2e-5, 0.0037, 2700]} 128 | 129 | death_scale_dict = {'Atlanta':[1.20], 130 | 'Chicago':[1.30], 131 | 'Dallas':[1.03], 132 | 'Houston':[0.83], 133 | 'LosAngeles':[1.52], 134 | 'Miami':[0.78], 135 | 'Philadelphia':[2.08], 136 | 'SanFrancisco':[0.64], 137 | 'WashingtonDC':[1.40], 138 | 'Toy10':[1.40], 139 | 'Toy20':[1.40], 140 | 'Toy100':[1.40], 141 | 'Toy1000':[1.40] 142 | } -------------------------------------------------------------------------------- /code/grid_simulator/grid_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Model: 4 | def __init__(self): 5 | pass 6 | 7 | def init_exogenous_variables(self,num_grid,num_diamond,diamond_extent,num_move): 8 | self.num_grid=num_grid 9 | self.num_diamond=num_diamond 10 | self.diamond_extent=diamond_extent 11 | self.num_move=num_move 12 | self.field_length=3*self.num_grid 13 | 14 | x,y=np.meshgrid(range(-3*self.diamond_extent-1,3*(self.diamond_extent)+2),range(-3*self.diamond_extent-1,3*(self.diamond_extent)+2)) 15 | self.diamond_mat=np.exp(-2*((x/3)**2+(y/3)**2)/self.field_length) 16 | 17 | self.Gmat=np.zeros((self.num_grid*self.num_grid,self.num_grid*self.num_grid)) 18 | for x in range(self.num_grid): 19 | for y in range(self.num_grid): 20 | pos=self.num_grid*x+y 21 | self.Gmat[pos,self.num_grid*((x-1)%self.num_grid)+y%self.num_grid]=1 22 | self.Gmat[pos,self.num_grid*((x+1)%self.num_grid)+y%self.num_grid]=1 23 | self.Gmat[pos,self.num_grid*(x%self.num_grid)+(y-1)%self.num_grid]=1 24 | self.Gmat[pos,self.num_grid*(x%self.num_grid)+(y+1)%self.num_grid]=1 25 | 26 | def init_endogenous_variables(self): 27 | self.field=np.zeros((self.field_length,self.field_length)) 28 | self.miner=np.ones((self.num_grid,self.num_grid)) 29 | self.diamond_pos=list() 30 | 31 | for i in range(self.num_diamond): 32 | rand_x=np.random.randint(0,self.num_grid) 33 | rand_y=np.random.randint(0,self.num_grid) 34 | self.diamond_pos.append([rand_x,rand_y]) 35 | 36 | for x in range(-self.diamond_extent,self.diamond_extent+1): 37 | for y in range(-self.diamond_extent,self.diamond_extent+1): 38 | self.field[(3*(rand_x+x))%self.field_length,(3*(rand_y+y))%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x),3*(self.diamond_extent+y)] 39 | self.field[(3*(rand_x+x))%self.field_length,(3*(rand_y+y)+1)%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x),3*(self.diamond_extent+y)+1] 40 | self.field[(3*(rand_x+x))%self.field_length,(3*(rand_y+y)+2)%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x),3*(self.diamond_extent+y)+2] 41 | 42 | self.field[(3*(rand_x+x)+1)%self.field_length,(3*(rand_y+y))%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x)+1,3*(self.diamond_extent+y)] 43 | self.field[(3*(rand_x+x)+1)%self.field_length,(3*(rand_y+y)+1)%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x)+1,3*(self.diamond_extent+y)+1] 44 | self.field[(3*(rand_x+x)+1)%self.field_length,(3*(rand_y+y)+2)%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x)+1,3*(self.diamond_extent+y)+2] 45 | 46 | self.field[(3*(rand_x+x)+2)%self.field_length,(3*(rand_y+y))%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x)+2,3*(self.diamond_extent+y)] 47 | self.field[(3*(rand_x+x)+2)%self.field_length,(3*(rand_y+y)+1)%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x)+2,3*(self.diamond_extent+y)+1] 48 | self.field[(3*(rand_x+x)+2)%self.field_length,(3*(rand_y+y)+2)%self.field_length]+=self.diamond_mat[3*(self.diamond_extent+x)+2,3*(self.diamond_extent+y)+2] 49 | 50 | self.init_return=self.get_return() 51 | 52 | def move_miner(self,action_vector): 53 | action_vector=action_vector.reshape(self.num_grid,self.num_grid,-1) 54 | move_vector=self.num_move*action_vector 55 | self.miner=self.miner-self.num_move 56 | 57 | self.miner[:-1,:]+=move_vector[1:,:,0] 58 | self.miner[-1,:]+=move_vector[0,:,0] 59 | 60 | self.miner[1:,:]+=move_vector[:-1,:,1] 61 | self.miner[0,:]+=move_vector[-1,:,1] 62 | 63 | self.miner[:,:-1]+=move_vector[:,1:,2] 64 | self.miner[:,-1]+=move_vector[:,0,2] 65 | 66 | self.miner[:,1:]+=move_vector[:,:-1,3] 67 | self.miner[:,0]+=move_vector[:,-1,3] 68 | 69 | self.miner=np.clip(self.miner,0,None) 70 | 71 | def output_record(self): 72 | left_up=self.field[::3,::3].reshape(self.num_grid*self.num_grid,-1) 73 | right_up=self.field[::3,2::3].reshape(self.num_grid*self.num_grid,-1) 74 | center=self.field[1::3,1::3].reshape(self.num_grid*self.num_grid,-1) 75 | left_down=self.field[2::3,::3].reshape(self.num_grid*self.num_grid,-1) 76 | right_down=self.field[2::3,2::3].reshape(self.num_grid*self.num_grid,-1) 77 | 78 | return np.hstack((left_up,right_up,center,left_down,right_down)) 79 | 80 | def get_reward(self): 81 | reward=np.empty((self.num_grid,self.num_grid)) 82 | for x in range(self.num_grid): 83 | for y in range(self.num_grid): 84 | reward[x,y]=self.miner[x,y]*np.sum(self.field[3*x:3*x+3,3*y:3*y+3]) 85 | 86 | reward_sum=np.empty((self.num_grid,self.num_grid)) 87 | for x in range(self.num_grid): 88 | for y in range(self.num_grid): 89 | reward_sum[x,y]=reward[(x-1)%self.num_grid,y]+reward[(x+1)%self.num_grid,y]+reward[x,(y-1)%self.num_grid]+reward[x,(y+1)%self.num_grid] 90 | 91 | return reward_sum.reshape(-1,1) 92 | 93 | def get_return(self): 94 | reward=np.empty((self.num_grid,self.num_grid)) 95 | for x in range(self.num_grid): 96 | for y in range(self.num_grid): 97 | reward[x,y]=self.miner[x,y]*np.sum(self.field[3*x:3*x+3,3*y:3*y+3]) 98 | 99 | return np.sum(reward) 100 | 101 | def show_grid(self,path): 102 | import matplotlib.pyplot as plt 103 | plt.figure(figsize=(9,6)) 104 | plt.contour(self.field.T) 105 | cb=plt.colorbar() 106 | cb.set_label(label='Number of Diamonds',size=30) 107 | cb.ax.tick_params(labelsize=20) 108 | 109 | plt.imshow(self.miner.repeat(3,0).repeat(3,1),cmap='Reds',vmin=0.3,vmax=1.7) 110 | cb=plt.colorbar() 111 | cb.set_label(label='Number of Miners',size=30) 112 | cb.ax.tick_params(labelsize=20) 113 | 114 | for diamond in self.diamond_pos: 115 | plt.scatter(3*diamond[0]+1,3*diamond[1]+1,marker='d',facecolor='Blue') 116 | plt.hlines(np.linspace(2.5,self.field_length-3.5,self.num_grid-1),xmin=-0.5,xmax=self.field_length-0.5,color='red',linewidth=0.3) 117 | plt.vlines(np.linspace(2.5,self.field_length-3.5,self.num_grid-1),ymin=-0.5,ymax=self.field_length-0.5,color='red',linewidth=0.3) 118 | plt.xticks(range(1,self.field_length+1,3),range(0,self.num_grid),fontsize=20) 119 | plt.yticks(range(1,self.field_length+1,3),range(0,self.num_grid),fontsize=20) 120 | 121 | plt.title('Return=%.2f'%(self.get_return()-self.init_return),fontsize=30) 122 | 123 | plt.tight_layout() 124 | plt.savefig(path) 125 | plt.close() 126 | 127 | print(self.get_return()-self.init_return,np.sum(self.miner)) -------------------------------------------------------------------------------- /code/simulator/disease_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | class Model: 5 | def __init__(self, 6 | num_seeds=1, 7 | debug=False, 8 | clip_poisson_approximation=True, 9 | random_simulation=False): 10 | 11 | self.num_seeds = num_seeds 12 | self.debug = debug 13 | self.clip_poisson_approximation = clip_poisson_approximation 14 | self.random_simulation=random_simulation 15 | 16 | def init_exogenous_variables(self, 17 | poi_areas, 18 | cbg_sizes, 19 | p_sick_at_t0, 20 | poi_psi, 21 | home_beta, 22 | cbg_attack_rates_original, 23 | cbg_death_rates_original, 24 | poi_cbg_visits_list=None, 25 | poi_dwell_time_correction_factors=None, 26 | just_compute_r0=False, 27 | latency_period=96, # 4 days 28 | infectious_period=84, # 3.5 days 29 | confirmation_rate=.1, 30 | confirmation_lag=168, # 7 days 31 | death_lag=432, # 18 days 32 | no_print=False, 33 | ): 34 | self.M = len(poi_areas) 35 | self.N = len(cbg_sizes) 36 | self.MAX_T=len(poi_cbg_visits_list) 37 | 38 | self.PSI = poi_psi 39 | self.POI_AREAS = poi_areas 40 | self.DWELL_TIME_CORRECTION_FACTORS = poi_dwell_time_correction_factors 41 | self.POI_FACTORS = self.PSI / poi_areas 42 | if poi_dwell_time_correction_factors is not None: 43 | self.POI_FACTORS = poi_dwell_time_correction_factors * self.POI_FACTORS 44 | self.included_dwell_time_correction_factors = True 45 | else: 46 | self.included_dwell_time_correction_factors = False 47 | self.POI_CBG_VISITS_LIST = poi_cbg_visits_list 48 | self.clipping_monitor = { 49 | 'num_base_infection_rates_clipped':[], 50 | 'num_active_pois':[], 51 | 'num_poi_infection_rates_clipped':[], 52 | 'num_cbgs_active_at_pois':[], 53 | 'num_cbgs_with_clipped_poi_cases':[]} 54 | 55 | self.CBG_SIZES = cbg_sizes 56 | self.HOME_BETA = home_beta 57 | self.CBG_ATTACK_RATES_ORIGINAL = cbg_attack_rates_original 58 | self.CBG_DEATH_RATES_ORIGINAL = cbg_death_rates_original 59 | self.LATENCY_PERIOD = latency_period 60 | self.INFECTIOUS_PERIOD = infectious_period 61 | self.P_SICK_AT_T0 = p_sick_at_t0 62 | 63 | self.VACCINATION_VECTOR = np.zeros((self.num_seeds,self.N),dtype=np.float32) 64 | self.VACCINE_ACCEPTANCE = np.ones((self.num_seeds,self.N),dtype=np.float32) 65 | self.PROTECTION_RATE = 1.0 66 | 67 | self.just_compute_r0 = just_compute_r0 68 | self.confirmation_rate = confirmation_rate 69 | self.confirmation_lag = confirmation_lag 70 | self.death_lag = death_lag 71 | 72 | self.CBG_ATTACK_RATES_NEW = self.CBG_ATTACK_RATES_ORIGINAL * (1-self.PROTECTION_RATE*self.VACCINATION_VECTOR/self.CBG_SIZES) 73 | self.CBG_DEATH_RATES_NEW = self.CBG_DEATH_RATES_ORIGINAL 74 | self.CBG_ATTACK_RATES_NEW = np.clip(self.CBG_ATTACK_RATES_NEW, 0, None) 75 | self.CBG_DEATH_RATES_NEW = np.clip(self.CBG_DEATH_RATES_NEW, 0, None) 76 | self.CBG_DEATH_RATES_NEW = np.clip(self.CBG_DEATH_RATES_NEW, None, 1) 77 | 78 | assert((self.CBG_DEATH_RATES_NEW>=0).all()) 79 | assert((self.CBG_DEATH_RATES_NEW<=1).all()) 80 | 81 | def init_endogenous_variables(self): 82 | 83 | self.P0 = np.random.binomial(self.CBG_SIZES,self.P_SICK_AT_T0,size=(self.num_seeds, self.N)) 84 | 85 | self.cbg_latent = self.P0 86 | self.cbg_infected = np.zeros((self.num_seeds, self.N),dtype=np.float32) 87 | self.cbg_removed = np.zeros((self.num_seeds, self.N),dtype=np.float32) 88 | self.cases_to_confirm = np.zeros((self.num_seeds, self.N),dtype=np.float32) 89 | self.new_confirmed_cases = np.zeros((self.num_seeds, self.N),dtype=np.float32) 90 | self.deaths_to_happen = np.zeros((self.num_seeds, self.N),dtype=np.float32) 91 | self.new_deaths = np.zeros((self.num_seeds, self.N),dtype=np.float32) 92 | self.C2=np.zeros((self.num_seeds, self.N),dtype=np.float32) 93 | self.D2=np.zeros((self.num_seeds, self.N),dtype=np.float32) 94 | 95 | self.VACCINATION_VECTOR = np.zeros((self.num_seeds,self.N),dtype=np.float32) 96 | self.CBG_ATTACK_RATES_NEW = self.CBG_ATTACK_RATES_ORIGINAL * (1-self.PROTECTION_RATE*self.VACCINATION_VECTOR/self.CBG_SIZES) 97 | self.CBG_DEATH_RATES_NEW = self.CBG_DEATH_RATES_ORIGINAL 98 | self.CBG_ATTACK_RATES_NEW = np.clip(self.CBG_ATTACK_RATES_NEW, 0, None) 99 | self.CBG_DEATH_RATES_NEW = np.clip(self.CBG_DEATH_RATES_NEW, 0, None) 100 | self.CBG_DEATH_RATES_NEW = np.clip(self.CBG_DEATH_RATES_NEW, None, 1) 101 | 102 | self.L_1=[] 103 | self.I_1=[] 104 | self.R_1=[] 105 | self.C_1=[] 106 | self.D_1=[] 107 | self.T1=[] 108 | self.t = 0 109 | self.C=[0] 110 | self.D=[0] 111 | self.history_C2 = [] 112 | self.history_D2 = [] 113 | self.epidemic_over = False 114 | 115 | def reset_random_seed(self): 116 | np.random.seed(np.random.randint(10000,size=self.num_seeds)) 117 | 118 | def add_vaccine(self,vaccine_vector): 119 | self.VACCINATION_VECTOR+=vaccine_vector 120 | self.VACCINATION_VECTOR = np.clip(self.VACCINATION_VECTOR, None, (self.CBG_SIZES*self.VACCINE_ACCEPTANCE)) 121 | self.CBG_ATTACK_RATES_NEW = self.CBG_ATTACK_RATES_ORIGINAL * (1-self.PROTECTION_RATE*self.VACCINATION_VECTOR/self.CBG_SIZES) 122 | self.CBG_ATTACK_RATES_NEW = np.clip(self.CBG_ATTACK_RATES_NEW, 0, None) 123 | 124 | def get_new_infectious(self): 125 | if self.random_simulation: 126 | new_infectious = np.random.binomial(np.round(self.cbg_latent).astype(int), 1 / self.LATENCY_PERIOD) 127 | else: 128 | new_infectious = self.cbg_latent / self.LATENCY_PERIOD 129 | 130 | return new_infectious 131 | 132 | def get_new_removed(self): 133 | if self.random_simulation: 134 | new_removed = np.random.binomial(np.round(self.cbg_infected).astype(int), 1 / self.INFECTIOUS_PERIOD) 135 | else: 136 | new_removed=self.cbg_infected / self.INFECTIOUS_PERIOD 137 | 138 | return new_removed 139 | 140 | def format_floats(self, arr): 141 | return [int(round(x)) for x in arr] 142 | 143 | def simulate_disease_spread(self,length=24,verbosity=1,no_print=False): 144 | assert(self.t 0) and (self.t % verbosity == 0): 151 | L = np.sum(self.cbg_latent, axis=1) 152 | I = np.sum(self.cbg_infected, axis=1) 153 | R = np.sum(self.cbg_removed, axis=1) 154 | 155 | self.T1.append(self.t) 156 | self.L_1.append(L) 157 | self.I_1.append(I) 158 | self.R_1.append(R) 159 | self.C_1.append(self.C) 160 | self.D_1.append(self.D) 161 | 162 | self.history_C2.append(self.C2) 163 | self.history_D2.append(self.D2) 164 | 165 | if(no_print==False): 166 | print('t:',self.t,'L:',L,'I:',I,'R',R,'C',self.C,'D',self.D) 167 | 168 | self.update_states(self.t) 169 | C1 = np.sum(self.new_confirmed_cases,axis=1) 170 | self.C2=self.C2+self.new_confirmed_cases 171 | self.C[0]=self.C[0]+C1 172 | D1 = np.sum(self.new_deaths,axis=1) 173 | self.D2=self.D2+self.new_deaths 174 | self.D[0]=self.D[0]+D1 175 | if self.debug and verbosity > 0 and self.t % verbosity == 0: 176 | print('Num active POIs: %d. Num with infection rates clipped: %d' % (self.num_active_pois, self.num_poi_infection_rates_clipped)) 177 | print('Num CBGs active at POIs: %d. Num with clipped num cases from POIs: %d' % (self.num_cbgs_active_at_pois, self.num_cbgs_with_clipped_poi_cases)) 178 | if self.debug: 179 | print("Time for iteration %i: %2.3f seconds" % (self.t, time.time() - iter_t0)) 180 | 181 | self.t += 1 182 | 183 | def empty_record(self): 184 | self.L_1=[] 185 | self.I_1=[] 186 | self.R_1=[] 187 | self.C_1=[] 188 | self.D_1=[] 189 | self.T1=[] 190 | self.history_C2 = [] 191 | self.history_D2 = [] 192 | 193 | def output_record(self,full=False): 194 | cbg_all_affected = self.cbg_latent + self.cbg_infected + self.cbg_removed 195 | total_affected = np.sum(cbg_all_affected, axis=1) 196 | 197 | if full: 198 | print('Output records') 199 | o_T1=np.array(self.T1,dtype=np.float32) 200 | o_L_1=np.array(self.L_1,dtype=np.float32) 201 | o_I_1=np.array(self.I_1,dtype=np.float32) 202 | o_R_1=np.array(self.R_1,dtype=np.float32) 203 | o_C2=np.array(self.C2,dtype=np.float32) 204 | o_D2=np.array(self.D2,dtype=np.float32) 205 | o_history_C2=np.array(self.history_C2,dtype=np.float32) 206 | o_history_D2=np.array(self.history_D2,dtype=np.float32) 207 | o_total_affected=np.array(total_affected,dtype=np.float32) 208 | o_cbg_all_affected=np.array(cbg_all_affected,dtype=np.float32) 209 | 210 | o_history_C2=np.transpose(o_history_C2,(1,0,2)) 211 | o_history_D2=np.transpose(o_history_D2,(1,0,2)) 212 | 213 | return o_T1,o_L_1,o_I_1,o_R_1,o_C2,o_D2, o_total_affected, o_history_C2, o_history_D2, o_cbg_all_affected 214 | 215 | else: 216 | 217 | o_history_C2=np.array(self.history_C2,dtype=np.float32)[-1,:,:] 218 | o_history_D2=np.array(self.history_D2,dtype=np.float32)[-1,:,:] 219 | 220 | return o_history_C2,o_history_D2 221 | 222 | def update_states(self, t): 223 | self.get_new_cases(t) 224 | new_infectious = self.get_new_infectious() 225 | new_removed = self.get_new_removed() 226 | if not self.just_compute_r0: 227 | self.cbg_latent = self.cbg_latent + self.cbg_new_cases - new_infectious 228 | self.cbg_infected = self.cbg_infected + new_infectious - new_removed 229 | self.cbg_removed = self.cbg_removed + new_removed 230 | 231 | if self.random_simulation: 232 | self.new_confirmed_cases = np.random.binomial(np.round(self.cases_to_confirm).astype(int), 1/self.confirmation_lag) 233 | new_cases_to_confirm = np.random.binomial(np.round(new_infectious).astype(int), self.confirmation_rate) 234 | else: 235 | self.new_confirmed_cases=self.cases_to_confirm/self.confirmation_lag 236 | new_cases_to_confirm=new_infectious*self.confirmation_rate 237 | 238 | self.cases_to_confirm = self.cases_to_confirm + new_cases_to_confirm - self.new_confirmed_cases 239 | 240 | if self.random_simulation: 241 | self.new_deaths = np.random.binomial(np.round(self.deaths_to_happen).astype(int), 1/self.death_lag) 242 | new_deaths_to_happen = np.random.binomial(np.round(new_infectious).astype(int), self.CBG_DEATH_RATES_NEW) 243 | else: 244 | self.new_deaths=self.deaths_to_happen/self.death_lag 245 | new_deaths_to_happen=new_infectious*self.CBG_DEATH_RATES_NEW 246 | 247 | self.deaths_to_happen = self.deaths_to_happen + new_deaths_to_happen - self.new_deaths 248 | else: 249 | self.cbg_latent = self.cbg_latent - new_infectious 250 | self.cbg_infected = self.cbg_infected + new_infectious - new_removed 251 | self.cbg_removed = self.cbg_removed + new_removed + self.cbg_new_cases 252 | 253 | def get_new_cases(self, t): 254 | cbg_densities = self.cbg_infected / self.CBG_SIZES 255 | overall_densities = (np.sum(self.cbg_infected, axis=1) / np.sum(self.CBG_SIZES)).reshape(-1, 1) 256 | num_sus = np.clip(self.CBG_SIZES - self.cbg_latent - self.cbg_infected - self.cbg_removed, 0, None) 257 | sus_frac = num_sus / self.CBG_SIZES 258 | 259 | if self.PSI > 0: 260 | cbg_base_infection_rates = self.HOME_BETA * self.CBG_ATTACK_RATES_NEW * cbg_densities 261 | cbg_base_infection_rates=np.nan_to_num(cbg_base_infection_rates) 262 | else: 263 | cbg_base_infection_rates = np.tile(overall_densities, self.N) * self.HOME_BETA 264 | self.num_base_infection_rates_clipped = np.sum(cbg_base_infection_rates > 1) 265 | cbg_base_infection_rates = np.clip(cbg_base_infection_rates, None, 1.0) 266 | 267 | if self.POI_CBG_VISITS_LIST is not None: 268 | poi_cbg_visits = self.POI_CBG_VISITS_LIST[t] 269 | poi_visits = poi_cbg_visits @ np.ones(poi_cbg_visits.shape[1]) 270 | 271 | if not self.just_compute_r0: 272 | self.num_active_pois = np.sum(poi_visits > 0) 273 | col_sums = np.squeeze(np.array(poi_cbg_visits.sum(axis=0))) 274 | self.cbg_num_out = col_sums 275 | poi_infection_rates = self.POI_FACTORS * (poi_cbg_visits @ cbg_densities.T).T 276 | self.num_poi_infection_rates_clipped = np.sum(poi_infection_rates > 1) 277 | if self.clip_poisson_approximation: 278 | poi_infection_rates = np.clip(poi_infection_rates, None, 1.0) 279 | 280 | cbg_mean_new_cases_from_poi = self.CBG_ATTACK_RATES_NEW * sus_frac * (poi_infection_rates @ poi_cbg_visits) 281 | cbg_mean_new_cases_from_poi=np.nan_to_num(cbg_mean_new_cases_from_poi) 282 | 283 | if self.random_simulation: 284 | num_cases_from_poi = np.random.poisson(cbg_mean_new_cases_from_poi) 285 | else: 286 | num_cases_from_poi=cbg_mean_new_cases_from_poi 287 | self.num_cbgs_active_at_pois = np.sum(cbg_mean_new_cases_from_poi > 0) 288 | 289 | self.num_cbgs_with_clipped_poi_cases = np.sum(num_cases_from_poi > num_sus) 290 | self.cbg_new_cases_from_poi = np.clip(num_cases_from_poi, None, num_sus) 291 | num_sus_remaining = num_sus - self.cbg_new_cases_from_poi 292 | 293 | if self.random_simulation: 294 | self.cbg_new_cases_from_base = np.random.binomial(np.round(num_sus_remaining).astype(int),cbg_base_infection_rates) 295 | else: 296 | self.cbg_new_cases_from_base = num_sus_remaining*cbg_base_infection_rates 297 | 298 | self.cbg_new_cases = self.cbg_new_cases_from_poi + self.cbg_new_cases_from_base 299 | 300 | self.clipping_monitor['num_base_infection_rates_clipped'].append(self.num_base_infection_rates_clipped) 301 | self.clipping_monitor['num_active_pois'].append(self.num_active_pois) 302 | self.clipping_monitor['num_poi_infection_rates_clipped'].append(self.num_poi_infection_rates_clipped) 303 | self.clipping_monitor['num_cbgs_active_at_pois'].append(self.num_cbgs_active_at_pois) 304 | self.clipping_monitor['num_cbgs_with_clipped_poi_cases'].append(self.num_cbgs_with_clipped_poi_cases) 305 | assert (self.cbg_new_cases <= num_sus).all() -------------------------------------------------------------------------------- /code/grid_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.chdir(os.path.split(os.path.realpath(__file__))[0]) 3 | 4 | # external packages 5 | import numpy as np 6 | import time 7 | import torch 8 | import copy 9 | import json 10 | from tqdm import tqdm 11 | 12 | # self writing files 13 | import grid_networks as networks 14 | from grid_simulator import grid_model 15 | 16 | class MARL(object): 17 | def __init__(self, 18 | num_grid=10, 19 | num_diamond=2, 20 | diamond_extent=10, 21 | num_move=0.1, 22 | max_steps=10, 23 | max_episode=100000, 24 | update_batch=100, 25 | batch_size=256, 26 | buffer_capacity=3000000, 27 | update_interval=100, 28 | save_interval=1000, 29 | lr=0.0001, 30 | lr_decay=False, 31 | grad_clip=False, 32 | max_grad_norm=10, 33 | soft_replace_rate=0.01, 34 | gamma=0, 35 | explore_noise=0.1, 36 | explore_noise_decay=True, 37 | explore_noise_decay_rate=0.2): 38 | super().__init__() 39 | 40 | print('Initializing...') 41 | 42 | # generate config 43 | config_data=locals() 44 | del config_data['self'] 45 | del config_data['__class__'] 46 | time_data=time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime()) 47 | config_data['time']=time_data 48 | 49 | # environment 50 | self.device='cuda' if torch.cuda.is_available() else 'cpu' 51 | torch.manual_seed(9) 52 | if self.device=='cuda': 53 | torch.cuda.manual_seed(9) 54 | np.random.seed(9) 55 | 56 | self.num_grid=num_grid 57 | self.num_diamond=num_diamond 58 | self.diamond_extent=diamond_extent 59 | self.num_move=num_move 60 | 61 | # simulator 62 | self.simulator=grid_model.Model() 63 | self.simulator.init_exogenous_variables(num_grid=self.num_grid, 64 | num_diamond=self.num_diamond, 65 | diamond_extent=self.diamond_extent, 66 | num_move=self.num_move 67 | ) 68 | 69 | # adjency matrix 70 | self.Gmat=self.simulator.Gmat 71 | self.Gmat=torch.FloatTensor(self.Gmat).to(self.device) 72 | 73 | # learning parameters 74 | self.max_steps=max_steps 75 | self.max_episode=max_episode 76 | self.update_batch=update_batch 77 | self.batch_size=batch_size 78 | self.update_interval=update_interval 79 | self.save_interval=save_interval 80 | self.lr=lr 81 | self.lr_decay=lr_decay 82 | self.grad_clip=grad_clip 83 | self.max_grad_norm=max_grad_norm 84 | self.soft_replace_rate=soft_replace_rate 85 | self.gamma=gamma 86 | self.explore_noise=explore_noise 87 | self.explore_noise_decay=explore_noise_decay 88 | self.explore_noise_decay_rate=explore_noise_decay_rate 89 | 90 | # networks and optimizers 91 | self.actor=networks.Actor().to(self.device) 92 | # self.actor.load_state_dict(torch.load(os.path.join('..','model','30grid_2022-09-14_22-35-29','actor_92400.pth'))) 93 | self.actor_target=copy.deepcopy(self.actor).eval() 94 | self.actor_optimizer=torch.optim.Adam(self.actor.parameters(),lr=self.lr) 95 | 96 | self.critic=networks.Critic().to(self.device) 97 | self.critic_target=copy.deepcopy(self.critic).eval() 98 | self.critic_optimizer=torch.optim.Adam(self.critic.parameters(),lr=self.lr) 99 | 100 | self.actor_attention=networks.Attention().to(self.device) 101 | # self.actor_attention.load_state_dict(torch.load(os.path.join('..','model','30grid_2022-09-14_22-35-29','actor_attention_92400.pth'))) 102 | self.actor_attention_target=copy.deepcopy(self.actor_attention).eval() 103 | self.actor_attention_optimizer=torch.optim.Adam(self.actor_attention.parameters(),lr=self.lr) 104 | 105 | self.critic_attention=networks.Attention().to(self.device) 106 | self.critic_attention_target=copy.deepcopy(self.critic_attention).eval() 107 | self.critic_attention_optimizer=torch.optim.Adam(self.critic_attention.parameters(),lr=self.lr) 108 | 109 | if self.lr_decay: 110 | self.actor_optimizer_scheduler = torch.optim.lr_scheduler.LambdaLR(self.actor_optimizer, lr_lambda=lambda epoch: 0.99**epoch) 111 | self.critic_optimizer_scheduler = torch.optim.lr_scheduler.LambdaLR(self.critic_optimizer, lr_lambda=lambda epoch: 0.99**epoch) 112 | self.actor_attention_optimizer_scheduler = torch.optim.lr_scheduler.LambdaLR(self.actor_attention_optimizer, lr_lambda=lambda epoch: 0.99**epoch) 113 | self.critic_attention_optimizer_scheduler = torch.optim.lr_scheduler.LambdaLR(self.critic_attention_optimizer, lr_lambda=lambda epoch: 0.99**epoch) 114 | 115 | # buffer 116 | self.buffer_capacity=buffer_capacity 117 | self.buffer_pointer=0 118 | self.buffer_size=0 119 | self.buffer_s=np.empty((self.buffer_capacity,self.num_grid*self.num_grid,5),dtype=np.float32) # S,C,D,Sdiff,Cdiff,Ddiff 120 | self.buffer_a=np.empty((self.buffer_capacity,self.num_grid*self.num_grid,4),dtype=np.float32) 121 | self.buffer_s1=np.empty((self.buffer_capacity,self.num_grid*self.num_grid,5),dtype=np.float32) # S,C,D,Sdiff,Cdiff,Ddiff 122 | self.buffer_r=np.empty((self.buffer_capacity,self.num_grid*self.num_grid,1),dtype=np.float32) 123 | self.buffer_end=np.ones((self.buffer_capacity,self.num_grid*self.num_grid,1),dtype=np.float32) 124 | 125 | # training trackors 126 | self.episode_return_trackor=list() 127 | self.critic_loss_trackor=list() 128 | self.actor_loss_trackor=list() 129 | 130 | # making output directory 131 | self.output_dir=os.path.join('..','model',f'{self.num_grid}grid_{time_data}') 132 | os.mkdir(self.output_dir) 133 | with open(os.path.join(self.output_dir,'config.json'),'w') as f: 134 | json.dump(config_data,f) 135 | 136 | print(f'Training platform with {self.num_grid*self.num_grid} grids initialized') 137 | 138 | def get_action(self,action): 139 | action_vector=torch.softmax(action,dim=-1) 140 | 141 | return action_vector 142 | 143 | def get_entropy(self,action): 144 | weight=torch.softmax(action,dim=1) 145 | action_entropy=-torch.sum(weight*torch.log2(weight)) 146 | 147 | return action_entropy 148 | 149 | def update(self): 150 | actor_loss_sum=0 151 | critic_loss_sum=0 152 | 153 | for batch_count in range(self.update_batch): 154 | batch_index=np.random.randint(self.buffer_size,size=self.batch_size) 155 | s_batch=torch.FloatTensor(self.buffer_s[batch_index]).to(self.device) 156 | a_batch=torch.FloatTensor(self.buffer_a[batch_index]).to(self.device) 157 | s1_batch=torch.FloatTensor(self.buffer_s1[batch_index]).to(self.device) 158 | r_batch=torch.FloatTensor(self.buffer_r[batch_index]).to(self.device) 159 | end_batch=torch.FloatTensor(self.buffer_end[batch_index]).to(self.device) 160 | 161 | with torch.no_grad(): 162 | update_Actor_attention1=self.actor_attention_target(s1_batch,self.Gmat) 163 | update_Actor_state1_bar=torch.bmm(update_Actor_attention1,s1_batch) 164 | update_Actor_state1_all=torch.concat([s1_batch,update_Actor_state1_bar],dim=-1) 165 | update_action1=self.actor_target(update_Actor_state1_all) 166 | 167 | update_Critic_attention1=self.critic_attention_target(s1_batch,self.Gmat) 168 | update_Critic_state1_bar=torch.bmm(update_Critic_attention1,s1_batch) 169 | update_Critic_state1_all=torch.concat([s1_batch,update_Critic_state1_bar],dim=-1) 170 | update_action1_bar=torch.bmm(update_Critic_attention1,update_action1) 171 | update_action1_all=torch.concat([update_action1,update_action1_bar],dim=-1) 172 | 173 | y=r_batch+self.gamma*self.critic_target(update_Critic_state1_all,update_action1_all)*end_batch 174 | 175 | update_Critic_attention=self.critic_attention(s_batch,self.Gmat) 176 | update_Critic_state_bar=torch.bmm(update_Critic_attention,s_batch) 177 | update_Critic_state_all=torch.concat([s_batch,update_Critic_state_bar],dim=-1) 178 | update_action_bar=torch.bmm(update_Critic_attention,a_batch) 179 | update_action_all=torch.concat([a_batch,update_action_bar],dim=-1) 180 | 181 | critic_loss=torch.sum(torch.square(y-self.critic(update_Critic_state_all,update_action_all)))/self.batch_size 182 | critic_loss_sum+=critic_loss.cpu().item() 183 | 184 | self.critic_optimizer.zero_grad() 185 | self.critic_attention_optimizer.zero_grad() 186 | critic_loss.backward() 187 | if self.grad_clip: 188 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) 189 | torch.nn.utils.clip_grad_norm_(self.critic_attention.parameters(), self.max_grad_norm) 190 | self.critic_optimizer.step() 191 | self.critic_attention_optimizer.step() 192 | 193 | update_Actor_attention=self.actor_attention(s_batch,self.Gmat) 194 | update_Actor_state_bar=torch.bmm(update_Actor_attention,s_batch) 195 | update_Actor_state_all=torch.concat([s_batch,update_Actor_state_bar],dim=-1) 196 | update_action=self.actor(update_Actor_state_all) 197 | 198 | with torch.no_grad(): 199 | update_Critic_attention_new=self.critic_attention(s_batch,self.Gmat) 200 | update_Critic_state_bar_new=torch.bmm(update_Critic_attention_new,s_batch) 201 | update_Critic_state_all_new=torch.concat([s_batch,update_Critic_state_bar_new],dim=-1) 202 | 203 | update_action_bar_new=torch.bmm(update_Critic_attention_new,update_action) 204 | update_action_all_new=torch.concat([update_action,update_action_bar_new],dim=-1) 205 | 206 | actor_loss=-torch.sum(self.critic(update_Critic_state_all_new,update_action_all_new))/self.batch_size 207 | actor_loss_sum+=actor_loss.cpu().item() 208 | 209 | self.actor_optimizer.zero_grad() 210 | self.actor_attention_optimizer.zero_grad() 211 | actor_loss.backward() 212 | if self.grad_clip: 213 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) 214 | torch.nn.utils.clip_grad_norm_(self.actor_attention.parameters(), self.max_grad_norm) 215 | self.actor_optimizer.step() 216 | self.actor_attention_optimizer.step() 217 | 218 | for x in self.actor.state_dict().keys(): 219 | eval('self.actor_target.'+x+'.data.mul_(1-self.soft_replace_rate)') 220 | eval('self.actor_target.'+x+'.data.add_(self.soft_replace_rate*self.actor.'+x+'.data)') 221 | for x in self.actor_attention.state_dict().keys(): 222 | eval('self.actor_attention_target.'+x+'.data.mul_(1-self.soft_replace_rate)') 223 | eval('self.actor_attention_target.'+x+'.data.add_(self.soft_replace_rate*self.actor_attention.'+x+'.data)') 224 | for x in self.critic.state_dict().keys(): 225 | eval('self.critic_target.'+x+'.data.mul_(1-self.soft_replace_rate)') 226 | eval('self.critic_target.'+x+'.data.add_(self.soft_replace_rate*self.critic.'+x+'.data)') 227 | for x in self.critic_attention.state_dict().keys(): 228 | eval('self.critic_attention_target.'+x+'.data.mul_(1-self.soft_replace_rate)') 229 | eval('self.critic_attention_target.'+x+'.data.add_(self.soft_replace_rate*self.critic_attention.'+x+'.data)') 230 | 231 | if self.lr_decay: 232 | self.actor_optimizer_scheduler.step() 233 | self.actor_attention_optimizer_scheduler.step() 234 | self.critic_optimizer_scheduler.step() 235 | self.critic_attention_optimizer_scheduler.step() 236 | 237 | critic_loss_mean=critic_loss_sum/self.update_batch 238 | self.critic_loss_trackor.append(critic_loss_mean) 239 | actor_loss_mean=actor_loss_sum/self.update_batch 240 | self.actor_loss_trackor.append(actor_loss_mean) 241 | 242 | tqdm.write(f'Update: Critic Loss {critic_loss_mean} | Actor Loss {actor_loss_mean}') 243 | 244 | def train(self): 245 | for episode in tqdm(range(self.max_episode)): 246 | self.simulator.init_endogenous_variables() 247 | 248 | if self.explore_noise_decay: 249 | self.explore_noise=self.explore_noise/((episode+1)**self.explore_noise_decay_rate) 250 | 251 | current_state=self.simulator.output_record() 252 | for step in range(self.max_steps): 253 | self.buffer_s[self.buffer_pointer]=current_state 254 | 255 | with torch.no_grad(): 256 | state=torch.FloatTensor(current_state).to(self.device).unsqueeze(0) 257 | 258 | Actor_attention=self.actor_attention(state,self.Gmat) 259 | Actor_state_bar=torch.bmm(Actor_attention,state) 260 | Actor_state_all=torch.concat([state,Actor_state_bar],dim=-1) 261 | action=self.actor(Actor_state_all) 262 | 263 | action=action+torch.randn_like(action)*torch.mean(torch.abs(action))*self.explore_noise 264 | # action=action+torch.randn_like(action)*self.explore_noise 265 | action_vector=self.get_action(action).squeeze(0).cpu().numpy() 266 | 267 | reward_old=self.simulator.get_reward() 268 | self.simulator.move_miner(action_vector) 269 | reward_new=self.simulator.get_reward() 270 | reward=reward_new-reward_old 271 | 272 | current_state=self.simulator.output_record() 273 | 274 | self.buffer_s1[self.buffer_pointer]=current_state 275 | self.buffer_a[self.buffer_pointer]=action.cpu().numpy() 276 | self.buffer_r[self.buffer_pointer]=reward 277 | if step==self.max_steps-1: 278 | self.buffer_end[self.buffer_pointer]=0 279 | 280 | self.buffer_size=max(self.buffer_size,self.buffer_pointer+1) 281 | self.buffer_pointer=(self.buffer_pointer+1)%self.buffer_capacity 282 | 283 | episode_return=self.simulator.get_return() 284 | self.episode_return_trackor.append(episode_return) 285 | 286 | tqdm.write(f'Episode {episode}: Return {episode_return}') 287 | 288 | if (episode+1)%self.update_interval==0: 289 | self.update() 290 | 291 | if (episode+1)%self.save_interval==0: 292 | self.save_models(episode+1) 293 | 294 | def save_models(self,episode): 295 | torch.save(self.actor.state_dict(),os.path.join(self.output_dir,f'actor_{episode}.pth')) 296 | torch.save(self.actor_target.state_dict(),os.path.join(self.output_dir,f'actor_target_{episode}.pth')) 297 | 298 | torch.save(self.actor_attention.state_dict(),os.path.join(self.output_dir,f'actor_attention_{episode}.pth')) 299 | torch.save(self.actor_attention_target.state_dict(),os.path.join(self.output_dir,f'actor_attention_target_{episode}.pth')) 300 | 301 | torch.save(self.critic.state_dict(),os.path.join(self.output_dir,f'critic_{episode}.pth')) 302 | torch.save(self.critic_target.state_dict(),os.path.join(self.output_dir,f'critic_target_{episode}.pth')) 303 | 304 | torch.save(self.critic_attention.state_dict(),os.path.join(self.output_dir,f'critic_attention_{episode}.pth')) 305 | torch.save(self.critic_attention_target.state_dict(),os.path.join(self.output_dir,f'critic_attention_target_{episode}.pth')) 306 | 307 | with open(os.path.join(self.output_dir,'episode_return.json'),'w') as f: 308 | json.dump(str(self.episode_return_trackor),f) 309 | with open(os.path.join(self.output_dir,'critic_loss.json'),'w') as f: 310 | json.dump(str(self.critic_loss_trackor),f) 311 | with open(os.path.join(self.output_dir,'actor_loss.json'),'w') as f: 312 | json.dump(str(self.actor_loss_trackor),f) 313 | 314 | 315 | if __name__ == '__main__': 316 | train_platform=MARL() 317 | # train_platform.test_simulation() 318 | # train_platform.test_network() 319 | train_platform.train() -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | # environment setting 2 | import os 3 | os.chdir(os.path.split(os.path.realpath(__file__))[0]) 4 | 5 | # external packages 6 | import numpy as np 7 | import time 8 | import torch 9 | import copy 10 | import json 11 | from tqdm import tqdm 12 | 13 | # self writing files 14 | import networks as networks 15 | from reward_fun import reward_fun 16 | import constants 17 | import data_range 18 | from simulator import disease_model 19 | from util import * 20 | 21 | class MARL(object): 22 | def __init__(self, 23 | MSA_name='Atlanta', 24 | vaccine_day=0.01, 25 | step_length=24, 26 | num_seed=64, 27 | max_episode=110, 28 | update_batch=300, 29 | batch_size=24, 30 | buffer_capacity=64, 31 | save_interval=1, 32 | lr=0.0001, 33 | lr_decay=False, 34 | grad_clip=False, 35 | max_grad_norm=10, 36 | soft_replace_rate=0.01, 37 | gamma=0.6, 38 | D_weight=0, 39 | entropy_weight=0, 40 | reward_option=0, 41 | explore_noise=0.002, 42 | explore_noise_decay=True, 43 | explore_noise_decay_rate=0.2, 44 | manual_seed=0, 45 | random_simulation=False): 46 | super().__init__() 47 | 48 | print('Initializing...') 49 | 50 | # generate config 51 | config_data=locals() 52 | del config_data['self'] 53 | del config_data['__class__'] 54 | time_data=time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime()) 55 | config_data['time']=time_data 56 | 57 | # environment 58 | self.device='cuda' if torch.cuda.is_available() else 'cpu' 59 | self.manual_seed=manual_seed 60 | torch.manual_seed(self.manual_seed) 61 | if self.device=='cuda': 62 | torch.cuda.manual_seed(self.manual_seed) 63 | np.random.seed(self.manual_seed) 64 | self.random_simulation=random_simulation 65 | 66 | # loading cbg data (for simulation) 67 | self.MSA_name=MSA_name 68 | self.data_range_key=(self.MSA_name+'_random') if self.random_simulation else self.MSA_name 69 | self.data=load_data(self.MSA_name) 70 | self.poi_areas=self.data['poi_areas'] 71 | self.cbg_sizes=self.data['cbg_sizes'] 72 | self.poi_cbg_visits_list=self.data['poi_cbg_visits_list']# time_length*poi*cbg 73 | self.time_length=len(self.poi_cbg_visits_list) 74 | self.day_length=int(self.time_length/24) 75 | self.step_length=step_length 76 | 77 | self.num_cbg=len(self.cbg_sizes) 78 | self.sum_population=np.sum(self.cbg_sizes) 79 | assert len(self.poi_areas)==self.poi_cbg_visits_list[0].shape[0] 80 | self.num_poi=len(self.poi_areas) 81 | 82 | # simulator 83 | self.num_seed=num_seed 84 | self.simulator=disease_model.Model(num_seeds=self.num_seed,random_simulation=self.random_simulation) 85 | self.simulator.init_exogenous_variables(poi_areas=self.poi_areas, 86 | poi_dwell_time_correction_factors=self.data['poi_dwell_time_correction_factors'], 87 | cbg_sizes=self.cbg_sizes, 88 | poi_cbg_visits_list=self.poi_cbg_visits_list, 89 | cbg_attack_rates_original = self.data['cbg_attack_rates_scaled'], 90 | cbg_death_rates_original = self.data['cbg_death_rates_scaled'], 91 | p_sick_at_t0=constants.parameters_dict[self.MSA_name][0], 92 | home_beta=constants.parameters_dict[self.MSA_name][1], 93 | poi_psi=constants.parameters_dict[self.MSA_name][2], 94 | just_compute_r0=False, 95 | latency_period=96, # 4 days 96 | infectious_period=84, # 3.5 days 97 | confirmation_rate=.1, 98 | confirmation_lag=168, # 7 days 99 | death_lag=432) 100 | 101 | # adjency matrix 102 | self.Gmat=np.load(os.path.join('..','data',self.MSA_name,f'{self.MSA_name}_Gmat.npy')) 103 | self.Gmat=np.clip(self.Gmat,0,data_range.Gmat[self.data_range_key])/data_range.Gmat[self.data_range_key] 104 | self.Gmat=torch.FloatTensor(self.Gmat).to(self.device) 105 | 106 | # dynamic features 107 | self.cbg_state_raw=np.zeros((self.num_seed,self.num_cbg,3),dtype=np.float32)#S,C,D 108 | self.cbg_state=np.zeros((self.num_seed,self.num_cbg,3),dtype=np.float32)#S,C,D 109 | self.cbg_state_diff=np.zeros((self.num_seed,self.num_cbg,3),dtype=np.float32) 110 | 111 | # vaccine number 112 | self.vaccine_day=int(vaccine_day*self.sum_population) 113 | 114 | # learning parameters 115 | self.max_episode=max_episode 116 | self.update_batch=update_batch 117 | self.batch_size=batch_size 118 | self.save_interval=save_interval 119 | self.lr=lr 120 | self.lr_decay=lr_decay 121 | self.grad_clip=grad_clip 122 | self.max_grad_norm=max_grad_norm 123 | self.soft_replace_rate=soft_replace_rate 124 | self.gamma=gamma 125 | self.D_weight=D_weight 126 | self.entropy_weight=entropy_weight 127 | self.reward_option=reward_option 128 | self.explore_noise=explore_noise 129 | self.explore_noise_decay=explore_noise_decay 130 | self.explore_noise_decay_rate=explore_noise_decay_rate 131 | 132 | # networks and optimizers 133 | self.actor=networks.Actor().to(self.device) 134 | self.actor_target=copy.deepcopy(self.actor).eval() 135 | self.actor_optimizer=torch.optim.Adam(self.actor.parameters(),lr=self.lr) 136 | 137 | self.critic=networks.Critic().to(self.device) 138 | self.critic_target=copy.deepcopy(self.critic).eval() 139 | self.critic_optimizer=torch.optim.Adam(self.critic.parameters(),lr=self.lr) 140 | 141 | self.actor_attention=networks.Attention().to(self.device) 142 | self.actor_attention_target=copy.deepcopy(self.actor_attention).eval() 143 | self.actor_attention_optimizer=torch.optim.Adam(self.actor_attention.parameters(),lr=self.lr) 144 | 145 | self.critic_attention=networks.Attention().to(self.device) 146 | self.critic_attention_target=copy.deepcopy(self.critic_attention).eval() 147 | self.critic_attention_optimizer=torch.optim.Adam(self.critic_attention.parameters(),lr=self.lr) 148 | 149 | if self.lr_decay: 150 | self.actor_optimizer_scheduler = torch.optim.lr_scheduler.LambdaLR(self.actor_optimizer, lr_lambda=lambda epoch: 0.99**epoch) 151 | self.critic_optimizer_scheduler = torch.optim.lr_scheduler.LambdaLR(self.critic_optimizer, lr_lambda=lambda epoch: 0.99**epoch) 152 | self.actor_attention_optimizer_scheduler = torch.optim.lr_scheduler.LambdaLR(self.actor_attention_optimizer, lr_lambda=lambda epoch: 0.99**epoch) 153 | self.critic_attention_optimizer_scheduler = torch.optim.lr_scheduler.LambdaLR(self.critic_attention_optimizer, lr_lambda=lambda epoch: 0.99**epoch) 154 | 155 | # buffer 156 | self.buffer_capacity=self.num_seed*(self.day_length-1)*buffer_capacity 157 | self.buffer_pointer=0 158 | self.buffer_size=0 159 | self.buffer_s=np.empty((self.buffer_capacity,self.num_cbg,6),dtype=np.float32) # S,C,D,Sdiff,Cdiff,Ddiff 160 | self.buffer_a=np.empty((self.buffer_capacity,self.num_cbg,1),dtype=np.float32) 161 | self.buffer_s1=np.empty((self.buffer_capacity,self.num_cbg,6),dtype=np.float32) # S,C,D,Sdiff,Cdiff,Ddiff 162 | self.buffer_r=np.empty((self.buffer_capacity,self.num_cbg,1),dtype=np.float32) 163 | self.buffer_end=np.ones((self.buffer_capacity,self.num_cbg,1),dtype=np.float32) 164 | 165 | # training trackors 166 | self.episode_deaths_trackor=list() 167 | self.episode_cases_trackor=list() 168 | self.critic_loss_trackor=list() 169 | self.actor_loss_trackor=list() 170 | self.action_entropy_trackor=list() 171 | 172 | # making output directory 173 | self.output_dir=os.path.join('..','model',f'{self.MSA_name}_{self.num_seed}seeds_{time_data}') 174 | os.mkdir(self.output_dir) 175 | with open(os.path.join(self.output_dir,'config.json'),'w') as f: 176 | json.dump(config_data,f) 177 | 178 | print(f'Training platform on {self.MSA_name} initialized') 179 | print(f'Number of CBGs={self.num_cbg}') 180 | print(f'Number of POIs={self.num_poi}') 181 | print(f'Total population={self.sum_population}') 182 | print(f'Time length={self.time_length}') 183 | print(f'Train with {self.num_seed} random seeds') 184 | 185 | def test_simulation(self): 186 | for num in range(1): 187 | self.simulator.reset_random_seed() 188 | self.simulator.init_endogenous_variables() 189 | # mat=500*np.ones((self.num_seed,self.num_cbg)) 190 | # mat[:30,:]-=300 191 | # self.simulator.add_vaccine(mat) 192 | for i in tqdm(range(63)): 193 | # if i==20: 194 | # mat=500*np.ones((self.num_seed,self.num_cbg)) 195 | # self.simulator.add_vaccine(mat) 196 | self.simulator.simulate_disease_spread(no_print=True) 197 | # current_C,current_D=self.simulator.output_record() 198 | 199 | T1,L_1,I_1,R_1,C2,D2,total_affected, history_C2, history_D2, total_affected_each_cbg=self.simulator.output_record(full=True) 200 | print(np.mean(np.sum(history_C2[:,-1,:],axis=-1))) 201 | print(np.mean(np.sum(history_D2[:,-1,:],axis=-1))) 202 | 203 | print(history_C2[:5,-1,:10]) 204 | 205 | gt_result_root=os.path.join('..','model','simulator_test') 206 | if not os.path.exists(gt_result_root): 207 | os.mkdir(gt_result_root) 208 | savepath = os.path.join(gt_result_root, f'cases_cbg_no_vaccination_{self.MSA_name}_{self.num_seed}seeds_step_raw{num}.npy') 209 | np.save(savepath, history_C2) 210 | savepath = os.path.join(gt_result_root, f'deaths_cbg_no_vaccination_{self.MSA_name}_{self.num_seed}seeds_step_raw{num}.npy') 211 | np.save(savepath, history_D2) 212 | 213 | def test_network(self): 214 | index=np.array(range(self.batch_size)) 215 | state=torch.FloatTensor(self.buffer_s[index,:,:]).to(self.device) 216 | print(state.shape) 217 | 218 | Actor_attention=self.actor_attention(state,self.Gmat) 219 | print(Actor_attention.shape) 220 | Actor_state_bar=torch.bmm(Actor_attention,state) 221 | print(Actor_state_bar.shape) 222 | Actor_state_all=torch.concat([state,Actor_state_bar],dim=-1) 223 | print(Actor_state_all.shape) 224 | action=self.actor(Actor_state_all) 225 | print(action.shape) 226 | 227 | Critic_attention=self.critic_attention(state,self.Gmat) 228 | print(Critic_attention.shape) 229 | Critic_state_bar=torch.bmm(Critic_attention,state) 230 | print(Critic_state_bar.shape) 231 | Critic_state_all=torch.concat([state,Critic_state_bar],dim=-1) 232 | print(Critic_state_all.shape) 233 | action_bar=torch.bmm(Critic_attention,action) 234 | print(action_bar.shape) 235 | action_all=torch.concat([action,action_bar],dim=-1) 236 | print(action_all.shape) 237 | estimate_reward=self.critic(Critic_state_all,action_all) 238 | print(estimate_reward.shape) 239 | 240 | def update_cbg_state(self,current_C,current_D): 241 | self.cbg_state_raw[:,:,1]=current_C 242 | self.cbg_state_raw[:,:,2]=current_D 243 | self.cbg_state_raw[:,:,0]=self.cbg_sizes-current_C-current_D 244 | 245 | self.cbg_state[:,:,1]=np.clip(self.cbg_state_raw[:,:,1],0,data_range.Idata[self.data_range_key])/data_range.Idata[self.data_range_key] 246 | self.cbg_state[:,:,2]=np.clip(self.cbg_state_raw[:,:,2],0,data_range.Ddata[self.data_range_key])/data_range.Ddata[self.data_range_key] 247 | self.cbg_state[:,:,0]=np.clip(self.cbg_state_raw[:,:,0],0,data_range.cbg_size[self.data_range_key])/data_range.cbg_size[self.data_range_key] 248 | 249 | def norm_cbg_state_diff(self): 250 | self.cbg_state_diff[:,:,1]=np.clip(self.cbg_state_diff[:,:,1],-data_range.Idata_diff[self.data_range_key],data_range.Idata_diff[self.data_range_key])/data_range.Idata_diff[self.data_range_key] 251 | self.cbg_state_diff[:,:,2]=np.clip(self.cbg_state_diff[:,:,2],-data_range.Ddata_diff[self.data_range_key],data_range.Ddata_diff[self.data_range_key])/data_range.Ddata_diff[self.data_range_key] 252 | self.cbg_state_diff[:,:,0]=np.clip(self.cbg_state_diff[:,:,0],-data_range.Sdata_diff[self.data_range_key],data_range.Sdata_diff[self.data_range_key])/data_range.Sdata_diff[self.data_range_key] 253 | 254 | def get_vaccine(self,action): 255 | weight=torch.softmax(action,dim=1) 256 | vaccine_mat=self.vaccine_day*weight 257 | 258 | return vaccine_mat 259 | 260 | def get_entropy(self,action): 261 | weight=torch.softmax(action,dim=1) 262 | action_entropy=-torch.sum(weight*torch.log2(weight)) 263 | 264 | return action_entropy 265 | 266 | def update(self): 267 | actor_loss_sum=0 268 | action_entropy_sum=0 269 | critic_loss_sum=0 270 | 271 | for batch_count in range(self.update_batch): 272 | batch_index=np.random.randint(self.buffer_size,size=self.batch_size) 273 | s_batch=torch.FloatTensor(self.buffer_s[batch_index]).to(self.device) 274 | a_batch=torch.FloatTensor(self.buffer_a[batch_index]).to(self.device) 275 | s1_batch=torch.FloatTensor(self.buffer_s1[batch_index]).to(self.device) 276 | r_batch=torch.FloatTensor(self.buffer_r[batch_index]).to(self.device) 277 | end_batch=torch.FloatTensor(self.buffer_end[batch_index]).to(self.device) 278 | 279 | with torch.no_grad(): 280 | update_Actor_attention1=self.actor_attention_target(s1_batch,self.Gmat) 281 | update_Actor_state1_bar=torch.bmm(update_Actor_attention1,s1_batch) 282 | update_Actor_state1_all=torch.concat([s1_batch,update_Actor_state1_bar],dim=-1) 283 | update_action1=self.actor_target(update_Actor_state1_all) 284 | # update_action1=self.get_vaccine(update_action1)/self.vaccine_day 285 | # update_action1=self.get_vaccine(update_action1)/(2*self.vaccine_day/self.num_cbg) 286 | # update_action1=self.get_vaccine(update_action1) 287 | 288 | update_Critic_attention1=self.critic_attention_target(s1_batch,self.Gmat) 289 | update_Critic_state1_bar=torch.bmm(update_Critic_attention1,s1_batch) 290 | update_Critic_state1_all=torch.concat([s1_batch,update_Critic_state1_bar],dim=-1) 291 | update_action1_bar=torch.bmm(update_Critic_attention1,update_action1) 292 | update_action1_all=torch.concat([update_action1,update_action1_bar],dim=-1) 293 | 294 | y=r_batch+self.gamma*self.critic_target(update_Critic_state1_all,update_action1_all)*end_batch 295 | 296 | update_Critic_attention=self.critic_attention(s_batch,self.Gmat) 297 | update_Critic_state_bar=torch.bmm(update_Critic_attention,s_batch) 298 | update_Critic_state_all=torch.concat([s_batch,update_Critic_state_bar],dim=-1) 299 | update_action_bar=torch.bmm(update_Critic_attention,a_batch) 300 | update_action_all=torch.concat([a_batch,update_action_bar],dim=-1) 301 | 302 | critic_loss=torch.sum(torch.square(y-self.critic(update_Critic_state_all,update_action_all)))/self.batch_size 303 | critic_loss_sum+=critic_loss.cpu().item() 304 | 305 | self.critic_optimizer.zero_grad() 306 | self.critic_attention_optimizer.zero_grad() 307 | critic_loss.backward() 308 | if self.grad_clip: 309 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) 310 | torch.nn.utils.clip_grad_norm_(self.critic_attention.parameters(), self.max_grad_norm) 311 | self.critic_optimizer.step() 312 | self.critic_attention_optimizer.step() 313 | 314 | update_Actor_attention=self.actor_attention(s_batch,self.Gmat) 315 | update_Actor_state_bar=torch.bmm(update_Actor_attention,s_batch) 316 | update_Actor_state_all=torch.concat([s_batch,update_Actor_state_bar],dim=-1) 317 | update_action=self.actor(update_Actor_state_all) 318 | action_entropy=self.get_entropy(update_action) 319 | action_entropy_sum+=action_entropy.cpu().item() 320 | # update_action=self.get_vaccine(update_action)/self.vaccine_day 321 | # update_action=self.get_vaccine(update_action)/(2*self.vaccine_day/self.num_cbg) 322 | # update_action=self.get_vaccine(update_action) 323 | 324 | with torch.no_grad(): 325 | update_Critic_attention_new=self.critic_attention(s_batch,self.Gmat) 326 | update_Critic_state_bar_new=torch.bmm(update_Critic_attention_new,s_batch) 327 | update_Critic_state_all_new=torch.concat([s_batch,update_Critic_state_bar_new],dim=-1) 328 | 329 | update_action_bar_new=torch.bmm(update_Critic_attention_new,update_action) 330 | update_action_all_new=torch.concat([update_action,update_action_bar_new],dim=-1) 331 | 332 | # print(action_entropy) 333 | # print(-torch.sum(self.critic(update_Critic_state_all_new,update_action_all_new))) 334 | 335 | actor_loss=-torch.sum(self.critic(update_Critic_state_all_new,update_action_all_new))/self.batch_size-self.entropy_weight*action_entropy/self.batch_size 336 | actor_loss_sum+=actor_loss.cpu().item() 337 | 338 | self.actor_optimizer.zero_grad() 339 | self.actor_attention_optimizer.zero_grad() 340 | actor_loss.backward() 341 | if self.grad_clip: 342 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) 343 | torch.nn.utils.clip_grad_norm_(self.actor_attention.parameters(), self.max_grad_norm) 344 | self.actor_optimizer.step() 345 | self.actor_attention_optimizer.step() 346 | 347 | for x in self.actor.state_dict().keys(): 348 | eval('self.actor_target.'+x+'.data.mul_(1-self.soft_replace_rate)') 349 | eval('self.actor_target.'+x+'.data.add_(self.soft_replace_rate*self.actor.'+x+'.data)') 350 | for x in self.actor_attention.state_dict().keys(): 351 | eval('self.actor_attention_target.'+x+'.data.mul_(1-self.soft_replace_rate)') 352 | eval('self.actor_attention_target.'+x+'.data.add_(self.soft_replace_rate*self.actor_attention.'+x+'.data)') 353 | for x in self.critic.state_dict().keys(): 354 | eval('self.critic_target.'+x+'.data.mul_(1-self.soft_replace_rate)') 355 | eval('self.critic_target.'+x+'.data.add_(self.soft_replace_rate*self.critic.'+x+'.data)') 356 | for x in self.critic_attention.state_dict().keys(): 357 | eval('self.critic_attention_target.'+x+'.data.mul_(1-self.soft_replace_rate)') 358 | eval('self.critic_attention_target.'+x+'.data.add_(self.soft_replace_rate*self.critic_attention.'+x+'.data)') 359 | 360 | if self.lr_decay: 361 | self.actor_optimizer_scheduler.step() 362 | self.actor_attention_optimizer_scheduler.step() 363 | self.critic_optimizer_scheduler.step() 364 | self.critic_attention_optimizer_scheduler.step() 365 | 366 | critic_loss_mean=critic_loss_sum/self.update_batch 367 | self.critic_loss_trackor.append(critic_loss_mean) 368 | actor_loss_mean=actor_loss_sum/self.update_batch 369 | self.actor_loss_trackor.append(actor_loss_mean) 370 | action_entropy_mean=action_entropy_sum/self.update_batch 371 | self.action_entropy_trackor.append(action_entropy_mean) 372 | 373 | tqdm.write(f'Update: Critic Loss {critic_loss_mean} | Actor Loss {actor_loss_mean} | Action Entropy {action_entropy_mean}') 374 | 375 | def train(self): 376 | for episode in tqdm(range(self.max_episode)): 377 | self.simulator.init_endogenous_variables() 378 | self.update_cbg_state(0,0) 379 | self.cbg_state_init=copy.deepcopy(self.cbg_state_raw) 380 | 381 | self.simulator.simulate_disease_spread(length=self.step_length,verbosity=1,no_print=True) 382 | current_C,current_D=self.simulator.output_record() 383 | self.simulator.empty_record() 384 | 385 | self.update_cbg_state(current_C,current_D) 386 | self.cbg_state_diff=self.cbg_state_raw-self.cbg_state_init 387 | self.norm_cbg_state_diff() 388 | 389 | if self.explore_noise_decay: 390 | self.explore_noise=self.explore_noise/((episode+1)**self.explore_noise_decay_rate) 391 | 392 | for day in range(1,self.day_length): 393 | current_state=np.concatenate((self.cbg_state,self.cbg_state_diff),axis=-1) 394 | self.buffer_s[self.buffer_pointer:self.buffer_pointer+self.num_seed]=current_state 395 | 396 | with torch.no_grad(): 397 | state=torch.FloatTensor(current_state).to(self.device) 398 | 399 | Actor_attention=self.actor_attention(state,self.Gmat) 400 | Actor_state_bar=torch.bmm(Actor_attention,state) 401 | Actor_state_all=torch.concat([state,Actor_state_bar],dim=-1) 402 | action=self.actor(Actor_state_all) 403 | 404 | action=action+torch.randn_like(action)*self.explore_noise 405 | vaccine_mat=self.get_vaccine(action).cpu().numpy() 406 | 407 | self.simulator.add_vaccine(vaccine_mat.squeeze(-1)) 408 | self.simulator.simulate_disease_spread(length=self.step_length,verbosity=1,no_print=True) 409 | current_C,current_D=self.simulator.output_record() 410 | self.simulator.empty_record() 411 | 412 | cbg_state_raw_old=copy.deepcopy(self.cbg_state_raw) 413 | self.update_cbg_state(current_C,current_D) 414 | self.cbg_state_diff=self.cbg_state_raw-cbg_state_raw_old 415 | self.norm_cbg_state_diff() 416 | reward=reward_fun(cbg_state_raw_old,self.cbg_state_raw,D_weight=self.D_weight,option=self.reward_option) 417 | 418 | self.buffer_s1[self.buffer_pointer:self.buffer_pointer+self.num_seed]=np.concatenate((self.cbg_state,self.cbg_state_diff),axis=-1) 419 | # self.buffer_a[self.buffer_pointer:self.buffer_pointer+self.num_seed]=vaccine_mat/self.vaccine_day 420 | # self.buffer_a[self.buffer_pointer:self.buffer_pointer+self.num_seed]=vaccine_mat/(2*self.vaccine_day/self.num_cbg) 421 | # self.buffer_a[self.buffer_pointer:self.buffer_pointer+self.num_seed]=vaccine_mat 422 | self.buffer_a[self.buffer_pointer:self.buffer_pointer+self.num_seed]=action.cpu().numpy() 423 | self.buffer_r[self.buffer_pointer:self.buffer_pointer+self.num_seed]=reward 424 | if day==self.day_length-1: 425 | self.buffer_end[self.buffer_pointer:self.buffer_pointer+self.num_seed]=0 426 | 427 | self.buffer_size=max(self.buffer_size,self.buffer_pointer+self.num_seed) 428 | self.buffer_pointer=(self.buffer_pointer+self.num_seed)%self.buffer_capacity 429 | 430 | episode_C=np.mean(np.sum(self.cbg_state_raw[:,:,1],axis=1),axis=0) 431 | self.episode_cases_trackor.append(episode_C) 432 | episode_D=np.mean(np.sum(self.cbg_state_raw[:,:,2],axis=1),axis=0) 433 | self.episode_deaths_trackor.append(episode_D) 434 | 435 | tqdm.write(f'Episode {episode}: Cases {episode_C} | Deaths {episode_D}') 436 | 437 | self.update() 438 | 439 | if (episode+1)%self.save_interval==0: 440 | self.save_models(episode+1) 441 | 442 | def save_models(self,episode): 443 | torch.save(self.actor.state_dict(),os.path.join(self.output_dir,f'actor_{episode}.pth')) 444 | torch.save(self.actor_target.state_dict(),os.path.join(self.output_dir,f'actor_target_{episode}.pth')) 445 | 446 | torch.save(self.actor_attention.state_dict(),os.path.join(self.output_dir,f'actor_attention_{episode}.pth')) 447 | torch.save(self.actor_attention_target.state_dict(),os.path.join(self.output_dir,f'actor_attention_target_{episode}.pth')) 448 | 449 | torch.save(self.critic.state_dict(),os.path.join(self.output_dir,f'critic_{episode}.pth')) 450 | torch.save(self.critic_target.state_dict(),os.path.join(self.output_dir,f'critic_target_{episode}.pth')) 451 | 452 | torch.save(self.critic_attention.state_dict(),os.path.join(self.output_dir,f'critic_attention_{episode}.pth')) 453 | torch.save(self.critic_attention_target.state_dict(),os.path.join(self.output_dir,f'critic_attention_target_{episode}.pth')) 454 | 455 | with open(os.path.join(self.output_dir,'episode_cases.json'),'w') as f: 456 | json.dump(str(self.episode_cases_trackor),f) 457 | with open(os.path.join(self.output_dir,'episode_deaths.json'),'w') as f: 458 | json.dump(str(self.episode_deaths_trackor),f) 459 | with open(os.path.join(self.output_dir,'critic_loss.json'),'w') as f: 460 | json.dump(str(self.critic_loss_trackor),f) 461 | with open(os.path.join(self.output_dir,'actor_loss.json'),'w') as f: 462 | json.dump(str(self.actor_loss_trackor),f) 463 | with open(os.path.join(self.output_dir,'action_entropy.json'),'w') as f: 464 | json.dump(str(self.action_entropy_trackor),f) 465 | 466 | 467 | if __name__ == '__main__': 468 | train_platform=MARL() 469 | # train_platform.test_simulation() 470 | # train_platform.test_network() 471 | train_platform.train() --------------------------------------------------------------------------------