├── .gitignore ├── MSMTC ├── DigitalPose2D │ ├── PoseEnvLarge_multi.json │ ├── __init__.py │ ├── pose_env_base.py │ └── render.py └── DigitalPose2DBase │ ├── PoseEnvLarge_multi.json │ ├── __init__.py │ └── pose_env_base.py ├── README.md ├── environment.py ├── main.py ├── model.py ├── multiagent ├── __init__.py ├── core.py ├── environment.py ├── multi_discrete.py ├── policy.py ├── rendering.py ├── scenario.py └── scenarios │ ├── CN.py │ └── __init__.py ├── perception.py ├── player_util.py ├── render_test.py ├── requirements.txt ├── shared_optim.py ├── test.py ├── train.py ├── utils.py └── worker.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__/ 3 | logs/ 4 | trainedModel/ -------------------------------------------------------------------------------- /MSMTC/DigitalPose2D/PoseEnvLarge_multi.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name": "PoseEnv-v1", 3 | "max_steps": 100, 4 | "cam_number": 4, 5 | "visual_distance": 800, 6 | "rotation_scale": 5, 7 | "target_number": 5, 8 | "discrete_actions": [ 9 | [0], 10 | [1], 11 | [-1] 12 | ], 13 | "safe_start" :[ 14 | [ 0, 0] 15 | ], 16 | "reset_area" : [-1250, 1250, -1250, 1250], 17 | "obstacle_reset_area" : [-250, 250, -250, 250], 18 | "cam_area" : [ 19 | [400, 500, 400, 500], 20 | [400, 500, -500, -400], 21 | [-500, -400, -500, -400], 22 | [-500, -400, 400, 500], 23 | [-1350, -1250, -50, 50], 24 | [-1200, -1100, -1200, -1100], 25 | [-50, 50, -1350, -1250], 26 | [1250, 1350, -50, 50], 27 | [1100, 1200, 1100, 1200], 28 | [-50, 50, 1250, 1350] 29 | ], 30 | "continous_actions_player": { 31 | "high": [100, 30], 32 | "low": [50, -30] 33 | }, 34 | "obstacle_radius": [70, 110], 35 | "obstacle_numRange": [4,5] 36 | } -------------------------------------------------------------------------------- /MSMTC/DigitalPose2D/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | from MSMTC.DigitalPose2D.pose_env_base import Pose_Env_Base 3 | 4 | 5 | class Gym: 6 | def make(self, env_id, args): 7 | reset_type = env_id.split('-v')[1] 8 | env = Pose_Env_Base(int(reset_type),args) 9 | return env 10 | 11 | 12 | gym = Gym() 13 | -------------------------------------------------------------------------------- /MSMTC/DigitalPose2D/pose_env_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import torch 5 | import random 6 | import numpy as np 7 | from gym import spaces 8 | import matplotlib.pyplot as plt 9 | from datetime import datetime 10 | 11 | from MSMTC.DigitalPose2D.render import render 12 | from model import A3C_Single 13 | from utils import goal_id_filter 14 | #from main import parser 15 | import random 16 | 17 | #args = parser.parse_args() 18 | 19 | 20 | class Pose_Env_Base: 21 | def __init__(self, reset_type, args, 22 | nav='Goal', # Random, Goal 23 | config_path="PoseEnvLarge_multi.json", 24 | setting_path=None, 25 | ): 26 | self.nav = nav 27 | self.reset_type = reset_type 28 | self.ENV_PATH = 'MSMTC/DigitalPose2D' 29 | if setting_path: 30 | self.SETTING_PATH = setting_path 31 | else: 32 | self.SETTING_PATH = os.path.join(self.ENV_PATH, config_path) 33 | with open(self.SETTING_PATH, encoding='utf-8') as f: 34 | setting = json.load(f) 35 | 36 | self.env_name = setting['env_name'] 37 | 38 | if args.num_agents == -1: 39 | self.n = setting['cam_number'] 40 | self.num_target 
= setting['target_number'] 41 | else: 42 | self.n = args.num_agents 43 | self.num_target = args.num_targets 44 | 45 | self.cam_area = np.array(setting['cam_area']) 46 | self.reset_area = setting['reset_area'] 47 | # for obstacle 48 | self.obstacle_pos_list = None 49 | if self.reset_type == 3: 50 | #self.obstacle_numRange = setting['obstacle_numRange'] 51 | self.obstacle_numRange = [self.n-1,self.n] 52 | self.num_obstacle = np.random.randint(self.obstacle_numRange[0], self.obstacle_numRange[1]) 53 | self.obstacle_radiusRange = setting['obstacle_radius'] 54 | self.obstacle_radius_list = [0 for i in range(self.num_obstacle)] 55 | self.obstacle_reset_area = setting['obstacle_reset_area'] 56 | 57 | # for executor 58 | self.discrete_actions = setting['discrete_actions'] 59 | self.continous_actions_player = setting['continous_actions_player'] 60 | 61 | self.max_steps = setting['max_steps'] 62 | 63 | self.visual_distance = setting['visual_distance'] 64 | self.safe_start = setting['safe_start'] 65 | self.start_area = self.get_start_area(self.safe_start[0], 1000) 66 | 67 | # define action space for coordinator 68 | self.action_space = [spaces.Discrete(2) for i in range(self.n * self.num_target)] 69 | self.rotation_scale = setting['rotation_scale'] 70 | 71 | # define observation space 72 | self.state_dim = 4 73 | if self.reset_type != 3: 74 | self.observation_space = np.zeros((self.n, self.num_target, self.state_dim), int) 75 | else: 76 | self.observation_space = np.zeros((self.n, self.num_target+self.num_obstacle, self.state_dim), int) 77 | 78 | # render 79 | self.render_save = args.render_save 80 | self.render = args.render 81 | 82 | # communication edges for render 83 | self.comm_edges = None 84 | 85 | self.cam = dict() 86 | for i in range(self.n): 87 | self.cam[i] = dict( 88 | location=[0, 0], 89 | rotation=[0], 90 | ) 91 | 92 | self.count_steps = 0 93 | self.goal_keep = 0 94 | self.KEEP = 10 95 | self.goals4cam = np.ones([self.n, self.num_target]) 96 | 97 | # construct target_agent 98 | self.target_type_prob = [0.3, 0.7] 99 | if 'Goal' in self.nav: 100 | self.random_agents = [GoalNavAgent(i, self.continous_actions_player, self.cam, self.visual_distance, self.reset_area) 101 | for i in range(self.num_target)] 102 | 103 | self.slave_rule = (args.load_executor_dir is None) 104 | 105 | if not self.slave_rule: 106 | self.device = torch.device('cpu') 107 | self.slave = A3C_Single(np.zeros((1, 1, 4)), [spaces.Discrete(3)], args) 108 | self.slave = self.slave.to(self.device) 109 | saved_state = torch.load( 110 | args.load_executor_dir, # 'trainedModel/best_executor.pth', 111 | map_location=lambda storage, loc: storage) 112 | self.slave.load_state_dict(saved_state['model'], strict=False) 113 | self.slave.eval() 114 | 115 | def set_location(self, cam_id, loc): 116 | self.cam[cam_id]['location'] = loc 117 | 118 | def get_location(self, cam_id): 119 | return self.cam[cam_id]['location'] 120 | 121 | def set_rotation(self, cam_id, rot): 122 | for i in range(len(rot)): 123 | if rot[i] > 180: 124 | rot[i] -= 360 125 | if rot[i] < -180: 126 | rot[i] += 360 127 | self.cam[cam_id]['rotation'] = rot 128 | 129 | def get_rotation(self, cam_id): 130 | return self.cam[cam_id]['rotation'] 131 | 132 | def get_cam_states(self): 133 | cam_states=None 134 | for i in range(self.n): 135 | rotation=self.get_rotation(i) 136 | location=self.get_location(i) 137 | state=np.array(location + rotation) 138 | if cam_states is None: 139 | cam_states=np.expand_dims(state,0) 140 | else: 141 | state=np.expand_dims(state,0) 142 | 
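# stack each camera's [x, y, yaw] row so cam_states ends up with shape (n, 3)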
cam_states=np.concatenate((cam_states,state),0) 143 | return cam_states 144 | 145 | def get_hori_direction(self, current_pose, target_pose): 146 | y_delt = target_pose[1] - current_pose[1] 147 | x_delt = target_pose[0] - current_pose[0] 148 | angle_now = np.arctan2(y_delt, x_delt) / np.pi * 180 - current_pose[2] 149 | if angle_now > 180: 150 | angle_now -= 360 151 | if angle_now < -180: 152 | angle_now += 360 153 | return angle_now 154 | 155 | def get_distance(self, current_pose, target_pose): 156 | y_delt = target_pose[1] - current_pose[1] 157 | x_delt = target_pose[0] - current_pose[0] 158 | d = np.sqrt(y_delt * y_delt + x_delt * x_delt) 159 | return d 160 | 161 | def target_init_sample(self): 162 | # let target be initially located in the tracking area of one camera 163 | cam = np.random.randint(0, self.n) 164 | cam_loc = self.get_location(cam) 165 | cam_rot = self.get_rotation(cam) 166 | theta = (cam_rot[0] + np.random.randint(-45, 46)) * math.pi/180 167 | distance = np.random.randint(100, self.visual_distance - 100) 168 | x = cam_loc[0] + distance * math.cos(theta) 169 | y = cam_loc[1] + distance * math.sin(theta) 170 | return [float(x),float(y)] 171 | #return [float(np.random.randint(self.start_area[0], self.start_area[1])),float(np.random.randint(self.start_area[2], self.start_area[3]))] 172 | def obstacle_init_sample(self, loc_a, loc_b): 173 | d = self.get_distance(loc_a, loc_b) 174 | R = self.visual_distance 175 | #r1 = math.sqrt(R**2 - (d/2)**2) 176 | r2 = R - d/2 177 | #r = min(np.random.rand()*(r1+r2)/2, r2) 178 | r= np.random.rand() * r2 * 0.8 179 | c_x = (loc_a[0]+loc_b[0])/2 180 | c_y = (loc_a[1]+loc_b[1])/2 181 | theta = np.random.randint(0,360) * math.pi/180 182 | x = c_x + r * math.cos(theta) 183 | y = c_y + r * math.sin(theta) 184 | return [x,y] 185 | 186 | def get_mask(self): 187 | # agents only communicate with other agents who overlap with itself 188 | # this value should be updated in reset 189 | return self.mask 190 | 191 | def reset_mask(self): 192 | mask_all = [] 193 | for i in range(self.n): 194 | mask = [] 195 | loc_i = self.get_location(i) 196 | for j in range(self.n): 197 | if i == j: 198 | mask.append(0) 199 | else: 200 | loc_j = self.get_location(j) 201 | if self.get_distance(loc_i, loc_j) > 2 * self.visual_distance: 202 | mask.append(0) 203 | else: 204 | mask.append(1) 205 | mask_all.append(mask) 206 | mask_all = np.array(mask_all) 207 | return mask_all 208 | 209 | def reset(self): 210 | 211 | # reset camera 212 | camera_id_list = [i for i in range(self.n)] 213 | #random.shuffle(camera_id_list) 214 | 215 | for i in range(self.n): 216 | cam_loc = [np.random.randint(self.cam_area[i][0], self.cam_area[i][1]), 217 | np.random.randint(self.cam_area[i][2], self.cam_area[i][3]) 218 | ] 219 | self.set_location(camera_id_list[i], cam_loc) # shuffle 220 | 221 | # reset mask 222 | self.mask = self.reset_mask() 223 | 224 | for i in range(self.n): 225 | cam_rot = self.get_rotation(i) 226 | 227 | # start with non-focusing 228 | angle_h = np.random.randint(-180, 180) 229 | cam_rot[0] += angle_h * 1.0 230 | 231 | self.set_rotation(i, cam_rot) 232 | 233 | # reset the position and shape of obstacles 234 | if self.reset_type==3: 235 | self.num_obstacle = np.random.randint(self.obstacle_numRange[0], self.obstacle_numRange[1]) 236 | self.obstacle_pos_list = [] 237 | #choices = [(i, (i+1)%self.n) for i in range(self.n)] 238 | choices = [] 239 | for i in range(self.n-1): 240 | choices.append((i,i+1)) 241 | ''' 242 | for i in range(self.n): 243 | for j in range(i+1,self.n): 244 | 
choices.append((i,j)) 245 | ''' 246 | choices = random.sample(choices, self.num_obstacle) 247 | for k in range(self.num_obstacle): 248 | loc_a = self.get_location(choices[k][0]) 249 | loc_b = self.get_location(choices[k][1]) 250 | #dist = self.get_distance(loc_a, loc_b) 251 | for t in range(20): 252 | #tmp = [float(np.random.randint(self.obstacle_reset_area[0], self.obstacle_reset_area[1])), 253 | # float(np.random.randint(self.obstacle_reset_area[2], self.obstacle_reset_area[3]))] 254 | tmp_loc = self.obstacle_init_sample(loc_a, loc_b) 255 | tmp_R = np.random.randint(self.obstacle_radiusRange[0], self.obstacle_radiusRange[1]) 256 | flag = True 257 | for i in range(self.n): 258 | if self.get_distance(self.get_location(i), tmp_loc) < tmp_R: 259 | flag = False 260 | break 261 | if flag: 262 | break 263 | if t == 20: 264 | print("obstacle reset tried for 20 times") 265 | ''' 266 | flag = True 267 | while flag: 268 | tmp = [float(np.random.randint(self.obstacle_reset_area[0], self.obstacle_reset_area[1])), 269 | float(np.random.randint(self.obstacle_reset_area[2], self.obstacle_reset_area[3]))] 270 | flag = False 271 | for i in range(self.n): 272 | if self.get_distance(self.get_location(i), tmp) < self.obstacle_radius: 273 | flag = True # 274 | ''' 275 | self.obstacle_pos_list.append(tmp_loc) 276 | self.obstacle_radius_list[k] = tmp_R 277 | # reset targets 278 | self.target_pos_list = [] 279 | for _ in range(self.num_target): 280 | # tmp = [float(np.random.randint(self.start_area[0], self.start_area[1])), 281 | # float(np.random.randint(self.start_area[2], self.start_area[3]))] 282 | tmp = self.target_init_sample() 283 | if self.reset_type == 3: 284 | for i in range(20): 285 | if self.collision(tmp[0], tmp[1]): 286 | tmp = self.target_init_sample() 287 | #print("resample target initial state") 288 | else: 289 | break 290 | if i == 20: 291 | print("obstacle reset tried for 20 times") 292 | self.target_pos_list.append(tmp) 293 | self.target_pos_list = np.array(self.target_pos_list) 294 | # reset agent 295 | p = self.target_type_prob 296 | choices = [0, 1] # 0 for random walk, 1 for goal nav 297 | self.target_type = np.random.choice(choices, size=self.num_target, p=p) 298 | for i in range(len(self.random_agents)): 299 | if 'Goal' in self.nav: 300 | self.random_agents[i].reset(self.cam) 301 | 302 | self.count_steps = 0 303 | self.goal_keep = 0 304 | self.goals4cam = np.ones([self.n, self.num_target]) 305 | 306 | info = dict( 307 | Done=False, 308 | Reward=[0 for i in range(self.n)], 309 | Target_Pose=[], 310 | Cam_Pose=[], 311 | Steps=self.count_steps, 312 | ) 313 | 314 | gt_directions = [] 315 | gt_distance = [] 316 | cam_info = [] 317 | for i in range(self.n): 318 | # for target navigation 319 | cam_loc = self.get_location(i) 320 | cam_rot = self.get_rotation(i) 321 | cam_info.append([cam_loc, cam_rot]) 322 | gt_directions.append([]) 323 | gt_distance.append([]) 324 | for j in range(self.num_target): 325 | angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[j]) 326 | gt_directions[i].append([angle_h]) 327 | d = self.get_distance(cam_loc + cam_rot, self.target_pos_list[j]) 328 | gt_distance[i].append(d) 329 | 330 | info['Cam_Pose'].append(cam_loc + cam_rot) 331 | 332 | info['Directions'] = np.array(gt_directions) 333 | info['Distance'] = np.array(gt_distance) 334 | info['Target_Pose'] = np.array(self.target_pos_list) # copy.deepcopy 335 | info['Reward'], info['Global_reward'], others = self.multi_reward(cam_info, self.goals4cam) 336 | if others: 337 | 
info['Camera_target_dict'] = self.Camera_target_dict = others['Camera_target_dict'] 338 | info['Target_camera_dict'] = self.Target_camera_dict = others['Target_camera_dict'] 339 | 340 | state, self.state_dim = self.preprocess_pose(info) 341 | return state 342 | 343 | def visible(self, cam_i, target_j, distance, angle): 344 | # whether target is visible from cam_i 345 | if self.reset_type == 2: 346 | return distance <= self.visual_distance #and 1-abs(angle)/45>0 347 | if self.reset_type == 3: 348 | if not distance <= self.visual_distance: #and 1-abs(angle)/45>0) 349 | return False 350 | # whether the obstacle block the view 351 | cam_loc = self.get_location(cam_i) 352 | cam_rot = self.get_rotation(cam_i) 353 | direction = list(cam_rot+angle) 354 | for i in range(self.num_obstacle): 355 | position_obstacle = self.obstacle_pos_list[i] 356 | angle_ = self.get_hori_direction(cam_loc+direction, position_obstacle) 357 | obstacle_distance = self.get_distance(cam_loc, position_obstacle) 358 | distance_vertical = math.sin(abs(angle_)/180*math.pi)*obstacle_distance 359 | if 0 self.reset_area[1] or \ 411 | loc[1] + delta_y < self.reset_area[2] or loc[1] + delta_y > self.reset_area[3]: #or \ 412 | #self.reset_type == 3 and self.collision(loc[0] + delta_x, loc[1] + delta_y): 413 | action = self.random_agents[i].act(loc) 414 | 415 | target_hpr_now = np.array(action[1:]) 416 | delta_x = target_hpr_now[0] * action[0] * delta_time 417 | delta_y = target_hpr_now[1] * action[0] * delta_time 418 | else: 419 | break 420 | if _ == 20: 421 | print("Target action Sample for 20 times") 422 | self.target_pos_list[i][0] += delta_x 423 | self.target_pos_list[i][1] += delta_y 424 | 425 | def step(self, actions, obstacle=False): 426 | # obstacle: whether the action contains obstacle or not 427 | info = dict( 428 | Done=False, 429 | Reward=[0 for i in range(self.n)], 430 | Target_Pose=[], 431 | Cam_Pose=[], 432 | Steps=self.count_steps 433 | ) 434 | 435 | if not obstacle: 436 | self.goals4cam = np.squeeze(actions) # num_agents * (num_targets) 437 | else: 438 | actions = np.squeeze(actions) # num_agents * (num_targets + num_obstacles) 439 | self.goals4cam = actions[:,:self.num_target] 440 | obstacle_goals = actions[:,self.num_target:] 441 | # target move 442 | self.target_move() 443 | 444 | # camera move 445 | cam_info = [] 446 | for i in range(self.n): 447 | cam_loc = self.get_location(i) 448 | cam_rot = self.get_rotation(i) 449 | cam_info.append([cam_loc, cam_rot]) 450 | 451 | if not obstacle: 452 | r, gr, others, cam_info = self.simulate(self.goals4cam, cam_info, keep=10) 453 | else: 454 | r, gr, others, cam_info = self.simulate(self.goals4cam, cam_info, keep=10, obstacle_goals=obstacle_goals) 455 | 456 | for i in range(self.n): 457 | cam_loc, cam_rot = cam_info[i] 458 | self.set_rotation(i, cam_rot) 459 | 460 | if others: 461 | info['Coverage_rate'] = others['Coverage_rate'] 462 | info['Camera_target_dict'] = self.Camera_target_dict = others['Camera_target_dict'] 463 | info['Target_camera_dict'] = self.Target_camera_dict = others['Target_camera_dict'] 464 | info['Camera_local_goal'] = others['Camera_local_goal'] 465 | info['cost'] = others['cost'] 466 | 467 | info['Reward'] = np.array(r) 468 | info['Global_reward'] = np.array(gr) 469 | 470 | gt_directions = [] 471 | gt_distance = [] 472 | for i in range(self.n): 473 | # for target navigation 474 | cam_loc = self.get_location(i) 475 | cam_rot = self.get_rotation(i) 476 | gt_directions.append([]) 477 | gt_distance.append([]) 478 | for j in range(self.num_target): 479 | 
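# ground-truth bearing and range from camera i to every target; stored in info['Directions'] / info['Distance'] and later consumed by preprocess_pose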
angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[j]) 480 | gt_directions[i].append([angle_h]) 481 | d = self.get_distance(cam_loc + cam_rot, self.target_pos_list[j]) 482 | gt_distance[i].append(d) 483 | 484 | info['Cam_Pose'].append(self.get_location(i) + self.get_rotation(i)) 485 | 486 | info['Target_Pose'] = np.array(self.target_pos_list) # copy.deepcopy 487 | info['Distance'] = np.array(gt_distance) 488 | info['Directions'] = np.array(gt_directions) 489 | 490 | self.count_steps += 1 491 | # set your done condition 492 | if self.count_steps >= self.max_steps: 493 | info['Done'] = True 494 | 495 | reward = info['Global_reward'] 496 | 497 | state, self.state_dim = self.preprocess_pose(info, GoalMap=self.goals4cam) 498 | return state, reward, info['Done'], info 499 | 500 | def get_baseline_action(self, cam_loc_rot, goals, i, obstacle_goals=None): 501 | camera_target_visible = [] 502 | for k, v in self.Camera_target_dict.items(): 503 | camera_target_visible += v 504 | 505 | goal_ids = goal_id_filter(goals) 506 | if len(goal_ids) != 0: 507 | if self.slave_rule: 508 | if obstacle_goals is None: 509 | target_position = (self.target_pos_list[goal_ids]).mean(axis=0) # avg pos: [x,y,z] 510 | else: 511 | obstacle_ids = goal_id_filter(obstacle_goals) 512 | selected_targets = np.array(self.target_pos_list[goal_ids]) 513 | selected_obstacles = (np.array(self.obstacle_pos_list))[obstacle_ids] 514 | all_goals = np.concatenate((selected_targets,selected_obstacles),axis=0) 515 | target_position = (all_goals).mean(axis=0) 516 | 517 | angle_h = self.get_hori_direction(cam_loc_rot, target_position) 518 | 519 | action_h = angle_h // self.rotation_scale 520 | action_h = np.clip(action_h, -1, 1) 521 | action_h *= self.rotation_scale 522 | action = [action_h] 523 | else: 524 | tmp = [] 525 | for j in range(len(self.target_pos_list[goal_ids])): 526 | tar_p = self.target_pos_list[goal_ids][j] 527 | angle_h = self.get_hori_direction(cam_loc_rot, tar_p) 528 | d = self.get_distance(cam_loc_rot, tar_p) 529 | tmp.append([i / 4, j / 5, angle_h / 180, d / 2000]) 530 | target = np.zeros((1, self.num_target, 4)) 531 | target[0, :len(tmp)] = tmp 532 | values, actions, entropies, log_probs = self.slave(torch.from_numpy(target).float().to(self.device), 533 | test=True) 534 | action = actions.item() 535 | action = np.array(self.discrete_actions[action]) * self.rotation_scale 536 | else: 537 | action = np.array( 538 | self.discrete_actions[np.random.choice(range(len(self.discrete_actions)))]) * self.rotation_scale 539 | 540 | return action 541 | 542 | def simulate(self, GoalMap, cam_info, keep=-1, obstacle_goals=None): 543 | cost = 0 544 | gre = np.array([0.0]) 545 | for _ in range(keep): 546 | # camera move 547 | visible = [] 548 | Cam_Pose = [] 549 | for i in range(self.n): 550 | cam_loc, cam_rot = cam_info[i] 551 | if obstacle_goals is None: 552 | action = self.get_baseline_action(cam_loc + cam_rot, GoalMap[i], i) 553 | else: 554 | # obstacle_goals: num_agents * num_obstacles 555 | action = self.get_baseline_action(cam_loc + cam_rot, GoalMap[i], i, obstacle_goals[i]) 556 | if action[0] != 0: 557 | cost += 1 558 | cam_rot[0] += action[0] 559 | cam_info[i] = cam_loc, cam_rot 560 | Cam_Pose.append(cam_loc + cam_rot) 561 | sub_visible = [] 562 | for j in range(self.num_target): 563 | angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[j]) 564 | d = self.get_distance(cam_loc + cam_rot, self.target_pos_list[j]) 565 | sub_visible.append(self.visible(i, j, d, angle_h)) 566 | 
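# collect this camera's visibility flags over all targets; the render() call below uses them to color the targets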
visible.append(sub_visible) 567 | 568 | # target move 569 | self.target_move() 570 | # 571 | r, gr, others = self.multi_reward(cam_info, GoalMap) 572 | gre += gr 573 | 574 | # render 575 | if self.render: 576 | render(Cam_Pose, np.array(self.target_pos_list), goal=self.goals4cam, obstacle_pos=np.array(self.obstacle_pos_list), comm_edges=self.comm_edges, obstacle_radius=self.obstacle_radius_list, save=self.render_save, visible=visible) 577 | 578 | cost = cost / keep 579 | others['cost'] = cost 580 | 581 | return r, gre / keep, others, cam_info 582 | 583 | def close(self): 584 | pass 585 | 586 | def seed(self, para): 587 | pass 588 | 589 | def get_start_area(self, safe_start, safe_range): 590 | start_area = [safe_start[0] - safe_range, safe_start[0] + safe_range, 591 | safe_start[1] - safe_range, safe_start[1] + safe_range] 592 | return start_area 593 | 594 | def angle_reward(self, angle_h, d): 595 | hori_reward = 1 - abs(angle_h) / 45.0 596 | visible = hori_reward > 0 and d <= self.visual_distance 597 | if visible: 598 | reward = np.clip(hori_reward, -1, 1) # * (self.visual_distance-d) 599 | else: 600 | reward = -1 601 | return reward, visible 602 | 603 | def simplified_multi_reward(self, cam_info): 604 | coverage_rate = [] 605 | min_angle = [180 for j in range(self.num_target)] 606 | for i in range(self.n): 607 | cam_loc, cam_rot = cam_info[i] 608 | for j in range(self.num_target): 609 | angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[j]) 610 | d = self.get_distance(cam_loc + cam_rot, self.target_pos_list[j]) 611 | if d < self.visual_distance + 100: 612 | # 613 | min_angle[j] = min(min_angle[j], np.abs(angle_h)) 614 | reward, visible = self.angle_reward(angle_h, d) 615 | if visible: 616 | coverage_rate.append(j) 617 | min_angle_sum = sum(min_angle) 618 | coverage_rate = len(set(coverage_rate)) / self.num_target 619 | return coverage_rate, min_angle_sum 620 | 621 | def multi_reward(self, cam_info, goals4cam): 622 | camera_local_rewards = [] 623 | camera_local_goal = [] 624 | 625 | camera_target_dict = {} 626 | target_camera_dict = {} 627 | captured_targets = [] 628 | coverage_rate = [] 629 | for i in range(self.n): 630 | cam_loc, cam_rot = cam_info[i] 631 | camera_target_dict[i] = [] 632 | local_rewards = [] 633 | captured_num = 0 634 | goal_num = 0 635 | for j in range(self.num_target): 636 | if not target_camera_dict.get(j): 637 | target_camera_dict[j] = [] 638 | angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[j]) 639 | d = self.get_distance(cam_loc + cam_rot, self.target_pos_list[j]) 640 | reward, visible = self.angle_reward(angle_h, d) 641 | #reward = self.angle_reward(angle_h, d) 642 | #visible = self.visible(i,j,d,angle_h) 643 | if visible: 644 | camera_target_dict[i].append(j) 645 | target_camera_dict[j].append(i) 646 | coverage_rate.append(j) 647 | if goals4cam is None or goals4cam[i][j] > 0: 648 | captured_targets.append(j) 649 | captured_num += 1 650 | 651 | if goals4cam is None and visible or goals4cam is not None and goals4cam[i][j] > 0: 652 | local_rewards.append(reward) 653 | goal_num += 1 654 | camera_local_goal.append(captured_num / goal_num if goal_num != 0 else -1) 655 | camera_local_rewards.append(np.mean(local_rewards) if len(local_rewards) > 0 else 0) 656 | camera_local = camera_local_rewards 657 | 658 | # real coverage rate 659 | coverage_rate = len(set(coverage_rate)) / self.num_target 660 | 661 | camera_global_reward = [coverage_rate] 662 | 663 | # if torch.is_tensor(goals4cam): 664 | # goals_sum = 
torch.sum(goals4cam) 665 | # else: 666 | # goals_sum = np.sum(goals4cam) 667 | # if goals_sum == 0: 668 | # camera_global_reward = [-0.1] 669 | if len(set(captured_targets)) == 0: 670 | camera_global_reward = [-0.1] 671 | 672 | return camera_local, camera_global_reward, {'Camera_target_dict': camera_target_dict, 673 | 'Target_camera_dict': target_camera_dict, 674 | 'Coverage_rate': coverage_rate, 675 | 'Captured_targetsN': len(set(captured_targets)), 676 | 'Camera_local_goal': camera_local_goal 677 | } 678 | 679 | def preprocess_pose(self, info, GoalMap=None): 680 | cam_pose_info = np.array(info['Cam_Pose']) 681 | target_pose_info = np.array(info['Target_Pose']) 682 | angles = info['Directions'] 683 | distances = info['Distance'] 684 | 685 | camera_num = len(cam_pose_info) 686 | target_num = len(target_pose_info) 687 | 688 | # normalize center 689 | center = np.mean(cam_pose_info[:, :2], axis=0) 690 | cam_pose_info[:, :2] -= center 691 | if target_pose_info is not None: 692 | target_pose_info[:, :2] -= center 693 | 694 | # scale 695 | norm_d = int(max(np.linalg.norm(cam_pose_info[:, :2], axis=1, ord=2))) + 1e-8 696 | cam_pose_info[:, :2] /= norm_d 697 | if target_pose_info is not None: 698 | target_pose_info[:, :2] /= norm_d 699 | 700 | state_dim = 4 701 | feature_dim = target_num * state_dim if self.reset_type!=3 else (target_num+self.num_obstacle)*state_dim 702 | state = np.zeros((camera_num, feature_dim)) 703 | for cam_i in range(camera_num): 704 | # target info 705 | target_info = [] 706 | for target_j in range(target_num): 707 | if self.reset_type == 1 or self.reset_type >= 2 and self.visible(cam_i, target_j, distances[cam_i, target_j], angles[cam_i, target_j]): 708 | angle_h = angles[cam_i, target_j] 709 | target_info += [cam_i / camera_num, target_j / target_num, angle_h / 180, distances[cam_i, target_j] / 2000] # 2000 is related with the area of cameras 710 | else: 711 | target_info += [0,0,0,0] 712 | if self.reset_type==3: 713 | for obstacle_i in range(self.num_obstacle): 714 | cam_loc = self.get_location(cam_i) 715 | cam_rot = self.get_rotation(cam_i) 716 | obstacle_angle = self.get_hori_direction(cam_loc + cam_rot, self.obstacle_pos_list[obstacle_i]) 717 | obstacle_distance = self.get_distance(cam_loc + cam_rot, self.obstacle_pos_list[obstacle_i]) 718 | #visible = 1-abs(obstacle_angle)/45>=0 and obstacle_distance-self.obstacle_radius<=self.visual_distance or \ 719 | # 0 self.max_len: 777 | self.goal = self.generate_goal(self.goal_area) 778 | for _ in range(20): 779 | if np.linalg.norm(self.goal[:2] - pose[:2]) == 0: 780 | self.goal = self.generate_goal(self.goal_area) 781 | print("resample target goal") 782 | else: 783 | break 784 | if _ == 20: 785 | print("Target Goal sample for 20 times") 786 | self.velocity = np.random.randint(self.velocity_low, self.velocity_high) 787 | 788 | self.step_counter = 0 789 | 790 | if np.linalg.norm(self.goal[:2] - pose[:2]) == 0: 791 | print("target already reached goal.{}".format(self.goal[:2] - pose[:2])) 792 | assert np.linalg.norm(self.goal[:2] - pose[:2]) != 0 793 | delt_unit = (self.goal[:2] - pose[:2]) / np.linalg.norm(self.goal[:2] - pose[:2]) 794 | velocity = self.velocity * (1 + 0.2 * np.random.random()) 795 | return [velocity, delt_unit[0], delt_unit[1]] 796 | 797 | def reset(self,cam): 798 | self.step_counter = 0 799 | self.keep_steps = 0 800 | self.goal_id = 0 801 | self.goal_area = self.cam2goal_area(cam) 802 | self.goal = self.generate_goal(self.goal_area) 803 | self.velocity = np.random.randint(self.velocity_low, 
self.velocity_high) 804 | self.pose_last = [[], []] 805 | 806 | def generate_goal(self, goal_area): 807 | if self.goal_list and len(self.goal_list) != 0: 808 | index = self.goal_id % len(self.goal_list) 809 | goal = np.array(self.goal_list[index]) 810 | else: 811 | ''' 812 | cam_num = len(goal_area) 813 | radius = self.cam_radius 814 | cam_id = np.random.randint(0,cam_num) 815 | theta = 2*np.random.rand()*math.pi 816 | x = goal_area[cam_id][0] + radius * math.cos(theta) 817 | y = goal_area[cam_id][1] + radius * math.sin(theta) 818 | ''' 819 | x = np.random.randint(goal_area[0], goal_area[1]) 820 | y = np.random.randint(goal_area[2], goal_area[3]) 821 | goal = np.array([x, y]) 822 | self.goal_id += 1 823 | return goal 824 | 825 | def check_reach(self, goal, now): 826 | error = np.array(now[:2]) - np.array(goal[:2]) 827 | distance = np.linalg.norm(error) 828 | return distance < 5 -------------------------------------------------------------------------------- /MSMTC/DigitalPose2D/render.py: -------------------------------------------------------------------------------- 1 | Cam_Pose = [[-742, 706, -62.842588558231455], [-843, 69, -26.590794532324466], [510, 703, -135.84636503921902], 2 | [466, -609, 153.13035548432399]] 3 | Target_Pose = [[407.90650859, -716.624028], 4 | [-64.83188835, -233.64760113], 5 | [-980.29575616, 201.18355808], 6 | [-493.24174167, 655.69319226], 7 | [-571.57383471, -673.35637078]] 8 | Target_camera_dict = {0: [], 1: [3], 2: [], 3: [], 4: [1]} 9 | Camera_target_dict = {0: [], 1: [4], 2: [], 3: [1]} 10 | Distance = [[1829.24686786, 1158.22893495, 558.23338079, 253.79410157, 1389.84498251], 11 | [1477.15002847, 834.94980715, 190.58493563, 683.03714475, 790.42086539], 12 | [1423.29036457, 1098.97244214, 1572.51428681, 1004.35647371, 1750.47388421], 13 | [122.3020243, 650.13223042, 1657.7601793, 1587.3227742, 1039.56779718]] 14 | reward = [0.4] 15 | goals4cam = [[1, 1, 1, 1, 1], 16 | [1, 1, 1, 1, 1], 17 | [1, 1, 1, 1, 1], 18 | [1, 1, 1, 1, 1]] 19 | 20 | import math 21 | import os 22 | import numpy as np 23 | import matplotlib 24 | import matplotlib.pyplot as plt 25 | from datetime import datetime 26 | import matplotlib.patches as mpatches 27 | from matplotlib.patches import Circle 28 | visual_distance = 100 29 | 30 | 31 | def render(camera_pos, target_pos, 32 | obstacle_pos=None, goal=None, comm_edges=None, obstacle_radius=None, save=False, visible=None): 33 | camera_pos = np.array(camera_pos) 34 | target_pos = np.array(target_pos) 35 | 36 | camera_pos[:, :2] /= 1000.0 37 | target_pos[:, :2] /= 1000.0 38 | 39 | length = 600 40 | area_length = 1 # for random cam loc 41 | target_pos[:, :2] = (target_pos[:, :2] + 1) / 2 42 | camera_pos[:, :2] = (camera_pos[:, :2] + 1) / 2 43 | 44 | img = np.zeros((length + 1, length + 1, 3)) + 255 45 | num_cam = len(camera_pos) 46 | camera_position = [camera_pos[i][:2] for i in range(num_cam)] 47 | camera_position = length * (1 - np.array(camera_position) / area_length) / 2 48 | abs_angles = [camera_pos[i][2] * -1 for i in range(num_cam)] 49 | 50 | num_target = len(target_pos) 51 | target_position = [target_pos[i][:2] for i in range(num_target)] 52 | target_position = length * (1 - np.array(target_position) / area_length) / 2 53 | 54 | if obstacle_pos.shape != () : 55 | num_obstacle = len(obstacle_pos) 56 | obstacle_pos = np.array(obstacle_pos) 57 | obstacle_pos[:, :2] /= 1000.0 58 | obstacle_pos[:, :2] = (obstacle_pos[:, :2] + 1) / 2 59 | obstacle_position = [obstacle_pos[i][:2] for i in range(num_obstacle)] 60 | obstacle_position = length 
* (1 - np.array(obstacle_position) / area_length) / 2 61 | 62 | fig = plt.figure(0) 63 | plt.cla() 64 | plt.imshow(img.astype(np.uint8)) 65 | 66 | # get camera's view space positions 67 | visua_len = 100 # length of arrow 68 | L = 120 # length of arrow 69 | ax = plt.gca() 70 | # obstacle 71 | if obstacle_pos.shape != () : 72 | for i in range(num_obstacle): 73 | disk_obs = plt.Circle((obstacle_position[i][0] + visua_len, obstacle_position[i][1] + visua_len), obstacle_radius[i] * L/800, color='grey', fill=True) 74 | ax.add_artist(disk_obs) 75 | plt.annotate(str(i + 1), xy=(obstacle_position[i][0] + visua_len, obstacle_position[i][1] + visua_len), 76 | xytext=(obstacle_position[i][0] + visua_len, obstacle_position[i][1] + visua_len), fontsize=10, 77 | color='black') 78 | 79 | for i in range(num_cam): 80 | # drawing the visible area of a camera 81 | # dash-circle 82 | r = L 83 | a, b = np.array(camera_position[i]) + visua_len 84 | theta = np.arange(0, 2 * np.pi, 0.01) 85 | x = a + r * np.cos(theta) 86 | y = b + r * np.sin(theta) 87 | plt.plot(x, y, linestyle=' ', 88 | linewidth=1, 89 | color='steelblue', 90 | dashes=(6, 5.), 91 | dash_capstyle='round', 92 | alpha=0.9) 93 | 94 | # fill circle 95 | disk1 = plt.Circle((a, b), r, color='steelblue', fill=True, alpha=0.05) 96 | ax.add_artist(disk1) 97 | # 98 | 99 | for i in range(num_cam): 100 | theta = abs_angles[i] # -90 101 | theta -= 90 102 | the1 = theta - 45 103 | the2 = theta + 45 104 | 105 | a = camera_position[i][0] + visua_len 106 | b = camera_position[i][1] + visua_len 107 | wedge = mpatches.Wedge((a, b), L, the1*-1, the2*-1+180, color='green', alpha=0.2) # drawing the current sector that the camera is monitoring 108 | # print(i, the1*-1, the2*-1) 109 | ax.add_artist(wedge) 110 | 111 | disk1 = plt.Circle((camera_position[i][0] + visua_len, camera_position[i][1] + visua_len), 4, color='navy', fill=True) # drawing the camera 112 | ax.add_artist(disk1) 113 | plt.annotate(str(i + 1), xy=(camera_position[i][0] + visua_len, camera_position[i][1] + visua_len), 114 | xytext=(camera_position[i][0] + visua_len, camera_position[i][1] + visua_len), fontsize=10, 115 | color='black') 116 | 117 | # draw the communication edges 118 | if comm_edges is not None: 119 | edges = (comm_edges.numpy()).reshape(num_cam, num_cam) 120 | # print the edge matrix and the sum of edges 121 | edge_cnt = np.sum(edges) 122 | plt.text(600,470, 'Total {} Comm Edges :'.format(edge_cnt), color="black") 123 | for i in range(num_cam): 124 | for j in range(num_cam): 125 | edge = edges[i][j] 126 | if edge: 127 | x,y = np.array(camera_position[i]) + visua_len 128 | x_target, y_target = np.array(camera_position[j]) + visua_len 129 | dx,dy = x_target-x, y_target-y 130 | ax.arrow(x, y, dx, dy, head_width=15, head_length=15, fc='y', ec='k') 131 | plt.text(600, 500 + i * 30, str(edges[i])) 132 | 133 | plt.text(5, 5, '{} sensors & {} targets'.format(num_cam, num_target), color="black") 134 | 135 | for i in range(num_target): 136 | c = 'firebrick' 137 | for j in range(num_cam): 138 | if visible[j][i]: 139 | c = 'yellow' 140 | 141 | plt.plot(target_position[i][0] + visua_len, target_position[i][1] + visua_len, color=c, 142 | marker="o") 143 | plt.annotate(str(i + 1), xy=(target_position[i][0] + visua_len, target_position[i][1] + visua_len), 144 | xytext=(target_position[i][0] + visua_len, target_position[i][1] + visua_len), fontsize=10, 145 | color='maroon') 146 | 147 | if goal is not None: 148 | plt.text(400, 470, 'Goals:') 149 | for i in range(len(goal)): 150 | tmp = 
np.zeros(len(goal[i])) 151 | tmp[goal[i] > 0.5] = 1 152 | plt.text(400, 500 + i * 30, str(tmp)) 153 | 154 | plt.axis('off') 155 | # plt.show() 156 | if save: 157 | file_path = '../demo/img' 158 | file_name = '{}.jpg'.format(datetime.now()) 159 | if not os.path.exists(file_path): 160 | os.makedirs(file_path) 161 | plt.savefig(os.path.join(file_path, file_name)) 162 | plt.pause(0.01) 163 | 164 | 165 | if __name__ == '__main__': 166 | render(Cam_Pose, Target_Pose, reward, np.array(goals4cam)) -------------------------------------------------------------------------------- /MSMTC/DigitalPose2DBase/PoseEnvLarge_multi.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name": "PoseEnv-v0", 3 | "max_steps": 100, 4 | "cam_number": 4, 5 | "visual_distance": 800, 6 | "rotation_scale": 5, 7 | "target_number": 5, 8 | "discrete_actions": [ 9 | [0], 10 | [1], 11 | [-1] 12 | ], 13 | "safe_start" :[ 14 | [ 0, 0] 15 | ], 16 | "reset_area" : [-1250, 1250, -1250, 1250], 17 | "cam_area" : [ 18 | [250, 750, 250, 750], 19 | [250, 750, -750, -250], 20 | [-1250, -750, -250, 250], 21 | [-750, -250, 250, 750], 22 | [750, 1250, -250, 250], 23 | [-750, -250, -750, -250] 24 | ], 25 | "continous_actions_player": { 26 | "high": [100, 30], 27 | "low": [50, -30] 28 | } 29 | 30 | } -------------------------------------------------------------------------------- /MSMTC/DigitalPose2DBase/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | from MSMTC.DigitalPose2DBase.pose_env_base import Pose_Env_Base 3 | 4 | 5 | class Gym: 6 | def make(self, env_id, render_save): 7 | reset_type = env_id.split('-v')[1] 8 | env = Pose_Env_Base(int(reset_type),render_save=render_save) 9 | return env 10 | 11 | 12 | gym = Gym() 13 | -------------------------------------------------------------------------------- /MSMTC/DigitalPose2DBase/pose_env_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | from gym import spaces 6 | 7 | from MSMTC.DigitalPose2D.render import render 8 | from main import parser 9 | 10 | args = parser.parse_args() 11 | 12 | 13 | class Pose_Env_Base: 14 | def __init__(self, reset_type, 15 | nav='Goal', # Random, Goal 16 | config_path="PoseEnvLarge_multi.json", 17 | render_save=False, 18 | setting_path=None 19 | ): 20 | 21 | self.nav = nav 22 | self.reset_type = reset_type 23 | self.ENV_PATH = 'MSMTC/DigitalPose2DBase' 24 | 25 | if setting_path: 26 | self.SETTING_PATH = setting_path 27 | else: 28 | self.SETTING_PATH = os.path.join(self.ENV_PATH, config_path) 29 | with open(self.SETTING_PATH, encoding='utf-8') as f: 30 | setting = json.load(f) 31 | 32 | self.env_name = setting['env_name'] 33 | self.cam_id = setting['cam_id'] 34 | self.n = setting['cam_number'] 35 | self.discrete_actions = setting['discrete_actions'] 36 | self.cam_area = np.array(setting['cam_area']) 37 | 38 | self.num_target = setting['target_number'] 39 | self.continous_actions_player = setting['continous_actions_player'] 40 | self.reset_area = setting['reset_area'] 41 | 42 | self.max_steps = setting['max_steps'] 43 | self.visual_distance = setting['visual_distance'] 44 | self.safe_start = setting['safe_start'] 45 | self.start_area = self.get_start_area(self.safe_start[0], self.visual_distance // 2) 46 | 47 | # define action space 48 | self.action_space = [spaces.Discrete(len(self.discrete_actions)) for i in range(self.n)] 49 | 
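# each action index selects a yaw delta from discrete_actions ([0], [1] or [-1]); step() multiplies it by rotation_scale degrees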
self.rotation_scale = setting['rotation_scale'] 50 | 51 | # define observation space 52 | self.state_dim = 2 + 2 53 | self.observation_space = np.zeros((self.n, self.num_target, self.state_dim), int) 54 | 55 | self.render_save = render_save 56 | self.render = args.render 57 | 58 | self.cam = dict() 59 | for i in range(self.n): 60 | self.cam[i] = dict( 61 | location=[0, 0], 62 | rotation=[0], 63 | ) 64 | 65 | self.count_steps = 0 66 | self.goals4cam = np.ones([self.n, self.num_target]) 67 | 68 | # construct target_agent 69 | if 'Goal' in self.nav: 70 | self.random_agents = [GoalNavAgent(i, self.continous_actions_player, self.reset_area) 71 | for i in range(self.num_target)] 72 | 73 | def set_location(self, cam_id, loc): 74 | self.cam[cam_id]['location'] = loc 75 | 76 | def get_location(self, cam_id): 77 | return self.cam[cam_id]['location'] 78 | 79 | def set_rotation(self, cam_id, rot): 80 | for i in range(len(rot)): 81 | if rot[i] > 180: 82 | rot[i] -= 360 83 | if rot[i] < -180: 84 | rot[i] += 360 85 | self.cam[cam_id]['rotation'] = rot 86 | 87 | def get_rotation(self, cam_id): 88 | return self.cam[cam_id]['rotation'] 89 | 90 | def get_hori_direction(self, current_pose, target_pose): 91 | y_delt = target_pose[1] - current_pose[1] 92 | x_delt = target_pose[0] - current_pose[0] 93 | angle_now = np.arctan2(y_delt, x_delt) / np.pi * 180 - current_pose[2] 94 | if angle_now > 180: 95 | angle_now -= 360 96 | if angle_now < -180: 97 | angle_now += 360 98 | return angle_now 99 | 100 | def get_distance(self, current_pose, target_pose): 101 | y_delt = target_pose[1] - current_pose[1] 102 | x_delt = target_pose[0] - current_pose[0] 103 | d = np.sqrt(y_delt * y_delt + x_delt * x_delt) 104 | return d 105 | 106 | def reset(self): 107 | 108 | # reset targets 109 | self.target_pos_list = np.array([[ 110 | float(np.random.randint(self.start_area[0], self.start_area[1])), 111 | float(np.random.randint(self.start_area[2], self.start_area[3]))] for _ in range(self.num_target)]) 112 | # reset agent 113 | for i in range(len(self.random_agents)): 114 | if 'Goal' in self.nav: 115 | self.random_agents[i].reset() 116 | 117 | # reset camera 118 | camera_id_list = [i for i in range(self.n)] 119 | random.shuffle(camera_id_list) 120 | 121 | for i in range(self.n): 122 | cam_loc = [np.random.randint(self.cam_area[i][0], self.cam_area[i][1]), 123 | np.random.randint(self.cam_area[i][2], self.cam_area[i][3]) 124 | ] 125 | self.set_location(camera_id_list[i], cam_loc) # shuffle 126 | 127 | for cam_i in range(self.n): 128 | cam_loc = self.get_location(cam_i) 129 | cam_rot = self.get_rotation(cam_i) 130 | 131 | angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[cam_i]) 132 | cam_rot[0] += angle_h 133 | 134 | self.set_rotation(cam_i, cam_rot) 135 | 136 | self.count_steps = 0 137 | self.goals4cam = np.ones([self.n, self.num_target]) 138 | 139 | info = dict( 140 | Done=False, 141 | Reward=[0 for i in range(self.n)], 142 | Target_Pose=[], 143 | Cam_Pose=[], 144 | Steps=self.count_steps 145 | ) 146 | 147 | gt_directions = [] 148 | gt_distance = [] 149 | cam_info = [] 150 | for cam_i in range(self.n): 151 | # for target navigation 152 | cam_loc = self.get_location(cam_i) 153 | cam_rot = self.get_rotation(cam_i) 154 | cam_info.append([cam_loc, cam_rot]) 155 | gt_directions.append([]) 156 | gt_distance.append([]) 157 | for j in range(self.num_target): 158 | angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[j]) 159 | gt_directions[cam_i].append([angle_h]) 160 | d = 
self.get_distance(cam_loc + cam_rot, self.target_pos_list[j]) 161 | gt_distance[cam_i].append(d) 162 | 163 | info['Cam_Pose'].append(cam_loc + cam_rot) 164 | 165 | info['Directions'] = np.array(gt_directions) 166 | info['Distance'] = np.array(gt_distance) 167 | info['Target_Pose'] = np.array(self.target_pos_list) # copy.deepcopy 168 | info['Reward'], info['Global_reward'], others = self.multi_reward(cam_info, self.goals4cam) 169 | if others: 170 | info['Camera_target_dict'] = self.Camera_target_dict = others['Camera_target_dict'] 171 | info['Target_camera_dict'] = self.Target_camera_dict = others['Target_camera_dict'] 172 | 173 | state, self.state_dim = self.preprocess_pose(info) 174 | return state 175 | 176 | def step(self, actions): 177 | 178 | info = dict( 179 | Done=False, 180 | Reward=[0 for i in range(self.n)], 181 | Target_Pose=[], 182 | Cam_Pose=[], 183 | Steps=self.count_steps 184 | ) 185 | 186 | actions = np.squeeze(actions) # [num_cam, action_dim] 187 | 188 | # actions for cameras 189 | actions2cam = [] 190 | for i in range(self.n): 191 | actions2cam.append(self.discrete_actions[actions[i]]) # delta_yaw, delta_pitch 192 | 193 | # target move 194 | step = 10 195 | if 'Random' in self.nav: 196 | for i in range(self.num_target): 197 | self.target_pos_list[i][:3] += [np.random.randint(-1 * step, step), 198 | np.random.randint(-1 * step, step)] 199 | elif 'Goal' in self.nav: 200 | delta_time = 0.3 201 | for i in range(self.num_target): # only one 202 | loc = list(self.target_pos_list[i]) 203 | action = self.random_agents[i].act(loc) 204 | 205 | target_hpr_now = np.array(action[1:]) 206 | delta_x = target_hpr_now[0] * action[0] * delta_time 207 | delta_y = target_hpr_now[1] * action[0] * delta_time 208 | while loc[0] + delta_x < self.reset_area[0] or loc[0] + delta_x > self.reset_area[1] or \ 209 | loc[1] + delta_y < self.reset_area[2] or loc[1] + delta_y > self.reset_area[3]: 210 | action = self.random_agents[i].act(loc) 211 | 212 | target_hpr_now = np.array(action[1:]) 213 | delta_x = target_hpr_now[0] * action[0] * delta_time 214 | delta_y = target_hpr_now[1] * action[0] * delta_time 215 | 216 | self.target_pos_list[i][0] += delta_x 217 | self.target_pos_list[i][1] += delta_y 218 | 219 | # camera move 220 | for cam_i in range(self.n): 221 | cam_rot = self.get_rotation(cam_i) 222 | cam_rot[0] += actions2cam[cam_i][0] * self.rotation_scale 223 | self.set_rotation(cam_i, cam_rot) 224 | 225 | cam_info = [] 226 | for cam_i in range(self.n): 227 | cam_loc = self.get_location(cam_i) 228 | cam_rot = self.get_rotation(cam_i) 229 | cam_info.append([cam_loc, cam_rot]) 230 | 231 | # r: every camera complete its goal; [camera_num] 232 | # gr: coverage rate; [1] 233 | r, gr, others = self.multi_reward(cam_info, self.goals4cam) 234 | # cost by rotation 235 | for cam_i in range(self.n): 236 | if actions[cam_i] != 0: 237 | r[cam_i] += -0.01 238 | 239 | if others: 240 | info['Coverage_rate'] = others['Coverage_rate'] 241 | info['Camera_target_dict'] = self.Camera_target_dict = others['Camera_target_dict'] 242 | info['Target_camera_dict'] = self.Target_camera_dict = others['Target_camera_dict'] 243 | info['Camera_local_goal'] = others['Camera_local_goal'] 244 | 245 | info['Reward'] = np.array(r) 246 | info['Global_reward'] = np.array(gr) 247 | 248 | gt_directions = [] 249 | gt_distance = [] 250 | for cam_i in range(self.n): 251 | # for target navigation 252 | cam_loc = self.get_location(cam_i) 253 | cam_rot = self.get_rotation(cam_i) 254 | gt_directions.append([]) 255 | gt_distance.append([]) 256 | 
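# recompute ground-truth bearing and range to every target after the cameras have rotated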
for j in range(self.num_target): 257 | angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[j]) 258 | gt_directions[cam_i].append([angle_h]) 259 | d = self.get_distance(cam_loc + cam_rot, self.target_pos_list[j]) 260 | gt_distance[cam_i].append(d) 261 | 262 | info['Cam_Pose'].append(self.get_location(cam_i) + self.get_rotation(cam_i)) 263 | 264 | info['Target_Pose'] = np.array(self.target_pos_list) # copy.deepcopy 265 | info['Distance'] = np.array(gt_distance) 266 | info['Directions'] = np.array(gt_directions) 267 | 268 | self.count_steps += 1 269 | 270 | # set your done condition 271 | if self.count_steps > self.max_steps: 272 | info['Done'] = True 273 | 274 | reward = info['Reward'] 275 | 276 | if self.render: 277 | render(info['Cam_Pose'], np.array(self.target_pos_list), goal=self.goals4cam, save=self.render_save) 278 | 279 | if self.count_steps % 10 == 0: 280 | self.reset_goalmap(info['Distance']) 281 | state, self.state_dim = self.preprocess_pose(info) 282 | return state, reward, info['Done'], info 283 | 284 | def reset_goalmap(self, distances): 285 | for cam_i in range(self.n): 286 | self.goals4cam[cam_i] = list(map(int, distances[cam_i] <= self.visual_distance)) 287 | 288 | def close(self): 289 | pass 290 | 291 | def seed(self, para): 292 | pass 293 | 294 | def get_start_area(self, safe_start, safe_range): 295 | start_area = [safe_start[0] - safe_range, safe_start[0] + safe_range, 296 | safe_start[1] - safe_range, safe_start[1] + safe_range] 297 | return start_area 298 | 299 | def angle_reward(self, angle_h, d): 300 | hori_reward = 1 - abs(angle_h) / 45.0 301 | visible = hori_reward > 0 and d <= self.visual_distance 302 | if visible: 303 | reward = np.clip(hori_reward, -1, 1) 304 | else: 305 | reward = -1 306 | return reward, visible 307 | 308 | def multi_reward(self, cam_info, goals4cam): 309 | # generate reward 310 | camera_local_rewards = [] 311 | camera_local_goal = [] 312 | 313 | camera_target_dict = {} 314 | target_camera_dict = {} 315 | captured_targets = [] 316 | camera_target_reward = [] 317 | coverage_rate = [] 318 | for cam_i in range(self.n): 319 | cam_loc, cam_rot = cam_info[cam_i] 320 | camera_target_dict[cam_i] = [] 321 | local_rewards = [] 322 | camera_target_reward.append([]) 323 | captured_num = 0 324 | goal_num = 0 325 | for j in range(self.num_target): 326 | if not target_camera_dict.get(j): 327 | target_camera_dict[j] = [] 328 | angle_h = self.get_hori_direction(cam_loc + cam_rot, self.target_pos_list[j]) 329 | d = self.get_distance(cam_loc + cam_rot, self.target_pos_list[j]) 330 | reward, visible = self.angle_reward(angle_h, d) 331 | if visible: 332 | camera_target_dict[cam_i].append(j) 333 | target_camera_dict[j].append(cam_i) 334 | coverage_rate.append(j) 335 | if goals4cam is None or goals4cam[cam_i][j] > 0 or self.reset_type == 1: 336 | captured_targets.append(j) 337 | captured_num += 1 338 | 339 | if goals4cam is None and visible or goals4cam is not None and goals4cam[cam_i][j] > 0: 340 | local_rewards.append(reward) 341 | goal_num += 1 342 | camera_target_reward[cam_i].append(reward) 343 | camera_local_goal.append(captured_num / goal_num if goal_num != 0 else -1) 344 | camera_local_rewards.append(np.mean(local_rewards) if len(local_rewards) > 0 else 0) 345 | camera_local = camera_local_rewards 346 | 347 | # real coverage rate 348 | coverage_rate = len(set(coverage_rate)) / self.num_target 349 | 350 | camera_global_reward = [coverage_rate] # 1)reward: [-1, 1], coverage 351 | if len(set(captured_targets)) == 0: 352 | 
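# if no camera captured any of its assigned targets, fall back to a small global penalty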
camera_global_reward = [-0.1] 353 | 354 | return camera_local, camera_global_reward, {'Camera_target_dict': camera_target_dict, 355 | 'Target_camera_dict': target_camera_dict, 356 | 'Coverage_rate': coverage_rate, 357 | 'Captured_targetsN': len(set(captured_targets)), 358 | 'Camera_local_goal': camera_local_goal 359 | } 360 | 361 | def preprocess_pose(self, info): 362 | cam_pose_info = np.array(info['Cam_Pose']) 363 | target_pose_info = np.array(info['Target_Pose']) 364 | angles = info['Directions'] 365 | distances = info['Distance'] 366 | 367 | camera_num = len(cam_pose_info) 368 | target_num = len(target_pose_info) 369 | 370 | # normalize center 371 | center = np.mean(cam_pose_info[:, :2], axis=0) 372 | cam_pose_info[:, :2] -= center 373 | if target_pose_info is not None: 374 | target_pose_info[:, :2] -= center 375 | 376 | # scale 377 | norm_d = int(max(np.linalg.norm(cam_pose_info[:, :2], axis=1, ord=2))) + 1e-8 378 | cam_pose_info[:, :2] /= norm_d 379 | if target_pose_info is not None: 380 | target_pose_info[:, :2] /= norm_d 381 | 382 | state_dim = 4 383 | feature_dim = target_num * state_dim 384 | state = np.zeros((camera_num, feature_dim)) 385 | for cam_i in range(camera_num): 386 | target_isSelected_list = self.goals4cam[cam_i] 387 | # target info 388 | target_info = [] 389 | for target_j in range(target_num): 390 | if self.reset_type == 0 and target_isSelected_list[target_j] == 0: 391 | continue 392 | [angle_h] = angles[cam_i, target_j] 393 | target_angle = [cam_i / camera_num, target_j / target_num, angle_h / 180] 394 | line = target_angle + [distances[cam_i, target_j] / 2000] # 2000 is related with the area of cameras 395 | target_info += line 396 | target_info = target_info + [0] * (feature_dim - len(target_info)) 397 | state[cam_i] = target_info 398 | state = state.reshape((camera_num, target_num, state_dim)) 399 | return state, state_dim 400 | 401 | 402 | class GoalNavAgent(object): 403 | 404 | def __init__(self, id, action_space, goal_area, goal_list=None): 405 | self.id = id 406 | self.step_counter = 0 407 | self.keep_steps = 0 408 | self.goal_id = 0 409 | self.velocity_high = action_space['high'][0] 410 | self.velocity_low = action_space['low'][0] 411 | self.angle_high = action_space['high'][1] 412 | self.angle_low = action_space['low'][1] 413 | self.goal_area = goal_area 414 | self.goal_list = goal_list 415 | self.goal = self.generate_goal(self.goal_area) 416 | 417 | self.max_len = 100 418 | 419 | def act(self, pose): 420 | self.step_counter += 1 421 | if len(self.pose_last[0]) == 0: 422 | self.pose_last[0] = np.array(pose) 423 | self.pose_last[1] = np.array(pose) 424 | d_moved = 30 425 | else: 426 | d_moved = min(np.linalg.norm(np.array(self.pose_last[0]) - np.array(pose)), 427 | np.linalg.norm(np.array(self.pose_last[1]) - np.array(pose))) 428 | self.pose_last[0] = np.array(self.pose_last[1]) 429 | self.pose_last[1] = np.array(pose) 430 | if self.check_reach(self.goal, pose) or d_moved < 10 or self.step_counter > self.max_len: 431 | self.goal = self.generate_goal(self.goal_area) 432 | self.velocity = np.random.randint(self.velocity_low, self.velocity_high) 433 | 434 | self.step_counter = 0 435 | 436 | delt_unit = (self.goal[:2] - pose[:2]) / np.linalg.norm(self.goal[:2] - pose[:2]) 437 | velocity = self.velocity * (1 + 0.2 * np.random.random()) 438 | return [velocity, delt_unit[0], delt_unit[1]] 439 | 440 | def reset(self): 441 | self.step_counter = 0 442 | self.keep_steps = 0 443 | self.goal_id = 0 444 | self.goal = self.generate_goal(self.goal_area) 445 | self.velocity 
= np.random.randint(self.velocity_low, self.velocity_high) 446 | self.pose_last = [[], []] 447 | 448 | def generate_goal(self, goal_area): 449 | if self.goal_list and len(self.goal_list) != 0: 450 | index = self.goal_id % len(self.goal_list) 451 | goal = np.array(self.goal_list[index]) 452 | else: 453 | x = np.random.randint(goal_area[0], goal_area[1]) 454 | y = np.random.randint(goal_area[2], goal_area[3]) 455 | goal = np.array([x, y]) 456 | self.goal_id += 1 457 | return goal 458 | 459 | def check_reach(self, goal, now): 460 | error = np.array(now[:2]) - np.array(goal[:2]) 461 | distance = np.linalg.norm(error) 462 | return distance < 5 463 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ToM2C 2 | 3 | This repository is the offcial implementation of ToM2C, "[ToM2C: Target-oriented Multi-agent Communication and Cooperation with Theory of Mind (ICLR 2022)](https://arxiv.org/abs/2111.09189)" . 4 | 5 | ## Installation 6 | 7 | To install requirements: 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | All the environments have been included in the code, so there is no need to install Multi-sensor Multi-target Coverage(MSMTC) or MPE(Cooperative Navigation) additionally. 14 | 15 | ## Training 16 | 17 | To train ToM2C in `MSMTC`, run this command: 18 | 19 | ```bash 20 | python main.py --env MSMTC-v3 --model ToM2C --workers 6 --norm-reward 21 | ``` 22 | 23 | To train ToM2C in `CN`, run this command: 24 | 25 | ```bash 26 | python main.py --env CN --model ToM2C --num-agents 7 --num-targets 7 --workers 12 --env-steps 10 --A2C-steps 10 --norm-reward --gpu-id 0 27 | ``` 28 | 29 | Note that the command above will load the default environment described in the paper. If you want to change the number of agents and targets, please refer to the `num-agents` and `num-targets` arguments. 30 | 31 | After running the above command, you can run the following command respectively to do `Communication Reduction` mentioned in the paper: 32 | 33 | ```bash 34 | python main.py --env MSMTC-v3 --model ToM2C --workers 6 --norm-reward --train-comm --load-model-dir [trained_model_file_path] 35 | ``` 36 | 37 | The above command is for cpu training. If you want to train the model on GPU, try to add `--gpu-id [cuda_device_id]` in the command. Note that this implementation does NOT support multi-gpu training. 38 | 39 | ## Rendering 40 | 41 | After training, you can load the trained model and render its behavior by the following command. 
42 | 43 | In `CN`: 44 | 45 | ```bash 46 | python render_test.py --env CN --model ToM2C --render --env-steps 10 --load-model-dir [trained_model_file_path] 47 | ``` 48 | 49 | In `MSMTC`: 50 | 51 | ```bash 52 | python render_test.py --env MSMTC-v3 --model ToM2C --render --env-steps 20 --load-model-dir [trained_model_file_path] 53 | ``` 54 | 55 | ## Citation 56 | 57 | If you found ToM2C useful, please consider citing: 58 | ``` 59 | @inproceedings{ 60 | wang2021tomc, 61 | title={ToM2C: Target-oriented Multi-agent Communication and Cooperation with Theory of Mind}, 62 | author={Yuanfei Wang and Fangwei Zhong and Jing Xu and Yizhou Wang}, 63 | booktitle={International Conference on Learning Representations}, 64 | year={2022}, 65 | url={https://openreview.net/forum?id=M3tw78MH1Bk} 66 | } 67 | ``` 68 | ## Contact 69 | 70 | If you have any suggestion or questions, please get in touch at [yuanfei_wang@pku.edu.cn](yuanfei_wang@pku.edu.cn) or [zfw@pku.edu.cn](zfw@pku.edu.cn). 71 | 72 | -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import time 4 | 5 | def create_env(env_id, args, rank=-1): 6 | if 'MSMTC' in env_id: 7 | import MSMTC.DigitalPose2D as poseEnv 8 | env = poseEnv.gym.make(env_id, args) 9 | # adjust env steps according to args 10 | env.max_steps = args.env_steps 11 | return env 12 | elif 'CN' in env_id: 13 | from multiagent.environment import MultiAgentEnv 14 | import multiagent.scenarios as scenarios 15 | scenario_name = args.env 16 | # load scenario from script 17 | scenario = scenarios.load(scenario_name + ".py").Scenario() 18 | # create world 19 | world = scenario.make_world(args.num_agents, args.num_targets) 20 | # create multiagent environment 21 | env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation) 22 | env_wrap = env_wrapper(env, args) 23 | return env_wrap 24 | else: 25 | raise NotImplementedError 26 | 27 | class env_wrapper: 28 | # wrap for CN low level execution 29 | def __init__(self,env,args): 30 | self.env = env 31 | self.n = self.env.n_agents 32 | self.num_target = len(self.env.world.landmarks) 33 | self.observation_space = np.zeros([self.n, self.num_target, 2]) 34 | self.action_space = np.zeros([self.n,self.num_target,1]) 35 | self.max_steps = args.env_steps 36 | self.render = args.render 37 | 38 | def rule_policy(self,obs): 39 | x_rel = obs[0] 40 | y_rel = obs[1] 41 | if max(abs(x_rel),abs(y_rel)) < 0.05: 42 | action = [0] 43 | elif abs(x_rel) > abs(y_rel): 44 | if x_rel > 0: 45 | action = [2] 46 | else: 47 | action = [1] 48 | else: 49 | if y_rel > 0: 50 | action = [4] 51 | else: 52 | action = [3] 53 | action = np.array(action) 54 | return action 55 | 56 | def step(self, goals_n): 57 | #print(goals_n) 58 | goals_n = np.squeeze(goals_n) 59 | keep = 10 60 | rew_ave = 0 61 | for step in range(keep): 62 | # get low level obs 63 | act_low_n = [] 64 | for i in range(self.n): 65 | goal = int(goals_n[i]) 66 | land_goal = self.env.world.landmarks[goal] 67 | agent = self.env.world.agents[i] 68 | entity_pos = [(land_goal.state.p_pos - agent.state.p_pos)] 69 | obs_low = np.concatenate(entity_pos) 70 | act_low_n.append(self.rule_policy(obs_low)) 71 | 72 | obs_n, rew, done_n, info_n = self.env.step(act_low_n) 73 | if self.render: 74 | self.env.render() 75 | time.sleep(0.1) 76 | rew_ave += rew[0] 77 | rew_all = np.array([rew_ave/keep]) 78 | return obs_n, rew_all, 
done_n, info_n 79 | 80 | def reset(self): 81 | obs_n = self.env.reset() 82 | return obs_n 83 | 84 | def seed(self, s): 85 | self.env.seed(s) 86 | 87 | def close(self): 88 | self.env.close() 89 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import time 4 | import torch 5 | import argparse 6 | from datetime import datetime 7 | import torch.multiprocessing as mp 8 | 9 | from test import test 10 | from train import train 11 | from worker import worker 12 | #from train_new import Policy_train 13 | from model import build_model 14 | from environment import create_env 15 | from shared_optim import SharedRMSprop, SharedAdam 16 | 17 | os.environ["OMP_NUM_THREADS"] = "1" 18 | 19 | parser = argparse.ArgumentParser(description='A3C') 20 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate (default: 0.0001)') 21 | parser.add_argument('--gamma', type=float, default=0.1, metavar='G', help='discount factor for rewards (default: 0.99)') 22 | parser.add_argument('--gamma-rate', type=float, default=0.002, metavar='G', help='the increase rate of gamma') 23 | parser.add_argument('--gamma-final', type=float, default=0.9, metavar='G', help='the increase rate of gamma') 24 | parser.add_argument('--tau', type=float, default=1.00, metavar='T', help='parameter for GAE (default: 1.00)') 25 | parser.add_argument('--entropy', type=float, default=0.005, metavar='T', help='parameter for entropy (default: 0.01)') 26 | parser.add_argument('--grad-entropy', type=float, default=1.0, metavar='T', help='parameter for entropy (default: 0.01)') 27 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') 28 | parser.add_argument('--workers', type=int, default=1, metavar='W', help='how many training processes to use (default: 32)') 29 | parser.add_argument('--A2C-steps', type=int, default=20, metavar='NS', help='number of forward steps in A2C (default: 300)') 30 | parser.add_argument('--env-steps', type=int, default=20, metavar='NS', help='number of steps in one env episode') 31 | parser.add_argument('--start-eps', type=int, default=2000, metavar='NS', help='number of episodes before increasing gamma and env steps') 32 | parser.add_argument('--ToM-train-loops', type=int, default=1, metavar='NS', help='ToM training loops num') 33 | parser.add_argument('--policy-train-loops', type=int, default=1, metavar='NS', help='Policy training loops num') 34 | parser.add_argument('--test-eps', type=int, default=20, metavar='M', help='testing episodes') 35 | parser.add_argument('--ToM-frozen', type=int, default=5, metavar='M', help='episode length of freezing ToM in training') 36 | parser.add_argument('--env', default='MSMTC-v3', help='environment to train on') 37 | parser.add_argument('--optimizer', default='Adam', metavar='OPT', help='shares optimizer choice of Adam or RMSprop') 38 | parser.add_argument('--amsgrad', default=True, metavar='AM', help='Adam optimizer amsgrad parameter') 39 | parser.add_argument('--load-model-dir', default=None, metavar='LMD', help='folder to load trained models from') 40 | parser.add_argument('--load-executor-dir', default=None, metavar='LMD', help='folder to load trained low-level policy models from') 41 | parser.add_argument('--log-dir', default='logs/', metavar='LG', help='folder to save logs') 42 | parser.add_argument('--model', 
default='ToM2C', metavar='M', help='ToM2C') 43 | parser.add_argument('--gpu-id', type=int, default=-1, nargs='+', help='GPU to use [-1 CPU only] (default: -1)') 44 | parser.add_argument('--norm-reward', dest='norm_reward', action='store_true', default='True', help='normalize reward') 45 | parser.add_argument('--train-comm', dest='train_comm', action='store_true', help='train comm') 46 | parser.add_argument('--random-target', dest='random_target', action='store_true', default='True', help='random target in MSMTC') 47 | parser.add_argument('--mask-actions', dest='mask_actions', action='store_true', help='mask unavailable actions to boost training') 48 | parser.add_argument('--mask', dest='mask', action='store_true', help='mask ToM and communication to those out of range') 49 | parser.add_argument('--render', dest='render', action='store_true', help='render test') 50 | parser.add_argument('--fix', dest='fix', action='store_true', help='fix random seed') 51 | parser.add_argument('--shared-optimizer', dest='shared_optimizer', action='store_true', help='use an optimizer without shared statistics.') 52 | parser.add_argument('--train-mode', type=int, default=-1, metavar='TM', help='his') 53 | parser.add_argument('--lstm-out', type=int, default=32, metavar='LO', help='lstm output size') 54 | parser.add_argument('--sleep-time', type=int, default=0, metavar='LO', help='seconds') 55 | parser.add_argument('--max-step', type=int, default=3000000, metavar='LO', help='max learning steps') 56 | parser.add_argument('--render_save', dest='render_save', action='store_true', help='render save') 57 | 58 | parser.add_argument('--num-agents', type=int, default=-1) # if -1, then the env will load the default setting 59 | parser.add_argument('--num-targets', type=int, default=-1) # else, you can assign the number of agents and targets yourself 60 | 61 | # num_step: 20 62 | # max_step: 500000 63 | # env_max_step: 100 64 | # low-level step: 10 65 | # training mode: -1 for worker collecting trajectories, -10 for workers waiting for training process, -20 for training, -100 for all processes end 66 | 67 | def start(): 68 | args = parser.parse_args() 69 | args.shared_optimizer = True 70 | if args.gamma_rate == 0: 71 | args.gamma = 0.9 72 | args.env_steps *= 5 73 | if args.gpu_id == -1: 74 | torch.manual_seed(args.seed) 75 | args.gpu_id = [-1] 76 | device_share = torch.device('cpu') 77 | mp.set_start_method('spawn') 78 | else: 79 | torch.cuda.manual_seed(args.seed) 80 | mp.set_start_method('spawn', force=True) 81 | if len(args.gpu_id) > 1: 82 | raise AssertionError("Do not support multi-gpu training") 83 | #device_share = torch.device('cpu') 84 | else: 85 | device_share = torch.device('cuda:' + str(args.gpu_id[-1])) 86 | #device_share = torch.device('cuda:0') 87 | env = create_env(args.env, args) 88 | assert env.max_steps % args.A2C_steps == 0 89 | shared_model = build_model(env, args, device_share).to(device_share) 90 | shared_model.share_memory() 91 | shared_model.train() 92 | env.close() 93 | del env 94 | 95 | if args.load_model_dir is not None: 96 | saved_state = torch.load( 97 | args.load_model_dir, 98 | map_location=lambda storage, loc: storage) 99 | if args.load_model_dir[-3:] == 'pth': 100 | shared_model.load_state_dict(saved_state['model'], strict=False) 101 | else: 102 | shared_model.load_state_dict(saved_state) 103 | 104 | #params = shared_model.parameters() 105 | params = [] 106 | params_ToM = [] 107 | for name, param in shared_model.named_parameters(): 108 | if 'ToM' in name or 'other' in name: 109 | 
#print("ToM: ",name) 110 | params_ToM.append(param) 111 | else: 112 | #print("Not ToM: ",name) 113 | params.append(param) 114 | 115 | if args.shared_optimizer: 116 | print('share memory') 117 | if args.optimizer == 'RMSprop': 118 | optimizer_Policy = SharedRMSprop(params, lr=args.lr) 119 | if 'ToM' in args.model: 120 | optimizer_ToM = SharedRMSprop(params_ToM, lr=args.lr) 121 | else: 122 | optimizer_ToM = None 123 | if args.optimizer == 'Adam': 124 | optimizer_Policy = SharedAdam(params, lr=args.lr, amsgrad=args.amsgrad) 125 | if 'ToM' in args.model: 126 | print("ToM optimizer lr * 10") 127 | optimizer_ToM = SharedAdam(params_ToM, lr=args.lr*10, amsgrad=args.amsgrad) 128 | else: 129 | optimizer_ToM = None 130 | optimizer_Policy.share_memory() 131 | if optimizer_ToM is not None: 132 | optimizer_ToM.share_memory() 133 | else: 134 | optimizer_Policy = None 135 | optimizer_ToM = None 136 | 137 | current_time = datetime.now().strftime('%b%d_%H-%M') 138 | args.log_dir = os.path.join(args.log_dir, args.env, current_time) 139 | 140 | processes = [] 141 | manager = mp.Manager() 142 | train_modes = manager.list() 143 | n_iters = manager.list() 144 | curr_env_steps = manager.list() 145 | ToM_count = manager.list() 146 | ToM_history = manager.list() 147 | Policy_history = manager.list() 148 | step_history = manager.list() 149 | loss_history = manager.list() 150 | 151 | for rank in range(0, args.workers): 152 | p = mp.Process(target=worker, args=(rank, args, shared_model, train_modes, n_iters, curr_env_steps, ToM_count, ToM_history, Policy_history, step_history, loss_history)) 153 | 154 | train_modes.append(args.train_mode) 155 | n_iters.append(0) 156 | curr_env_steps.append(args.env_steps) 157 | ToM_count.append(0) 158 | ToM_history.append([]) 159 | Policy_history.append([]) 160 | step_history.append([]) 161 | loss_history.append([]) 162 | 163 | p.start() 164 | processes.append(p) 165 | time.sleep(args.sleep_time) 166 | 167 | p = mp.Process(target=test, args=(args, shared_model, optimizer_Policy, optimizer_ToM, train_modes, n_iters)) 168 | p.start() 169 | processes.append(p) 170 | time.sleep(args.sleep_time) 171 | 172 | if args.workers > 0: 173 | # not only test 174 | p = mp.Process(target=train, args=(args, shared_model, optimizer_Policy, optimizer_ToM, train_modes, n_iters, curr_env_steps, ToM_count, ToM_history, Policy_history, step_history, loss_history)) 175 | p.start() 176 | processes.append(p) 177 | time.sleep(args.sleep_time) 178 | 179 | for p in processes: 180 | time.sleep(args.sleep_time) 181 | p.join() 182 | 183 | 184 | if __name__=='__main__': 185 | start() -------------------------------------------------------------------------------- /multiagent/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | # Multiagent envs 4 | # ---------------------------------------- 5 | 6 | register( 7 | id='MultiagentSimple-v0', 8 | entry_point='multiagent.envs:SimpleEnv', 9 | # FIXME(cathywu) currently has to be exactly max_path_length parameters in 10 | # rllab run script 11 | max_episode_steps=100, 12 | ) 13 | 14 | register( 15 | id='MultiagentSimpleSpeakerListener-v0', 16 | entry_point='multiagent.envs:SimpleSpeakerListenerEnv', 17 | max_episode_steps=100, 18 | ) 19 | -------------------------------------------------------------------------------- /multiagent/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # physical/external base state 
of all entites 4 | class EntityState(object): 5 | def __init__(self): 6 | # physical position 7 | self.p_pos = None 8 | # physical velocity 9 | self.p_vel = None 10 | 11 | # state of agents (including communication and internal/mental state) 12 | class AgentState(EntityState): 13 | def __init__(self): 14 | super(AgentState, self).__init__() 15 | # communication utterance 16 | self.c = None 17 | 18 | # action of the agent 19 | class Action(object): 20 | def __init__(self): 21 | # physical action 22 | self.u = None 23 | # communication action 24 | self.c = None 25 | 26 | # properties and state of physical world entity 27 | class Entity(object): 28 | def __init__(self): 29 | # name 30 | self.name = '' 31 | # properties: 32 | self.size = 0.050 33 | # entity can move / be pushed 34 | self.movable = False 35 | # entity collides with others 36 | self.collide = True 37 | # material density (affects mass) 38 | self.density = 25.0 39 | # color 40 | self.color = None 41 | # max speed and accel 42 | self.max_speed = None 43 | self.accel = None 44 | # state 45 | self.state = EntityState() 46 | # mass 47 | self.initial_mass = 1.0 48 | 49 | @property 50 | def mass(self): 51 | return self.initial_mass 52 | 53 | # properties of landmark entities 54 | class Landmark(Entity): 55 | def __init__(self): 56 | super(Landmark, self).__init__() 57 | 58 | # properties of agent entities 59 | class Agent(Entity): 60 | def __init__(self): 61 | super(Agent, self).__init__() 62 | # agents are movable by default 63 | self.movable = True 64 | # cannot send communication signals 65 | self.silent = False 66 | # cannot observe the world 67 | self.blind = False 68 | # physical motor noise amount 69 | self.u_noise = None 70 | # communication noise amount 71 | self.c_noise = None 72 | # control range 73 | self.u_range = 1.0 74 | # state 75 | self.state = AgentState() 76 | # action 77 | self.action = Action() 78 | # script behavior to execute 79 | self.action_callback = None 80 | 81 | # multi-agent world 82 | class World(object): 83 | def __init__(self): 84 | # list of agents and entities (can change at execution-time!) 
85 | self.agents = [] 86 | self.landmarks = [] 87 | self.preys = [] 88 | # communication channel dimensionality 89 | self.dim_c = 0 90 | # position dimensionality 91 | self.dim_p = 2 92 | # position range 93 | self.range_p = 1 94 | # color dimensionality 95 | self.dim_color = 3 96 | # simulation timestep 97 | self.dt = 0.1 98 | # physical damping 99 | self.damping = 0.25 100 | # contact response parameters 101 | self.contact_force = 1e+2 102 | self.contact_margin = 1e-3 103 | # observation info 104 | self.num_agents_obs = 0 105 | self.num_landmarks_obs = 0 106 | self.num_preys_obs = 0 107 | 108 | # return all entities in the world 109 | @property 110 | def entities(self): 111 | return self.agents + self.preys + self.landmarks 112 | 113 | # return all agents controllable by external policies 114 | @property 115 | def policy_agents(self): 116 | return [agent for agent in self.agents if agent.action_callback is None] 117 | 118 | # return all preys controllable by external policies 119 | @property 120 | def policy_preys(self): 121 | return [prey for prey in self.preys if prey.action_callback is None] 122 | 123 | # return all agents controlled by world scripts 124 | @property 125 | def scripted_agents(self): 126 | return [agent for agent in self.agents if agent.action_callback is not None] 127 | 128 | # return all preys controlled by world scripts 129 | @property 130 | def scripted_preys(self): 131 | return [prey for prey in self.preys if prey.action_callback is not None] 132 | 133 | # update state of the world 134 | def step(self): 135 | # set actions for scripted agents and preys 136 | for agent in self.scripted_agents: 137 | agent.action = agent.action_callback(agent, self) 138 | for prey in self.scripted_preys: 139 | prey.action = prey.action_callback(prey, self) 140 | # gather forces applied to entities 141 | p_force = [None] * len(self.entities) 142 | # apply agent physical controls 143 | p_force = self.apply_action_force(p_force) 144 | # apply environment forces 145 | p_force = self.apply_environment_force(p_force) 146 | # integrate physical state 147 | self.integrate_state(p_force) 148 | # update agent and prey state 149 | for agent in self.agents: 150 | self.update_agent_state(agent) 151 | for prey in self.preys: 152 | self.update_agent_state(prey) 153 | 154 | # gather agent action forces 155 | def apply_action_force(self, p_force): 156 | # set applied forces 157 | for i,agent in enumerate(self.agents): 158 | if agent.movable: 159 | noise = np.random.randn(*agent.action.u.shape) * agent.u_noise if agent.u_noise else 0.0 160 | p_force[i] = agent.action.u + noise 161 | for j,prey in enumerate(self.preys): 162 | if prey.movable: 163 | noise = np.random.randn(*prey.action.u.shape) * prey.u_noise if prey.u_noise else 0.0 164 | p_force[j+len(self.agents)] = prey.action.u + noise 165 | return p_force 166 | 167 | # gather physical forces acting on entities 168 | def apply_environment_force(self, p_force): 169 | # simple (but inefficient) collision response 170 | for a,entity_a in enumerate(self.entities): 171 | for b,entity_b in enumerate(self.entities): 172 | if(b <= a): continue 173 | [f_a, f_b] = self.get_collision_force(entity_a, entity_b) 174 | if(f_a is not None): 175 | if(p_force[a] is None): p_force[a] = 0.0 176 | p_force[a] = f_a + p_force[a] 177 | if(f_b is not None): 178 | if(p_force[b] is None): p_force[b] = 0.0 179 | p_force[b] = f_b + p_force[b] 180 | return p_force 181 | 182 | # integrate physical state 183 | def integrate_state(self, p_force): 184 | for i,entity in 
enumerate(self.entities): 185 | if not entity.movable: continue 186 | entity.state.p_vel = entity.state.p_vel * (1 - self.damping) 187 | if (p_force[i] is not None): 188 | entity.state.p_vel += (p_force[i] / entity.mass) * self.dt 189 | if entity.max_speed is not None: 190 | speed = np.sqrt(np.square(entity.state.p_vel[0]) + np.square(entity.state.p_vel[1])) 191 | if speed > entity.max_speed: 192 | entity.state.p_vel = entity.state.p_vel / np.sqrt(np.square(entity.state.p_vel[0]) + 193 | np.square(entity.state.p_vel[1])) * entity.max_speed 194 | entity.state.p_pos += entity.state.p_vel * self.dt 195 | 196 | def update_agent_state(self, agent): 197 | # set communication state (directly for now) 198 | if agent.silent: 199 | agent.state.c = np.zeros(self.dim_c) 200 | else: 201 | noise = np.random.randn(*agent.action.c.shape) * agent.c_noise if agent.c_noise else 0.0 202 | agent.state.c = agent.action.c + noise 203 | 204 | # get collision forces for any contact between two entities 205 | def get_collision_force(self, entity_a, entity_b): 206 | if (not entity_a.collide) or (not entity_b.collide): 207 | return [None, None] # not a collider 208 | if (entity_a is entity_b): 209 | return [None, None] # don't collide against itself 210 | # compute actual distance between entities 211 | delta_pos = entity_a.state.p_pos - entity_b.state.p_pos 212 | dist = np.sqrt(np.sum(np.square(delta_pos))) 213 | # minimum allowable distance 214 | dist_min = entity_a.size + entity_b.size 215 | # softmax penetration 216 | k = self.contact_margin 217 | penetration = np.logaddexp(0, -(dist - dist_min)/k)*k 218 | force = self.contact_force * delta_pos / dist * penetration 219 | force_a = +force if entity_a.movable else None 220 | force_b = -force if entity_b.movable else None 221 | return [force_a, force_b] -------------------------------------------------------------------------------- /multiagent/environment.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | from gym.envs.registration import EnvSpec 4 | import numpy as np 5 | from multiagent.multi_discrete import MultiDiscrete 6 | 7 | # environment for all agents in the multiagent world 8 | # currently code assumes that no agents will be created/destroyed at runtime! 
9 | class MultiAgentEnv(gym.Env): 10 | metadata = { 11 | 'render.modes' : ['human', 'rgb_array'] 12 | } 13 | 14 | def __init__(self, world, reset_callback=None, reward_callback=None, 15 | observation_callback=None, info_callback=None, 16 | done_callback=None, shared_viewer=True): 17 | 18 | self.world = world 19 | self.agents = self.world.policy_agents 20 | self.preys = self.world.policy_preys 21 | # set required vectorized gym env property 22 | self.n_agents = len(world.policy_agents) 23 | self.n_landmarks = len(world.landmarks) 24 | self.n_landmarks_obs = world.num_landmarks_obs 25 | self.n_agents_obs = world.num_agents_obs 26 | self.n_preys_obs = world.num_preys_obs 27 | # scenario callbacks 28 | self.reset_callback = reset_callback 29 | self.reward_callback = reward_callback 30 | self.observation_callback = observation_callback 31 | self.info_callback = info_callback 32 | self.done_callback = done_callback 33 | # environment parameters 34 | self.discrete_action_space = True 35 | # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector 36 | self.discrete_action_input = True 37 | # if true, even the action is continuous, action will be performed discretely 38 | self.force_discrete_action = world.discrete_action if hasattr(world, 'discrete_action') else False 39 | # if true, every agent has the same reward 40 | self.shared_reward = world.collaborative if hasattr(world, 'collaborative') else False 41 | self.range_p = world.range_p 42 | self.dim_p = world.dim_p 43 | self.time = 0 44 | 45 | # configure spaces 46 | self.action_space = [] 47 | self.observation_space = [] 48 | for agent in self.agents: 49 | total_action_space = [] 50 | # physical action space 51 | if self.discrete_action_space: 52 | u_action_space = spaces.Discrete(world.dim_p * 2 + 1) 53 | else: 54 | u_action_space = spaces.Box(low=-agent.u_range, high=+agent.u_range, shape=(world.dim_p,), dtype=np.float32) 55 | if agent.movable: 56 | total_action_space.append(u_action_space) 57 | # communication action space 58 | if self.discrete_action_space: 59 | c_action_space = spaces.Discrete(world.dim_c) 60 | else: 61 | c_action_space = spaces.Box(low=0.0, high=1.0, shape=(world.dim_c,), dtype=np.float32) 62 | if not agent.silent: 63 | total_action_space.append(c_action_space) 64 | # total action space 65 | if len(total_action_space) > 1: 66 | # all action spaces are discrete, so simplify to MultiDiscrete action space 67 | if all([isinstance(act_space, spaces.Discrete) for act_space in total_action_space]): 68 | act_space = MultiDiscrete([[0, act_space.n - 1] for act_space in total_action_space]) 69 | else: 70 | act_space = spaces.Tuple(total_action_space) 71 | self.action_space.append(act_space) 72 | else: 73 | self.action_space.append(total_action_space[0]) 74 | # observation space 75 | obs_dim = len(observation_callback(agent, self.world)) 76 | self.observation_space.append(spaces.Box(low=-np.inf, high=+np.inf, shape=(obs_dim,), dtype=np.float32)) 77 | agent.action.c = np.zeros(self.world.dim_c) 78 | 79 | # rendering 80 | self.shared_viewer = shared_viewer 81 | if self.shared_viewer: 82 | self.viewers = [None] 83 | else: 84 | self.viewers = [None] * self.n 85 | self._reset_render() 86 | 87 | def bound(self, x): 88 | d = np.zeros(2) 89 | if abs(x[0])>abs(x[1]) and x[0]<0 and abs(x[0])>0.8*self.range_p: 90 | d[0] = 2 91 | if abs(x[0])>abs(x[1]) and x[0]>0 and abs(x[0])>0.8*self.range_p: 92 | d[0] = -2 93 | if abs(x[0])<abs(x[1]) and x[1]<0 and abs(x[1])>0.8*self.range_p: 94 | d[1] = 2 95 | if abs(x[0])<abs(x[1]) and x[1]>0 and abs(x[1])>0.8*self.range_p: 96 |
d[1] = -2 97 | return d 98 | 99 | def step(self, action_n): 100 | obs_n = [] 101 | reward_n = [] 102 | done_n = [] 103 | info_n = {'n': []} 104 | self.agents = self.world.policy_agents 105 | # set action for each agent 106 | for i, agent in enumerate(self.agents): 107 | self._set_action(action_n[i], agent, self.action_space[i]) 108 | # set action for each prey 109 | for j, prey in enumerate(self.preys): 110 | prey_action = np.zeros(self.action_space[0].n) 111 | min_dist = 10000 112 | direction = [] 113 | # move following the oppisite direction of closest agent 114 | for agent in self.agents: 115 | dist = np.sqrt(np.sum(np.square(prey.state.p_pos - agent.state.p_pos))) 116 | if dist < min_dist: 117 | min_dist = dist 118 | direction = (prey.state.p_pos - agent.state.p_pos)/dist 119 | direction_intensity = np.abs(direction) 120 | direction[np.argmax(direction_intensity)] = np.sign(direction[np.argmax(direction_intensity)])*1 121 | direction[np.argmin(direction_intensity)] = 0 122 | # not allow to cross the boundary 123 | in_bound = self.bound(prey.state.p_pos) 124 | prey_action[1] = direction[0] + in_bound[0] 125 | prey_action[3] = direction[1] + in_bound[1] 126 | # if captured, prey chooses to stay 127 | if min_dist <= (prey.size + agent.size): 128 | prey_action[0] = 1 129 | prey_action[1] = 0 130 | prey_action[3] = 0 131 | self.force_discrete_action = False 132 | self._set_action(prey_action, prey, self.action_space[0]) 133 | self.force_discrete_action = True 134 | # advance world state 135 | self.world.step() 136 | # record observation for each agent 137 | for agent in self.agents: 138 | obs_n.append(self._get_obs(agent)) 139 | r = self._get_reward(agent) 140 | reward_n.append(r) 141 | done_n.append(self._get_done(agent)) 142 | info_n['n'].append(self._get_info(agent)) 143 | # all agents get total reward in cooperative case 144 | reward = np.sum(reward_n) 145 | if self.shared_reward: 146 | reward_n = [reward] * self.n 147 | return obs_n, reward_n, done_n, info_n 148 | 149 | def reset(self): 150 | # reset world 151 | self.reset_callback(self.world) 152 | # reset renderer 153 | self._reset_render() 154 | # record observations for each agent 155 | obs_n = [] 156 | self.agents = self.world.policy_agents 157 | for agent in self.agents: 158 | obs_n.append(self._get_obs(agent)) 159 | return obs_n 160 | 161 | # get info used for benchmarking 162 | def _get_info(self, agent): 163 | if self.info_callback is None: 164 | return {} 165 | return self.info_callback(agent, self.world) 166 | 167 | # get observation for a particular agent 168 | def _get_obs(self, agent): 169 | if self.observation_callback is None: 170 | return np.zeros(0) 171 | return self.observation_callback(agent, self.world) 172 | 173 | # get dones for a particular agent 174 | # unused right now -- agents are allowed to go beyond the viewing screen 175 | def _get_done(self, agent): 176 | if self.done_callback is None: 177 | return False 178 | return self.done_callback(agent, self.world) 179 | 180 | # get reward for a particular agent 181 | def _get_reward(self, agent): 182 | if self.reward_callback is None: 183 | return 0.0 184 | return self.reward_callback(agent, self.world) 185 | 186 | # set env action for a particular agent 187 | def _set_action(self, action, agent, action_space, time=None): 188 | agent.action.u = np.zeros(self.world.dim_p) 189 | agent.action.c = np.zeros(self.world.dim_c) 190 | # process action 191 | if isinstance(action_space, MultiDiscrete): 192 | act = [] 193 | size = action_space.high - action_space.low + 1 194 
| index = 0 195 | for s in size: 196 | act.append(action[index:(index+s)]) 197 | index += s 198 | action = act 199 | else: 200 | action = [action] 201 | 202 | if agent.movable: 203 | # physical action 204 | if self.discrete_action_input: 205 | agent.action.u = np.zeros(self.world.dim_p) 206 | # process discrete action 207 | if action[0] == 1: agent.action.u[0] = -1.0 208 | if action[0] == 2: agent.action.u[0] = +1.0 209 | if action[0] == 3: agent.action.u[1] = -1.0 210 | if action[0] == 4: agent.action.u[1] = +1.0 211 | else: 212 | if self.force_discrete_action: 213 | d = np.argmax(action[0]) 214 | action[0][:] = 0.0 215 | action[0][d] = 1.0 216 | if self.discrete_action_space: 217 | agent.action.u[0] += action[0][1] - action[0][2] 218 | agent.action.u[1] += action[0][3] - action[0][4] 219 | else: 220 | agent.action.u = action[0] 221 | sensitivity = 5.0 222 | if agent.accel is not None: 223 | sensitivity = agent.accel 224 | agent.action.u *= sensitivity 225 | action = action[1:] 226 | if not agent.silent: 227 | # communication action 228 | if self.discrete_action_input: 229 | agent.action.c = np.zeros(self.world.dim_c) 230 | agent.action.c[action[0]] = 1.0 231 | else: 232 | agent.action.c = action[0] 233 | action = action[1:] 234 | # make sure we used all elements of action 235 | assert len(action) == 0 236 | 237 | # reset rendering assets 238 | def _reset_render(self): 239 | self.render_geoms = None 240 | self.render_geoms_xform = None 241 | 242 | # render environment 243 | def render(self, mode='human'): 244 | if mode == 'human': 245 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 246 | message = '' 247 | for agent in self.world.agents: 248 | comm = [] 249 | for other in self.world.agents: 250 | if other is agent: continue 251 | if np.all(other.state.c == 0): 252 | word = '_' 253 | else: 254 | word = alphabet[np.argmax(other.state.c)] 255 | message += (other.name + ' to ' + agent.name + ': ' + word + ' ') 256 | 257 | 258 | for i in range(len(self.viewers)): 259 | # create viewers (if necessary) 260 | if self.viewers[i] is None: 261 | # import rendering only if we need it (and don't import for headless machines) 262 | #from gym.envs.classic_control import rendering 263 | from multiagent import rendering 264 | self.viewers[i] = rendering.Viewer(700,700) 265 | 266 | # create rendering geometry 267 | if self.render_geoms is None: 268 | # import rendering only if we need it (and don't import for headless machines) 269 | #from gym.envs.classic_control import rendering 270 | from multiagent import rendering 271 | self.render_geoms = [] 272 | self.render_geoms_xform = [] 273 | for entity in self.world.entities: 274 | geom = rendering.make_circle(entity.size) 275 | xform = rendering.Transform() 276 | if 'agent' in entity.name: 277 | geom.set_color(*entity.color, alpha=0.5) 278 | else: 279 | geom.set_color(*entity.color) 280 | geom.add_attr(xform) 281 | self.render_geoms.append(geom) 282 | self.render_geoms_xform.append(xform) 283 | 284 | # add geoms to viewer 285 | for viewer in self.viewers: 286 | viewer.geoms = [] 287 | for geom in self.render_geoms: 288 | viewer.add_geom(geom) 289 | 290 | results = [] 291 | for i in range(len(self.viewers)): 292 | from multiagent import rendering 293 | # update bounds to center around agent 294 | cam_range = 1 295 | if self.shared_viewer: 296 | pos = np.zeros(self.world.dim_p) 297 | else: 298 | pos = self.agents[i].state.p_pos 299 | self.viewers[i].set_bounds(pos[0]-cam_range,pos[0]+cam_range,pos[1]-cam_range,pos[1]+cam_range) 300 | # update geometry positions 301 | 
for e, entity in enumerate(self.world.entities): 302 | self.render_geoms_xform[e].set_translation(*entity.state.p_pos) 303 | # render to display or array 304 | results.append(self.viewers[i].render(return_rgb_array = mode=='rgb_array')) 305 | 306 | return results 307 | 308 | # create receptor field locations in local coordinate frame 309 | def _make_receptor_locations(self, agent): 310 | receptor_type = 'polar' 311 | range_min = 0.05 * 2.0 312 | range_max = 1.00 313 | dx = [] 314 | # circular receptive field 315 | if receptor_type == 'polar': 316 | for angle in np.linspace(-np.pi, +np.pi, 8, endpoint=False): 317 | for distance in np.linspace(range_min, range_max, 3): 318 | dx.append(distance * np.array([np.cos(angle), np.sin(angle)])) 319 | # add origin 320 | dx.append(np.array([0.0, 0.0])) 321 | # grid receptive field 322 | if receptor_type == 'grid': 323 | for x in np.linspace(-range_max, +range_max, 5): 324 | for y in np.linspace(-range_max, +range_max, 5): 325 | dx.append(np.array([x,y])) 326 | return dx 327 | 328 | 329 | # vectorized wrapper for a batch of multi-agent environments 330 | # assumes all environments have the same observation and action space 331 | class BatchMultiAgentEnv(gym.Env): 332 | metadata = { 333 | 'runtime.vectorized': True, 334 | 'render.modes' : ['human', 'rgb_array'] 335 | } 336 | 337 | def __init__(self, env_batch): 338 | self.env_batch = env_batch 339 | 340 | @property 341 | def n(self): 342 | return np.sum([env.n for env in self.env_batch]) 343 | 344 | @property 345 | def action_space(self): 346 | return self.env_batch[0].action_space 347 | 348 | @property 349 | def observation_space(self): 350 | return self.env_batch[0].observation_space 351 | 352 | def step(self, action_n, time): 353 | obs_n = [] 354 | reward_n = [] 355 | done_n = [] 356 | info_n = {'n': []} 357 | i = 0 358 | for env in self.env_batch: 359 | obs, reward, done, _ = env.step(action_n[i:(i+env.n)], time) 360 | i += env.n 361 | obs_n += obs 362 | # reward = [r / len(self.env_batch) for r in reward] 363 | reward_n += reward 364 | done_n += done 365 | return obs_n, reward_n, done_n, info_n 366 | 367 | def reset(self): 368 | obs_n = [] 369 | for env in self.env_batch: 370 | obs_n += env.reset() 371 | return obs_n 372 | 373 | # render environment 374 | def render(self, mode='human', close=True): 375 | results_n = [] 376 | for env in self.env_batch: 377 | results_n += env.render(mode, close) 378 | return results_n 379 | -------------------------------------------------------------------------------- /multiagent/multi_discrete.py: -------------------------------------------------------------------------------- 1 | # An old version of OpenAI Gym's multi_discrete.py. 
(Was getting affected by Gym updates) 2 | # (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py) 3 | 4 | import numpy as np 5 | 6 | import gym 7 | from gym.spaces import prng 8 | 9 | class MultiDiscrete(gym.Space): 10 | """ 11 | - The multi-discrete action space consists of a series of discrete action spaces with different parameters 12 | - It can be adapted to both a Discrete action space or a continuous (Box) action space 13 | - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space 14 | - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space 15 | where the discrete action space can take any integers from `min` to `max` (both inclusive) 16 | Note: A value of 0 always need to represent the NOOP action. 17 | e.g. Nintendo Game Controller 18 | - Can be conceptualized as 3 discrete action spaces: 19 | 1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4 20 | 2) Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 21 | 3) Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 22 | - Can be initialized as 23 | MultiDiscrete([ [0,4], [0,1], [0,1] ]) 24 | """ 25 | def __init__(self, array_of_param_array): 26 | self.low = np.array([x[0] for x in array_of_param_array]) 27 | self.high = np.array([x[1] for x in array_of_param_array]) 28 | self.num_discrete_space = self.low.shape[0] 29 | 30 | def sample(self): 31 | """ Returns a array with one sample from each discrete action space """ 32 | # For each row: round(random .* (max - min) + min, 0) 33 | random_array = prng.np_random.rand(self.num_discrete_space) 34 | return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)] 35 | def contains(self, x): 36 | return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all() 37 | 38 | @property 39 | def shape(self): 40 | return self.num_discrete_space 41 | def __repr__(self): 42 | return "MultiDiscrete" + str(self.num_discrete_space) 43 | def __eq__(self, other): 44 | return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high) -------------------------------------------------------------------------------- /multiagent/policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyglet.window import key 3 | 4 | # individual agent policy 5 | class Policy(object): 6 | def __init__(self): 7 | pass 8 | def action(self, obs): 9 | raise NotImplementedError() 10 | 11 | # interactive policy based on keyboard input 12 | # hard-coded to deal only with movement, not communication 13 | class InteractivePolicy(Policy): 14 | def __init__(self, env, agent_index): 15 | super(InteractivePolicy, self).__init__() 16 | self.env = env 17 | # hard-coded keyboard events 18 | self.move = [False for i in range(4)] 19 | self.comm = [False for i in range(env.world.dim_c)] 20 | # register keyboard events with this environment's window 21 | env.viewers[agent_index].window.on_key_press = self.key_press 22 | env.viewers[agent_index].window.on_key_release = self.key_release 23 | 24 | def action(self, obs): 25 | # ignore observation and just act based on keyboard events 26 | if self.env.discrete_action_input: 27 | u = 0 28 | if self.move[0]: u = 1 29 | if self.move[1]: u = 2 30 | if self.move[2]: u = 4 31 | if self.move[3]: u = 3 32 | 
else: 33 | u = np.zeros(5) # 5-d because of no-move action 34 | if self.move[0]: u[1] += 1.0 35 | if self.move[1]: u[2] += 1.0 36 | if self.move[3]: u[3] += 1.0 37 | if self.move[2]: u[4] += 1.0 38 | if True not in self.move: 39 | u[0] += 1.0 40 | return np.concatenate([u, np.zeros(self.env.world.dim_c)]) 41 | 42 | # keyboard event callbacks 43 | def key_press(self, k, mod): 44 | if k==key.LEFT: self.move[0] = True 45 | if k==key.RIGHT: self.move[1] = True 46 | if k==key.UP: self.move[2] = True 47 | if k==key.DOWN: self.move[3] = True 48 | def key_release(self, k, mod): 49 | if k==key.LEFT: self.move[0] = False 50 | if k==key.RIGHT: self.move[1] = False 51 | if k==key.UP: self.move[2] = False 52 | if k==key.DOWN: self.move[3] = False 53 | -------------------------------------------------------------------------------- /multiagent/rendering.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2D rendering framework 3 | """ 4 | from __future__ import division 5 | import os 6 | import six 7 | import sys 8 | 9 | if "Apple" in sys.version: 10 | if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ: 11 | os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib' 12 | # (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite 13 | 14 | from gym.utils import reraise 15 | from gym import error 16 | 17 | try: 18 | import pyglet 19 | except ImportError as e: 20 | reraise(suffix="HINT: you can install pyglet directly via 'pip install pyglet'. But if you really just want to install all Gym dependencies and not have to think about it, 'pip install -e .[all]' or 'pip install gym[all]' will do it.") 21 | 22 | try: 23 | from pyglet.gl import * 24 | except ImportError as e: 25 | reraise(prefix="Error occured while running `from pyglet.gl import *`",suffix="HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work: 'xvfb-run -s \"-screen 0 1400x900x24\" python '") 26 | 27 | import math 28 | import numpy as np 29 | 30 | RAD2DEG = 57.29577951308232 31 | 32 | def get_display(spec): 33 | """Convert a display specification (such as :0) into an actual Display 34 | object. 35 | 36 | Pyglet only supports multiple Displays on Linux. 37 | """ 38 | if spec is None: 39 | return None 40 | elif isinstance(spec, six.string_types): 41 | return pyglet.canvas.Display(spec) 42 | else: 43 | raise error.Error('Invalid display specification: {}. 
(Must be a string like :0 or None.)'.format(spec)) 44 | 45 | class Viewer(object): 46 | def __init__(self, width, height, display=None): 47 | display = get_display(display) 48 | 49 | self.width = width 50 | self.height = height 51 | 52 | self.window = pyglet.window.Window(width=width, height=height, display=display) 53 | self.window.on_close = self.window_closed_by_user 54 | self.geoms = [] 55 | self.onetime_geoms = [] 56 | self.transform = Transform() 57 | 58 | glEnable(GL_BLEND) 59 | # glEnable(GL_MULTISAMPLE) 60 | glEnable(GL_LINE_SMOOTH) 61 | # glHint(GL_LINE_SMOOTH_HINT, GL_DONT_CARE) 62 | glHint(GL_LINE_SMOOTH_HINT, GL_NICEST) 63 | glLineWidth(2.0) 64 | glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) 65 | 66 | def close(self): 67 | self.window.close() 68 | 69 | def window_closed_by_user(self): 70 | self.close() 71 | 72 | def set_bounds(self, left, right, bottom, top): 73 | assert right > left and top > bottom 74 | scalex = self.width/(right-left) 75 | scaley = self.height/(top-bottom) 76 | self.transform = Transform( 77 | translation=(-left*scalex, -bottom*scaley), 78 | scale=(scalex, scaley)) 79 | 80 | def add_geom(self, geom): 81 | self.geoms.append(geom) 82 | 83 | def add_onetime(self, geom): 84 | self.onetime_geoms.append(geom) 85 | 86 | def render(self, return_rgb_array=False): 87 | glClearColor(1,1,1,1) 88 | self.window.clear() 89 | self.window.switch_to() 90 | self.window.dispatch_events() 91 | self.transform.enable() 92 | for geom in self.geoms: 93 | geom.render() 94 | for geom in self.onetime_geoms: 95 | geom.render() 96 | self.transform.disable() 97 | arr = None 98 | if return_rgb_array: 99 | buffer = pyglet.image.get_buffer_manager().get_color_buffer() 100 | image_data = buffer.get_image_data() 101 | arr = np.fromstring(image_data.data, dtype=np.uint8, sep='') 102 | # In https://github.com/openai/gym-http-api/issues/2, we 103 | # discovered that someone using Xmonad on Arch was having 104 | # a window of size 598 x 398, though a 600 x 400 window 105 | # was requested. (Guess Xmonad was preserving a pixel for 106 | # the boundary.) So we use the buffer height/width rather 107 | # than the requested one. 
108 | arr = arr.reshape(buffer.height, buffer.width, 4) 109 | arr = arr[::-1,:,0:3] 110 | self.window.flip() 111 | self.onetime_geoms = [] 112 | return arr 113 | 114 | # Convenience 115 | def draw_circle(self, radius=10, res=30, filled=True, **attrs): 116 | geom = make_circle(radius=radius, res=res, filled=filled) 117 | _add_attrs(geom, attrs) 118 | self.add_onetime(geom) 119 | return geom 120 | 121 | def draw_polygon(self, v, filled=True, **attrs): 122 | geom = make_polygon(v=v, filled=filled) 123 | _add_attrs(geom, attrs) 124 | self.add_onetime(geom) 125 | return geom 126 | 127 | def draw_polyline(self, v, **attrs): 128 | geom = make_polyline(v=v) 129 | _add_attrs(geom, attrs) 130 | self.add_onetime(geom) 131 | return geom 132 | 133 | def draw_line(self, start, end, **attrs): 134 | geom = Line(start, end) 135 | _add_attrs(geom, attrs) 136 | self.add_onetime(geom) 137 | return geom 138 | 139 | def get_array(self): 140 | self.window.flip() 141 | image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data() 142 | self.window.flip() 143 | arr = np.fromstring(image_data.data, dtype=np.uint8, sep='') 144 | arr = arr.reshape(self.height, self.width, 4) 145 | return arr[::-1,:,0:3] 146 | 147 | def _add_attrs(geom, attrs): 148 | if "color" in attrs: 149 | geom.set_color(*attrs["color"]) 150 | if "linewidth" in attrs: 151 | geom.set_linewidth(attrs["linewidth"]) 152 | 153 | class Geom(object): 154 | def __init__(self): 155 | self._color=Color((0, 0, 0, 1.0)) 156 | self.attrs = [self._color] 157 | def render(self): 158 | for attr in reversed(self.attrs): 159 | attr.enable() 160 | self.render1() 161 | for attr in self.attrs: 162 | attr.disable() 163 | def render1(self): 164 | raise NotImplementedError 165 | def add_attr(self, attr): 166 | self.attrs.append(attr) 167 | def set_color(self, r, g, b, alpha=1): 168 | self._color.vec4 = (r, g, b, alpha) 169 | 170 | class Attr(object): 171 | def enable(self): 172 | raise NotImplementedError 173 | def disable(self): 174 | pass 175 | 176 | class Transform(Attr): 177 | def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1,1)): 178 | self.set_translation(*translation) 179 | self.set_rotation(rotation) 180 | self.set_scale(*scale) 181 | def enable(self): 182 | glPushMatrix() 183 | glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc ppint 184 | glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0) 185 | glScalef(self.scale[0], self.scale[1], 1) 186 | def disable(self): 187 | glPopMatrix() 188 | def set_translation(self, newx, newy): 189 | self.translation = (float(newx), float(newy)) 190 | def set_rotation(self, new): 191 | self.rotation = float(new) 192 | def set_scale(self, newx, newy): 193 | self.scale = (float(newx), float(newy)) 194 | 195 | class Color(Attr): 196 | def __init__(self, vec4): 197 | self.vec4 = vec4 198 | def enable(self): 199 | glColor4f(*self.vec4) 200 | 201 | class LineStyle(Attr): 202 | def __init__(self, style): 203 | self.style = style 204 | def enable(self): 205 | glEnable(GL_LINE_STIPPLE) 206 | glLineStipple(1, self.style) 207 | def disable(self): 208 | glDisable(GL_LINE_STIPPLE) 209 | 210 | class LineWidth(Attr): 211 | def __init__(self, stroke): 212 | self.stroke = stroke 213 | def enable(self): 214 | glLineWidth(self.stroke) 215 | 216 | class Point(Geom): 217 | def __init__(self): 218 | Geom.__init__(self) 219 | def render1(self): 220 | glBegin(GL_POINTS) # draw point 221 | glVertex3f(0.0, 0.0, 0.0) 222 | glEnd() 223 | 224 | class FilledPolygon(Geom): 225 | def __init__(self, v): 
226 | Geom.__init__(self) 227 | self.v = v 228 | def render1(self): 229 | if len(self.v) == 4 : glBegin(GL_QUADS) 230 | elif len(self.v) > 4 : glBegin(GL_POLYGON) 231 | else: glBegin(GL_TRIANGLES) 232 | for p in self.v: 233 | glVertex3f(p[0], p[1],0) # draw each vertex 234 | glEnd() 235 | 236 | color = (self._color.vec4[0] * 0.5, self._color.vec4[1] * 0.5, self._color.vec4[2] * 0.5, self._color.vec4[3] * 0.5) 237 | glColor4f(*color) 238 | glBegin(GL_LINE_LOOP) 239 | for p in self.v: 240 | glVertex3f(p[0], p[1],0) # draw each vertex 241 | glEnd() 242 | 243 | def make_circle(radius=10, res=30, filled=True): 244 | points = [] 245 | for i in range(res): 246 | ang = 2*math.pi*i / res 247 | points.append((math.cos(ang)*radius, math.sin(ang)*radius)) 248 | if filled: 249 | return FilledPolygon(points) 250 | else: 251 | return PolyLine(points, True) 252 | 253 | def make_polygon(v, filled=True): 254 | if filled: return FilledPolygon(v) 255 | else: return PolyLine(v, True) 256 | 257 | def make_polyline(v): 258 | return PolyLine(v, False) 259 | 260 | def make_capsule(length, width): 261 | l, r, t, b = 0, length, width/2, -width/2 262 | box = make_polygon([(l,b), (l,t), (r,t), (r,b)]) 263 | circ0 = make_circle(width/2) 264 | circ1 = make_circle(width/2) 265 | circ1.add_attr(Transform(translation=(length, 0))) 266 | geom = Compound([box, circ0, circ1]) 267 | return geom 268 | 269 | class Compound(Geom): 270 | def __init__(self, gs): 271 | Geom.__init__(self) 272 | self.gs = gs 273 | for g in self.gs: 274 | g.attrs = [a for a in g.attrs if not isinstance(a, Color)] 275 | def render1(self): 276 | for g in self.gs: 277 | g.render() 278 | 279 | class PolyLine(Geom): 280 | def __init__(self, v, close): 281 | Geom.__init__(self) 282 | self.v = v 283 | self.close = close 284 | self.linewidth = LineWidth(1) 285 | self.add_attr(self.linewidth) 286 | def render1(self): 287 | glBegin(GL_LINE_LOOP if self.close else GL_LINE_STRIP) 288 | for p in self.v: 289 | glVertex3f(p[0], p[1],0) # draw each vertex 290 | glEnd() 291 | def set_linewidth(self, x): 292 | self.linewidth.stroke = x 293 | 294 | class Line(Geom): 295 | def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)): 296 | Geom.__init__(self) 297 | self.start = start 298 | self.end = end 299 | self.linewidth = LineWidth(1) 300 | self.add_attr(self.linewidth) 301 | 302 | def render1(self): 303 | glBegin(GL_LINES) 304 | glVertex2f(*self.start) 305 | glVertex2f(*self.end) 306 | glEnd() 307 | 308 | class Image(Geom): 309 | def __init__(self, fname, width, height): 310 | Geom.__init__(self) 311 | self.width = width 312 | self.height = height 313 | img = pyglet.image.load(fname) 314 | self.img = img 315 | self.flip = False 316 | def render1(self): 317 | self.img.blit(-self.width/2, -self.height/2, width=self.width, height=self.height) 318 | 319 | # ================================================================ 320 | 321 | class SimpleImageViewer(object): 322 | def __init__(self, display=None): 323 | self.window = None 324 | self.isopen = False 325 | self.display = display 326 | def imshow(self, arr): 327 | if self.window is None: 328 | height, width, channels = arr.shape 329 | self.window = pyglet.window.Window(width=width, height=height, display=self.display) 330 | self.width = width 331 | self.height = height 332 | self.isopen = True 333 | assert arr.shape == (self.height, self.width, 3), "You passed in an image with the wrong number shape" 334 | image = pyglet.image.ImageData(self.width, self.height, 'RGB', arr.tobytes(), pitch=self.width * -3) 335 | 
self.window.clear() 336 | self.window.switch_to() 337 | self.window.dispatch_events() 338 | image.blit(0,0) 339 | self.window.flip() 340 | def close(self): 341 | if self.isopen: 342 | self.window.close() 343 | self.isopen = False 344 | def __del__(self): 345 | self.close() -------------------------------------------------------------------------------- /multiagent/scenario.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # defines scenario upon which the world is built 4 | class BaseScenario(object): 5 | # create elements of the world 6 | def make_world(self): 7 | raise NotImplementedError() 8 | # create initial conditions of the world 9 | def reset_world(self, world): 10 | raise NotImplementedError() 11 | -------------------------------------------------------------------------------- /multiagent/scenarios/CN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiagent.core import World, Agent, Landmark 3 | from multiagent.scenario import BaseScenario 4 | import random 5 | 6 | 7 | class Scenario(BaseScenario): 8 | def make_world(self, num_agents, num_targets): 9 | world = World() 10 | # set any world properties first 11 | world.dim_c = 0 12 | if num_agents == -1: 13 | num_agents = 3 14 | num_landmarks = 3 15 | else: 16 | if num_targets == -1: 17 | raise AssertionError("Number of targets is not assigned") 18 | else: 19 | num_landmarks = num_targets 20 | world.collaborative = False 21 | world.discrete_action = True 22 | world.num_agents_obs = 2 23 | world.num_landmarks_obs = 2 24 | # add agents 25 | world.agents = [Agent() for i in range(num_agents)] 26 | for i, agent in enumerate(world.agents): 27 | agent.name = 'agent %d' % i 28 | agent.collide = True 29 | agent.silent = True 30 | agent.size = 0.05 31 | # add landmarks 32 | world.landmarks = [Landmark() for i in range(num_landmarks)] 33 | for i, landmark in enumerate(world.landmarks): 34 | landmark.name = 'landmark %d' % i 35 | landmark.collide = False 36 | landmark.movable = False 37 | # make initial conditions 38 | self.reset_world(world) 39 | return world 40 | 41 | def reset_world(self, world): 42 | # random properties for agents 43 | for i, agent in enumerate(world.agents): 44 | agent.color = np.array([0.35, 0.35, 0.85]) 45 | # random properties for landmarks 46 | for i, landmark in enumerate(world.landmarks): 47 | landmark.color = np.array([0.25, 0.25, 0.25]) 48 | # set random initial states 49 | for agent in world.agents: 50 | agent.state.p_pos = np.random.uniform(-world.range_p, +world.range_p, world.dim_p) 51 | agent.state.p_vel = np.zeros(world.dim_p) 52 | agent.state.c = np.zeros(world.dim_c) 53 | for i, landmark in enumerate(world.landmarks): 54 | landmark.state.p_pos = np.random.uniform(-world.range_p, +world.range_p, world.dim_p) 55 | if i != 0: 56 | for j in range(i): 57 | while True: 58 | if np.sqrt(np.sum(np.square(landmark.state.p_pos - world.landmarks[j].state.p_pos)))>0.22: 59 | break 60 | else: landmark.state.p_pos = np.random.uniform(-world.range_p, +world.range_p, world.dim_p) 61 | landmark.state.p_vel = np.zeros(world.dim_p) 62 | 63 | # # set agent goals 64 | # if goals is None: 65 | # goals = [i for i in range(len(world.agents))] 66 | # random.shuffle(goals) 67 | # world.goals = goals 68 | 69 | def benchmark_data(self, agent, world): 70 | rew = 0 71 | collisions = 0 72 | occupied_landmarks = 0 73 | min_dists = 0 74 | for l in world.landmarks: 75 | collision_dist = agent.size + l.size 
76 | dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents] 77 | min_dists += min(dists) 78 | rew -= min(dists) 79 | if min(dists) < collision_dist: 80 | occupied_landmarks += 1 81 | if agent.collide: 82 | for a in world.agents: 83 | for b in world.agents: 84 | if a is b: continue 85 | if self.is_collision(a, b): 86 | collisions += 0.5 87 | return (rew, collisions, min_dists, occupied_landmarks) 88 | 89 | def is_collision(self, agent1, agent2): 90 | delta_pos = agent1.state.p_pos - agent2.state.p_pos 91 | dist = np.sqrt(np.sum(np.square(delta_pos))) 92 | collision_dist = agent1.size + agent2.size 93 | return True if dist < collision_dist else False 94 | 95 | def reward(self, agent, world): 96 | # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions 97 | rew = 0 98 | # local reward 99 | #dists = [np.sqrt(np.sum(np.square(agent.state.p_pos - l.state.p_pos))) for l in world.landmarks] 100 | #rew = rew - min(dists) 101 | # global reward 102 | for l in world.landmarks: 103 | dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents] 104 | rew -= min(dists) 105 | # collisions penalty 106 | if agent.collide: 107 | for a in world.agents: 108 | for b in world.agents: 109 | if a is b: continue 110 | if self.is_collision(a, b): 111 | rew -= 0.5 112 | return rew 113 | 114 | def observation(self, agent, world): 115 | entity_pos = [] 116 | dist_n = [] 117 | for entity in world.landmarks: # world.entities: 118 | entity_pos.append(entity.state.p_pos - agent.state.p_pos) 119 | dist_n.append(np.sqrt(np.sum(np.square(agent.state.p_pos - entity.state.p_pos)))) 120 | # dist_sort = dist_n.copy() 121 | # dist_sort.sort() 122 | # num_landmarks_obs = world.num_landmarks_obs 123 | # dist_thresh = dist_sort[num_landmarks_obs-1] 124 | target_pos = [] 125 | for i,pos in enumerate(entity_pos): 126 | if True:#dist_n[i] <= dist_thresh: 127 | target_pos.append(pos) 128 | else: 129 | target_pos.append(np.array([100,100])) 130 | other_pos = [] 131 | for other in world.agents: 132 | if other is agent: continue 133 | other_pos.append(other.state.p_pos - agent.state.p_pos) 134 | #print(target_pos) 135 | return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + other_pos + target_pos) 136 | 137 | 138 | -------------------------------------------------------------------------------- /multiagent/scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os.path as osp 3 | 4 | 5 | def load(name): 6 | pathname = osp.join(osp.dirname(__file__), name) 7 | return imp.load_source('', pathname) 8 | -------------------------------------------------------------------------------- /perception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | from torch.nn import Parameter 7 | from torch.autograd import Variable 8 | 9 | 10 | class NoisyLinear(nn.Linear): 11 | def __init__(self, in_features, out_features, sigma_init=0.017, bias=True): 12 | super(NoisyLinear, self).__init__(in_features, out_features, bias=True) 13 | self.sigma_init = sigma_init 14 | self.sigma_weight = Parameter(torch.Tensor(out_features, in_features)) 15 | self.sigma_bias = Parameter(torch.Tensor(out_features)) 16 | self.register_buffer('epsilon_weight', torch.zeros(out_features, in_features)) 17 | 
self.register_buffer('epsilon_bias', torch.zeros(out_features)) 18 | self.reset_parameters() 19 | 20 | def reset_parameters(self): 21 | if hasattr(self, 'sigma_weight'): 22 | init.uniform(self.weight, -math.sqrt(3 / self.in_features), math.sqrt(3 / self.in_features)) 23 | init.uniform(self.bias, -math.sqrt(3 / self.in_features), math.sqrt(3 / self.in_features)) 24 | init.constant(self.sigma_weight, self.sigma_init) 25 | init.constant(self.sigma_bias, self.sigma_init) 26 | 27 | def forward(self, input): 28 | return F.linear(input, self.weight + self.sigma_weight * Variable(self.epsilon_weight), 29 | self.bias + self.sigma_bias * Variable(self.epsilon_bias)) 30 | 31 | def sample_noise(self): 32 | self.epsilon_weight = torch.randn(self.out_features, self.in_features) 33 | self.epsilon_bias = torch.randn(self.out_features) 34 | 35 | def remove_noise(self): 36 | self.epsilon_weight = torch.zeros(self.out_features, self.in_features) 37 | self.epsilon_bias = torch.zeros(self.out_features) 38 | 39 | 40 | class BiRNN(torch.nn.Module): 41 | def __init__(self, input_size, hidden_size, num_layers, device, head_name): 42 | super(BiRNN, self).__init__() 43 | self.hidden_size = hidden_size 44 | self.num_layers = num_layers 45 | if 'lstm' in head_name: 46 | self.lstm = True 47 | else: 48 | self.lstm = False 49 | if self.lstm: 50 | self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True).to(device) 51 | else: 52 | self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True).to(device) 53 | self.feature_dim = hidden_size * 2 54 | self.device = device 55 | 56 | def forward(self, x, state=None): 57 | # Set initial states 58 | 59 | h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(self.device) # 2 for bidirection 60 | c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(self.device) 61 | 62 | # Forward propagate LSTM 63 | if self.lstm: 64 | out, (_, hn) = self.rnn(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size*2) 65 | else: 66 | out, hn = self.rnn(x, h0) # out: tensor of shape (batch_size, seq_length, hidden_size*2) 67 | return out, hn 68 | 69 | class RNN(torch.nn.Module): 70 | def __init__(self, input_size, hidden_size, num_layers, device, head_name): 71 | super(RNN, self).__init__() 72 | self.hidden_size = hidden_size 73 | self.num_layers = num_layers 74 | if 'lstm' in head_name: 75 | self.lstm = True 76 | else: 77 | self.lstm = False 78 | if self.lstm: 79 | self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True).to(device) 80 | else: 81 | self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=True).to(device) 82 | 83 | self.feature_dim = hidden_size 84 | # add layer normalization to stable training 85 | self.LayerNorm = nn.LayerNorm([hidden_size]) 86 | self.device = device 87 | 88 | def forward(self, x, h0, c0=None, state=None): 89 | # x: [batch_size, seq_length, input_size] h:[num_layers, batch_size, hidden_size] 90 | # Forward propagate LSTM 91 | h0 = self.LayerNorm(h0) 92 | if self.lstm: 93 | out, (_, hn) = self.rnn(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size*2) 94 | else: 95 | out, hn = self.rnn(x, h0) # out: tensor of shape (batch_size, seq_length, hidden_size*2) 96 | hn = self.LayerNorm(hn) 97 | 98 | return out, hn 99 | 100 | 101 | def xavier_init(layer): 102 | torch.nn.init.xavier_uniform_(layer.weight) 103 | torch.nn.init.constant_(layer.bias, 0) 104 | return layer 105 | 106 | 107 | class 
AttentionLayer(torch.nn.Module): 108 | def __init__(self, feature_dim, weight_dim, device=torch.device('cpu')): 109 | super(AttentionLayer, self).__init__() 110 | self.in_dim = feature_dim 111 | self.device = device 112 | 113 | self.Q = xavier_init(nn.Linear(self.in_dim, weight_dim)) 114 | self.K = xavier_init(nn.Linear(self.in_dim, weight_dim)) 115 | self.V = xavier_init(nn.Linear(self.in_dim, weight_dim)) 116 | 117 | self.feature_dim = weight_dim 118 | 119 | def forward(self, x): 120 | ''' 121 | inference 122 | :param x: [num_agent, num_target, feature_dim] 123 | :return z: [num_agent, num_target, weight_dim] 124 | ''' 125 | # z = softmax(Q,K)*V 126 | q = torch.tanh(self.Q(x)) # [batch_size, sequence_len, weight_dim] 127 | k = torch.tanh(self.K(x)) # [batch_size, sequence_len, weight_dim] 128 | v = torch.tanh(self.V(x)) # [batch_size, sequence_len, weight_dim] 129 | 130 | z = torch.bmm(F.softmax(torch.bmm(q, k.permute(0, 2, 1)), dim=2), v) # [batch_size, sequence_len, weight_dim] 131 | 132 | global_feature = z.sum(dim=1) 133 | return z, global_feature 134 | -------------------------------------------------------------------------------- /player_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from torch.autograd import Variable 8 | import json 9 | from utils import ensure_shared_grads 10 | 11 | class Agent(object): 12 | def __init__(self, model, env, args, state, device): 13 | self.model = model 14 | self.env = env 15 | self.num_agents = env.n 16 | self.num_targets = env.num_target 17 | self.state_dim = env.observation_space.shape[2] 18 | self.model_name = args.model 19 | self.prior = torch.FloatTensor(np.array([0.7, 0.3])) # communication edge prior 20 | 21 | self.model_name = args.model 22 | self.eps_len = 0 23 | self.eps_num = 0 24 | self.args = args 25 | self.values = [] 26 | self.log_probs = [] 27 | self.rewards = [] 28 | self.entropies = [] 29 | self.rewards_eps = [] 30 | self.done = True 31 | self.info = None 32 | self.reward = 0 33 | self.device = device 34 | self.lstm_out = args.lstm_out 35 | self.reward_mean = None 36 | self.reward_std = 1 37 | self.num_steps = 0 38 | self.env_step = 0 39 | self.vk = 0 40 | self.state = state 41 | self.rank = 0 42 | # evaluation for ToM & Comm 43 | self.comm_ToM_loss = torch.zeros(1) 44 | self.no_comm_ToM_loss = torch.zeros(1) 45 | self.ToM_loss = torch.zeros(1) 46 | 47 | self.hself = torch.zeros(self.num_agents, self.lstm_out).to(device) 48 | self.hToM = torch.zeros(self.num_agents, self.num_agents, self.lstm_out).to(device) 49 | 50 | self.poses = None # cam_dim=3 ndarray 51 | self.ToM_history = [] 52 | self.Policy_history = [] 53 | 54 | def get_other_poses(self): 55 | # ToM2C requires the poses of each agent, so you need to declare how to get the poses for each env 56 | if "MSMTC" in self.args.env: 57 | cam_states = self.env.get_cam_states() 58 | cam_states = torch.from_numpy(np.array(cam_states)).float().to(self.device) 59 | 60 | # compute relative camera poses in self coordinate 61 | cam_dim = cam_states.size()[-1] # cam_dim=3 62 | cam_states_duplicate = cam_states.unsqueeze(0).expand(self.num_agents, self.num_agents, cam_dim) 63 | cam_states_relative = cam_states_duplicate - cam_states.unsqueeze(1).expand(self.num_agents, self.num_agents, cam_dim) 64 | cam_state_theta = ((cam_states_relative[:,:,-1]/180) * np.pi).reshape(self.num_agents, self.num_agents, 
1) 65 | poses = torch.cat((cam_states_relative[:,:,:2], torch.cos(cam_state_theta), torch.sin(cam_state_theta)),-1) 66 | return poses 67 | elif "CN" in self.args.env: 68 | return torch.zeros(self.num_agents, self.num_agents, 1) 69 | 70 | def get_mask(self): 71 | if not self.args.mask: 72 | return torch.ones(self.num_agents, self.num_agents, 1) 73 | # ToM2C provides the option to mask the ToM inference and communication to agents out of ranges(include self) 74 | if "MSMTC" in self.args.env: 75 | mask = self.env.get_mask() 76 | mask = torch.from_numpy(mask).unsqueeze(-1).bool() 77 | mask = mask.to(self.device) 78 | return mask 79 | else: 80 | return torch.ones(self.num_agents, self.num_agents, 1) 81 | 82 | def get_available_actions(self): 83 | available_actions = self.env.get_available_actions() 84 | available_actions = torch.from_numpy(available_actions).to(self.device) 85 | return available_actions 86 | 87 | def action_train(self): 88 | if self.args.mask_actions: 89 | available_actions = self.get_available_actions() 90 | available_actions_data = available_actions.cpu().numpy() 91 | else: 92 | available_actions = None 93 | available_actions_data = 0 94 | 95 | self.poses = self.get_other_poses() 96 | self.mask = self.get_mask() 97 | value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goals, edge_logits, comm_edges, probs, real_cover, ToM_target_cover =\ 98 | self.model(self.state, self.hself, self.hToM, self.poses, self.mask, available_actions = available_actions) 99 | 100 | actions_env = actions.cpu().numpy() # only ndarrays can be processed by the environment 101 | state_multi, reward, self.done, self.info = self.env.step(actions_env)#,obstacle=True) 102 | reward_multi = reward.repeat(self.num_agents) # all agents share the same reward 103 | 104 | self.reward_org = reward_multi.copy() 105 | 106 | if self.args.norm_reward: 107 | reward_multi = self.reward_normalizer(reward_multi) 108 | 109 | # save state for training 110 | Policy_data = {"state":self.state.detach().cpu().numpy(), "poses": self.poses.detach().cpu().numpy(),"actions": actions_env, "reward": reward_multi,\ 111 | "mask":self.mask.detach().cpu().numpy(),"available_actions": available_actions_data} 112 | real_goals = torch.cat((1-actions,actions),-1) 113 | ToM_data = {"state":self.state.detach().cpu().numpy(), "poses":self.poses.detach().cpu().numpy(), "mask":self.mask.detach().cpu().numpy(),\ 114 | "real":real_goals.detach().cpu().numpy(), "available_actions": available_actions_data} 115 | self.Policy_history.append(Policy_data) 116 | self.ToM_history.append(ToM_data) 117 | 118 | if isinstance(self.done, list): self.done = np.sum(self.done) 119 | self.state = torch.from_numpy(np.array(state_multi)).float().to(self.device) 120 | 121 | self.reward = torch.tensor(reward_multi).float().to(self.device) 122 | self.eps_len += 1 123 | 124 | self.hself=hn_self 125 | self.hToM=hn_ToM 126 | 127 | self.env_step += 1 128 | if self.env_step >= self.env.max_steps: 129 | self.done = True 130 | 131 | def action_test(self): 132 | if self.args.mask_actions: 133 | available_actions = self.get_available_actions() 134 | else: 135 | available_actions = None 136 | 137 | with torch.no_grad(): 138 | self.poses = self.get_other_poses() 139 | self.mask = self.get_mask() 140 | value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goals, edge_logits, comm_edges, probs, real_cover, ToM_target_cover=\ 141 | self.model(self.state, self.hself, self.hToM, self.poses, self.mask, True, available_actions = available_actions) 142 | 143 | 
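# record communication statistics: comm_cnt counts the active communication edges at this step, and comm_bit scales that count by the number of targets as a bandwidth estimate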
self.comm_cnt = torch.sum(comm_edges) 144 | self.comm_bit = self.comm_cnt * self.num_targets 145 | self.env.comm_edges = comm_edges 146 | 147 | ''' 148 | # compute ToM prediction accuracy 149 | ToM_goal = (ToM_goals[:,:,:,-1]>=0.1).unsqueeze(-1) # n * n-1 * m * 1 150 | random_ToM_goal = torch.randint(2,(self.num_agents,self.num_agents-1,self.num_targets,1)) 151 | real_goal = torch.from_numpy(actions) 152 | real_goal = real_goal.unsqueeze(0).repeat(self.num_agents,1,1,1) 153 | idx= (torch.ones(self.num_agents, self.num_agents) - torch.diag(torch.ones(self.num_agents))).bool() 154 | real_goal = real_goal[idx].reshape(self.num_agents, self.num_agents-1, self.num_targets, -1) 155 | ToM_cover = (ToM_target_cover >= 0.1) 156 | random_ToM_cover = torch.randint(2,(self.num_agents,self.num_agents-1,self.num_targets,1)) 157 | self.ToM_acc = (ToM_goal==real_goal)[real_cover].float() 158 | self.ToM_acc = torch.mean(self.ToM_acc) 159 | self.ToM_target_acc = torch.mean((real_cover==ToM_cover)[real_cover].float()) 160 | self.random_ToM_acc = torch.mean((random_ToM_goal==real_goal)[real_cover].float()) 161 | self.random_ToM_target_acc = torch.mean((real_cover==random_ToM_cover)[real_cover].float()) 162 | #print(torch.mean(ToM_goal.float())) 163 | ''' 164 | state_multi, self.reward, self.done, self.info = self.env.step(actions)#, obstacle=True) 165 | if isinstance(self.done, list): self.done = np.sum(self.done) 166 | self.state = torch.from_numpy(np.array(state_multi)).float().to(self.device) 167 | self.eps_len += 1 168 | 169 | self.hself=hn_self 170 | self.hToM=hn_ToM 171 | 172 | self.env_step += 1 173 | if self.env_step >= self.env.max_steps: 174 | self.done = True 175 | 176 | def reset(self): 177 | obs = self.env.reset() 178 | self.state = torch.from_numpy(np.array(obs)).float().to(self.device) 179 | 180 | self.eps_len = 0 181 | self.eps_num += 1 182 | self.reset_rnn_hidden() 183 | 184 | self.model.sample_noise() 185 | 186 | def clean_buffer(self, done): 187 | self.env_step = 0 188 | # outputs 189 | self.values = [] 190 | self.log_probs = [] 191 | self.entropies = [] 192 | # gt 193 | self.rewards = [] 194 | if done: 195 | # clean 196 | self.rewards_eps = [] 197 | 198 | return self 199 | 200 | def reward_normalizer(self, reward): 201 | reward = np.array(reward) 202 | self.num_steps += 1 203 | if self.num_steps == 1: 204 | self.reward_mean = reward 205 | self.vk = 0 206 | self.reward_std = 1 207 | else: 208 | delt = reward - self.reward_mean 209 | self.reward_mean = self.reward_mean + delt/self.num_steps 210 | self.vk = self.vk + delt * (reward-self.reward_mean) 211 | self.reward_std = np.sqrt(self.vk/(self.num_steps - 1)) 212 | reward = (reward - self.reward_mean) / (self.reward_std + 1e-8) 213 | return reward 214 | 215 | def reset_rnn_hidden(self): 216 | self.hself = torch.zeros(self.num_agents, self.lstm_out).to(self.device) 217 | self.hToM = torch.zeros(self.num_agents, self.num_agents, self.lstm_out).to(self.device) 218 | 219 | def update_rnn_hidden(self): 220 | self.hself = Variable(self.hself.data) 221 | self.hToM = Variable(self.hToM.data) 222 | 223 | -------------------------------------------------------------------------------- /render_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from setproctitle import setproctitle as ptitle 3 | 4 | import os 5 | import time 6 | import torch 7 | import logging 8 | import numpy as np 9 | import argparse 10 | from tensorboardX import SummaryWriter 11 | 12 | from model import 
build_model 13 | from utils import setup_logger 14 | from player_util import Agent 15 | from environment import create_env 16 | 17 | 18 | parser = argparse.ArgumentParser(description='render') 19 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') 20 | parser.add_argument('--test-eps', type=int, default=5, metavar='M', help='testing episode length') 21 | parser.add_argument('--env', default='simple', metavar='Pose-v0', help='environment to train on (default: Pose-v0|Pose-v1)') 22 | parser.add_argument('--load-model-dir', default=None, help='folder to load trained high-level models from') 23 | parser.add_argument('--load-executor-dir', default=None, help='folder to load trained low-level models from') 24 | parser.add_argument('--env-steps', type=int, default=20, help='env steps') 25 | parser.add_argument('--model', default='single', metavar='M', help='multi-shapleyV|') 26 | parser.add_argument('--lstm-out', type=int, default=32, metavar='LO', help='lstm output size') 27 | parser.add_argument('--mask', dest='mask', action='store_true', help='mask ToM and communication to those out of range') 28 | parser.add_argument('--mask-actions', dest='mask_actions', action='store_true', help='mask unavailable actions to boost training') 29 | parser.add_argument('--gpu-id', type=int, default=-1, nargs='+', help='GPUs to use [-1 CPU only] (default: -1)') 30 | parser.add_argument('--render', dest='render', action='store_true', help='render test') 31 | parser.add_argument('--render_save', dest='render_save', action='store_true', help='render save') 32 | 33 | parser.add_argument('--num-agents', type=int, default=-1) # if -1, then the env will load the default setting 34 | parser.add_argument('--num-targets', type=int, default=-1) # else, you can assign the number of agents and targets yourself 35 | 36 | 37 | def render_test(args): 38 | gpu_id = args.gpu_id 39 | 40 | torch.manual_seed(args.seed) 41 | if gpu_id >= 0: 42 | torch.cuda.manual_seed(args.seed) 43 | device = torch.device('cuda:' + str(gpu_id)) 44 | else: 45 | device = torch.device('cpu') 46 | 47 | env = create_env(args.env, args) 48 | 49 | env.seed(args.seed) 50 | 51 | player = Agent(None, env, args, None, device) 52 | player.gpu_id = gpu_id 53 | player.model = build_model(player.env, args, device).to(device) 54 | player.model.eval() 55 | 56 | saved_state = torch.load(args.load_model_dir) 57 | player.model.load_state_dict(saved_state['model'],strict=False) 58 | 59 | ave_reward_list = [] 60 | comm_cnt_list = [] 61 | comm_bit_list = [] 62 | 63 | for i_episode in range(args.test_eps): 64 | player.reset() 65 | comm_cnt = 0 66 | comm_bit = 0 67 | reward_sum_ep = 0 68 | 69 | print(f"Episode:{i_episode}") 70 | for i_step in range(args.env_steps): 71 | player.action_test() 72 | comm_cnt += player.comm_cnt 73 | comm_bit += player.comm_bit 74 | reward_sum_ep += player.reward 75 | comm_cnt_list.append(comm_cnt/env.max_steps) 76 | comm_bit_list.append(comm_bit/env.max_steps) 77 | print('reward step',reward_sum_ep[0]/args.env_steps) 78 | print('comm_edge', comm_cnt.data/args.env_steps) 79 | print('comm_bandwidth', comm_bit.data/args.env_steps) 80 | # print(comm_bit_list) 81 | 82 | if __name__ == '__main__': 83 | args = parser.parse_args() 84 | render_test(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym == 0.10.5 2 | numpy == 1.15.1 3 | torch == 1.7.0 4 | tensorboard 5 | 
torchvision 6 | opencv-python 7 | setproctitle -------------------------------------------------------------------------------- /shared_optim.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import math 3 | import torch 4 | import torch.optim as optim 5 | from collections import defaultdict 6 | 7 | 8 | class SharedRMSprop(optim.Optimizer): 9 | """Implements RMSprop algorithm with shared states. 10 | """ 11 | 12 | def __init__(self, 13 | params, 14 | lr=7e-4, 15 | alpha=0.99, 16 | eps=0.1, 17 | weight_decay=0, 18 | momentum=0, 19 | centered=False): 20 | defaults = defaultdict(lr=lr, alpha=alpha, eps=eps, 21 | weight_decay=weight_decay, momentum=momentum, centered=centered) 22 | super(SharedRMSprop, self).__init__(params, defaults) 23 | 24 | for group in self.param_groups: 25 | for p in group['params']: 26 | state = self.state[p] 27 | state['step'] = torch.zeros(1) 28 | state['grad_avg'] = p.data.new().resize_as_(p.data).zero_() 29 | state['square_avg'] = p.data.new().resize_as_(p.data).zero_() 30 | state['momentum_buffer'] = p.data.new( 31 | ).resize_as_(p.data).zero_() 32 | 33 | def share_memory(self): 34 | for group in self.param_groups: 35 | for p in group['params']: 36 | state = self.state[p] 37 | state['square_avg'].share_memory_() 38 | state['step'].share_memory_() 39 | state['grad_avg'].share_memory_() 40 | state['momentum_buffer'].share_memory_() 41 | 42 | def step(self, closure=None): 43 | """Performs a single optimization step. 44 | Arguments: 45 | closure (callable, optional): A closure that reevaluates the model 46 | and returns the loss. 47 | """ 48 | loss = None 49 | if closure is not None: 50 | loss = closure() 51 | 52 | for group in self.param_groups: 53 | for p in group['params']: 54 | if p.grad is None: 55 | continue 56 | grad = p.grad.data 57 | if grad.is_sparse: 58 | raise RuntimeError( 59 | 'RMSprop does not support sparse gradients') 60 | state = self.state[p] 61 | 62 | square_avg = state['square_avg'] 63 | alpha = group['alpha'] 64 | 65 | state['step'] += 1 66 | 67 | if group['weight_decay'] != 0: 68 | grad = grad.add(group['weight_decay'], p.data) 69 | 70 | square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) 71 | 72 | if group['centered']: 73 | grad_avg = state['grad_avg'] 74 | grad_avg.mul_(alpha).add_(1 - alpha, grad) 75 | avg = square_avg.addcmul( 76 | -1, grad_avg, grad_avg).sqrt().add_(group['eps']) 77 | else: 78 | avg = square_avg.sqrt().add_(group['eps']) 79 | 80 | if group['momentum'] > 0: 81 | buf = state['momentum_buffer'] 82 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 83 | p.data.add_(-group['lr'], buf) 84 | else: 85 | p.data.addcdiv_(-group['lr'], grad, avg) 86 | 87 | return loss 88 | 89 | 90 | class SharedAdam(optim.Optimizer): 91 | """Implements Adam algorithm with shared states. 
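All per-parameter state (step, exp_avg, exp_avg_sq, max_exp_avg_sq) is allocated at construction time so it can be moved into shared memory with share_memory() and updated by multiple worker processes.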
92 | """ 93 | 94 | def __init__(self, 95 | params, 96 | lr=1e-3, 97 | betas=(0.9, 0.999), 98 | eps=1e-3, 99 | weight_decay=0, amsgrad=True): 100 | defaults = defaultdict(lr=lr, betas=betas, eps=eps, 101 | weight_decay=weight_decay, amsgrad=amsgrad) 102 | super(SharedAdam, self).__init__(params, defaults) 103 | 104 | for group in self.param_groups: 105 | for p in group['params']: 106 | state = self.state[p] 107 | state['step'] = torch.zeros(1) 108 | state['exp_avg'] = p.data.new().resize_as_(p.data).zero_() 109 | state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_() 110 | state['max_exp_avg_sq'] = p.data.new( 111 | ).resize_as_(p.data).zero_() 112 | 113 | def share_memory(self): 114 | for group in self.param_groups: 115 | for p in group['params']: 116 | state = self.state[p] 117 | state['step'].share_memory_() 118 | state['exp_avg'].share_memory_() 119 | state['exp_avg_sq'].share_memory_() 120 | state['max_exp_avg_sq'].share_memory_() 121 | 122 | def step(self, closure=None): 123 | """Performs a single optimization step. 124 | Arguments: 125 | closure (callable, optional): A closure that reevaluates the model 126 | and returns the loss. 127 | """ 128 | loss = None 129 | if closure is not None: 130 | loss = closure() 131 | 132 | for group in self.param_groups: 133 | for p in group['params']: 134 | if p.grad is None: 135 | continue 136 | grad = p.grad.data 137 | if grad.is_sparse: 138 | raise RuntimeError( 139 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 140 | amsgrad = group['amsgrad'] 141 | 142 | state = self.state[p] 143 | 144 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 145 | if amsgrad: 146 | max_exp_avg_sq = state['max_exp_avg_sq'] 147 | beta1, beta2 = group['betas'] 148 | 149 | state['step'] += 1 150 | 151 | if group['weight_decay'] != 0: 152 | grad = grad.add(group['weight_decay'], p.data) 153 | 154 | # Decay the first and second moment running average coefficient 155 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 156 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 157 | 158 | if amsgrad: 159 | # Maintains the maximum of all 2nd moment running avg. till 160 | # now 161 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 162 | # Use the max. for normalizing running avg. 
of gradient 163 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 164 | else: 165 | denom = exp_avg_sq.sqrt().add_(group['eps']) 166 | 167 | # bias_correction1 = 1 - beta1**state['step'][0] 168 | # bias_correction2 = 1 - beta2**state['step'][0] 169 | bias_correction1 = 1 - beta1**state['step'].item() 170 | bias_correction2 = 1 - beta2**state['step'].item() 171 | step_size = group['lr'] * \ 172 | math.sqrt(bias_correction2) / bias_correction1 173 | 174 | p.data.addcdiv_(-step_size, exp_avg, denom) 175 | 176 | return loss 177 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from setproctitle import setproctitle as ptitle 3 | 4 | import os 5 | import time 6 | import torch 7 | import logging 8 | import numpy as np 9 | from tensorboardX import SummaryWriter 10 | 11 | from model import build_model 12 | from utils import setup_logger 13 | from player_util import Agent 14 | from environment import create_env 15 | 16 | 17 | def test(args, shared_model, optimizer, optimizer_ToM, train_modes, n_iters): 18 | ptitle('Test Agent') 19 | n_iter = 0 20 | writer = SummaryWriter(os.path.join(args.log_dir, 'Test')) 21 | gpu_id = args.gpu_id[-1] 22 | log = {} 23 | print(os.path.isdir(args.log_dir)) 24 | setup_logger('{}_log'.format(args.env), 25 | r'{0}/logger'.format(args.log_dir)) 26 | log['{}_log'.format(args.env)] = logging.getLogger( 27 | '{}_log'.format(args.env)) 28 | d_args = vars(args) 29 | for k in d_args.keys(): 30 | log['{}_log'.format(args.env)].info('{0}: {1}'.format(k, d_args[k])) 31 | 32 | torch.manual_seed(args.seed) 33 | if gpu_id >= 0: 34 | torch.cuda.manual_seed(args.seed) 35 | device = torch.device('cuda:' + str(gpu_id)) 36 | else: 37 | device = torch.device('cpu') 38 | 39 | env = create_env(args.env, args) 40 | #env.seed(args.seed) 41 | if "MSMTC" in args.env: 42 | # freeze env max steps to 100 43 | env.max_steps = 100 44 | 45 | start_time = time.time() 46 | count_eps = 0 47 | 48 | player = Agent(None, env, args, None, device) 49 | player.gpu_id = gpu_id 50 | player.model = build_model(player.env, args, device).to(device) 51 | player.model.eval() 52 | max_score = -100 53 | 54 | ave_reward_list = [] 55 | comm_cnt_list = [] 56 | comm_bit_list = [] 57 | tmp_list_1 = [] 58 | tmp_list_2 = [] 59 | while True: 60 | AG = 0 61 | reward_sum = np.zeros(player.num_agents) 62 | reward_sum_list = [] 63 | len_sum = 0 64 | 65 | for i_episode in range(args.test_eps): 66 | player.model.load_state_dict(shared_model.state_dict()) 67 | player.reset() 68 | reward_sum_ep = np.zeros(player.num_agents) 69 | rotation_sum_ep = 0 70 | 71 | fps_counter = 0 72 | t0 = time.time() 73 | count_eps += 1 74 | fps_all = [] 75 | 76 | comm_cnt = 0 77 | comm_bit = 0 78 | ToM_acc = 0 79 | ToM_target_acc = 0 80 | while True: 81 | player.action_test() 82 | fps_counter += 1 83 | reward_sum_ep += player.reward 84 | 85 | #ToM_acc += player.random_ToM_acc 86 | #ToM_target_acc += player.random_ToM_target_acc 87 | # comm_ToM_loss += player.comm_ToM_loss 88 | # no_comm_ToM_loss +=player.no_comm_ToM_loss 89 | # ToM_loss +=player.ToM_loss 90 | if 'comm' in args.model or 'ToM-v5' in args.model: 91 | comm_cnt += player.comm_cnt 92 | comm_bit += player.comm_bit 93 | if player.done: 94 | # print(ToM_acc/fps_counter) 95 | # print(ToM_target_acc/fps_counter) 96 | tmp_list_1.append(ToM_acc/fps_counter) 97 | tmp_list_2.append(ToM_target_acc/fps_counter) 98 | 99 | # if len(tmp_list_1) 
== 3: 100 | # print(np.mean(tmp_list_1),np.std(tmp_list_1)) 101 | # print(np.mean(tmp_list_2),np.std(tmp_list_2)) 102 | 103 | #print("steps:{}".format(fps_counter)) 104 | #print("comm:{}, no comm:{}, Total:{}".format(comm_ToM_loss.item()/fps_counter,no_comm_ToM_loss.item()/fps_counter,\ 105 | # ToM_loss.item()/fps_counter)) 106 | #print("reward:{}".format(reward_sum_ep[0])) 107 | #AG += reward_sum_ep[0]/rotation_sum_ep*player.num_agents 108 | reward_sum += reward_sum_ep 109 | reward_sum_list.append(reward_sum_ep[0]) 110 | len_sum += player.eps_len 111 | fps = fps_counter / (time.time()-t0) 112 | #n_iter = n_iters[0] if len(n_iters) > 0 else count_eps 113 | 114 | #for n in n_iters: 115 | # n_iter += n 116 | new_n_iter = sum(n_iters) 117 | if new_n_iter > n_iter: 118 | n_iter = new_n_iter 119 | # for i, r_i in enumerate(reward_sum_ep): 120 | # writer.add_scalar('test/reward'+str(i), r_i, n_iter) 121 | writer.add_scalar('test/reward', reward_sum_ep[0], n_iter) 122 | writer.add_scalar('test/fps', fps, n_iter) 123 | fps_all.append(fps) 124 | player.clean_buffer(player.done) 125 | 126 | #writer.add_scalar('test/eps_len', player.eps_len, n_iter) 127 | break 128 | ''' 129 | comm_cnt_list.append(comm_cnt/env.max_steps) 130 | comm_bit_list.append(comm_bit/env.max_steps) 131 | print("cnt: ",np.mean(comm_cnt_list),np.std(comm_cnt_list)) 132 | print("bit: ",np.mean(comm_bit_list),np.std(comm_bit_list)) 133 | comm_bit_list=[] 134 | comm_cnt_list=[] 135 | 136 | comm_cnt_avg = comm_cnt/(args.test_eps * 100) 137 | comm_bit_avg = comm_bit/(args.test_eps * 100) 138 | print("comm_cnt",comm_cnt_avg) 139 | print("comm_bandwidth",comm_bit_avg) 140 | comm_cnt_list.append(comm_cnt_avg) 141 | comm_bit_list.append(comm_bit_avg) 142 | if len(comm_cnt_list)==5: 143 | print(np.mean(comm_cnt_list),np.std(comm_cnt_list)) 144 | print(np.mean(comm_bit_list),np.std(comm_bit_list)) 145 | comm_bit_list=[] 146 | comm_cnt_list=[] 147 | ''' 148 | # player.max_length: 149 | ave_AG = AG/args.test_eps 150 | ave_reward_sum = reward_sum/args.test_eps 151 | len_mean = len_sum/args.test_eps 152 | reward_step = reward_sum / len_sum 153 | mean_reward = np.mean(reward_sum_list) 154 | std_reward = np.std(reward_sum_list) 155 | 156 | if args.workers == 0: 157 | # pure test, so compute reward mean and std 158 | ave_reward_list.append(mean_reward) 159 | if len(ave_reward_list) == 5: 160 | reward_mean = np.mean(ave_reward_list) 161 | reward_std = np.std(ave_reward_list) 162 | ave_reward_list = [] 163 | log['{}_log'.format(args.env)].info("mean reward {0}, std reward {1}".format(reward_mean, reward_std)) 164 | print("---------------") 165 | #n_iter = sum(n_iters) 166 | #writer.add_scalar('test/reward', ave_reward_sum[0], n_iter) 167 | 168 | log['{}_log'.format(args.env)].info( 169 | "Time {0}, ave eps reward {1}, ave eps length {2}, reward step {3}, FPS {4}, " 170 | "mean reward {5}, std reward {6}, AG {7}". 
171 | format( 172 | time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)), 173 | np.around(ave_reward_sum, decimals=2), np.around(len_mean, decimals=2), 174 | np.around(reward_step, decimals=2), np.around(np.mean(fps_all), decimals=2), 175 | mean_reward, std_reward, np.around(ave_AG, decimals=2) 176 | )) 177 | 178 | # save model 179 | if ave_reward_sum[0] > max_score: 180 | print('save best!') 181 | max_score = ave_reward_sum[0] 182 | model_dir = os.path.join(args.log_dir, 'best.pth') 183 | elif n_iter % 100000 == 0: 184 | model_dir = os.path.join(args.log_dir, ('new_'+str(n_iter)+'.pth').format(args.env)) 185 | #else: 186 | new_model_dir = os.path.join(args.log_dir, 'new.pth'.format(args.env)) 187 | state_to_save = {"model": player.model.state_dict(), 188 | "optimizer": optimizer.state_dict()} 189 | torch.save(state_to_save, model_dir) 190 | torch.save(state_to_save, new_model_dir) 191 | time.sleep(args.sleep_time) 192 | 193 | for rank in range(args.workers): 194 | if train_modes[rank] == -100: 195 | print("test process ended due to train process collapse") 196 | return 197 | 198 | if n_iter > args.max_step: 199 | env.close() 200 | for id in range(0, args.workers): 201 | train_modes[id] = -100 202 | break 203 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import torch.optim as optim 9 | from tensorboardX import SummaryWriter 10 | from setproctitle import setproctitle as ptitle 11 | 12 | import json 13 | from model import build_model 14 | from player_util import Agent 15 | from environment import create_env 16 | from shared_optim import SharedRMSprop, SharedAdam 17 | 18 | class HLoss(nn.Module): 19 | def __init__(self): 20 | super(HLoss, self).__init__() 21 | 22 | def forward(self, x, prior=None): 23 | if prior is None: 24 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1) 25 | b = -b.sum(1) 26 | b = b.mean() 27 | else: 28 | b = F.softmax(x, dim = -1) 29 | b = b * (F.log_softmax(x, dim = -1) - torch.log(prior).view(-1, x.size(-1))) 30 | b = -b.sum(-1) 31 | b = b.mean() 32 | return b 33 | 34 | def optimize_ToM(state, poses, masks, available_actions, args, params, optimizer_ToM, shared_model, device_share, env): 35 | num_agents = env.n 36 | num_targets = env.num_target 37 | max_steps = env.max_steps 38 | seg_num = int(max_steps/args.A2C_steps) 39 | if "MSMTC" in args.env: 40 | batch_size, num_agents, num_both, obs_dim = state.size() 41 | elif "CN" in args.env: 42 | batch_size, num_agents, obs_dim = state.size() 43 | count = int(batch_size/max_steps) 44 | print("batch_size = ",batch_size) 45 | # state, poses are only to device when being used 46 | if "MSMTC" in args.env: 47 | state = state.reshape(count, max_steps, num_agents, num_both, obs_dim)#.to(device_share) 48 | elif "CN" in args.env: 49 | state = state.reshape(count, max_steps, num_agents, obs_dim)#.to(device_share) 50 | 51 | batch_size, num_agents, num_agents, cam_dim = poses.size() 52 | poses = poses.reshape(count, max_steps, num_agents, num_agents, cam_dim)#.to(device_share) 53 | 54 | masks = masks.reshape(count, max_steps, num_agents, num_agents, 1) 55 | h_ToM = torch.zeros(count, num_agents, num_agents, args.lstm_out).to(device_share) 56 | hself = torch.zeros(count, num_agents, args.lstm_out ).to(device_share) 57 | 
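# the *_start tensors below are detached snapshots of the recurrent states, so each A2C segment can be replayed from the same starting hidden state in every ToM training loop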
hself_start = hself.clone().detach() # save the intial hidden state for every args.num_steps 58 | hToM_start = h_ToM.clone().detach() 59 | 60 | if args.mask_actions: 61 | available_actions = available_actions.reshape(count, max_steps, num_agents, num_targets, -1) 62 | 63 | ToM_loss_sum = torch.zeros(1).to(device_share) 64 | ToM_target_loss_sum = torch.zeros(1).to(device_share) 65 | ToM_target_acc_sum = torch.zeros(1).to(device_share) 66 | for seg in range(seg_num): 67 | for train_loop in range(args.ToM_train_loops): 68 | hself = hself_start.clone().detach() 69 | h_ToM = hToM_start.clone().detach() 70 | ToM_goals = None 71 | real_goals = None 72 | BCE_criterion = torch.nn.BCELoss(reduction='sum') 73 | ToM_target_loss = torch.zeros(1).to(device_share) 74 | ToM_target_acc = torch.zeros(1).to(device_share) 75 | for s_i in range(args.A2C_steps): 76 | step = seg * args.A2C_steps + s_i 77 | available_action = available_actions[:,step].to(device_share) if args.mask_actions else None 78 | 79 | value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goal, edge_logits, comm_edges, probs, real_cover, ToM_target_cover =\ 80 | shared_model(state[:,step].to(device_share), hself, h_ToM, poses[:,step].to(device_share), masks[:,step].to(device_share), available_actions = available_action) 81 | ToM_target_loss += BCE_criterion(ToM_target_cover.float(), real_cover.float()) 82 | ToM_target_cover_discrete = (ToM_target_cover > 0.6) 83 | ToM_target_acc += torch.sum((ToM_target_cover_discrete == real_cover)) 84 | 85 | hself = hn_self 86 | h_ToM = hn_ToM 87 | 88 | ToM_goal = ToM_goal.unsqueeze(1) 89 | if "MSMTC" in args.env: 90 | real_goal = torch.cat((1-actions,actions),-1).detach() 91 | real_goal_duplicate = real_goal.reshape(count, 1, num_agents, num_targets, -1).repeat(1, num_agents, 1, 1, 1) 92 | idx= (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool() 93 | real_goal_duplicate = real_goal_duplicate[:,idx].reshape(count, 1, num_agents, num_agents-1, num_targets, -1) 94 | elif "CN" in args.env: 95 | real_goal = actions.reshape(count * num_agents, 1) 96 | real_goal_duplicate = torch.zeros(count * num_agents, num_targets).to(device_share).scatter_(1, real_goal, 1) 97 | real_goal_duplicate = real_goal_duplicate.reshape(count, 1, num_agents, num_targets, -1).repeat(1, num_agents, 1, 1, 1) 98 | idx= (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool() 99 | real_goal_duplicate = real_goal_duplicate[:,idx].reshape(count, 1, num_agents, num_agents-1, num_targets) 100 | if ToM_goals is None: 101 | ToM_goals = ToM_goal 102 | real_goals = real_goal_duplicate 103 | else: 104 | ToM_goals = torch.cat((ToM_goals, ToM_goal),1) 105 | real_goals = torch.cat((real_goals, real_goal_duplicate), 1) 106 | ToM_loss = torch.zeros(1).to(device_share) 107 | KL_criterion = torch.nn.KLDivLoss(reduction='sum') 108 | real_prob = real_goals.float() 109 | ToM_prob = ToM_goals.float() 110 | ToM_loss += KL_criterion(ToM_prob.log(), real_prob) 111 | 112 | loss = ToM_loss + 0.5 * ToM_target_loss 113 | loss = loss/(count) 114 | shared_model.zero_grad() 115 | loss.backward() 116 | all_grads = [p.grad for p in params] 117 | flat_grads = torch.cat([g.view(-1) for g in all_grads]) 118 | if torch.isinf(flat_grads).any() or torch.isnan(flat_grads).any(): 119 | print("Detect inf/nan gradients, skip updating model") 120 | else: 121 | torch.nn.utils.clip_grad_norm_(params, 20) 122 | optimizer_ToM.step() 123 | 124 | # update hidden state start & loss sum 125 | hself_start = hself.clone().detach() 
126 | hToM_start = h_ToM.clone().detach() 127 | ToM_loss_sum += ToM_loss 128 | ToM_target_loss_sum += ToM_target_loss 129 | ToM_target_acc_sum += ToM_target_acc 130 | 131 | print("ToM_loss =", ToM_loss_sum.sum().data) 132 | print("ToM Target loss=", ToM_target_loss_sum.sum().data) 133 | cnt_all = (num_agents * (num_agents-1) * num_targets * batch_size) 134 | ToM_loss_mean = ToM_loss_sum/cnt_all 135 | ToM_target_loss_mean = ToM_target_loss_sum/cnt_all 136 | ToM_target_acc_mean = ToM_target_acc_sum/cnt_all 137 | return ToM_loss_sum, ToM_loss_mean, ToM_target_loss_mean, ToM_target_acc_mean 138 | 139 | def optimize_Policy(state, poses, real_actions, reward, masks, available_actions, args, params, optimizer_Policy, shared_model, device_share, env): 140 | num_agents = env.n 141 | num_targets = env.num_target 142 | max_steps = env.max_steps 143 | assert max_steps % args.A2C_steps == 0 144 | seg_num = int(max_steps/args.A2C_steps) 145 | if "MSMTC" in args.env: 146 | batch_size, num_agents, num_both, obs_dim = state.size() 147 | elif "CN" in args.env: 148 | batch_size, num_agents, obs_dim = state.size() 149 | count = int(batch_size/max_steps) 150 | 151 | if count != args.workers: 152 | print(count) 153 | assert count == args.workers 154 | 155 | # state, cam_state, reward, real_actions are to device only when being used 156 | if "MSMTC" in args.env: 157 | state = state.reshape(count, max_steps, num_agents, num_both, obs_dim)#.to(device_share) 158 | real_actions = real_actions.reshape(count, max_steps, num_agents, num_targets, 1)#.to(device_share) 159 | elif "CN" in args.env: 160 | state = state.reshape(count, max_steps, num_agents, obs_dim)#.to(device_share) 161 | real_actions = real_actions.reshape(count, max_steps, num_agents, 1)#.to(device_share) 162 | 163 | batch_size, num_agents, num_agents, cam_dim = poses.size() 164 | poses = poses.reshape(count, max_steps, num_agents, num_agents, cam_dim)#.to(device_share) 165 | batch_size, num_agents, r_dim = reward.size() 166 | reward = reward.reshape(count, max_steps, num_agents, r_dim)#.to(device_share) 167 | 168 | masks = masks.reshape(count, max_steps, num_agents, num_agents, 1) 169 | h_ToM = torch.zeros(count, num_agents, num_agents, args.lstm_out).to(device_share) 170 | 171 | hself = torch.zeros(count, num_agents, args.lstm_out ).to(device_share) 172 | #hothers = torch.zeros(count, num_agents, num_agents-1, args.lstm_out).to(device_share) 173 | hself_start = hself.clone().detach() # save the intial hidden state for every args.num_steps 174 | hToM_start = h_ToM.clone().detach() 175 | if args.mask_actions: 176 | available_actions = available_actions.reshape(count, max_steps, num_agents, num_targets, -1) 177 | 178 | policy_loss_sum = torch.zeros(count, num_agents, num_targets, 1).to(device_share) 179 | value_loss_sum = torch.zeros(count, num_agents, 1).to(device_share) 180 | entropies_all = torch.zeros(1).to(device_share) 181 | Sparsity_loss_sum = torch.zeros(count, 1).to(device_share) 182 | 183 | for seg in range(seg_num): # loop for every args.A2C_steps 184 | for train_loop in range(args.policy_train_loops): 185 | hself = hself_start.clone().detach() 186 | h_ToM = hToM_start.clone().detach() 187 | values = [] 188 | entropies = [] 189 | log_probs = [] 190 | rewards = [] 191 | edge_logits = [] 192 | for s_i in range(args.A2C_steps): 193 | step = s_i + seg * args.A2C_steps 194 | available_action = available_actions[:,step].to(device_share) if args.mask_actions else None 195 | 196 | if "ToM2C" in args.model: 197 | value_multi, actions, entropy, log_prob, 
hn_self, hn_ToM, ToM_goal, edge_logit, comm_edges, probs, real_cover, ToM_target_cover =\ 198 | shared_model(state[:,step].to(device_share), hself, h_ToM, poses[:,step].to(device_share), masks[:,step].to(device_share), available_actions= available_action) 199 | hself = hn_self 200 | hToM = hn_ToM 201 | 202 | values.append(value_multi) 203 | entropies.append(entropy) 204 | log_probs.append(torch.log(probs).gather(-1, real_actions[:,step].to(device_share))) 205 | rewards.append(reward[:,step].to(device_share)) 206 | 207 | edge_logits.append(edge_logit) 208 | 209 | R = torch.zeros(count, num_agents, 1).to(device_share) 210 | if seg < seg_num -1: 211 | # not the last segment of the episode 212 | next_step = (seg+1) * args.A2C_steps 213 | available_action = available_actions[:,next_step].to(device_share) if args.mask_actions else None 214 | value_multi, *others = shared_model(state[:,next_step].to(device_share), hself, h_ToM, poses[:,next_step].to(device_share), masks[:,next_step].to(device_share), available_actions= available_action) 215 | R = value_multi.clone().detach() 216 | 217 | R = R.to(device_share) 218 | values.append(R) 219 | 220 | policy_loss = torch.zeros(count, num_agents, num_targets, 1).to(device_share) 221 | value_loss = torch.zeros(count, num_agents, 1).to(device_share) 222 | entropies_sum = torch.zeros(1).to(device_share) 223 | w_entropies = float(args.entropy) 224 | 225 | Sparsity_loss = torch.zeros(count, 1).to(device_share) 226 | 227 | criterionH = HLoss() 228 | edge_prior = torch.FloatTensor(np.array([0.7, 0.3])).to(device_share) 229 | gae = torch.zeros(count, num_agents, 1).to(device_share) 230 | 231 | for i in reversed(range(args.A2C_steps)): 232 | R = args.gamma * R + rewards[i] 233 | advantage = R - values[i] 234 | value_loss = value_loss + 0.5 * advantage.pow(2) 235 | # Generalized Advantage Estimataion 236 | delta_t = rewards[i] + args.gamma * values[i + 1].data - values[i].data 237 | gae = gae * args.gamma * args.tau + delta_t 238 | #value_loss = value_loss + 0.5 * (gae + values[i].data -values[i]).pow(2) 239 | 240 | if "MSMTC" in args.env: 241 | gae_duplicate = gae.unsqueeze(2).repeat(1,1,num_targets,1) 242 | policy_loss = policy_loss - (w_entropies * entropies[i]) - (log_probs[i] * gae_duplicate) 243 | elif "CN" in args.env: 244 | gae_duplicate = gae 245 | if policy_loss.sum() == 0 : policy_loss = torch.zeros(1).to(device_share) 246 | policy_loss = policy_loss - (w_entropies * entropies[i].sum()) - (log_probs[i] * gae_duplicate).sum() 247 | 248 | entropies_sum += entropies[i].sum() 249 | 250 | edge_logit = edge_logits[i]#.reshape(count * num_agents * num_agents, -1) # k * 2 251 | Sparsity_loss += -criterionH(edge_logit, edge_prior) 252 | 253 | shared_model.zero_grad() 254 | loss = policy_loss.sum() + 0.5 * value_loss.sum() #+ 0.3 * Sparsity_loss.sum() 255 | loss = loss/(count * 4) 256 | loss.backward() 257 | 258 | torch.nn.utils.clip_grad_norm_(params, 5) 259 | optimizer_Policy.step() 260 | # update hself & hothers start for next segment 261 | hself_start = hself.clone().detach() 262 | hToM_start = h_ToM.clone().detach() 263 | # sum all the loss 264 | policy_loss_sum += policy_loss 265 | value_loss_sum += value_loss 266 | Sparsity_loss_sum += Sparsity_loss 267 | entropies_all += entropies_sum 268 | 269 | return policy_loss_sum, value_loss_sum, Sparsity_loss_sum, entropies_all 270 | 271 | def reduce_comm(policy_data, args, params_comm, optimizer, lr_scheduler, shared_model, ori_model, device_share, env): 272 | state, poses, real_actions, reward, comm_domains, 
available_actions = policy_data 273 | 274 | num_agents = env.n 275 | num_targets = env.num_target 276 | max_steps = env.max_steps 277 | assert max_steps % args.A2C_steps == 0 278 | 279 | if "MSMTC" in args.env: 280 | batch_size, num_agents, num_both, obs_dim = state.size() 281 | elif "CN" in args.env: 282 | batch_size, num_agents, obs_dim = state.size() 283 | count = int(batch_size/max_steps) 284 | 285 | # state, cam_state, reward, real_actions are to device only when being used 286 | if "MSMTC" in args.env: 287 | state = state.reshape(count, max_steps, num_agents, num_both, obs_dim)#.to(device_share) 288 | elif "CN" in args.env: 289 | state = state.reshape(count, max_steps, num_agents, obs_dim)#.to(device_share) 290 | 291 | batch_size, num_agents, num_agents, cam_dim = poses.size() 292 | poses = poses.reshape(count, max_steps, num_agents, num_agents, cam_dim)#.to(device_share) 293 | 294 | comm_domains = comm_domains.reshape(count, max_steps, num_agents, num_agents, 1) 295 | h_ToM = torch.zeros(count, num_agents, num_agents, args.lstm_out).to(device_share) 296 | hself = torch.zeros(count, num_agents, args.lstm_out ).to(device_share) 297 | if args.mask_actions: 298 | available_actions = available_actions.reshape(count, max_steps, num_agents, num_targets, -1) 299 | 300 | 301 | comm_loss_sum = torch.zeros(1).to(device_share) 302 | 303 | # sample_ids = [i for i in range(args.comm_train_loops * count)] 304 | # random.shuffle(sample_ids) 305 | # sample_ids = np.array(sample_ids) % count 306 | mini_batch_size = count 307 | 308 | #epoch_cnt = int(count * args.comm_train_loops / args.mini_batch_size) 309 | for epoch in range(1): 310 | #ids = sample_ids[epoch * mini_batch_size:(epoch+1)*mini_batch_size] 311 | 312 | h_ToM = torch.zeros(mini_batch_size, num_agents, num_agents, args.lstm_out).to(device_share) 313 | hself = torch.zeros(mini_batch_size, num_agents, args.lstm_out).to(device_share) 314 | 315 | comm_loss = torch.zeros(1).to(device_share) 316 | CE_criterion = nn.CrossEntropyLoss(reduction='mean') 317 | 318 | for step in range(max_steps): 319 | available_action = available_actions[:,step].to(device_share) if args.mask_actions else None 320 | 321 | hn_self, hn_ToM, edge_logit, curr_edges,_ , _= shared_model(state[:,step].to(device_share), hself, h_ToM,\ 322 | poses[:,step].to(device_share), comm_domains[:,step].to(device_share), available_actions= available_action, train_comm = args.train_comm) 323 | _, _, _, _, best_edges, edge_label= ori_model(state[:,step].to(device_share), hself.detach(), h_ToM.detach(),\ 324 | poses[:,step].to(device_share), comm_domains[:,step].to(device_share), available_actions= available_action, train_comm = args.train_comm) 325 | hself = hn_self 326 | hToM = hn_ToM 327 | 328 | # print(curr_edges) 329 | # print(best_edges) 330 | # idx = (best_edges == 1) 331 | # if curr_edges[idx].size()[0] > 0: 332 | # print(torch.sum(1-curr_edges[idx])/curr_edges[idx].size()[0]) 333 | # print(curr_edges.sum()/mini_batch_size) 334 | # print("------------") 335 | # print(edge_label) 336 | # print(edge_logit) 337 | # print(edge_logit.shape, edge_label.shape) 338 | edge_label = edge_label.detach() 339 | idx_0 = (edge_label == 0) 340 | idx_1 = (edge_label == 1) 341 | logit_0 = edge_logit[idx_0] 342 | logit_1 = edge_logit[idx_1] 343 | label_0 = edge_label[idx_0] 344 | label_1 = edge_label[idx_1] 345 | size_0 = label_0.size()[0] 346 | size_1 = label_1.size()[0] 347 | ''' 348 | if size_0 > size_1: 349 | random_ids = [i for i in range(size_1)] 350 | random.shuffle(random_ids) 351 | 
random_ids = random_ids[:size_1] 352 | logit_0 = logit_0[random_ids] 353 | label_0 = label_0[random_ids] 354 | ''' 355 | loss_0 = CE_criterion(logit_0, label_0.long()) if size_0 > 0 else 0 356 | loss_1 = CE_criterion(logit_1, label_1.long()) if size_1 > 0 else 0 357 | #print(CE_criterion(edge_logit,edge_label.long()), loss_0+loss_1) 358 | comm_loss += loss_0 + loss_1 359 | #print(logit_0[:5,0].reshape(-1).data) 360 | shared_model.zero_grad() 361 | comm_loss.backward()#retain_graph=True) 362 | torch.nn.utils.clip_grad_norm_(params_comm, 20) 363 | optimizer.step() 364 | lr_scheduler.step() 365 | comm_loss_sum += comm_loss 366 | print(curr_edges.sum()/mini_batch_size) 367 | 368 | # for param_group in optimizer.param_groups(): 369 | # for param in param_group['params']: 370 | # for name, model_param in shared_model.named_parameters(): 371 | # if model_param is param: 372 | # print(name) 373 | 374 | # for name,param in shared_model.named_parameters(): 375 | # if 'graph' in name: 376 | # if param.grad is None: 377 | # print(name) 378 | # else: 379 | # print(name, torch.norm(param.grad)) 380 | # # break 381 | 382 | return comm_loss_sum 383 | 384 | def load_data(args, history): 385 | history_list = [] 386 | for rank in range(args.workers): 387 | history_list += history[rank] 388 | 389 | item_cnt = len(history_list[0]) 390 | item_name = [item for item in history_list[0]] 391 | data_list = [[] for i in range(item_cnt)] 392 | 393 | for history in history_list: 394 | for i,item in enumerate(history): 395 | data_list[i].append(history[item]) 396 | 397 | for i in range(item_cnt): 398 | data_list[i] = torch.from_numpy(np.array(data_list[i])) 399 | if 'reward' in item_name[i]: 400 | data_list[i] = data_list[i].unsqueeze(-1) 401 | 402 | return data_list 403 | 404 | def train(args, shared_model, optimizer_Policy, optimizer_ToM, train_modes, n_iters, curr_env_steps, ToM_count, ToM_history, Policy_history, step_history, loss_history, env=None): 405 | rank = args.workers 406 | writer = SummaryWriter(os.path.join(args.log_dir, 'Train')) 407 | ptitle('Training') 408 | gpu_id = args.gpu_id[rank % len(args.gpu_id)] 409 | torch.manual_seed(args.seed + rank) 410 | env_name = args.env 411 | 412 | if gpu_id >= 0: 413 | torch.cuda.manual_seed(args.seed + rank) 414 | device = torch.device('cuda:' + str(gpu_id)) 415 | if len(args.gpu_id) > 1: 416 | raise AssertionError("Do not support multi-gpu training") 417 | #device_share = torch.device('cpu') 418 | else: 419 | device_share = torch.device('cuda:' + str(args.gpu_id[-1])) 420 | else: 421 | device_share = torch.device('cpu') 422 | #device_share = torch.device('cuda:0') 423 | if env == None: 424 | env = create_env(env_name, args) 425 | 426 | params = [] 427 | params_ToM = [] 428 | params_comm = [] 429 | for name,param in shared_model.named_parameters(): 430 | if 'ToM' in name or 'other' in name: 431 | params_ToM.append(param) 432 | else: 433 | params.append(param) 434 | if 'graph' in name: 435 | params_comm.append(param) 436 | 437 | if args.train_comm: # train communication in supervised way (communication reduction) 438 | optimizer_comm = SharedAdam(params_comm, lr=0.02, amsgrad=args.amsgrad) #lr=0.1 for MSMTC 439 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_comm, step_size=20, gamma=0.2) #lr=0.5 for MSMTC 440 | ori_model = build_model(env, args, device_share) 441 | ori_model = ori_model.to(device_share) 442 | ori_state = torch.load( 443 | args.load_model_dir, 444 | map_location=lambda storage, loc: storage 445 | ) 446 | 
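# ori_model stays frozen (eval mode, excluded from optimizer_comm) and provides the best_edges / edge_label targets that supervise the trainable communication graph in reduce_comm()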
ori_model.load_state_dict(ori_state['model']) 447 | ori_model.eval() 448 | 449 | train_step_cnt = 0 450 | while True: # wait for all workers to finish collecting trajectories 451 | t1 = time.time() 452 | while True: 453 | flag = True 454 | curr_time = time.time() 455 | if curr_time - t1 > 180: 456 | print("waiting too long for workers") 457 | print("train modes:", train_modes) 458 | return 459 | for rank in range(args.workers): 460 | if train_modes[rank] != -10: 461 | flag = False # some worker is still collecting trajectories 462 | break 463 | if flag: 464 | break 465 | 466 | t2 = time.time() 467 | 468 | print("training start after waiting for {} seconds".format(t2-t1)) 469 | 470 | if args.train_comm: 471 | data_list = load_data(args, Policy_history) 472 | comm_loss = reduce_comm(data_list, args, params_comm, optimizer_comm, lr_scheduler, shared_model, ori_model, device_share, env) 473 | writer.add_scalar('train/comm_loss', comm_loss.sum(), sum(n_iters)) 474 | print("comm_loss:", comm_loss.item()) 475 | if comm_loss.sum() < 1: 476 | break 477 | else: 478 | train_step_cnt += 1 479 | state, poses, real_actions, reward, masks, available_actions = load_data(args, Policy_history) 480 | 481 | policy_loss, value_loss, Sparsity_loss, entropies_sum =\ 482 | optimize_Policy(state, poses, real_actions, reward, masks, available_actions, args, params, optimizer_Policy, shared_model, device_share, env) 483 | # log training information 484 | n_steps = sum(n_iters) # global_steps_count 485 | writer.add_scalar('train/policy_loss_sum', policy_loss.sum(), n_steps) 486 | writer.add_scalar('train/value_loss_sum', value_loss.sum(), n_steps) 487 | writer.add_scalar('train/Sparsity_loss_sum', Sparsity_loss.sum(), n_steps) 488 | writer.add_scalar('train/entropies_sum', entropies_sum.sum(), n_steps) 489 | writer.add_scalar('train/gamma', args.gamma, n_steps) 490 | print("policy loss:{}".format(policy_loss.sum().data)) 491 | print("value loss:{}".format(value_loss.sum().data)) 492 | print("entropies:{}".format(entropies_sum.sum().data)) 493 | print("Policy training finished") 494 | print("---------------------") 495 | 496 | ToM_len = args.ToM_frozen * args.workers * env.max_steps 497 | if 'ToM2C' in args.model: 498 | if sum(ToM_count) >= ToM_len: 499 | print("ToM training started") 500 | state, poses, masks, real_goals, available_actions = load_data(args, ToM_history) 501 | print("ToM data loaded") 502 | ToM_loss_sum, ToM_loss_avg, ToM_target_loss, ToM_target_acc = optimize_ToM(state, poses, masks, available_actions, args, params_ToM, optimizer_ToM, shared_model, device_share, env) 503 | print("optimized based on ToM loss") 504 | 505 | writer.add_scalar('train/ToM_loss_sum', ToM_loss_sum.sum(), n_steps) 506 | writer.add_scalar('train/ToM_loss_avg', ToM_loss_avg.sum(), n_steps) 507 | writer.add_scalar('train/ToM_target_loss_avg', ToM_target_loss.sum(), n_steps) 508 | writer.add_scalar('train/ToM_target_acc_avg', ToM_target_acc.sum(), n_steps) 509 | 510 | for rank in range(args.workers): 511 | ToM_history[rank] = [] 512 | ToM_count[rank] = 0 513 | print("---------------------") 514 | 515 | if args.gamma_rate > 0: 516 | # add this one for schedule learning 517 | if n_steps >= args.start_eps * 20 * args.workers and args.gamma < args.gamma_final and train_step_cnt % (args.ToM_frozen) == 0: 518 | if args.gamma > 0.4: 519 | args.gamma = args.gamma * (1 + args.gamma_rate/2) 520 | else: 521 | args.gamma = args.gamma * (1 + args.gamma_rate) 522 | if "MSMTC" in args.env: 523 | new_env_step = int((args.gamma + 0.1)/0.2) * 
args.env_steps 524 | env.max_steps = new_env_step 525 | for rank in range(args.workers): 526 | curr_env_steps[rank] = new_env_step 527 | 528 | print("gamma:", args.gamma) 529 | assert args.gamma < 0.95 530 | for rank in range(args.workers): 531 | Policy_history[rank] = [] 532 | if train_modes[rank] == -100: 533 | return 534 | train_modes[rank] = -1 535 | 536 | if train_modes[0] == -100: 537 | env.close() 538 | break -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import math 3 | import json 4 | import torch 5 | import logging 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | def setup_logger(logger_name, log_file, level=logging.INFO): 11 | l = logging.getLogger(logger_name) 12 | formatter = logging.Formatter('%(asctime)s : %(message)s') 13 | fileHandler = logging.FileHandler(log_file, mode='w') 14 | fileHandler.setFormatter(formatter) 15 | streamHandler = logging.StreamHandler() 16 | streamHandler.setFormatter(formatter) 17 | 18 | l.setLevel(level) 19 | l.addHandler(fileHandler) 20 | l.addHandler(streamHandler) 21 | 22 | 23 | def read_config(file_path): 24 | """Read JSON config.""" 25 | json_object = json.load(open(file_path, 'r')) 26 | return json_object 27 | 28 | 29 | def norm_col_init(weights, std=1.0): 30 | x = torch.randn(weights.size()) 31 | x *= std / torch.sqrt((x ** 2).sum(1, keepdim=True)) 32 | return x 33 | 34 | 35 | def ensure_shared_grads(model, shared_model, device, device_share): 36 | diff_device = device != device_share 37 | for param, shared_param in zip(model.parameters(), shared_model.parameters()): 38 | if param.grad is None: 39 | continue 40 | if shared_param.grad is not None and not diff_device: 41 | return 42 | elif not diff_device: 43 | shared_param._grad = param.grad 44 | else: 45 | shared_param._grad = param.grad.to(device_share) 46 | 47 | 48 | def ensure_shared_grads_param(params, shared_params, gpu=False): 49 | for param, shared_param in zip(params, shared_params): 50 | # print (shared_param) 51 | if shared_param.grad is not None and not gpu: 52 | return 53 | 54 | if not gpu: 55 | shared_param._grad = param.grad 56 | else: 57 | shared_param._grad = param.grad.clone().cpu() 58 | 59 | 60 | def weights_init(m): 61 | classname = m.__class__.__name__ 62 | if classname.find('Conv') != -1: 63 | weight_shape = list(m.weight.data.size()) 64 | fan_in = np.prod(weight_shape[1:4]) 65 | fan_out = np.prod(weight_shape[2:4]) * weight_shape[0] 66 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 67 | m.weight.data.uniform_(-w_bound, w_bound) 68 | m.bias.data.fill_(0) 69 | elif classname.find('Linear') != -1: 70 | weight_shape = list(m.weight.data.size()) 71 | fan_in = weight_shape[1] 72 | fan_out = weight_shape[0] 73 | w_bound = np.sqrt(6. 
/ (fan_in + fan_out)) 74 | m.weight.data.uniform_(-w_bound, w_bound) 75 | m.bias.data.fill_(0) 76 | 77 | 78 | def weights_init_mlp(m): 79 | classname = m.__class__.__name__ 80 | if classname.find('Linear') != -1: 81 | m.weight.data.normal_(0, 1) 82 | m.weight.data *= 1 / \ 83 | torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True)) 84 | if m.bias is not None: 85 | m.bias.data.fill_(0) 86 | 87 | 88 | def normal(x, mu, sigma, device): 89 | pi = np.array([math.pi]) 90 | pi = torch.from_numpy(pi).float() 91 | pi = Variable(pi).to(device) 92 | a = (-1 * (x - mu).pow(2) / (2 * sigma)).exp() 93 | b = 1 / (2 * sigma * pi.expand_as(sigma)).sqrt() 94 | return a * b 95 | 96 | 97 | def check_path(path): 98 | import os 99 | if not os.path.exists(path): 100 | os.mkdir(path) 101 | 102 | 103 | def goal_id_filter(goals): 104 | return np.where(goals > 0.5)[0] 105 | 106 | 107 | def norm(x, scale): 108 | assert len(x.shape) <= 2 109 | x = scale * (x - x.mean(0)) / (x.std(0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem 110 | return x 111 | 112 | 113 | class ToTensor(object): 114 | def __call__(self, sample): 115 | sample = sample.transpose(0, 3, 1, 2) 116 | return torch.from_numpy(sample.astype(np.float32)) 117 | -------------------------------------------------------------------------------- /worker.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | import torch 5 | import numpy as np 6 | import torch.optim as optim 7 | from tensorboardX import SummaryWriter 8 | from setproctitle import setproctitle as ptitle 9 | 10 | from model import build_model 11 | from player_util import Agent 12 | from environment import create_env 13 | 14 | 15 | def worker(rank, args, shared_model, train_modes, n_iters, curr_env_steps, ToM_count, ToM_history, Policy_history, step_history, loss_history, env=None): 16 | n_iter = 0 17 | writer = SummaryWriter(os.path.join(args.log_dir, 'Agent-{}'.format(rank))) 18 | ptitle('worker: {}'.format(rank)) 19 | gpu_id = args.gpu_id[rank % len(args.gpu_id)] 20 | torch.manual_seed(args.seed + rank) 21 | training_mode = args.train_mode 22 | env_name = args.env 23 | 24 | if gpu_id >= 0: 25 | torch.cuda.manual_seed(args.seed + rank) 26 | device = torch.device('cuda:' + str(gpu_id)) 27 | if len(args.gpu_id) > 1: 28 | raise AssertionError("Do not support multi-gpu training") 29 | #device_share = torch.device('cpu') 30 | else: 31 | device_share = torch.device('cuda:' + str(args.gpu_id[-1])) 32 | 33 | else: 34 | device = device_share = torch.device('cpu') 35 | 36 | #device = torch.device("cpu") # there's no need for worker to use 37 | 38 | if env == None: 39 | env = create_env(env_name, args, rank) 40 | 41 | if args.fix: 42 | env.seed(args.seed) 43 | else: 44 | env.seed(rank % (args.seed + 1)) 45 | 46 | player = Agent(None, env, args, None, device) 47 | player.rank = rank 48 | player.gpu_id = gpu_id 49 | 50 | # prepare model 51 | player.model = shared_model 52 | 53 | player.reset() 54 | reward_sum = torch.zeros(player.num_agents).to(device) 55 | reward_sum_org = np.zeros(player.num_agents) 56 | ave_reward = np.zeros(2) 57 | ave_reward_longterm = np.zeros(2) 58 | count_eps = 0 59 | #max_steps = env.max_steps 60 | while True: 61 | if "MSMTC" in args.env and args.random_target: 62 | p = 0.7 - (env.max_steps/20 -1) * 0.1 63 | 64 | env.target_type_prob = [p, 1-p] 65 | player.env.target_type_prob = [p, 1-p] 66 | 67 | # sys to the shared model 68 | 
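# (i.e. synchronize the worker's local copy with the latest shared model parameters before collecting the next trajectory)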
player.model.load_state_dict(shared_model.state_dict()) 69 | if player.done: 70 | player.reset() 71 | reward_sum = torch.zeros(player.num_agents).to(device) 72 | reward_sum_org = np.zeros(player.num_agents) 73 | count_eps += 1 74 | 75 | 76 | player.update_rnn_hidden() 77 | t0 = time.time() 78 | 79 | for s_i in range(env.max_steps): 80 | player.action_train() 81 | if 'ToM' in args.model: 82 | ToM_count[rank] += 1 83 | reward_sum += player.reward 84 | reward_sum_org += player.reward_org 85 | if player.done: 86 | writer.add_scalar('train/reward', reward_sum[0], n_iter) 87 | writer.add_scalar('train/reward_org', reward_sum_org[0].sum(), n_iter) 88 | break 89 | fps = s_i / (time.time() - t0) 90 | 91 | writer.add_scalar('train/fps', fps, n_iter) 92 | 93 | n_iter += env.max_steps # s_i 94 | n_iters[rank] = n_iter 95 | 96 | # wait for training process 97 | Policy_history[rank] = player.Policy_history 98 | player.Policy_history = [] 99 | ''' 100 | # for evaluation, no need in real training 101 | player.optimize(None, None, shared_model, training_mode, device_share) 102 | step_history[rank] = player.step_history 103 | loss_history[rank] = player.loss_history 104 | ' 105 | player.step_history = [] 106 | player.loss_history = [] 107 | # evaluation end 108 | ''' 109 | if 'ToM' in args.model: 110 | ToM_history[rank] += player.ToM_history 111 | player.ToM_history = [] 112 | 113 | train_modes[rank] = -10 # have to put this line at last 114 | train_start_time = time.time() 115 | while train_modes[rank] != -1: 116 | current_time = time.time() 117 | if current_time - train_start_time > 180 : 118 | print("stuck in training") 119 | train_modes[rank] = -100 120 | return 121 | # update env steps during training 122 | env.max_steps = curr_env_steps[rank] 123 | player.env.max_steps = env.max_steps 124 | 125 | player.clean_buffer(player.done) 126 | 127 | if sum(n_iters) > args.max_step: 128 | train_modes[rank] = -100 129 | 130 | if train_modes[rank] == -100: 131 | env.close() 132 | break --------------------------------------------------------------------------------