├── model ├── __init__.py ├── random_policy.py ├── inference_utils.py ├── gumbel.py ├── inference.py ├── contrastive.py └── contrastive_cmi.py ├── utils ├── __init__.py ├── penv.py ├── utils.py ├── multiprocessing_env.py ├── plot.py ├── sum_tree.py └── replay_buffer.py ├── .gitignore ├── .gitmodules ├── README.md ├── collect_minigrid_data.py ├── configs ├── minigrid_full.json ├── igibson_discrete.json ├── igibson_continuous.json └── minigrid_sparse.json ├── collect_igibson_data.py ├── test.py ├── causal_inference.py └── train.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | original_scripts 2 | *.idea* 3 | *.pyc -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "iGibson-CausalMoMa"] 2 | path = iGibson-CausalMoMa 3 | url = https://github.com/JiahengHu/iGibson-CausalMoMa.git 4 | [submodule "sb3-CausalMoMa"] 5 | path = sb3-CausalMoMa 6 | url = https://github.com/JiahengHu/sb3-CausalMoMa.git 7 | [submodule "Minigrid-CausalMoMa"] 8 | path = Minigrid-CausalMoMa 9 | url = https://github.com/JiahengHu/Minigrid-CausalMoMa.git 10 | -------------------------------------------------------------------------------- /model/random_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | 4 | 5 | class RandomPolicy(nn.Module): 6 | def __init__(self, params): 7 | super(RandomPolicy, self).__init__() 8 | self.continuous_action = params.continuous_action 9 | if self.continuous_action: 10 | action_low, action_high = params.action_spec 11 | self.action_mean = (action_low + action_high) / 2 12 | self.action_scale = (action_high - action_low) / 2 13 | else: 14 | self.action_dim = params.action_dim 15 | 16 | def act_randomly(self): 17 | if self.continuous_action: 18 | return self.action_mean + self.action_scale * np.random.uniform(-1, 1, self.action_scale.shape) 19 | else: 20 | return np.random.randint(self.action_dim) 21 | 22 | def act(self, obs): 23 | return self.act_randomly() 24 | 25 | def save(self, path): 26 | pass 27 | -------------------------------------------------------------------------------- /utils/penv.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process, Pipe 2 | import gym 3 | 4 | def worker(conn, env): 5 | while True: 6 | cmd, data = conn.recv() 7 | if cmd == "step": 8 | obs, reward, done, info = env.step(data) 9 | if done: 10 | obs = env.reset() 11 | conn.send((obs, reward, done, info)) 12 | elif cmd == "reset": 13 | obs = env.reset() 14 | conn.send(obs) 15 | else: 16 | raise NotImplementedError 17 | 18 | class ParallelEnv(gym.Env): 19 | """A concurrent execution of environments in multiple processes.""" 20 | 21 | def __init__(self, envs): 22 | assert len(envs) >= 1, "No environment given." 
23 | 
24 |         self.envs = envs
25 |         self.observation_space = self.envs[0].observation_space
26 |         self.action_space = self.envs[0].action_space
27 | 
28 |         self.locals = []
29 |         for env in self.envs[1:]:  # env 0 runs in the main process; the rest get daemon workers
30 |             local, remote = Pipe()
31 |             self.locals.append(local)
32 |             p = Process(target=worker, args=(remote, env))
33 |             p.daemon = True
34 |             p.start()
35 |             remote.close()
36 | 
37 |     def reset(self):
38 |         for local in self.locals:
39 |             local.send(("reset", None))
40 |         results = [self.envs[0].reset()] + [local.recv() for local in self.locals]
41 |         return results
42 | 
43 |     def step(self, actions):
44 |         for local, action in zip(self.locals, actions[1:]):
45 |             local.send(("step", action))
46 |         obs, reward, done, info = self.envs[0].step(actions[0])
47 |         if done:
48 |             obs = self.envs[0].reset()
49 |         results = zip(*[(obs, reward, done, info)] + [local.recv() for local in self.locals])
50 |         return results
51 | 
52 |     def render(self):
53 |         raise NotImplementedError
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Causal Policy Gradient for Whole-Body Mobile Manipulation
 2 | 
 3 | Jiaheng Hu, Peter Stone, Roberto Martin-Martin
 4 | 
 5 | RSS 2023
 6 | 
 7 | ## Setup
 8 | 
 9 | 1. Clone this repo and its submodules:
10 | ```bash
11 | git clone https://github.com/JiahengHu/CausalMoMa.git --recursive
12 | ```
13 | 
14 | 2. Install the cloned `iGibson-CausalMoMa`, `Minigrid-CausalMoMa` and `sb3-CausalMoMa`, following the respective `README.md` instructions.
15 | 
16 | 3. Download the required [iGibson data](https://stanfordvl.github.io/iGibson/installation.html#downloading-the-assets-and-datasets-of-scenes-and-objects).
17 | Download [HSR mesh data](https://drive.google.com/file/d/1Vz-Shra-Y3ZiHJdCjnQg8hZBqFPw6byG/view?usp=sharing) and extract
18 | it into `iGibson-CausalMoMa/igibson/data/assets/models/hsr`.
19 | 
20 | ## Causal Inference
21 | 1. Download pre-collected [Causal inference data](https://drive.google.com/drive/folders/1j0sSoHC_Hx6Wcel4mDvBevXboYOg1dKs?usp=sharing) and put them into `data/`.
22 | Alternatively, collect new data by running:
23 | ```
24 | # iGibson data
25 | python collect_igibson_data.py
26 | 
27 | # Minigrid data
28 | python collect_minigrid_data.py
29 | ```
30 | 
31 | 
32 | 2. Run causal discovery with one of the config files provided:
33 | ```
34 | python causal_inference.py --config PATH_TO_CONFIG
35 | 
36 | # e.g., for minigrid
37 | python causal_inference.py --config configs/minigrid_full.json
38 | ```
39 | 
40 | Results will be stored inside `causal/`.
41 | 
42 | ## Policy Learning
43 | 
44 | The inferred causal matrix is already included in `train.py`.
45 | 
46 | ```
47 | # HSR with Causal MoMa
48 | python train.py -sc --robot hsr
49 | 
50 | # HSR with Vanilla PPO
51 | python train.py -fc --robot hsr
52 | 
53 | # Fetch with Causal MoMa
54 | python train.py -sc --robot fetch
55 | 
56 | # Fetch with Vanilla PPO
57 | python train.py -fc --robot fetch
58 | ```
59 | 
60 | Results will be stored inside `log_dir/`.
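61 | 
62 | ### How the causal matrix is used (illustrative)
63 | 
64 | During policy learning, Causal MoMa splits the reward into channels and uses the
65 | inferred binary causal matrix between action dimensions and reward channels to
66 | update each action dimension only with the advantages of the channels it affects.
67 | The snippet below is a minimal sketch of that masking step -- it is **not** the
68 | `FPPO` implementation from `sb3-CausalMoMa`, and all tensor names and shapes are
69 | made up for illustration:
70 | 
71 | ```python
72 | import torch
73 | 
74 | batch, action_dim, reward_dim = 8, 4, 5
75 | 
76 | # Hypothetical binary causal matrix: entry (i, j) = 1 iff action dim i affects reward channel j.
77 | causal_matrix = torch.tensor([[1., 0., 0., 0., 1.],
78 |                               [0., 1., 0., 0., 1.],
79 |                               [0., 0., 1., 0., 1.],
80 |                               [0., 0., 0., 1., 1.]])
81 | 
82 | # Per-dimension log-probs of a factored policy and per-channel advantage estimates.
83 | log_probs = torch.randn(batch, action_dim, requires_grad=True)  # log pi_i(a_i | s)
84 | advantages = torch.randn(batch, reward_dim)                     # A_j(s, a)
85 | 
86 | # Each action dimension only receives the advantages of its causally relevant channels.
87 | masked_adv = advantages @ causal_matrix.T  # (batch, action_dim)
88 | causal_pg_loss = -(log_probs * masked_adv).sum(dim=1).mean()
89 | 
90 | # Vanilla counterpart: every dimension is updated with the single summed advantage.
91 | vanilla_pg_loss = -(log_probs.sum(dim=1) * advantages.sum(dim=1)).mean()
92 | ```
93 | 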
94 | ## Citing
95 | ```
96 | @inproceedings{hu2023causal,
97 |   title={Causal Policy Gradient for Whole-Body Mobile Manipulation},
98 |   author={Hu, Jiaheng and Stone, Peter and Mart{\'\i}n-Mart{\'\i}n, Roberto},
99 |   booktitle={Robotics: Science and Systems (RSS)},
100 |   year={2023}
101 | }
102 | ```
103 | 
--------------------------------------------------------------------------------
/collect_minigrid_data.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import numpy
 3 | 
 4 | from utils.utils import set_seed_everywhere
 5 | from utils.penv import ParallelEnv
 6 | import pickle
 7 | import numpy as np
 8 | import torch
 9 | import gym
10 | import gym_minigrid
11 | 
12 | def make_env(env_key, seed=None):
13 |     env = gym.make(env_key)
14 |     env.reset(seed=seed)
15 |     return env
16 | 
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 | 
19 | # Parse arguments
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--seed", type=int, default=0,
22 |                     help="random seed (default: 0)")
23 | parser.add_argument("--episodes", type=int, default=1000000,
24 |                     help="number of episodes (unused in this script)")
25 | parser.add_argument("--env", type=str, default="MiniGrid-SwampEnv-8x8-N3-v0",
26 |                     help="Environment name")
27 | parser.add_argument("--num", type=int, default=100000,
28 |                     help="Number of datapoints to collect")
29 | args = parser.parse_args()
30 | 
31 | # Set seed for all randomness sources
32 | set_seed_everywhere(args.seed)
33 | 
34 | # Set device
35 | print(f"Device: {device}\n")
36 | procs = 8
37 | 
38 | # Load environment
39 | envs = []
40 | for i in range(procs):
41 |     envs.append(make_env(args.env, args.seed + 10000 * i))
42 | env = ParallelEnv(envs)
43 | print("Environment loaded\n")
44 | 
45 | # Buffers for the collected transitions
46 | obs_list = []
47 | rewards_list = []
48 | actions_list = []
49 | 
50 | obs = env.reset()
51 | 
52 | save_fn = "minigrid_causal_data_" + args.env
53 | num_of_data = args.num
54 | for i in range(num_of_data):
55 |     # Do one agent-environment interaction
56 |     action = [env.action_space.sample() for _ in range(procs)]
57 | 
58 |     # For sparse reward scenario
59 |     if "Sparse1d" in args.env:
60 |         action_threshold = 0.7
61 |         probs = np.random.uniform(size=procs)
62 |         for j in range(procs):
63 |             if probs[j] < action_threshold:
64 |                 action[j][0] = 2  # manually move towards target
65 | 
66 |     nxt_obs, reward, done, info = env.step(action)
67 |     rewards_list += info  # info contains the decomposed reward
68 |     obs_list += obs
69 |     actions_list += action
70 |     obs = nxt_obs
71 | 
72 |     if (i+1) % 500 == 0:
73 |         print(f"saving iteration {i+1}...")
74 |         with open(save_fn, "wb") as fp:  # Pickling
75 |             pickle.dump([obs_list, actions_list, rewards_list], fp)
76 | 
77 | 
78 | 
79 | 
80 | 
--------------------------------------------------------------------------------
/configs/minigrid_full.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "info": "causal_inference_cmi_minigrid_full",
 3 |     "seed": 0,
 4 |     "cuda_id": 0,
 5 |     "domain": "minigrid",
 6 |     "igibson_reward_type": "",
 7 |     "mini_env_name": "full",
 8 |     "rb_path": "minigrid_causal_data",
 9 |     "train_mask": false,
10 |     "env_params": {
11 |         "env_name": "Causal",
12 |         "num_envs": 1
13 |     },
14 |     "training_params": {
15 |         "inference_algo": "cmi",
16 |         "object_level_obs": false,
17 |         "num_observation_steps": 1,
18 |         "load_inference": "",
19 |         "total_step": 75000,
20 |         "init_step": 0,
21 |         "collect_transitions": false,
22 | 
"num_inference_opt_steps": 1, 23 | "num_policy_opt_steps": 0, 24 | "eval_freq": 25, 25 | "saving_freq": 20000, 26 | "plot_freq": 1000, 27 | "replay_buffer_params": { 28 | "capacity": 2000000, 29 | "max_sample_time": 64, 30 | "saving_freq": 0, 31 | "prioritized_buffer": false, 32 | "parallel_sample": true, 33 | "prioritized_alpha": 10 34 | } 35 | }, 36 | "inference_params": { 37 | "num_pred_steps": 1, 38 | "batch_size": 64, 39 | "lr": 1e-4, 40 | "train_prop": 0.9, 41 | "residual": false, 42 | "log_std_min": -15, 43 | "log_std_max": 30, 44 | "grad_clip_norm": 20, 45 | "detach_encoder": true, 46 | "cmi_params": { 47 | "feature_fc_dims": [128, 128], 48 | "generative_fc_dims": [128, 128], 49 | "causal_pred_reward_weight": 0.0, 50 | "pred_diff_reward_weight": 1.0, 51 | "causal_opt_freq": 10, 52 | "eval_tau": 0.999, 53 | "CMI_threshold": 0.02 54 | } 55 | }, 56 | "contrastive_params": { 57 | "num_pred_steps": 1, 58 | "batch_size": 32, 59 | "lr": 3e-4, 60 | "loss_type": "contrastive", 61 | "gradient_through_all_samples": false, 62 | "l2_reg_coef": 1e-3, 63 | "num_negative_samples": 512, 64 | "grad_clip_norm": 20, 65 | "num_pred_samples": 16384, 66 | "num_pred_iters": 3, 67 | "pred_sigma_init": 0.33, 68 | "pred_sigma_shrink": 0.5, 69 | "modular_params": { 70 | "fc_dims": [256, 256, 256] 71 | }, 72 | "cmi_params": { 73 | "learn_bo": false, 74 | "dot_product_energy": true, 75 | "aggregation": "max", 76 | "train_all_masks": false, 77 | "feature_fc_dims": [256, 128], 78 | "enery_fc_dims": [128], 79 | "mask_opt_freq": 1, 80 | "full_opt_freq": 25, 81 | "causal_opt_freq": 25, 82 | "eval_num_negative_samples": 8192, 83 | "eval_tau": 0.995, 84 | "CMI_threshold": 0.01 85 | } 86 | } 87 | } -------------------------------------------------------------------------------- /configs/igibson_discrete.json: -------------------------------------------------------------------------------- 1 | { 2 | "info": "causal_inference_cmi_igibson_discrete", 3 | "seed": 0, 4 | "cuda_id": 0, 5 | "domain": "igibson", 6 | "igibson_reward_type": "discrete", 7 | "mini_env_name": "", 8 | "rb_path": "igibson_causal_data", 9 | "train_mask": false, 10 | "env_params": { 11 | "env_name": "Causal", 12 | "num_envs": 1 13 | }, 14 | "training_params": { 15 | "inference_algo": "cmi", 16 | "object_level_obs": false, 17 | "num_observation_steps": 1, 18 | "load_inference": "", 19 | "total_step": 75000, 20 | "init_step": 0, 21 | "collect_transitions": false, 22 | "num_inference_opt_steps": 1, 23 | "num_policy_opt_steps": 0, 24 | "eval_freq": 25, 25 | "saving_freq": 20000, 26 | "plot_freq": 1000, 27 | "replay_buffer_params": { 28 | "capacity": 2000000, 29 | "max_sample_time": 64, 30 | "saving_freq": 0, 31 | "prioritized_buffer": false, 32 | "parallel_sample": true, 33 | "prioritized_alpha": 10 34 | } 35 | }, 36 | "inference_params": { 37 | "num_pred_steps": 1, 38 | "batch_size": 64, 39 | "lr": 1e-4, 40 | "train_prop": 0.9, 41 | "residual": false, 42 | "log_std_min": -15, 43 | "log_std_max": 30, 44 | "grad_clip_norm": 20, 45 | "detach_encoder": true, 46 | "cmi_params": { 47 | "feature_fc_dims": [128, 128], 48 | "generative_fc_dims": [128, 128], 49 | "causal_pred_reward_weight": 0.0, 50 | "pred_diff_reward_weight": 1.0, 51 | "causal_opt_freq": 10, 52 | "eval_tau": 0.999, 53 | "CMI_threshold": 0.003 54 | } 55 | }, 56 | "contrastive_params": { 57 | "num_pred_steps": 1, 58 | "batch_size": 32, 59 | "lr": 3e-4, 60 | "loss_type": "contrastive", 61 | "gradient_through_all_samples": false, 62 | "l2_reg_coef": 1e-3, 63 | "num_negative_samples": 512, 64 | 
"grad_clip_norm": 20, 65 | "num_pred_samples": 16384, 66 | "num_pred_iters": 3, 67 | "pred_sigma_init": 0.33, 68 | "pred_sigma_shrink": 0.5, 69 | "modular_params": { 70 | "fc_dims": [256, 256, 256] 71 | }, 72 | "cmi_params": { 73 | "learn_bo": false, 74 | "dot_product_energy": true, 75 | "aggregation": "max", 76 | "train_all_masks": false, 77 | "feature_fc_dims": [256, 128], 78 | "enery_fc_dims": [128], 79 | "mask_opt_freq": 1, 80 | "full_opt_freq": 25, 81 | "causal_opt_freq": 25, 82 | "eval_num_negative_samples": 8192, 83 | "eval_tau": 0.995, 84 | "CMI_threshold": 0.01 85 | } 86 | } 87 | } -------------------------------------------------------------------------------- /configs/igibson_continuous.json: -------------------------------------------------------------------------------- 1 | { 2 | "info": "causal_inference_cmi_igibson_continuous", 3 | "seed": 0, 4 | "cuda_id": 0, 5 | "domain": "igibson", 6 | "igibson_reward_type": "continuous", 7 | "mini_env_name": "", 8 | "rb_path": "igibson_causal_data", 9 | "train_mask": false, 10 | "env_params": { 11 | "env_name": "Causal", 12 | "num_envs": 1 13 | }, 14 | "training_params": { 15 | "inference_algo": "cmi", 16 | "object_level_obs": false, 17 | "num_observation_steps": 1, 18 | "load_inference": "", 19 | "total_step": 40000, 20 | "init_step": 0, 21 | "collect_transitions": false, 22 | "num_inference_opt_steps": 1, 23 | "num_policy_opt_steps": 0, 24 | "eval_freq": 25, 25 | "saving_freq": 20000, 26 | "plot_freq": 1000, 27 | "replay_buffer_params": { 28 | "capacity": 2000000, 29 | "max_sample_time": 64, 30 | "saving_freq": 0, 31 | "prioritized_buffer": false, 32 | "parallel_sample": true, 33 | "prioritized_alpha": 10 34 | } 35 | }, 36 | "inference_params": { 37 | "num_pred_steps": 1, 38 | "batch_size": 64, 39 | "lr": 1e-4, 40 | "train_prop": 0.9, 41 | "residual": false, 42 | "log_std_min": -15, 43 | "log_std_max": 30, 44 | "grad_clip_norm": 20, 45 | "detach_encoder": true, 46 | "cmi_params": { 47 | "feature_fc_dims": [128, 128], 48 | "generative_fc_dims": [128, 128], 49 | "causal_pred_reward_weight": 0.0, 50 | "pred_diff_reward_weight": 1.0, 51 | "causal_opt_freq": 10, 52 | "eval_tau": 0.999, 53 | "CMI_threshold": 0.12 54 | } 55 | }, 56 | "contrastive_params": { 57 | "num_pred_steps": 1, 58 | "batch_size": 32, 59 | "lr": 3e-4, 60 | "loss_type": "contrastive", 61 | "gradient_through_all_samples": false, 62 | "l2_reg_coef": 1e-3, 63 | "num_negative_samples": 512, 64 | "grad_clip_norm": 20, 65 | "num_pred_samples": 16384, 66 | "num_pred_iters": 3, 67 | "pred_sigma_init": 0.33, 68 | "pred_sigma_shrink": 0.5, 69 | "modular_params": { 70 | "fc_dims": [256, 256, 256] 71 | }, 72 | "cmi_params": { 73 | "learn_bo": false, 74 | "dot_product_energy": true, 75 | "aggregation": "max", 76 | "train_all_masks": false, 77 | "feature_fc_dims": [256, 128], 78 | "enery_fc_dims": [128], 79 | "mask_opt_freq": 1, 80 | "full_opt_freq": 25, 81 | "causal_opt_freq": 25, 82 | "eval_num_negative_samples": 8192, 83 | "eval_tau": 0.995, 84 | "CMI_threshold": 0.01 85 | } 86 | } 87 | } -------------------------------------------------------------------------------- /configs/minigrid_sparse.json: -------------------------------------------------------------------------------- 1 | { 2 | "info": "causal_inference_cmi_minigrid_sparse", 3 | "seed": 0, 4 | "cuda_id": 0, 5 | "domain": "minigrid", 6 | "igibson_reward_type": "", 7 | "mini_env_name": "sparse", 8 | "rb_path": "minigrid_causal_data_Sparse1d", 9 | "train_mask": false, 10 | "env_params": { 11 | "env_name": "Causal", 12 | 
"num_envs": 1 13 | }, 14 | "training_params": { 15 | "inference_algo": "cmi", 16 | "object_level_obs": false, 17 | "num_observation_steps": 1, 18 | "load_inference": "", 19 | "total_step": 5000, 20 | "init_step": 0, 21 | "collect_transitions": false, 22 | "num_inference_opt_steps": 1, 23 | "num_policy_opt_steps": 0, 24 | "eval_freq": 25, 25 | "saving_freq": 20000, 26 | "plot_freq": 1000, 27 | "replay_buffer_params": { 28 | "capacity": 2000000, 29 | "max_sample_time": 64, 30 | "saving_freq": 0, 31 | "prioritized_buffer": false, 32 | "parallel_sample": true, 33 | "prioritized_alpha": 10 34 | } 35 | }, 36 | "inference_params": { 37 | "num_pred_steps": 1, 38 | "batch_size": 64, 39 | "lr": 1e-4, 40 | "train_prop": 0.9, 41 | "residual": false, 42 | "log_std_min": -15, 43 | "log_std_max": 30, 44 | "grad_clip_norm": 20, 45 | "detach_encoder": true, 46 | "cmi_params": { 47 | "feature_fc_dims": [128, 128], 48 | "generative_fc_dims": [128, 128], 49 | "causal_pred_reward_weight": 0.0, 50 | "pred_diff_reward_weight": 1.0, 51 | "causal_opt_freq": 10, 52 | "eval_tau": 0.999, 53 | "CMI_threshold": 0.02 54 | } 55 | }, 56 | "contrastive_params": { 57 | "num_pred_steps": 1, 58 | "batch_size": 32, 59 | "lr": 3e-4, 60 | "loss_type": "contrastive", 61 | "gradient_through_all_samples": false, 62 | "l2_reg_coef": 1e-3, 63 | "num_negative_samples": 512, 64 | "grad_clip_norm": 20, 65 | "num_pred_samples": 16384, 66 | "num_pred_iters": 3, 67 | "pred_sigma_init": 0.33, 68 | "pred_sigma_shrink": 0.5, 69 | "modular_params": { 70 | "fc_dims": [256, 256, 256] 71 | }, 72 | "cmi_params": { 73 | "learn_bo": false, 74 | "dot_product_energy": true, 75 | "aggregation": "max", 76 | "train_all_masks": false, 77 | "feature_fc_dims": [256, 128], 78 | "enery_fc_dims": [128], 79 | "mask_opt_freq": 1, 80 | "full_opt_freq": 25, 81 | "causal_opt_freq": 25, 82 | "eval_num_negative_samples": 8192, 83 | "eval_tau": 0.995, 84 | "CMI_threshold": 0.01 85 | } 86 | } 87 | } -------------------------------------------------------------------------------- /collect_igibson_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collect training data for causal inference 3 | """ 4 | 5 | import logging 6 | import os 7 | 8 | import igibson 9 | from igibson.envs.igibson_env import iGibsonEnv 10 | import yaml 11 | from igibson.render.profiler import Profiler 12 | import numpy as np 13 | import pickle 14 | 15 | if __name__ == "__main__": 16 | logging.basicConfig(level=logging.INFO) 17 | 18 | config_filename = os.path.join(igibson.configs_path, "fetch_reaching.yaml") 19 | config_data = yaml.load(open(config_filename, "r"), Loader=yaml.FullLoader) 20 | 21 | config_data["output"] = ['scan', 'task_obs'] 22 | 23 | # Set task to factored version 24 | config_data["task"] = "factored_reaching_random" 25 | 26 | # Create a new environment for evaluation 27 | eval_env = iGibsonEnv( 28 | config_file=config_data, 29 | mode="headless", 30 | action_timestep=1 / 10.0, 31 | physics_timestep=1 / 120.0, 32 | ) 33 | 34 | print(eval_env.action_space) 35 | max_iterations = 50000 36 | data_list = [] 37 | for j in range(max_iterations): 38 | print("Resetting environment") 39 | prev_state = eval_env.reset() 40 | for i in range(100): 41 | with Profiler("Environment action step"): 42 | action = eval_env.action_space.sample() 43 | step_reward = None 44 | action_duration = 10 45 | collision_name_list = ["base_collision", "arm_collision", "self_collision", "collision_occur"] 46 | 47 | merged_info = {} 48 | for collision in 
collision_name_list: 49 | merged_info[collision] = False 50 | 51 | for _ in range(action_duration): 52 | state, reward, done, info = eval_env.step(action) 53 | if step_reward is None: 54 | step_reward = reward 55 | else: 56 | step_reward += reward 57 | 58 | # Process info: has collision occur in the past n timesteps? 59 | for collision_name in collision_name_list: 60 | merged_info[collision_name] = merged_info[collision_name] or info[collision_name] 61 | 62 | if done: 63 | break 64 | 65 | # after action finished, store prev_state 66 | data = [prev_state, action, merged_info, step_reward] 67 | prev_state = state 68 | data_list.append(data) 69 | 70 | # Reset after collision 71 | if merged_info["collision_occur"]: 72 | done = True 73 | 74 | if done: 75 | print("Episode finished after {} timesteps".format(i + 1)) 76 | break 77 | 78 | if (j+1) % 500 == 0: 79 | print("\nsaving...\n") 80 | with open("causal_data", "wb") as fp: # Pickling 81 | pickle.dump(data_list, fp) 82 | eval_env.close() 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import torch 5 | import shutil 6 | import random 7 | 8 | from utils.multiprocessing_env import SubprocVecEnv 9 | import numpy as np 10 | 11 | import sys 12 | 13 | 14 | class Logger(object): 15 | def __init__(self, fp): 16 | self.terminal = sys.stdout 17 | self.log = open(fp, "a") 18 | 19 | def write(self, message): 20 | self.terminal.write(message) 21 | self.log.write(message) 22 | 23 | def flush(self): 24 | # this flush method is needed for python 3 compatibility. 25 | # this handles the flush command by doing nothing. 26 | # you might want to specify some extra behavior here. 
27 | pass 28 | 29 | 30 | class AttrDict(dict): 31 | def __init__(self, *args, **kwargs): 32 | super(AttrDict, self).__init__(*args, **kwargs) 33 | self.__dict__ = self 34 | 35 | 36 | class TrainingParams(AttrDict): 37 | def __init__(self, training_params_fname="params.json", train=True): 38 | config = json.load(open(training_params_fname)) 39 | for k, v in config.items(): 40 | self.__dict__[k] = v 41 | self.__dict__ = self._clean_dict(self.__dict__) 42 | 43 | repo_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 44 | training_params = self.training_params 45 | if getattr(training_params, "load_inference", None) is not None: 46 | training_params.load_inference = \ 47 | os.path.join(repo_path, "interesting_models", training_params.load_inference) 48 | 49 | if train: 50 | sub_dirname = "dynamics" 51 | info = self.info.replace(" ", "_") 52 | if config["train_mask"]: 53 | info += "_" + "train_mask" 54 | experiment_dirname = info + "_" + time.strftime("%Y_%m_%d_%H_%M_%S") 55 | self.rslts_dir = os.path.join(repo_path, "causal", "rslts", sub_dirname, experiment_dirname) 56 | os.makedirs(self.rslts_dir) 57 | shutil.copyfile(training_params_fname, os.path.join(self.rslts_dir, "params.json")) 58 | 59 | self.replay_buffer_dir = None 60 | if training_params.replay_buffer_params.saving_freq: 61 | self.replay_buffer_dir = os.path.join(repo_path, "replay_buffer", experiment_dirname) 62 | os.makedirs(self.replay_buffer_dir) 63 | 64 | super(TrainingParams, self).__init__(self.__dict__) 65 | 66 | def _clean_dict(self, _dict): 67 | for k, v in _dict.items(): 68 | if v == "": # encode empty string as None 69 | v = None 70 | if isinstance(v, dict): 71 | v = self._clean_dict(v) 72 | _dict[k] = v 73 | return AttrDict(_dict) 74 | 75 | 76 | def soft_update_params(net, target_net, tau): 77 | for param, target_param in zip(net.parameters(), target_net.parameters()): 78 | target_param.data.copy_( 79 | tau * param.data + (1 - tau) * target_param.data 80 | ) 81 | 82 | 83 | def set_seed_everywhere(seed): 84 | torch.manual_seed(seed) 85 | if torch.cuda.is_available(): 86 | torch.cuda.manual_seed_all(seed) 87 | np.random.seed(seed) 88 | random.seed(seed) 89 | 90 | 91 | def to_numpy(tensor): 92 | return tensor.cpu().detach().numpy() 93 | 94 | 95 | def to_device(dictionary, device): 96 | """ 97 | place dict of tensors + dict to device recursively 98 | """ 99 | new_dictionary = {} 100 | for key, val in dictionary.items(): 101 | if isinstance(val, dict): 102 | new_dictionary[key] = to_device(val, device) 103 | elif isinstance(val, torch.Tensor): 104 | new_dictionary[key] = val.to(device) 105 | else: 106 | raise ValueError("Unknown value type {} for key {}".format(type(val), key)) 107 | return new_dictionary 108 | 109 | 110 | def get_start_step_from_model_loading(params): 111 | """ 112 | if inference is loaded, return its training step; 113 | else, return 0 114 | """ 115 | load_inference = params.training_params.load_inference 116 | if load_inference is not None and os.path.exists(load_inference): 117 | model_name = load_inference.split(os.sep)[-1] 118 | start_step = int(model_name.split("_")[-1]) 119 | print("start_step:", start_step) 120 | else: 121 | start_step = 0 122 | return start_step 123 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Callable 4 | import numpy as np 5 | 6 | import igibson 7 | from igibson.envs.igibson_env 
import iGibsonEnv
 8 | import yaml
 9 | from igibson.render.profiler import Profiler
10 | try:
11 |     import gym
12 |     import torch as th
13 |     import torch.nn as nn
14 |     from stable_baselines3 import PPO
15 |     from stable_baselines3 import FPPO
16 |     from stable_baselines3.common.evaluation import evaluate_policy
17 |     from stable_baselines3.common.preprocessing import maybe_transpose
18 |     from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
19 |     from stable_baselines3.common.utils import set_random_seed
20 |     from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
21 |     from stable_baselines3.common.save_util import load_from_zip_file
22 | 
23 | except ModuleNotFoundError:
24 |     print("stable-baselines3 is not installed. Install the sb3-CausalMoMa fork of stable-baselines3 (see the README setup instructions).")
25 |     exit(1)
26 | 
27 | from train import CustomCombinedExtractor, get_causal_matrix
28 | 
29 | """
30 | This is to test (and visualize) the trained policy
31 | """
32 | 
33 | if __name__ == "__main__":
34 |     logging.basicConfig(level=logging.INFO)
35 |     np.set_printoptions(precision=2)
36 |     robot = "hsr"  # or "fetch"
37 |     if robot == "hsr":
38 |         config_fn = "hsr_reaching.yaml"
39 |         model_path = "weight_dir/ckpt_hsr.zip"
40 |     elif robot == "fetch":
41 |         config_fn = "fetch_reaching.yaml"
42 |         model_path = "weight_dir/ckpt_fetch.zip"
43 |     config_filename = os.path.join(igibson.configs_path, config_fn)
44 |     config_data = yaml.load(open(config_filename, "r"), Loader=yaml.FullLoader)
45 | 
46 |     # Improving visuals in the example (optional)
47 |     config_data["enable_shadow"] = True
48 |     config_data["enable_pbr"] = True
49 | 
50 |     scene_ids = ["Rs_int", "Beechwood_0_int", "Merom_0_int",
51 |                  "Wainscott_0_int", "Ihlen_0_int", "Benevolence_1_int", "Pomaria_1_int", "Ihlen_1_int", ]
52 | 
53 |     # No living: Benevolence_0_int, Beechwood_1_int, Benevolence_2_int, Pomaria_0_int
54 |     # 3: really hard & large env 4: uninteresting env, 5: hard env, 6: empty, 7: with sofa but easy
55 |     config_data["scene_id"] = scene_ids[1]
56 | 
57 |     factored = True
58 |     obstacles = True
59 |     multi_step = True
60 | 
61 |     if factored:
62 |         if multi_step:
63 |             config_data["task"] = "factored_multistep_reaching_random"
64 |         else:
65 |             config_data["task"] = "factored_reaching_random"
66 |     else:
67 |         if multi_step:
68 |             # the vanilla policy is still evaluated on the factored task so per-channel rewards can be reported
69 |             config_data["task"] = "factored_multistep_reaching_random"
70 | 
71 |     config_data["rd_target"] = True
72 |     config_data["vis_ee_target"] = False
73 |     config_data["simple_orientation"] = True
74 |     config_data["enum_orientation"] = True
75 |     config_data["position_reward"] = True
76 |     config_data["proportional_local_reward"] = True
77 | 
78 |     if not obstacles:
79 |         config_data["load_room_types"] = "kitchen"
80 | 
81 |     # Create a new environment for evaluation
82 |     eval_env = iGibsonEnv(
83 |         config_file=config_data,
84 |         mode="gui_interactive",  # "headless",
85 |         action_timestep=1 / 10.0,
86 |         physics_timestep=1 / 120.0,
87 |         print_reward=True,
88 |         # use_pb_gui=True,
89 |     )
90 | 
91 |     # Alternatively, we could make the causal arguments optional and store them in `data` -- probably a better approach
92 |     device = th.device("cpu")
93 | 
94 | 
95 |     data, params, pytorch_variables = load_from_zip_file(
96 |         model_path,
97 |         device=device,
98 |     )
99 | 
100 |     reward_channels_dim = 8
101 |     causal_matrix = get_causal_matrix(reward_channels_dim, eval_env, robot=robot, fc_causal=True)
102 |     policy_kwargs = dict(
103 |         features_extractor_class=CustomCombinedExtractor,
104 |     )
105 | 
106 |     if factored:
107 |         model = FPPO("MultiInputPolicy", eval_env,
reward_channels_dim, causal_matrix, 107 | policy_kwargs=policy_kwargs) 108 | 109 | else: 110 | model = PPO("MultiInputPolicy", eval_env, 111 | policy_kwargs=policy_kwargs) 112 | 113 | model.set_parameters(params, exact_match=True, device=device) 114 | 115 | from datetime import datetime 116 | set_random_seed(int(datetime.now().timestamp())) 117 | 118 | # Evaluate the trained model loaded from file 119 | mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=50, deterministic=False, 120 | reward_channels_dim=reward_channels_dim, report_factored_reward=True) 121 | 122 | print(f"After Loading: Mean reward: {mean_reward} +/- {std_reward}") 123 | print(f"After Loading: Mean reward: {mean_reward.sum()}") 124 | 125 | -------------------------------------------------------------------------------- /model/inference_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque, OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from model.gumbel import gumbel_sigmoid 9 | 10 | 11 | def reset_layer(w, b): 12 | fan_in = w.shape[0] 13 | bound = 1 / np.sqrt(fan_in) 14 | nn.init.uniform_(w, -bound, bound) 15 | nn.init.uniform_(b, -bound, bound) 16 | 17 | 18 | def reparameterize(mu, log_std): 19 | std = torch.exp(log_std) 20 | eps = torch.randn_like(std) 21 | return eps * std + mu 22 | 23 | 24 | def forward_network(input, weights, biases, activation=F.relu): 25 | """ 26 | given an input and a multi-layer networks (i.e., a list of weights and a list of biases), 27 | apply the network to each input, and return output 28 | the same activation function is applied to all layers except for the last layer 29 | """ 30 | x = input 31 | for i, (w, b) in enumerate(zip(weights, biases)): 32 | # x (p_bs, bs, in_dim), bs: data batch size which must be 1D 33 | # w (p_bs, in_dim, out_dim), p_bs: parameter batch size 34 | # b (p_bs, 1, out_dim) 35 | x = torch.bmm(x, w) + b 36 | if i < len(weights) - 1 and activation: 37 | x = activation(x) 38 | return x 39 | 40 | 41 | def forward_network_batch(inputs, weights, biases, activation=F.relu): 42 | """ 43 | given a list of inputs and a list of ONE-LAYER networks (i.e., a list of weights and a list of biases), 44 | apply each network to each input, and return a list 45 | """ 46 | x = [] 47 | for x_i, w, b in zip(inputs, weights, biases): 48 | # x_i (p_bs, bs, in_dim), bs: data batch size which must be 1D 49 | # w (p_bs, in_dim, out_dim), p_bs: parameter batch size 50 | # b (p_bs, 1, out_dim) 51 | x_i = torch.bmm(x_i, w) + b 52 | if activation: 53 | x_i = activation(x_i) 54 | x.append(x_i) 55 | return x 56 | 57 | 58 | def forward_gated_network(input, weights, biases, gate_weights, gate_biases, deterministic=False, activation=F.relu): 59 | """ 60 | given an input and a multi-layer networks (i.e., a list of weights and a list of biases), 61 | apply the network to each input, and return output 62 | the same activation function is applied to all layers except for the last layer 63 | """ 64 | gate = None 65 | if len(gate_weights): 66 | gate_log_alpha = input 67 | for i, (w, b) in enumerate(zip(gate_weights, gate_biases)): 68 | # gate_log_alpha (bs, p_bs, in_dim), bs: data batch size 69 | # w (p_bs, out_dim, in_dim), p_bs: parameter batch size 70 | # b (p_bs, out_dim) 71 | gate_log_alpha = gate_log_alpha.unsqueeze(dim=-2) # (bs, p_bs, 1, in_dim) 72 | gate_log_alpha = (gate_log_alpha * w).sum(dim=-1) + b # (bs, p_bs, out_dim) 73 | 
if i < len(gate_weights) - 1 and activation:
74 |                 gate_log_alpha = activation(gate_log_alpha)
75 | 
76 |         if deterministic:
77 |             gate = (gate_log_alpha > 0).float()
78 |         else:
79 |             gate = gumbel_sigmoid(gate_log_alpha, device=gate_log_alpha.device, hard=True)
80 | 
81 |     x = input
82 |     for i, (w, b) in enumerate(zip(weights, biases)):
83 |         # x (bs, p_bs, in_dim), bs: data batch size
84 |         # w (p_bs, out_dim, in_dim), p_bs: parameter batch size
85 |         # b (p_bs, out_dim)
86 |         x = x.unsqueeze(dim=-2)  # (bs, p_bs, 1, in_dim)
87 |         x = (x * w).sum(dim=-1) + b  # (bs, p_bs, out_dim)
88 |         if i < len(weights) - 1 and activation:
89 |             x = activation(x)
90 |         if i == len(weights) - 2 and gate is not None:
91 |             x = x * gate
92 |     return x
93 | 
94 | 
95 | def get_controllable(mask):
96 |     feature_dim = mask.shape[0]
97 |     M = mask[:, :feature_dim]
98 |     I = mask[:, feature_dim:]
99 | 
100 |     # features that are directly affected by actions
101 |     action_children = []
102 |     for i in range(feature_dim):
103 |         if I[i].any():
104 |             action_children.append(i)
105 | 
106 |     # descendants of those features
107 |     controllable = []
108 |     queue = deque(action_children)
109 |     while len(queue):
110 |         feature_idx = queue.popleft()
111 |         controllable.append(feature_idx)
112 |         for i in range(feature_dim):
113 |             if M[i, feature_idx] and (i not in controllable) and (i not in queue):
114 |                 queue.append(i)
115 |     return controllable
116 | 
117 | 
118 | def get_state_abstraction(mask):
119 |     feature_dim = mask.shape[0]
120 |     M = mask[:, :feature_dim]
121 | 
122 |     controllable = get_controllable(mask)
123 |     # ancestors of controllable features
124 |     action_relevant = []
125 |     queue = deque(controllable)
126 |     while len(queue):
127 |         feature_idx = queue.popleft()
128 |         if feature_idx not in controllable:
129 |             action_relevant.append(feature_idx)
130 |         for i in range(feature_dim):
131 |             if (i not in controllable + action_relevant) and (i not in queue):
132 |                 if M[feature_idx, i]:
133 |                     queue.append(i)
134 | 
135 |     abstraction_idx = list(set(controllable + action_relevant))
136 |     abstraction_idx.sort()
137 | 
138 |     abstraction_graph = OrderedDict()
139 |     for idx in abstraction_idx:
140 |         abstraction_graph[idx] = [i for i, e in enumerate(mask[idx]) if e]
141 | 
142 |     return abstraction_graph
143 | 
--------------------------------------------------------------------------------
/model/gumbel.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | import numpy as np
 6 | 
 7 | EPS = 1e-6
 8 | 
 9 | 
10 | def sample_logistic(shape, device):
11 |     u = torch.rand(shape, dtype=torch.float32, device=device)
12 |     u = torch.clip(u, EPS, 1 - EPS)
13 |     return torch.log(u) - torch.log(1 - u)
14 | 
15 | 
16 | def gumbel_sigmoid(log_alpha, device, bs=None, tau=1, hard=False):
17 |     if bs is None:
18 |         shape = log_alpha.shape
19 |     else:
20 |         shape = log_alpha.shape + bs
21 | 
22 |     logistic_noise = sample_logistic(shape, device)
23 |     y_soft = torch.sigmoid((log_alpha + logistic_noise) / tau)
24 | 
25 |     if hard:
26 |         y_hard = (y_soft > 0.5).float()
27 |         # This weird line does two things:
28 |         #   1) at forward, we get a hard sample.
29 |         #   2) at backward, we differentiate the gumbel sigmoid
30 |         y = y_hard.detach() - y_soft.detach() + y_soft
31 |     else:
32 |         y = y_soft
33 | 
34 |     return y
35 | 
36 | 
37 | class GumbelMatrix(torch.nn.Module):
38 |     """
39 |     Random matrix M used for the mask. Can sample a matrix and backpropagate using the
40 |     Gumbel straight-through estimator.
41 |     """
42 |     def __init__(self, shape, init_value, device):
43 |         super(GumbelMatrix, self).__init__()
44 |         self.device = device
45 |         self.shape = shape
46 |         self.log_alpha = torch.nn.Parameter(torch.zeros(shape))
47 |         self.reset_parameters(init_value)
48 | 
49 |     def forward(self, bs, tau=1, drawhard=True):
50 |         if self.training:
51 |             sample = gumbel_sigmoid(self.log_alpha, self.device, bs, tau=tau, hard=drawhard)
52 |         else:
53 |             sample = (self.log_alpha > 0).float()
54 |         return sample
55 | 
56 |     def get_prob(self):
57 |         """Returns probability of getting one"""
58 |         return torch.sigmoid(self.log_alpha)
59 | 
60 |     def reset_parameters(self, init_value):
61 |         log_alpha_init = -np.log(1 / init_value - 1)
62 |         torch.nn.init.constant_(self.log_alpha, log_alpha_init)
63 | 
64 | 
65 | class ConditionalGumbelMatrix(torch.nn.Module):
66 |     """
67 |     Random matrix M used for the mask that's conditioned on state and action.
68 |     Can sample a matrix and backpropagate using the Gumbel straight-through estimator.
69 |     """
70 |     def __init__(self, feature_dim, action_dim, final_dim, fc_dims, device):
71 |         super(ConditionalGumbelMatrix, self).__init__()
72 |         self.feature_dim = feature_dim
73 |         self.action_dim = action_dim
74 |         self.final_dim = final_dim
75 |         self.device = device
76 |         self.fc_dims = fc_dims
77 |         self.update_uniform(0, 1)
78 | 
79 |         self.weights = nn.ParameterList()
80 |         self.biases = nn.ParameterList()
81 | 
82 |         # Instantiate the parameters of each layer in the model of each variable
83 |         in_dim = feature_dim + action_dim
84 |         for out_dim in fc_dims + [final_dim]:
85 |             self.weights.append(nn.Parameter(torch.zeros(feature_dim, out_dim, in_dim)))
86 |             self.biases.append(nn.Parameter(torch.zeros(feature_dim, out_dim)))
87 |             in_dim = out_dim
88 |         self.reset_params()
89 | 
90 |         self.causal_feature_idxes = None
91 |         self.causal_weights = None
92 |         self.causal_biases = None
93 | 
94 |     def reset_params(self):
95 |         in_dim = self.feature_dim + self.action_dim
96 |         for w, b, fan_in in zip(self.weights, self.biases, [in_dim] + self.fc_dims):
97 |             nn.init.kaiming_uniform_(w, a=np.sqrt(5))
98 |             bound = 1 / np.sqrt(fan_in)
99 |             nn.init.uniform_(b, -bound, bound)
100 | 
101 |     def forward_fcs(self, feature, action, forward_idxes=None):
102 |         """
103 |         :param feature: (bs, feature_dim)
104 |         :param action: (bs, action_dim)
105 |         """
106 |         out = torch.cat([feature, action], dim=-1)  # (bs, num_forward_idxes, feature_dim + action_dim)
107 | 
108 |         weights = self.causal_weights if forward_idxes else self.weights
109 |         biases = self.causal_biases if forward_idxes else self.biases
110 |         for i, (w, b) in enumerate(zip(weights, biases)):
111 |             out = out[:, :, None]  # (bs, num_forward_idxes, 1, in_dim)
112 |             out = torch.sum(w * out, dim=-1) + b  # (bs, num_forward_idxes, out_dim)
113 |             if i < len(weights) - 1:
114 |                 out = F.leaky_relu(out)
115 | 
116 |         return out
117 | 
118 |     def forward(self, feature, action, forward_idxes, tau=1, drawhard=True):
119 |         log_alpha = self.forward_fcs(feature, action, forward_idxes)
120 |         if self.training:
121 |             sample = gumbel_sigmoid(log_alpha, self.device, bs=None, tau=tau, hard=drawhard)
122 |         else:
123 |             sample = torch.sigmoid(log_alpha)
124 |         prob = torch.sigmoid(log_alpha)
125 |         return sample, prob
126 | 
127 |     def setup_causal_feature_idxes(self, causal_feature_idxes):
128 |         self.causal_feature_idxes = causal_feature_idxes
129 |         self.causal_weights = [w[causal_feature_idxes] for w in self.weights]
130 |         self.causal_biases = [b[causal_feature_idxes] for b in self.biases]
131 | 
132
| def get_prob(self, feature, action): 133 | """Returns probability of getting one""" 134 | log_alpha = self.forward_fcs(feature, action) 135 | return torch.sigmoid(log_alpha) 136 | 137 | def update_uniform(self, low, high): 138 | low = torch.tensor(low, dtype=torch.float32, device=self.device) 139 | high = torch.tensor(high, dtype=torch.float32, device=self.device) 140 | self.uniform = torch.distributions.uniform.Uniform(low, high) 141 | -------------------------------------------------------------------------------- /utils/multiprocessing_env.py: -------------------------------------------------------------------------------- 1 | # Code is from OpenAI baseline: https://github.com/openai/baselines/tree/master/baselines/common/vec_env 2 | 3 | import numpy as np 4 | from multiprocessing import Process, Pipe 5 | 6 | 7 | def worker(remote, parent_remote, env_fn_wrapper): 8 | parent_remote.close() 9 | env = env_fn_wrapper.x() 10 | while True: 11 | cmd, data = remote.recv() 12 | if cmd == 'step': 13 | ob, reward, done, info = env.step(data) 14 | if done: 15 | info["obs"] = ob 16 | ob = env.reset() 17 | remote.send((ob, reward, done, info)) 18 | elif cmd == 'reset': 19 | ob = env.reset() 20 | remote.send(ob) 21 | elif cmd == 'reset_task': 22 | ob = env.reset_task() 23 | remote.send(ob) 24 | elif cmd == 'close': 25 | remote.close() 26 | break 27 | elif cmd == 'observation_spec': 28 | remote.send(env.observation_spec()) 29 | elif cmd == 'obs_delta_range': 30 | remote.send(env.obs_delta_range()) 31 | elif cmd == 'seed': 32 | np.random.seed(data) 33 | elif cmd.startswith("get_attr_"): 34 | attr_name = cmd[len("get_attr_"):] 35 | remote.send(getattr(env, attr_name)) 36 | else: 37 | raise NotImplementedError 38 | 39 | 40 | class VecEnv(object): 41 | """ 42 | An abstract asynchronous, vectorized environment. 43 | """ 44 | def __init__(self, num_envs): 45 | self.num_envs = num_envs 46 | 47 | def reset(self): 48 | """ 49 | Reset all the environments and return an array of 50 | observations, or a tuple of observation arrays. 51 | If step_async is still doing work, that work will 52 | be cancelled and step_wait() should not be called 53 | until step_async() is invoked again. 54 | """ 55 | pass 56 | 57 | def step_async(self, actions): 58 | """ 59 | Tell all the environments to start taking a step 60 | with the given actions. 61 | Call step_wait() to get the results of the step. 62 | You should not call this if a step_async run is 63 | already pending. 64 | """ 65 | pass 66 | 67 | def step_wait(self): 68 | """ 69 | Wait for the step taken with step_async(). 70 | Returns (obs, rews, dones, infos): 71 | - obs: an array of observations, or a tuple of 72 | arrays of observations. 73 | - rews: an array of rewards 74 | - dones: an array of "episode done" booleans 75 | - infos: a sequence of info objects 76 | """ 77 | pass 78 | 79 | def close(self): 80 | """ 81 | Clean up the environments' resources. 
82 | """ 83 | pass 84 | 85 | def step(self, actions): 86 | self.step_async(actions) 87 | return self.step_wait() 88 | 89 | 90 | class CloudpickleWrapper(object): 91 | """ 92 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 93 | """ 94 | def __init__(self, x): 95 | self.x = x 96 | 97 | def __getstate__(self): 98 | import cloudpickle 99 | return cloudpickle.dumps(self.x) 100 | 101 | def __setstate__(self, ob): 102 | import pickle 103 | self.x = pickle.loads(ob) 104 | 105 | 106 | class SubprocVecEnv(VecEnv): 107 | def __init__(self, env_fns): 108 | """ 109 | envs: list of gym environments to run in subprocesses 110 | """ 111 | self.waiting = False 112 | self.closed = False 113 | nenvs = len(env_fns) 114 | self.nenvs = nenvs 115 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 116 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 117 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 118 | for p in self.ps: 119 | p.daemon = True # if the main process crashes, we should not cause things to hang 120 | p.start() 121 | for remote in self.work_remotes: 122 | remote.close() 123 | 124 | VecEnv.__init__(self, len(env_fns)) 125 | 126 | self.seed() 127 | 128 | def step_async(self, actions): 129 | for remote, action in zip(self.remotes, actions): 130 | remote.send(('step', action)) 131 | self.waiting = True 132 | 133 | def step_wait(self): 134 | results = [remote.recv() for remote in self.remotes] 135 | self.waiting = False 136 | obs, rews, dones, infos = zip(*results) 137 | obs = {key: np.stack([d[key] for d in obs]) for key in obs[0].keys()} 138 | return obs, np.stack(rews), np.stack(dones), infos 139 | 140 | def reset(self): 141 | for remote in self.remotes: 142 | remote.send(('reset', None)) 143 | obs = [remote.recv() for remote in self.remotes] 144 | obs = {key: np.stack([d[key] for d in obs]) for key in obs[0].keys()} 145 | return obs 146 | 147 | def reset_task(self): 148 | for remote in self.remotes: 149 | remote.send(('reset_task', None)) 150 | return np.stack([remote.recv() for remote in self.remotes]) 151 | 152 | def observation_spec(self): 153 | self.remotes[0].send(('observation_spec', None)) 154 | return self.remotes[0].recv() 155 | 156 | def obs_delta_range(self): 157 | self.remotes[0].send(('obs_delta_range', None)) 158 | return self.remotes[0].recv() 159 | 160 | def __getattr__(self, name): 161 | self.remotes[0].send(('get_attr_{}'.format(name), None)) 162 | return self.remotes[0].recv() 163 | 164 | def close(self): 165 | if self.closed: 166 | return 167 | if self.waiting: 168 | for remote in self.remotes: 169 | remote.recv() 170 | for remote in self.remotes: 171 | remote.send(('close', None)) 172 | for p in self.ps: 173 | p.join() 174 | self.closed = True 175 | 176 | def seed(self): 177 | for i, remote in enumerate(self.remotes): 178 | remote.send(('seed', i)) 179 | 180 | def __len__(self): 181 | return self.nenvs 182 | 183 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | 7 | from utils.utils import to_numpy 8 | 9 | 10 | def set_axes_equal(ax): 11 | """Set 3D plot axes to equal scale. 12 | 13 | Make axes of 3D plot have equal scale so that spheres appear as 14 | spheres and cubes as cubes. 
Required since `ax.axis('equal')` 15 | and `ax.set_aspect('equal')` don't work on 3D. 16 | """ 17 | limits = np.array([ 18 | ax.get_xlim3d(), 19 | ax.get_ylim3d(), 20 | ax.get_zlim3d(), 21 | ]) 22 | origin = np.mean(limits, axis=1) 23 | radius = 0.5 * np.max(np.abs(limits[:, 1] - limits[:, 0])) 24 | _set_axes_radius(ax, origin, radius) 25 | 26 | 27 | def _set_axes_radius(ax, origin, radius): 28 | x, y, z = origin 29 | ax.set_xlim3d([x - radius, x + radius]) 30 | ax.set_ylim3d([y - radius, y + radius]) 31 | ax.set_zlim3d([z - radius, z + radius]) 32 | 33 | 34 | # ---------------------------------- TO CREATE A SERIES OF PICTURES ---------------------------------- # 35 | # from https://zulko.wordpress.com/2012/09/29/animate-your-3d-plots-with-pythons-matplotlib/ 36 | 37 | def make_views(ax, angles, elevation=None, width=4, height=3, 38 | prefix='tmprot_', **kwargs): 39 | """ 40 | Makes jpeg pictures of the given 3d ax, with different angles. 41 | Args: 42 | ax (3D axis): te ax 43 | angles (list): the list of angles (in degree) under which to 44 | take the picture. 45 | width,height (float): size, in inches, of the output images. 46 | prefix (str): prefix for the files created. 47 | 48 | Returns: the list of files created (for later removal) 49 | """ 50 | 51 | files = [] 52 | ax.figure.set_size_inches(width, height) 53 | 54 | for i, angle in enumerate(angles): 55 | ax.view_init(elev=elevation, azim=angle) 56 | fname = '%s%03d.jpeg' % (prefix, i) 57 | ax.figure.savefig(fname) 58 | files.append(fname) 59 | 60 | return files 61 | 62 | 63 | # ----------------------- TO TRANSFORM THE SERIES OF PICTURE INTO AN ANIMATION ----------------------- # 64 | 65 | def make_movie(files, output, fps=10, bitrate=1800, **kwargs): 66 | """ 67 | Uses mencoder, produces a .mp4/.ogv/... movie from a list of 68 | picture files. 69 | """ 70 | 71 | output_name, output_ext = os.path.splitext(output) 72 | command = {'.mp4': 'mencoder "mf://%s" -mf fps=%d -o %s.mp4 -ovc lavc\ 73 | -lavcopts vcodec=msmpeg4v2:vbitrate=%d' 74 | % (",".join(files), fps, output_name, bitrate)} 75 | 76 | command['.ogv'] = command['.mp4'] + '; ffmpeg -i %s.mp4 -r %d %s' % (output_name, fps, output) 77 | 78 | print(command[output_ext]) 79 | output_ext = os.path.splitext(output)[1] 80 | os.system(command[output_ext]) 81 | 82 | 83 | def make_gif(files, output, delay=100, repeat=True, **kwargs): 84 | """ 85 | Uses imageMagick to produce an animated .gif from a list of 86 | picture files. 87 | """ 88 | 89 | loop = -1 if repeat else 0 90 | os.system('convert -delay %d -loop %d %s %s' % (delay, loop, " ".join(files), output)) 91 | 92 | 93 | def make_strip(files, output, **kwargs): 94 | """ 95 | Uses imageMagick to produce a .jpeg strip from a list of 96 | picture files. 97 | """ 98 | 99 | os.system('montage -tile 1x -geometry +0+0 %s %s' % (" ".join(files), output)) 100 | 101 | 102 | # ---------------------------------------------- MAIN FUNCTION ---------------------------------------------- # 103 | 104 | def rotanimate(ax, angles, output, **kwargs): 105 | """ 106 | Produces an animation (.mp4,.ogv,.gif,.jpeg,.png) from a 3D plot on 107 | a 3D ax 108 | 109 | Args: 110 | ax (3D axis): the ax containing the plot of interest 111 | angles (list): the list of angles (in degree) under which to 112 | show the plot. 113 | output : name of the output file. The extension determines the 114 | kind of animation used. 
115 | **kwargs: 116 | - width : in inches 117 | - heigth: in inches 118 | - framerate : frames per second 119 | - delay : delay between frames in milliseconds 120 | - repeat : True or False (.gif only) 121 | """ 122 | 123 | output_ext = os.path.splitext(output)[1] 124 | 125 | files = make_views(ax, angles, **kwargs) 126 | 127 | D = {'.mp4': make_movie, 128 | '.ogv': make_movie, 129 | '.gif': make_gif, 130 | '.jpeg': make_strip, 131 | '.png': make_strip} 132 | 133 | D[output_ext](files, output, **kwargs) 134 | 135 | for f in files: 136 | os.remove(f) 137 | 138 | 139 | def plot_adjacency_intervention_mask(model, writer, step, adjacency=None): 140 | if adjacency is None: 141 | adjacency = model.get_adjacency() 142 | if adjacency is None: 143 | return 144 | adjacency_intervention = to_numpy(adjacency) 145 | 146 | feature_dim, action_dim = adjacency_intervention.shape 147 | 148 | fig = plt.figure(figsize=(action_dim * 0.45 + 2, feature_dim * 0.45 + 2)) 149 | 150 | vmax = adjacency[0, -1] 151 | while vmax < 0.1: 152 | vmax = vmax * 10 153 | adjacency_intervention = adjacency_intervention * 10 154 | sns.heatmap(adjacency_intervention, linewidths=3, vmin=0, vmax=vmax, square=True, annot=True, fmt='.2f', cbar=False) 155 | 156 | ax = plt.gca() 157 | ax.tick_params(axis="x", bottom=True, top=True, labelbottom=True, labeltop=True) 158 | 159 | fig.tight_layout() 160 | if writer: 161 | writer.add_figure("adjacency", fig, step + 1) 162 | else: 163 | plt.show() 164 | plt.close("all") 165 | 166 | 167 | def plot_adjacency(adjacency): 168 | 169 | adjacency_intervention = to_numpy(adjacency) 170 | 171 | feature_dim, action_dim = adjacency_intervention.shape 172 | 173 | fig = plt.figure(figsize=(action_dim * 0.45 + 2, feature_dim * 0.45 + 2)) 174 | 175 | vmax = 1 176 | sns.heatmap(adjacency_intervention, linewidths=3, vmin=0, vmax=vmax, square=True, annot=True, cbar=False) 177 | 178 | ax = plt.gca() 179 | ax.tick_params(axis="x", bottom=True, top=True, labelbottom=True, labeltop=True) 180 | 181 | # This is to plot the axis caption 182 | n_action = adjacency.shape[0] 183 | n_reward = adjacency.shape[-1] 184 | 185 | plt.xticks(np.array(range(n_reward)) + 0.5, ["R_up/down", "R_left/right", "R_3", "R_4", "R_5"]) 186 | plt.yticks(np.array(range(n_action)) + 0.5, ["A", "A2", "A3", "A4"], rotation=90) 187 | 188 | fig.tight_layout() 189 | 190 | plt.show() 191 | plt.close("all") 192 | 193 | 194 | if __name__ == '__main__': 195 | adjacency = torch.tensor([ [ 0.871, 0.000, 0.000, 0.000, 0.02], 196 | [ 0.000, 0.868, 0.000, 0.000, 0.02], 197 | [ 0.040, 0.041, 0.000, -0.000, 0.02], 198 | [ 0.000, 0.000, 0.034, 0.000, 0.02], 199 | [ -0.000, 0.000, 0.000, 0.028, 0.02]]) 200 | # |omni ,|head,| arm ,|gr 201 | # adjacency = torch.tensor([[1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0], # Reach 202 | # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0], # EE Local Orientation 203 | # [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], # EE Local Position 204 | # [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # Base Collision 205 | # [1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0], # Arm Collision 206 | # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0], # Self Collision 207 | # [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0], # Head Attention 208 | # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],]) # Gripper Grasp 209 | 210 | # adjacency = torch.tensor([ [ 0.871, 0.000, 0.000, 0.000], 211 | # [ 0.000, 0.868, 0.000, 0.000], 212 | # [ 0.040, 0.041, 0.000, -0.000], 213 | # [ 0.000, 0.000, 0.034, 0.000], 214 | # [ -0.000, 0.000, 0.000, 0.028]]) 215 | adjacency = torch.tensor([ [ 1, 0, 0, 0], 216 | [ 0, 1, 0, 0], 217 | [ 1, 1, 0, 0], 218 | [ 0, 0, 
1, 0], 219 | [ 0, 0, 0, 1]]).T 220 | adjacency[0, -1] = 1 221 | 222 | # plot_adjacency(adjacency) 223 | plot_adjacency_intervention_mask(None, None, None, adjacency=adjacency) -------------------------------------------------------------------------------- /utils/sum_tree.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/rlcode/per/blob/master/SumTree.py 2 | 3 | import numpy as np 4 | 5 | 6 | # SumTree 7 | # a binary tree data structure where the parent’s value is the sum of its children 8 | 9 | 10 | class SumTree: 11 | def __init__(self, capacity): 12 | self.capacity = capacity 13 | self.tree = np.zeros(2 * capacity - 1) 14 | self.write = 0 15 | 16 | # update to the root node 17 | def _propagate(self, idx, change): 18 | parent = (idx - 1) // 2 19 | self.tree[parent] += change 20 | if parent != 0: 21 | self._propagate(parent, change) 22 | 23 | # find sample on leaf node 24 | def _retrieve(self, idx, s): 25 | left = 2 * idx + 1 26 | right = left + 1 27 | 28 | if left >= len(self.tree): 29 | return idx 30 | 31 | if s <= self.tree[left]: 32 | return self._retrieve(left, s) 33 | else: 34 | return self._retrieve(right, s - self.tree[left]) 35 | 36 | def total(self): 37 | return self.tree[0] 38 | 39 | # store priority and sample 40 | def add(self, p): 41 | idx = self.write + self.capacity - 1 42 | 43 | self.update(idx, p) 44 | 45 | self.write += 1 46 | if self.write >= self.capacity: 47 | self.write = 0 48 | 49 | # update priority 50 | def update(self, idx, p): 51 | change = p - self.tree[idx] 52 | self.tree[idx] = p 53 | self._propagate(idx, change) 54 | 55 | # get priority and sample 56 | def get(self, s): 57 | idx = self._retrieve(0, s) 58 | dataIdx = idx - self.capacity + 1 59 | return idx, dataIdx 60 | 61 | def init_tree(self, ps): 62 | assert (self.tree == 0).all() and self.write == 0 63 | assert len(ps) <= self.capacity 64 | self.tree[self.capacity - 1:self.capacity - 1 + len(ps)] = ps 65 | self.write = len(ps) 66 | 67 | last_idx = len(ps) - 1 + self.capacity - 1 68 | last_parent = (last_idx - 1) // 2 69 | for i in reversed(range(last_parent + 1)): 70 | left = 2 * i + 1 71 | right = left + 1 72 | self.tree[i] = self.tree[left] + self.tree[right] 73 | 74 | assert self.total() == ps.sum() 75 | 76 | # 77 | # class ParallelSumTree: 78 | # def __init__(self, num_trees, capacity): 79 | # self.num_trees = num_trees 80 | # self.capacity = capacity 81 | # self.trees = np.zeros((num_trees, 2 * capacity - 1), dtype=np.float64) 82 | # self.write = 0 83 | # self.full = False 84 | # self.tree_idxes = np.arange(num_trees) 85 | # 86 | # self.num_updates = 0 87 | # self.valid_freq = 100000 88 | # 89 | # # update to the root node 90 | # def _propagate(self, idxes, changes): 91 | # # idxes, changes: (num_trees,) 92 | # zeros = np.zeros_like(changes) 93 | # while True: 94 | # idxes = (idxes - 1) // 2 95 | # self.trees[self.tree_idxes, idxes] += changes 96 | # 97 | # finish_props = (idxes <= 0) 98 | # changes = np.where(finish_props, zeros, changes) 99 | # 100 | # if finish_props.all(): 101 | # return 102 | # 103 | # # find sample on leaf node 104 | # def _retrieve(self, idxes, values): 105 | # # idxes, value: (num_trees,) 106 | # tree_len = self.capacity if self.full else self.write 107 | # tree_len += self.capacity - 1 108 | # while True: 109 | # lefts = 2 * idxes + 1 110 | # rights = lefts + 1 111 | # 112 | # found_idxes = lefts >= tree_len 113 | # if found_idxes.all(): 114 | # return idxes 115 | # 116 | # modified_lefts = 
np.where(found_idxes, idxes, lefts) 117 | # left_values = self.trees[self.tree_idxes, modified_lefts] 118 | # le_lefts = values <= left_values 119 | # idxes = np.where(le_lefts, modified_lefts, rights) 120 | # values = np.where(le_lefts, values, values - left_values) 121 | # 122 | # def total(self): 123 | # # return: (num_trees,) 124 | # return self.trees[:, 0] 125 | # 126 | # # store priority and sample 127 | # def add(self, p): 128 | # raise NotImplementedError 129 | # idx = self.write + self.capacity - 1 130 | # 131 | # self.update(idx, p) 132 | # 133 | # self.write += 1 134 | # if self.write >= self.capacity: 135 | # self.full = True 136 | # self.write = 0 137 | # 138 | # # update priority 139 | # def update(self, idxes, ps): 140 | # # idxes, ps: (num_trees,) 141 | # changes = ps - self.trees[self.tree_idxes, idxes] 142 | # self.trees[self.tree_idxes, idxes] = ps 143 | # self._propagate(idxes, changes) 144 | # 145 | # self.num_updates += 1 146 | # if self.num_updates % self.valid_freq == 0: 147 | # assert np.allclose(self.total, self.trees[:, self.capacity - 1:].sum(axis=-1)) 148 | # 149 | # # get priority and sample 150 | # def get(self, values): 151 | # # values: (num_trees,) 152 | # idxes = self._retrieve(np.zeros(self.num_trees, dtype=np.int32), values) 153 | # dataIdxes = idxes - self.capacity + 1 154 | # return idxes, dataIdxes 155 | # 156 | # def init_trees(self, ps): 157 | # assert (self.trees == 0).all() and self.write == 0 158 | # assert len(ps) <= self.capacity 159 | # self.trees[0, self.capacity - 1:self.capacity - 1 + len(ps)] = ps 160 | # self.write = len(ps) % self.capacity 161 | # self.full = len(ps) == self.capacity 162 | # 163 | # last_idx = len(ps) - 1 + self.capacity - 1 164 | # last_parent = (last_idx - 1) // 2 165 | # for i in reversed(range(last_parent + 1)): 166 | # left = 2 * i + 1 167 | # right = left + 1 168 | # self.trees[0, i] = self.trees[0, left] + self.trees[0, right] 169 | # 170 | # self.trees = np.tile(self.trees[0], (self.num_trees, 1)) 171 | # assert (self.total() == ps.sum()).all() 172 | 173 | 174 | class ParallelBatchSumTree: 175 | def __init__(self, num_trees, capacity, batch_size): 176 | self.num_trees = num_trees 177 | self.capacity = capacity 178 | self.batch_size = batch_size 179 | 180 | self.trees = np.zeros((num_trees, 2 * capacity - 1), dtype=np.float64) 181 | self.write = 0 182 | self.full = False 183 | 184 | self.tree_idxes = np.tile(np.arange(num_trees)[:, None], (1, batch_size)) 185 | 186 | # update to the root node 187 | def _propagate(self, idxes, changes): 188 | # idxes, changes: (num_trees, batch_size) 189 | zeros = np.zeros_like(changes) 190 | while True: 191 | idxes = (idxes - 1) // 2 192 | 193 | # similar to self.trees[self.tree_idxes, idxes] += changes but handles repeated inxes 194 | np.add.at(self.trees, (self.tree_idxes, idxes), changes) 195 | 196 | finish_props = (idxes <= 0) 197 | changes = np.where(finish_props, zeros, changes) 198 | 199 | if finish_props.all(): 200 | return 201 | 202 | # find sample on leaf node 203 | def _retrieve(self, idxes, values, monitor): 204 | # idxes, value: (num_trees, batch_size) 205 | tree_len = self.capacity if self.full else self.write 206 | tree_len += self.capacity - 1 207 | 208 | # print(f"tree_len: {tree_len}") 209 | while True: 210 | lefts = 2 * idxes + 1 211 | rights = lefts + 1 212 | 213 | found_idxes = lefts >= tree_len 214 | if found_idxes.all(): 215 | # if np.any(idxes > tree_len): 216 | # import sys 217 | # sys.stdout = sys.__stdout__ 218 | # import ipdb 219 | # ipdb.set_trace() 
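# Worked example of the retrieval logic above (hypothetical numbers, added for
# illustration only): for a single tree with capacity 4 and leaf priorities
# [3, 1, 2, 4], the flat array is [10, 4, 6, 3, 1, 2, 4]. A query s = 7.5
# first compares against the left child (4): since 7.5 > 4 it descends right
# with s = 7.5 - 4 = 3.5; there the left child holds 2, so it descends right
# again with s = 1.5 and stops at the leaf at index 6, i.e. data index
# 6 - capacity + 1 = 3. Leaves are therefore hit with probability
# proportional to their priority mass.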
220 | return idxes 221 | 222 | modified_lefts = np.where(found_idxes, idxes, lefts) 223 | left_values = self.trees[self.tree_idxes, modified_lefts] 224 | epsilon = 1e-8 225 | le_lefts = values <= left_values + epsilon # maybe remove the =? # This might be hacky 226 | idxes = np.where(le_lefts, modified_lefts, rights) 227 | 228 | if monitor: 229 | print("Printing retrieve results...") 230 | print(idxes) 231 | print(values) 232 | print(left_values) 233 | print(le_lefts) 234 | for i in range(100): 235 | print(self.trees[0, 508000 + i * 500: 508010+i*500]) 236 | exit() 237 | 238 | values = np.where(le_lefts, values, values - left_values) 239 | 240 | def total(self): 241 | return self.trees[:, 0] 242 | 243 | # store priority and sample 244 | def add(self, p): 245 | raise NotImplementedError 246 | idx = self.write + self.capacity - 1 247 | 248 | self.update(idx, p) 249 | 250 | self.write += 1 251 | if self.write >= self.capacity: 252 | self.full = True 253 | self.write = 0 254 | 255 | # update priority 256 | def update(self, idxes, ps): 257 | # idxes, ps: (num_trees, batch_size) 258 | changes = ps - self.trees[self.tree_idxes, idxes] 259 | self.trees[self.tree_idxes, idxes] = ps 260 | self._propagate(idxes, changes) 261 | 262 | # get priority and sample 263 | def get(self, values, monitor=False): 264 | # values: (num_trees, batch_size) 265 | idxes = self._retrieve(np.zeros((self.num_trees, self.batch_size), dtype=np.int32), values, monitor) 266 | dataIdxes = idxes - self.capacity + 1 267 | return idxes, dataIdxes 268 | 269 | def init_trees(self, ps): 270 | assert (self.trees == 0).all() and self.write == 0 271 | assert len(ps) <= self.capacity 272 | self.trees[:, self.capacity - 1:self.capacity - 1 + len(ps)] = ps 273 | self.write = len(ps) % self.capacity 274 | self.full = len(ps) == self.capacity 275 | 276 | last_idx = len(ps) - 1 + self.capacity - 1 277 | last_parent = (last_idx - 1) // 2 278 | for i in reversed(range(last_parent + 1)): 279 | left = 2 * i + 1 280 | right = left + 1 281 | self.trees[:, i] = self.trees[:, left] + self.trees[:, right] 282 | 283 | assert (self.total() == ps.sum()).all() -------------------------------------------------------------------------------- /causal_inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Supports minigrid and igibson inference 3 | - sparse minigrid has sparse reward and a single reward term 4 | - full minigrid has 5 reward terms 5 | - Igibson-discrete has 3 reward terms and 11 action dimensions 6 | - Igibson-continuous has 5 reward terms and 11 action dimensions 7 | """ 8 | 9 | import os 10 | import numpy as np 11 | 12 | import torch 13 | from torch.utils.tensorboard import SummaryWriter 14 | import argparse 15 | np.set_printoptions(precision=3, suppress=True) 16 | torch.set_printoptions(precision=3, sci_mode=False) 17 | 18 | from model.inference_cmi import InferenceCMI 19 | # from model.contrastive_cmi import ContrastiveCMI 20 | 21 | from utils.utils import TrainingParams, Logger, set_seed_everywhere, get_start_step_from_model_loading 22 | from utils.replay_buffer import ReplayBuffer, ParallelPrioritizedReplayBuffer 23 | from utils.plot import plot_adjacency_intervention_mask 24 | 25 | def train(params): 26 | device = torch.device("cuda:{}".format(params.cuda_id) if torch.cuda.is_available() else "cpu") 27 | if torch.cuda.device_count() < 2: 28 | device = "cpu" 29 | if params.domain == "minigrid": 30 | if params.mini_env_name == "sparse": 31 | params.action_part_dim = 4 32 | 
params.reward_dim = 1 33 | params.action_feature_inner_dim = [3, 3, 3, 3] 34 | params.reward_feature_inner_dim = [2] 35 | params.continuous_action = True 36 | params.continuous_reward = False 37 | params.convert_data_onehot = True 38 | params.obs_dim = 1 39 | elif params.mini_env_name == "full": 40 | params.action_part_dim = 4 41 | params.reward_dim = 5 42 | params.action_feature_inner_dim = [3, 3, 3, 3] 43 | params.reward_feature_inner_dim = [3, 3, 2, 2, 2] 44 | params.continuous_action = True # Manually specify, since not loading from env 45 | params.continuous_reward = False 46 | params.convert_data_onehot = True 47 | params.obs_dim = 1 48 | else: 49 | raise NotImplementedError 50 | elif params.domain == "igibson": 51 | if params.igibson_reward_type == "discrete": 52 | params.action_part_dim = 11 53 | params.reward_dim = 3 # three collisions 54 | params.reward_feature_inner_dim = [2, 2, 2] 55 | params.continuous_action = True 56 | params.continuous_reward = False 57 | params.convert_data_onehot = True 58 | params.obs_dim = 2 # scan & task_obs 59 | params.obs_ind_dim = 20 60 | elif params.igibson_reward_type == "continuous": 61 | params.action_part_dim = 11 62 | params.reward_dim = 5 # all reward terms except the three collisions 63 | params.continuous_action = True 64 | params.continuous_reward = True 65 | params.convert_data_onehot = False 66 | params.obs_dim = 2 # scan & task_obs 67 | params.obs_ind_dim = 20 68 | 69 | if not params.continuous_reward: 70 | assert(len(params.reward_feature_inner_dim) == params.reward_dim) 71 | 72 | params.ind_action_dim = 1 # Max number of action dimensions grouped together as one individual action 73 | params.ind_reward_dim = 1 74 | rb_path = os.path.join("data", params.rb_path) 75 | 76 | set_seed_everywhere(params.seed) 77 | 78 | params.device = device 79 | training_params = params.training_params 80 | replay_buffer_params = training_params.replay_buffer_params 81 | inference_params = params.inference_params 82 | contrastive_params = params.contrastive_params 83 | 84 | # init replay buffer 85 | use_prioritized_buffer = replay_buffer_params.prioritized_buffer 86 | if use_prioritized_buffer: 87 | replay_buffer = ParallelPrioritizedReplayBuffer(params, rb_path) 88 | else: 89 | replay_buffer = ReplayBuffer(params, rb_path) 90 | 91 | # init model 92 | encoder = None 93 | decoder = None 94 | 95 | inference_algo = params.training_params.inference_algo 96 | # For now, contrastive_cmi 97 | use_contrastive = "contrastive" in inference_algo 98 | if inference_algo == "cmi": 99 | Inference = InferenceCMI 100 | elif inference_algo == "contrastive_cmi": 101 | # Inference = ContrastiveCMI 102 | raise NotImplementedError 103 | else: 104 | raise NotImplementedError 105 | inference = Inference(encoder, decoder, params) 106 | 107 | start_step = get_start_step_from_model_loading(params) 108 | total_step = training_params.total_step 109 | num_inference_opt_steps = training_params.num_inference_opt_steps 110 | 111 | # init saving 112 | writer = None 113 | if num_inference_opt_steps: 114 | writer = SummaryWriter(os.path.join(params.rslts_dir, "tensorboard")) 115 | model_dir = os.path.join(params.rslts_dir, "trained_models") 116 | os.makedirs(model_dir, exist_ok=True) 117 | 118 | for step in range(start_step, total_step): 119 | is_init_stage = step < training_params.init_step 120 | 121 | loss_details = {"inference": [], 122 | "inference_eval": [], 123 | "policy": []} 124 | 125 | # training and logging 126 | if is_init_stage: 127 | continue 128 | 129 | if
num_inference_opt_steps > 0: 130 | inference_batch_size = contrastive_params.batch_size if use_contrastive else inference_params.batch_size 131 | inference.train() 132 | inference.setup_annealing(step) 133 | for i_grad_step in range(num_inference_opt_steps): 134 | action_batch, obss_batch, rewards_batch, idxes_batch = \ 135 | replay_buffer.sample_inference(inference_batch_size, "train") 136 | loss_detail = inference.update(action_batch, obss_batch, rewards_batch) 137 | 138 | if use_prioritized_buffer: 139 | replay_buffer.update_priorties(idxes_batch, loss_detail["priority"], "inference") 140 | if params.domain == "minigrid" and params.mini_env_name == "full": 141 | if (step+1) % 1000 == 0: 142 | reward_dim_we_care = 3 # only support 3 or 4 143 | extracted_num = 3 144 | priority_dim = loss_detail["priority"][reward_dim_we_care].cpu() 145 | values, top_idx = torch.topk(priority_dim, extracted_num) 146 | print(f"top priority for dim {reward_dim_we_care}: {values}") 147 | top_batch_idx = idxes_batch[reward_dim_we_care][top_idx] - \ 148 | params.training_params.replay_buffer_params.capacity + 1 149 | print(f"center grid") 150 | print(replay_buffer.scans[top_batch_idx][..., 0, 3, 3]) # This is the central grid 151 | non_empty_neighbor = (replay_buffer.scans[top_batch_idx][..., 0, 2:5, 2:5] != 1 152 | ).reshape(3, -1).sum(axis=1) % 3 153 | print(f"ideal action: \n{non_empty_neighbor}") 154 | print("actions:") 155 | print(replay_buffer.actions[top_batch_idx][:, reward_dim_we_care-1]) 156 | print("rewards:") 157 | print(replay_buffer.rewards[top_batch_idx][:, reward_dim_we_care]) 158 | 159 | if (step + 1) % 100 == 0: 160 | reward_dim_we_care = 3 # only support 3 or 4 161 | data_idxes = idxes_batch[reward_dim_we_care] - params.training_params.replay_buffer_params.capacity + 1 162 | special_r_list = replay_buffer.rewards[data_idxes][:, reward_dim_we_care] == -5 163 | print(f"\n Number of special reward datapoint is {np.sum(special_r_list)} \n") 164 | 165 | loss_details["inference"].append(loss_detail) 166 | 167 | inference.eval() 168 | if (step + 1) % training_params.eval_freq == 0: 169 | if params.train_mask: 170 | eval_data_part = "train" 171 | else: 172 | eval_data_part = "eval" 173 | action_batch, obss_batch, rewards_batch, _ = \ 174 | replay_buffer.sample_inference(inference_batch_size, use_part=eval_data_part) 175 | loss_detail = inference.update(action_batch, obss_batch, rewards_batch, eval=True) 176 | loss_details["inference_eval"].append(loss_detail) 177 | print("{}/{}, init_stage: {}".format(step + 1, total_step, is_init_stage)) 178 | cur_adj = inference.get_adjacency()[:, :-params.obs_dim] 179 | max_act = torch.max(cur_adj, dim=1) 180 | print(f"current adjacency: \n{cur_adj}") 181 | print(f"max action: {max_act}") 182 | print(f"current mask: \n{inference.get_mask()}") 183 | 184 | if writer is not None: 185 | for module_name, module_loss_detail in loss_details.items(): 186 | if not module_loss_detail: 187 | continue 188 | # list of dict to dict of list 189 | if isinstance(module_loss_detail, list): 190 | keys = set().union(*[dic.keys() for dic in module_loss_detail]) 191 | module_loss_detail = {k: [dic[k].item() for dic in module_loss_detail if k in dic] 192 | for k in keys if k not in ["priority"]} 193 | for loss_name, loss_values in module_loss_detail.items(): 194 | writer.add_scalar("{}/{}".format(module_name, loss_name), np.mean(loss_values), step) 195 | 196 | if (step + 1) % training_params.plot_freq == 0 and num_inference_opt_steps > 0: 197 | 
plot_adjacency_intervention_mask(inference, writer, step) 198 | 199 | if (step + 1) % training_params.saving_freq == 0: 200 | if num_inference_opt_steps > 0: 201 | inference.save(os.path.join(model_dir, "inference_{}".format(step + 1))) 202 | 203 | 204 | if __name__ == "__main__": 205 | parser = argparse.ArgumentParser(description='Command line arguments.') 206 | parser.add_argument('--config', type=str, 207 | help='params to load from') 208 | args = parser.parse_args() 209 | params = TrainingParams(training_params_fname=args.config, train=True) 210 | train(params) 211 | -------------------------------------------------------------------------------- /utils/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from joblib import Parallel, delayed 5 | 6 | from utils.utils import to_numpy 7 | from utils.sum_tree import SumTree, ParallelBatchSumTree 8 | 9 | import pickle 10 | import random 11 | import sklearn.preprocessing 12 | 13 | 14 | def take(array, start, end): 15 | """ 16 | get array[start:end] in a circular fashion which turns out to be expensive... 17 | """ 18 | # if start >= end: 19 | # end += len(array) 20 | # idxes = np.arange(start, end) % len(array) 21 | # return array[idxes] 22 | return array[start:end] 23 | 24 | 25 | def assign(array, start, end, value): 26 | if start >= end: 27 | end += len(array) 28 | idxes = np.arange(start, end) % len(array) 29 | array[idxes] = value 30 | 31 | 32 | class ReplayBuffer: 33 | def __init__(self, params, file_path): 34 | # load in the data, split into train and test 35 | with open(file_path, "rb") as fp: # Pickling 36 | data_list = pickle.load(fp) 37 | 38 | # random.shuffle(data_list, lambda: 0.5) 39 | 40 | if params.domain == "minigrid": 41 | scans, actions, rewards = data_list 42 | self.scans = np.array([a["image"] for a in scans]) 43 | self.scans = np.transpose(self.scans, (0, 3, 1, 2)) 44 | self.actions = np.array(actions) 45 | self.rewards = np.array(rewards) 46 | self.data_length = self.rewards.shape[0] 47 | if params.convert_data_onehot: 48 | # a list of tensors, [(bs, num_pred_steps, feature_i_dim)] * feature_dim 49 | rewards_list = [] 50 | for i in range(self.rewards.shape[1]): 51 | label_binarizer = sklearn.preprocessing.LabelBinarizer() 52 | unique_values = np.unique(self.rewards[:, i]) 53 | assert len(unique_values) <= params.reward_feature_inner_dim[i] 54 | label_binarizer.fit(unique_values) 55 | out = label_binarizer.transform(self.rewards[:, i]) 56 | if params.reward_feature_inner_dim[i] == 2: 57 | out = np.hstack((out, 1 - out)) 58 | rewards_list.append(out) 59 | self.rewards = rewards_list 60 | 61 | if params.domain == "igibson": 62 | self.data_length = len(data_list) 63 | self.scans = np.array([a[0]["scan"] for a in data_list]) 64 | self.scans = np.transpose(self.scans, (0, 2, 1)) # (total, 220, 1) - > (total, 1, 220) 65 | print(f"scan shape: {self.scans.shape}") 66 | self.actions = np.array([a[1] for a in data_list]) # (total, action_dim) 67 | self.task_obs = np.array([a[0]["task_obs"] for a in data_list]) 68 | 69 | if params.igibson_reward_type == "discrete": 70 | # shape: (total, 1) 71 | base_collision = np.array([a[2]["base_collision"] for a in data_list]).reshape(self.data_length, -1) 72 | arm_collision = np.array([a[2]["arm_collision"] for a in data_list]).reshape(self.data_length, -1) 73 | self_collision = np.array([a[2]["self_collision"] for a in data_list]).reshape(self.data_length, -1) 74 | 75 | self.rewards = 
np.concatenate([base_collision, arm_collision, self_collision], axis=1).astype(int) 76 | if params.convert_data_onehot: 77 | # a list of tensors, [(bs, num_pred_steps, feature_i_dim)] * feature_dim 78 | rewards_list = [] 79 | for i in range(self.rewards.shape[1]): 80 | reward_dim = self.rewards[:, i] 81 | n_values = params.reward_feature_inner_dim[i] 82 | rewards_list.append(np.eye(n_values)[reward_dim]) 83 | self.rewards = rewards_list 84 | elif params.igibson_reward_type == "continuous": 85 | self.rewards = np.array([np.append(a[3][:3],a[3][6:]) for a in data_list]) # (total, reward_dim) 86 | 87 | self.train_test_split = int(0.8 * self.data_length) 88 | self.device = params.device 89 | self.one_hot_reward = params.convert_data_onehot 90 | self.domain = params.domain 91 | 92 | 93 | def sample_inference(self, batch_size, use_part="all"): 94 | ''' 95 | return: actions, obss, rewards 96 | - actions: (bs, 1, action_size) 97 | - obss: (bs, 1, 220, 1) 98 | - reward: (bs, 1, r_size) 99 | ''' 100 | if use_part == "all": 101 | st_idx = 0 102 | ed_idx = self.data_length 103 | elif use_part == "train": 104 | st_idx = 0 105 | ed_idx = self.train_test_split 106 | elif use_part == "eval": 107 | st_idx = self.train_test_split 108 | ed_idx = self.data_length 109 | else: 110 | raise NotImplementedError 111 | 112 | idxs = np.random.choice(np.arange(st_idx, ed_idx), size=batch_size, replace=False).astype(int) 113 | actions, observations, rewards = self.sample_with_idx(idxs) 114 | return actions, observations, rewards, idxs 115 | 116 | def sample_with_idx(self, idxs): 117 | 118 | actions = self.actions[idxs] 119 | 120 | scans = self.scans[idxs] 121 | 122 | # Convert them to proper device & format 123 | # assumes: if self.continuous_action else torch.int64 124 | actions = torch.tensor(actions, 125 | dtype=torch.float32 , device=self.device).unsqueeze(dim=1) 126 | scans = torch.tensor(scans, 127 | dtype=torch.float32, device=self.device).unsqueeze(dim=1) 128 | if self.domain == "igibson": 129 | task_obs = self.task_obs[idxs] 130 | task_obs = torch.tensor(task_obs, 131 | dtype=torch.float32, device=self.device).unsqueeze(dim=1) 132 | observations = [scans, task_obs] 133 | elif self.domain == "minigrid": 134 | observations = scans 135 | else: 136 | raise NotImplementedError 137 | 138 | if self.one_hot_reward: 139 | rewards = [] 140 | for i in range(len(self.rewards)): 141 | reward_tensor = torch.tensor(self.rewards[i][idxs], 142 | dtype=torch.float32, device=self.device).unsqueeze(dim=1) 143 | rewards.append(reward_tensor) 144 | else: 145 | rewards = self.rewards[idxs] 146 | rewards = torch.tensor(rewards, 147 | dtype=torch.float32, device=self.device).unsqueeze(dim=1) 148 | 149 | return actions, observations, rewards 150 | 151 | 152 | class ParallelPrioritizedReplayBuffer(ReplayBuffer): 153 | def __init__(self, params, file_path): 154 | self.capacity = capacity = params.training_params.replay_buffer_params.capacity 155 | self.reward_dim = params.reward_dim 156 | self.inference_batch_size = params.contrastive_params.batch_size 157 | self.inference_train_trees = ParallelBatchSumTree(self.reward_dim, capacity, self.inference_batch_size) 158 | 159 | self.alpha = params.training_params.replay_buffer_params.prioritized_alpha 160 | self.max_priority = 1 161 | self.num_observation_steps = params.training_params.num_observation_steps 162 | self.num_inference_pred_steps = params.contrastive_params.num_pred_steps 163 | 164 | super(ParallelPrioritizedReplayBuffer, self).__init__(params, file_path) 165 | 166 | assert 
(self.capacity >= self.data_length) 167 | 168 | # Only the train data has priority 169 | train_priorities = np.zeros(self.data_length) 170 | train_priorities[:self.train_test_split] = 1 171 | self.inference_train_trees.init_trees(train_priorities) 172 | 173 | def update_priorties(self, idxes, probs, type): 174 | if isinstance(probs, torch.Tensor): 175 | probs = to_numpy(probs) 176 | 177 | # probs = np.minimum(probs ** self.alpha, self.max_priority) 178 | probs = np.clip(probs ** self.alpha, 1e-4, self.max_priority) 179 | 180 | if type == "inference": 181 | trees = self.inference_train_trees 182 | # idxes, probs: (feature_dim, bs) 183 | trees.update(idxes, probs) 184 | else: 185 | raise NotImplementedError 186 | 187 | def sample_idxes_from_parallel_trees(self, batch_size, num_steps): 188 | # - self.max_priority * num_steps to avoid infinite loop of sampling the newly added sample 189 | trees = self.inference_train_trees 190 | segment = trees.total() / batch_size # (feature_dim,) # removed - self.max_priority * num_steps) 191 | 192 | s = np.random.uniform(size=(self.reward_dim, batch_size)) + np.arange(batch_size) 193 | s = s * segment[:, None] # (feature_dim, batch_size) 194 | 195 | # no need to validate idxes because we pre-set priorities of non-valid idxes to 0 196 | # tree_idxes, data_idxes: (feature_dim, batch_size) 197 | tree_idxes, data_idxes = trees.get(s) 198 | 199 | if np.any(data_idxes > 507999): 200 | tree_idxes, data_idxes = trees.get(s, monitor=True) 201 | exit() 202 | # import sys 203 | # sys.stdout = sys.__stdout__ 204 | # import ipdb 205 | # ipdb.set_trace() 206 | 207 | data_idxes = np.array(data_idxes).flatten() # (feature_dim * batch_size) 208 | return tree_idxes, data_idxes 209 | 210 | 211 | def sample_inference(self, batch_size, use_part="all"): 212 | num_steps = self.num_inference_pred_steps 213 | batch_size = self.inference_batch_size 214 | reward_dim = self.reward_dim 215 | 216 | if use_part == "train": 217 | # size: (reward_dim * batch_size) 218 | tree_idxes, data_idxes = self.sample_idxes_from_parallel_trees(batch_size, num_steps) 219 | actions, scans, rewards = self.sample_with_idx(data_idxes) 220 | actions_shape = actions.shape[2:] 221 | scans_shape = scans.shape[2:] 222 | rewards_shape = rewards.shape[2:] 223 | # TODO: reshape?
desired out: (reward_dim, bs, 1, n) 224 | actions = actions.reshape([reward_dim, batch_size, self.num_inference_pred_steps, *actions_shape]) 225 | scans = scans.reshape([reward_dim, batch_size, self.num_inference_pred_steps, *scans_shape]) 226 | rewards = rewards.reshape([reward_dim, batch_size, self.num_inference_pred_steps, *rewards_shape]) 227 | 228 | else: 229 | actions, scans, rewards, tree_idxes = super(ParallelPrioritizedReplayBuffer, self).sample_inference(batch_size, use_part) 230 | 231 | return actions, scans, rewards, tree_idxes 232 | 233 | 234 | -------------------------------------------------------------------------------- /model/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | 9 | from torch.distributions.normal import Normal 10 | from torch.distributions.distribution import Distribution 11 | from torch.distributions.one_hot_categorical import OneHotCategorical 12 | from torch.distributions.kl import kl_divergence 13 | 14 | from utils.utils import to_numpy 15 | 16 | 17 | class Inference(nn.Module): 18 | def __init__(self, encoder, decoder, params): 19 | super(Inference, self).__init__() 20 | 21 | self.encoder = encoder 22 | self.decoder = decoder 23 | 24 | self.params = params 25 | self.device = device = params.device 26 | self.inference_params = inference_params = params.inference_params 27 | self.training_params = training_params = params.training_params 28 | 29 | self.residual = inference_params.residual 30 | self.log_std_min = inference_params.log_std_min 31 | self.log_std_max = inference_params.log_std_max 32 | self.continuous_action = params.continuous_action 33 | self.continuous_reward = params.continuous_reward 34 | 35 | self.object_level_obs = training_params.object_level_obs 36 | self.num_observation_steps = training_params.num_observation_steps 37 | 38 | self.init_model() 39 | self.reset_params() 40 | 41 | self.abstraction_quested = False 42 | 43 | self.to(device) 44 | self.optimizer = optim.Adam(self.parameters(), lr=inference_params.lr) 45 | 46 | self.load(params.training_params.load_inference, device) 47 | self.train() 48 | 49 | def init_model(self): 50 | raise NotImplementedError 51 | 52 | def reset_params(self): 53 | pass 54 | 55 | def forward_step(self, features, action): 56 | """ 57 | :param features: 58 | if observation space is continuous: (bs, num_observation_steps, feature_dim). 59 | else: [(bs, num_observation_steps, feature_i_dim)] * feature_dim 60 | notice that bs can be a multi-dimensional batch size 61 | :param action: (bs, action_dim) 62 | :return: next step value for all state variables in the format of distribution, 63 | if observation space is continuous: a Normal distribution of shape (bs, feature_dim) 64 | else: a list of distributions, [OneHotCategorical / Normal] * feature_dim, each of shape (bs, feature_i_dim) 65 | """ 66 | raise NotImplementedError 67 | 68 | def forward_step_abstraction(self, abstraction_features, action): 69 | """ 70 | :param abstraction_features: 71 | if observation space is continuous: (bs, num_observation_steps, abstraction_feature_dim). 
72 | else: [(bs, num_observation_steps, feature_i_dim)] * abstraction_feature_dim 73 | notice that bs can be a multi-dimensional batch size 74 | :param action: (bs, action_dim) 75 | :return: next step value for all abstracted state variables in the format of distribution, 76 | if observation space is continuous: a Normal distribution of shape (bs, abstraction_feature_dim) 77 | else: a list of distributions, [OneHotCategorical / Normal] * abstraction_feature_dim, 78 | each of shape (bs, feature_i_dim) 79 | """ 80 | raise NotImplementedError 81 | 82 | def cat_features(self, features, next_feature): 83 | """ 84 | :param features: 85 | if observation space is continuous: (bs, num_observation_steps, feature_dim). 86 | else: [(bs, num_observation_steps, feature_i_dim)] * feature_dim 87 | notice that bs can be a multi-dimensional batch size 88 | :param next_feature: 89 | if observation space is continuous: (bs, feature_dim). 90 | else: [(bs, feature_i_dim)] * feature_dim 91 | :return: 92 | if observation space is continuous: (bs, num_observation_steps, feature_dim). 93 | else: [(bs, num_observation_steps, feature_i_dim)] * feature_dim 94 | """ 95 | raise NotImplementedError 96 | if features is None and next_feature is None: 97 | # for cmi 98 | return None 99 | 100 | if self.continuous_action: 101 | features = torch.cat([features[..., 1:, :], next_feature.unsqueeze(dim=-2)], dim=-2) 102 | else: 103 | features = [torch.cat([features_i[..., 1:, :], next_feature_i.unsqueeze(dim=-2)], dim=-2) 104 | for features_i, next_feature_i in zip(features, next_feature)] 105 | return features 106 | 107 | def stack_dist(self, dist_list): 108 | """ 109 | list of distribution at different time steps to a single distribution stacked at dim=-2 110 | :param dist_list: 111 | if observation space is continuous: [Normal] * num_pred_steps, each of shape (bs, feature_dim) 112 | else: [[OneHotCategorical / Normal] * feature_dim] * num_pred_steps, each of shape (bs, feature_i_dim) 113 | notice that bs can be a multi-dimensional batch size 114 | :return: 115 | if observation space is continuous: Normal distribution of shape (bs, num_pred_steps, feature_dim) 116 | else: [OneHotCategorical / Normal] * feature_dim, each of shape (bs, num_pred_steps, feature_i_dim) 117 | """ 118 | if self.continuous_reward: 119 | mu = torch.stack([dist.mean for dist in dist_list], dim=-2) # (bs, num_pred_steps, feature_dim) 120 | std = torch.stack([dist.stddev for dist in dist_list], dim=-2) # (bs, num_pred_steps, feature_dim) 121 | return Normal(mu, std) 122 | else: 123 | # [(bs, num_pred_steps, feature_i_dim)] 124 | stacked_dist_list = [] 125 | for i, dist_i in enumerate(dist_list[0]): 126 | if isinstance(dist_i, Normal): 127 | # (bs, num_pred_steps, feature_i_dim) 128 | mu = torch.stack([dist[i].mean for dist in dist_list], dim=-2) 129 | std = torch.stack([dist[i].stddev for dist in dist_list], dim=-2) 130 | stacked_dist_i = Normal(mu, std) 131 | elif isinstance(dist_i, OneHotCategorical): 132 | # (bs, num_pred_steps, feature_i_dim) 133 | logits = torch.stack([dist[i].logits for dist in dist_list], dim=-2) 134 | stacked_dist_i = OneHotCategorical(logits=logits) 135 | else: 136 | raise NotImplementedError 137 | stacked_dist_list.append(stacked_dist_i) 138 | 139 | return stacked_dist_list 140 | 141 | def sample_from_distribution(self, dist): 142 | """ 143 | sample from the distribution 144 | :param dist: 145 | if observation space is continuous: Normal distribution of shape (bs, feature_dim). 
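(e.g., one Normal over the 5-dim reward vector under the igibson-continuous config -- an illustrative case, not the only one)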
146 | else: [OneHotCategorical / Normal] * feature_dim, each of shape (bs, feature_i_dim) 147 | notice that bs can be a multi-dimensional batch size 148 | :return: 149 | if observation space is continuous: (bs, feature_dim) 150 | else: [(bs, feature_i_dim)] * feature_dim 151 | """ 152 | if dist is None: 153 | # for cmi 154 | return None 155 | 156 | if self.continuous_reward: 157 | return dist.rsample() if self.training else dist.mean 158 | else: 159 | sample = [] 160 | for dist_i in dist: 161 | if isinstance(dist_i, Normal): 162 | sample_i = dist_i.rsample() if self.training else dist_i.mean 163 | elif isinstance(dist_i, OneHotCategorical): 164 | logits = dist_i.logits 165 | if self.training: 166 | sample_i = F.gumbel_softmax(logits, hard=True) 167 | else: 168 | sample_i = F.one_hot(torch.argmax(logits, dim=-1), logits.size(-1)).float() 169 | else: 170 | raise NotImplementedError 171 | sample.append(sample_i) 172 | return sample 173 | 174 | def log_prob_from_distribution(self, dist, value): 175 | """ 176 | calculate log_prob of value from the distribution 177 | :param dist: 178 | if observation space is continuous: Normal distribution of shape (bs, feature_dim). 179 | else: [OneHotCategorical / Normal] * feature_dim, each of shape (bs, feature_i_dim) 180 | notice that bs can be a multi-dimensional batch size 181 | :param value: 182 | if observation space is continuous: (bs, feature_dim). 183 | else: [(bs, feature_i_dim)] * feature_dim 184 | :return: (bs, feature_dim) 185 | """ 186 | if self.continuous_reward: 187 | return dist.log_prob(value) 188 | else: 189 | log_prob = [] 190 | for dist_i, val_i in zip(dist, value): 191 | log_prob_i = dist_i.log_prob(val_i) 192 | if isinstance(dist_i, Normal) and not self.object_level_obs: 193 | log_prob_i = log_prob_i.squeeze(dim=-1) 194 | log_prob.append(log_prob_i) 195 | return torch.cat(log_prob, dim=-1) if self.object_level_obs else torch.stack(log_prob, dim=-1) 196 | 197 | def forward_with_feature(self, features, actions, abstraction_mode=False): 198 | """ 199 | :param features: 200 | if observation space is continuous: (bs, num_observation_steps, feature_dim). 
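(e.g., (bs, 1, feature_dim) for batches drawn from this repo's replay buffer, which inserts a singleton observation-step axis -- illustrative only)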
201 | else: [(bs, num_observation_steps, feature_i_dim)] * feature_dim 202 | notice that bs can be a multi-dimensional batch size 203 | :param actions: 204 | if observation space is continuous: (bs, num_pred_steps, action_dim) 205 | else: (bs, num_pred_steps, 1) 206 | :param abstraction_mode: whether to only forward controllable & action-relevant state variables, 207 | used for model-based roll-out 208 | :return: next step value for all (abstracted) state variables in the format of distribution, 209 | if observation space is continuous: a Normal distribution of shape (bs, feature_dim) 210 | else: a list of distributions, [OneHotCategorical / Normal] * feature_dim, each of shape (bs, feature_i_dim) 211 | """ 212 | 213 | if abstraction_mode and self.abstraction_quested: 214 | if self.continuous_action: 215 | features = features[..., self.abstraction_idxes] 216 | else: 217 | features = [features[idx] for idx in self.abstraction_idxes] 218 | 219 | if not self.continuous_action: 220 | actions = F.one_hot(actions.squeeze(dim=-1), self.action_dim).float() # (bs, num_pred_steps, action_dim) 221 | actions = torch.unbind(actions, dim=-2) # [(bs, action_dim)] * num_pred_steps 222 | 223 | dists = [] 224 | for action in actions: 225 | if abstraction_mode and self.abstraction_quested: 226 | dist = self.forward_step_abstraction(features, action) 227 | else: 228 | dist = self.forward_step(features, action) 229 | 230 | next_feature = self.sample_from_distribution(dist) 231 | features = self.cat_features(features, next_feature) 232 | dists.append(dist) 233 | dists = self.stack_dist(dists) 234 | 235 | return dists 236 | 237 | def get_feature(self, obs): 238 | feature = self.encoder(obs) 239 | return feature 240 | 241 | def forward(self, obses, actions, abstraction_mode=False): 242 | features = self.get_feature(obses) 243 | return self.forward_with_feature(features, actions, abstraction_mode) 244 | 245 | def setup_annealing(self, step): 246 | pass 247 | 248 | def prediction_loss_from_dist(self, pred_dist, next_feature, keep_variable_dim=False): 249 | """ 250 | calculate prediction loss from the prediction distribution 251 | if use a CNN encoder: prediction loss = KL divergence 252 | else: prediction loss = -log_prob 253 | :param pred_dist: next step value for all state variables in the format of distribution, 254 | if observation space is continuous: 255 | a Normal distribution of shape (bs, num_pred_steps, feature_dim) 256 | else: 257 | a list of distributions, [OneHotCategorical / Normal] * feature_dim, 258 | each of shape (bs, num_pred_steps, feature_i_dim) 259 | :param next_feature: 260 | if use a CNN encoder: 261 | a Normal distribution of shape (bs, num_pred_steps, feature_dim) 262 | elif observation space is continuous: 263 | a tensor of shape (bs, num_pred_steps, feature_dim) 264 | else: 265 | a list of tensors, [(bs, num_pred_steps, feature_i_dim)] * feature_dim 266 | :param keep_variable_dim: whether to keep the dimension of state variables which is dim=-1 267 | :return: (bs, num_pred_steps, feature_dim) if keep_variable_dim else (bs, num_pred_steps) 268 | """ 269 | if isinstance(next_feature, Distribution): 270 | assert isinstance(next_feature, Normal) 271 | next_feature = Normal(next_feature.mean.detach(), next_feature.stddev.detach()) 272 | pred_loss = kl_divergence(next_feature, pred_dist) # (bs, num_pred_steps, feature_dim) 273 | else: 274 | if self.continuous_reward: 275 | next_feature = next_feature.detach() 276 | else: 277 | next_feature = [next_feature_i.detach() for next_feature_i in 
next_feature] 278 | pred_loss = -self.log_prob_from_distribution(pred_dist, next_feature) # (bs, num_pred_steps, feature_dim) 279 | 280 | if not keep_variable_dim: 281 | pred_loss = pred_loss.sum(dim=-1) # (bs, num_pred_steps) 282 | 283 | return pred_loss 284 | 285 | def backprop(self, loss, loss_detail): 286 | self.optimizer.zero_grad() 287 | loss.backward() 288 | 289 | grad_clip_norm = self.inference_params.grad_clip_norm 290 | if not grad_clip_norm: 291 | grad_clip_norm = np.inf 292 | loss_detail["grad_norm"] = torch.nn.utils.clip_grad_norm_(self.parameters(), grad_clip_norm) 293 | 294 | self.optimizer.step() 295 | return loss_detail 296 | 297 | def update(self, obses, actions, next_obses, eval=False): 298 | """ 299 | :param obs: {obs_i_key: (bs, num_observation_steps, obs_i_shape)} 300 | :param actions: (bs, num_pred_steps, action_dim) 301 | :param next_obses: ({obs_i_key: (bs, num_pred_steps, obs_i_shape)} 302 | :return: {"loss_name": loss_value} 303 | """ 304 | features = self.encoder(obses) 305 | next_features = self.encoder(next_obses) 306 | pred_next_dist = self.forward_with_feature(features, actions) 307 | 308 | # prediction loss in the state / latent space 309 | pred_loss = self.prediction_loss_from_dist(pred_next_dist, next_features) # (bs, num_pred_steps) 310 | loss = pred_loss = pred_loss.sum(dim=-1).mean() 311 | loss_detail = {"pred_loss": pred_loss} 312 | 313 | if not eval: 314 | self.backprop(loss, loss_detail) 315 | 316 | return loss_detail 317 | 318 | def update_mask(self, obs, actions, next_obses): 319 | raise NotImplementedError 320 | 321 | def get_state_abstraction(self): 322 | raise NotImplementedError 323 | 324 | def get_adjacency(self): 325 | return None 326 | 327 | def get_intervention_mask(self): 328 | return None 329 | 330 | def get_mask(self): 331 | return torch.cat([self.get_adjacency(), self.get_intervention_mask()], dim=-1) 332 | 333 | def train(self, training=True): 334 | self.training = training 335 | 336 | def eval(self): 337 | self.train(False) 338 | 339 | def save(self, path): 340 | torch.save({"model": self.state_dict(), 341 | "optimizer": self.optimizer.state_dict() 342 | }, path) 343 | 344 | def load(self, path, device): 345 | if path is not None and os.path.exists(path): 346 | print("inference loaded", path) 347 | checkpoint = torch.load(path, map_location=device) 348 | self.load_state_dict(checkpoint["model"]) 349 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 350 | -------------------------------------------------------------------------------- /model/contrastive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | 9 | from torch.distributions.normal import Normal 10 | from torch.distributions.distribution import Distribution 11 | from torch.distributions.categorical import Categorical 12 | 13 | from utils.utils import to_numpy 14 | 15 | 16 | class Contrastive(nn.Module): 17 | def __init__(self, encoder, decoder, params): 18 | super(Contrastive, self).__init__() 19 | 20 | self.encoder = encoder 21 | self.decoder = decoder 22 | 23 | self.params = params 24 | self.device = device = params.device 25 | self.contrastive_params = contrastive_params = params.contrastive_params 26 | 27 | self.continuous_state = params.continuous_action 28 | self.continuous_action = params.continuous_action 29 | self.num_observation_steps = 
params.training_params.num_observation_steps 30 | self.use_prioritized_buffer = params.training_params.replay_buffer_params.prioritized_buffer 31 | 32 | self.loss_type = contrastive_params.loss_type 33 | self.l2_reg_coef = contrastive_params.l2_reg_coef 34 | self.num_pred_steps = contrastive_params.num_pred_steps 35 | self.gradient_through_all_samples = contrastive_params.gradient_through_all_samples 36 | 37 | self.num_negative_samples = contrastive_params.num_negative_samples 38 | # # (feature_dim,) 39 | # self.delta_feature_min = self.encoder({key: val[0] for key, val in self.params.obs_delta_range.items()}) 40 | # self.delta_feature_max = self.encoder({key: val[1] for key, val in self.params.obs_delta_range.items()}) 41 | 42 | self.num_pred_samples = contrastive_params.num_pred_samples 43 | self.num_pred_iters = contrastive_params.num_pred_iters 44 | self.pred_sigma_init = contrastive_params.pred_sigma_init 45 | self.pred_sigma_shrink = contrastive_params.pred_sigma_shrink 46 | 47 | self.init_model() 48 | self.reset_params() 49 | 50 | self.to(device) 51 | self.optimizer = optim.Adam(self.parameters(), lr=contrastive_params.lr) 52 | 53 | self.load(params.training_params.load_inference, device) 54 | self.train() 55 | 56 | def init_model(self): 57 | raise NotImplementedError 58 | 59 | def reset_params(self): 60 | pass 61 | 62 | def setup_annealing(self, step): 63 | pass 64 | 65 | def cat_features(self, features, next_feature): 66 | """ 67 | :param features: 68 | if observation space is continuous: (bs, num_observation_steps, feature_dim). 69 | else: [(bs, num_observation_steps, feature_i_dim)] * feature_dim 70 | notice that bs can be a multi-dimensional batch size 71 | :param next_feature: 72 | if observation space is continuous: (bs, feature_dim). 73 | else: [(bs, feature_i_dim)] * feature_dim 74 | :return: 75 | if observation space is continuous: (bs, num_observation_steps, feature_dim). 
76 | else: [(bs, num_observation_steps, feature_i_dim)] * feature_dim 77 | """ 78 | if self.continuous_state: 79 | features = torch.cat([features[..., 1:, :], next_feature.unsqueeze(dim=-2)], dim=-2) 80 | else: 81 | features = [torch.cat([features_i[..., 1:, :], next_feature_i.unsqueeze(dim=-2)], dim=-2) 82 | for features_i, next_feature_i in zip(features, next_feature)] 83 | return features 84 | 85 | def sample_delta_feature(self, shape, num_samples): 86 | # (bs, num_pred_samples, feature_dim) 87 | uniform_noise = torch.rand(*shape, num_samples, self.feature_dim, dtype=torch.float32, device=self.device) 88 | delta_feature = uniform_noise * (self.delta_feature_max - self.delta_feature_min) + self.delta_feature_min 89 | return delta_feature 90 | 91 | # This is more like a temporary function that only works for collision reward 92 | # return # (bs, num_negative_samples, feature_dim) 93 | def sample_boolean_neg_feature(self, shape, reward): 94 | bs = shape[0] 95 | params = self.params 96 | if params.domain == "igibson": 97 | base = torch.ones(*shape, 1, device=self.device) 98 | neg = base - reward 99 | 100 | # # this will give collision 101 | # num_samples = self.num_negative_samples 102 | # reward_ranges = [[0, 1]] 103 | # assert (len(reward_ranges) == self.reward_dim) 104 | # neg_r_list = [] 105 | # for r_range in reward_ranges: 106 | # neg_r_list.append(np.random.choice(r_range, bs * num_samples)) 107 | # # r_dim * (bs*mnum_samples) 108 | # neg_r_list = np.asarray([neg_r_list]) 109 | # reshaped_neg_r_list = neg_r_list.T.reshape([bs, num_samples, self.reward_dim]) 110 | # neg = torch.tensor(reshaped_neg_r_list, device=self.device) 111 | 112 | elif params.domain == "minigrid": 113 | assert(self.reward_dim == 5) 114 | 115 | # TODO: figure out the most efficient way 116 | # OK here is a hacky implementation 117 | r_collision_range = [[0, 1], [-1, 1], [-1, 0]] 118 | idxes = reward[:,0, :2].type(torch.int) + 1 119 | 120 | with torch.no_grad(): 121 | neg1 = [r_collision_range[ind] for ind in idxes[:, 0]] 122 | neg2 = [r_collision_range[ind] for ind in idxes[:, 1]] 123 | base = torch.ones(*shape, 1, device=self.device) * -5 124 | neg3 = base - reward[:, :, 2:] 125 | 126 | tf_neg1 = torch.tensor(neg1, device=self.device).unsqueeze(-1) 127 | tf_neg2 = torch.tensor(neg2, device=self.device).unsqueeze(-1) 128 | tf_neg3 = neg3.expand([-1, 2, -1]) 129 | 130 | neg = torch.cat((tf_neg1, tf_neg2, tf_neg3), dim=2) 131 | 132 | # # # Previous approach to generate negative samples 133 | # num_samples = self.num_negative_samples 134 | # reward_ranges = [[-1, 0, 1], [-1, 0, 1], [0, -5], [0, -5], [0, -5]] 135 | # assert(len(reward_ranges) == self.reward_dim) 136 | # neg_r_list = [] 137 | # for r_range in reward_ranges: 138 | # neg_r_list.append(np.random.choice(r_range, bs * num_samples)) 139 | # # r_dim * (bs*mnum_samples) 140 | # neg_r_list = np.asarray([neg_r_list]) 141 | # reshaped_neg_r_list = neg_r_list.T.reshape([bs, num_samples, self.reward_dim]) 142 | # neg = torch.tensor(reshaped_neg_r_list, device=self.device) 143 | return neg 144 | 145 | 146 | def get_feature(self, obs): 147 | feature = self.encoder(obs) 148 | return feature 149 | 150 | def forward_step(self, features, action, delta_features): 151 | """ 152 | compute energy 153 | :param features: 154 | if observation space is continuous: (bs, num_observation_steps, feature_dim). 
155 | else: NotImplementedError 156 | notice that bs can be a multi-dimensional batch size 157 | :param action: (bs, action_dim) 158 | :param delta_features: 159 | if observation space is continuous: (bs, num_samples, feature_dim). 160 | else: NotImplementedError 161 | :return: energy: (bs, num_samples, feature_dim) 162 | """ 163 | raise NotImplementedError 164 | 165 | def forward_with_feature(self, features, actions, next_features, neg_delta_features=None): 166 | """ 167 | :param features: 168 | if observation space is continuous: (bs, num_observation_steps, feature_dim). 169 | else: NotImplementedError 170 | notice that bs can be a multi-dimensional batch size 171 | :param actions: (bs, num_pred_steps, action_dim) 172 | :param next_features: 173 | if observation space is continuous: (bs, num_pred_steps, feature_dim). 174 | else: NotImplementedError 175 | :param neg_delta_features: 176 | if observation space is continuous: (bs, num_pred_steps, num_negative_samples, feature_dim). 177 | else: NotImplementedError 178 | :return: energy: 179 | if observation space is continuous: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) 180 | else: NotImplementedError 181 | """ 182 | energies = [] 183 | actions = torch.unbind(actions, dim=-2) 184 | next_features = torch.unbind(next_features, dim=-2) 185 | neg_delta_features = torch.unbind(neg_delta_features, dim=-3) 186 | for i, (action, next_feature, neg_delta_features) in enumerate(zip(actions, next_features, neg_delta_features)): 187 | delta_feature = next_feature - features[..., -1, :] # (bs, feature_dim) 188 | delta_feature = delta_feature.unsqueeze(dim=-2) # (bs, 1, feature_dim) 189 | # (bs, 1 + num_negative_samples, feature_dim) 190 | delta_features = torch.cat([delta_feature, neg_delta_features], dim=-2) 191 | energy = self.forward_step(features, action, delta_features) 192 | energies.append(energy) 193 | 194 | if i == len(actions) - 1: 195 | break 196 | 197 | # (bs, num_negative_samples, feature_dim) 198 | neg_energy = energy[..., 1:, :] 199 | if self.gradient_through_all_samples: 200 | # (bs, num_negative_samples, feature_dim) 201 | delta_feature_select = F.gumbel_softmax(neg_energy, dim=-2, hard=True) 202 | delta_feature = (neg_delta_features * delta_feature_select).sum(dim=-2) # (bs, feature_dim) 203 | else: 204 | delta_feature_select = neg_energy.argmax(dim=-2, keepdim=True) # (bs, 1, feature_dim) 205 | delta_feature = torch.gather(neg_delta_features, -2, delta_feature_select) # (bs, 1, feature_dim) 206 | delta_feature = delta_feature[..., 0, :] # (bs, feature_dim) 207 | 208 | pred_next_feature = features[..., -1, :] + delta_feature 209 | features = self.cat_features(features, pred_next_feature) 210 | 211 | # (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) 212 | energies = torch.stack(energies, dim=-3) 213 | return energies 214 | 215 | def forward(self, obses, actions, next_obses, neg_delta_feature=None): 216 | features = self.get_feature(obses) 217 | next_features = self.get_feature(next_obses) 218 | return self.forward_with_feature(features, actions, next_features, neg_delta_feature) 219 | 220 | @staticmethod 221 | def nce_loss(energy): 222 | """ 223 | :param energy: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) or 224 | (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 225 | :return: 226 | loss: scalar 227 | """ 228 | if energy.ndim == 4: 229 | return -F.log_softmax(energy, dim=-2)[..., 0, :].sum(dim=(-2, -1)).mean() 230 | elif energy.ndim == 5: 231 | return 
-F.log_softmax(energy, dim=-3)[..., 0, :, :].sum(dim=(-3, -2, -1)).mean() 232 | else: 233 | raise NotImplementedError 234 | 235 | def backprop(self, loss, loss_detail): 236 | self.optimizer.zero_grad() 237 | loss.backward() 238 | 239 | grad_clip_norm = self.contrastive_params.grad_clip_norm 240 | if not grad_clip_norm: 241 | grad_clip_norm = np.inf 242 | loss_detail["grad_norm"] = torch.nn.utils.clip_grad_norm_(self.parameters(), grad_clip_norm) 243 | 244 | self.optimizer.step() 245 | return loss_detail 246 | 247 | def update(self, obses, actions, next_obses, eval=False): 248 | """ 249 | :param obses: {obs_i_key: (bs, num_observation_steps, obs_i_shape)} 250 | :param actions: (bs, num_pred_steps, action_dim) 251 | :param next_obs: ({obs_i_key: (bs, num_pred_steps, obs_i_shape)} 252 | :return: {"loss_name": loss_value} 253 | """ 254 | bs, num_pred_steps = actions.shape[:-2], actions.shape[-2] 255 | # (bs, num_pred_steps, num_negative_samples, feature_dim) 256 | neg_delta_feature = self.sample_delta_feature(bs + (num_pred_steps,), self.num_negative_samples) 257 | # (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) 258 | energy = self.forward(obses, actions, next_obses, neg_delta_feature) 259 | 260 | loss_detail = {} 261 | 262 | if self.loss_type == "contrastive": 263 | loss = self.nce_loss(energy) 264 | elif self.loss_type == "mle": 265 | # (bs, num_pred_steps, feature_dim), (bs, num_pred_steps, num_negative_samples, feature_dim) 266 | pos_energy, neg_energy = energy[..., 0, :], energy[..., 1:, :] 267 | # (bs, num_pred_steps, num_negative_samples, feature_dim) 268 | neg_weight = torch.softmax(neg_energy.detach(), dim=-2) 269 | mle_loss = (pos_energy - (neg_weight * neg_energy).sum(dim=-2)) # (bs, num_pred_steps, feature_dim) 270 | 271 | energy_norm = (pos_energy ** 2 + (neg_energy ** 2).sum(dim=-2)).sum(dim=(-2, -1)).mean() 272 | regularization = self.l2_reg_coef * energy_norm 273 | 274 | loss = -mle_loss.sum(dim=(-2, -1)).mean() + regularization 275 | else: 276 | raise NotImplementedError 277 | 278 | loss_detail["contrastive_loss"] = loss 279 | 280 | if not eval: 281 | self.backprop(loss, loss_detail) 282 | 283 | return loss_detail 284 | 285 | def predict_step_with_feature(self, features, action): 286 | """ 287 | :param features: 288 | if observation space is continuous: (bs, feature_dim). 289 | else: NotImplementedError 290 | notice that bs can be a multi-dimensional batch size 291 | :param action: (bs, action_dim) 292 | :return: pred_next_feature: 293 | if observation space is continuous: (bs, feature_dim). 
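obtained by re-sampling delta candidates around high-energy regions for num_pred_iters rounds, shrinking the candidate noise by pred_sigma_shrink each round (a refinement in the spirit of the cross-entropy method; see the loop below)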
294 | else: NotImplementedError 295 | """ 296 | bs = action.shape[:-1] 297 | num_pred_samples = self.num_pred_samples 298 | delta_feature_max = self.delta_feature_max 299 | delta_feature_min = self.delta_feature_min 300 | sigma = self.pred_sigma_init 301 | 302 | delta_feature_candidates = self.sample_delta_feature(bs, num_pred_samples) 303 | delta_feature_candidates = torch.sort(delta_feature_candidates, dim=1)[0] 304 | 305 | for i in range(self.num_pred_iters): 306 | # (bs, num_pred_samples, feature_dim) 307 | if self.params.training_params.inference_algo == "contrastive_cmi": 308 | mask = self.get_eval_mask(bs, self.feature_dim - 1) 309 | forward_mode = ("causal",) 310 | 311 | full_energy, mask_energy, mask_cond_energy, causal_energy, causal_cond_energy = \ 312 | self.forward_step(features, action, delta_feature_candidates, forward_mode) 313 | energy = causal_energy 314 | else: 315 | energy = self.forward_step(features, action, delta_feature_candidates) 316 | 317 | if i != self.num_pred_iters - 1: 318 | energy = energy.transpose(-2, -1) # (bs, feature_dim, num_pred_samples) 319 | dist = Categorical(logits=energy) 320 | idxes = dist.sample([num_pred_samples]) # (num_pred_samples, bs, feature_dim) 321 | idxes = idxes.permute(*(np.arange(len(bs)) + 1), 0, -1) # (bs, num_pred_samples, feature_dim) 322 | 323 | # (bs, num_pred_samples, feature_dim) 324 | delta_feature_candidates = torch.gather(delta_feature_candidates, -2, idxes) 325 | noise = torch.randn_like(delta_feature_candidates) * sigma * (delta_feature_max - delta_feature_min) 326 | delta_feature_candidates += noise 327 | delta_feature_candidates = torch.clip(delta_feature_candidates, delta_feature_min, delta_feature_max) 328 | 329 | sigma *= self.pred_sigma_shrink 330 | 331 | argmax_idx = torch.argmax(energy, dim=-2, keepdim=True) # (bs, 1, feature_dim) 332 | # (bs, feature_dim) 333 | delta_feature = torch.gather(delta_feature_candidates, -2, argmax_idx)[..., 0, :] 334 | pred_next_feature = features[..., -1, :] + delta_feature 335 | 336 | return pred_next_feature 337 | 338 | def predict_with_feature(self, features, actions): 339 | pred_next_features = [] 340 | for action in torch.unbind(actions, dim=-2): 341 | pred_next_feature = self.predict_step_with_feature(features, action) 342 | pred_next_features.append(pred_next_feature) 343 | features = self.cat_features(features, pred_next_feature) 344 | return torch.stack(pred_next_features, dim=-2) 345 | 346 | def get_adjacency(self): 347 | return None 348 | 349 | def get_intervention_mask(self): 350 | return None 351 | 352 | def train(self, training=True): 353 | self.training = training 354 | 355 | def eval(self): 356 | self.train(False) 357 | 358 | def save(self, path): 359 | torch.save({"model": self.state_dict(), 360 | "optimizer": self.optimizer.state_dict() 361 | }, path) 362 | 363 | def load(self, path, device): 364 | if path is not None and os.path.exists(path): 365 | print("contrastive loaded", path) 366 | checkpoint = torch.load(path, map_location=device) 367 | self.load_state_dict(checkpoint["model"]) 368 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 369 | 370 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Callable 4 | 5 | import igibson 6 | from igibson.envs.igibson_env import iGibsonEnv 7 | 8 | try: 9 | import gym 10 | import torch as th 11 | import torch.nn as nn 12 | from 
stable_baselines3 import PPO, FPPO, A2C, SAC, TD3 13 | from stable_baselines3.common.evaluation import evaluate_policy 14 | from stable_baselines3.common.preprocessing import maybe_transpose 15 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 16 | from stable_baselines3.common.utils import set_random_seed 17 | from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor 18 | from stable_baselines3.common.save_util import load_from_zip_file 19 | from stable_baselines3.common.callbacks import CheckpointCallback 20 | 21 | 22 | except ModuleNotFoundError: 23 | print("stable-baselines3 is not installed. ") 24 | exit(1) 25 | 26 | import argparse 27 | import uuid 28 | import yaml 29 | import numpy as np 30 | 31 | """ 32 | Main training code 33 | """ 34 | 35 | 36 | class CustomCombinedExtractor(BaseFeaturesExtractor): 37 | def __init__(self, observation_space: gym.spaces.Dict): 38 | # We do not know features-dim here before going over all the items, 39 | # so put something dummy for now. PyTorch requires calling 40 | # nn.Module.__init__ before adding modules 41 | super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1) 42 | 43 | extractors = {} 44 | 45 | total_concat_size = 0 46 | feature_size = 128 47 | for key, subspace in observation_space.spaces.items(): 48 | if key in ["proprioception", "task_obs"]: 49 | extractors[key] = nn.Sequential(nn.Linear(subspace.shape[0], feature_size), nn.ReLU()) 50 | elif key in ["rgb", "highlight", "depth", "seg", "ins_seg"]: 51 | n_input_channels = subspace.shape[2] # channel last 52 | cnn = nn.Sequential( 53 | nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), 54 | nn.ReLU(), 55 | nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), 56 | nn.ReLU(), 57 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), 58 | nn.ReLU(), 59 | nn.Flatten(), 60 | ) 61 | test_tensor = th.zeros([subspace.shape[2], subspace.shape[0], subspace.shape[1]]) 62 | with th.no_grad(): 63 | n_flatten = cnn(test_tensor[None]).shape[1] 64 | fc = nn.Sequential(nn.Linear(n_flatten, feature_size), nn.ReLU()) 65 | extractors[key] = nn.Sequential(cnn, fc) 66 | elif key in ["scan"]: 67 | n_input_channels = subspace.shape[1] # channel last 68 | cnn = nn.Sequential( 69 | nn.Conv1d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), 70 | nn.ReLU(), 71 | nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=0), 72 | nn.ReLU(), 73 | nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=0), 74 | nn.ReLU(), 75 | nn.Flatten(), 76 | ) 77 | test_tensor = th.zeros([subspace.shape[1], subspace.shape[0]]) 78 | with th.no_grad(): 79 | n_flatten = cnn(test_tensor[None]).shape[1] 80 | fc = nn.Sequential(nn.Linear(n_flatten, feature_size), nn.ReLU()) 81 | extractors[key] = nn.Sequential(cnn, fc) 82 | else: 83 | raise ValueError("Unknown observation key: %s" % key) 84 | total_concat_size += feature_size 85 | 86 | self.extractors = nn.ModuleDict(extractors) 87 | 88 | # Update the features dim manually 89 | self._features_dim = total_concat_size 90 | 91 | def forward(self, observations) -> th.Tensor: 92 | encoded_tensor_list = [] 93 | 94 | # self.extractors contain nn.Modules that do all the processing. 
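# Shape sketch (illustrative, assuming the iGibson defaults used elsewhere in
# this repo): a "scan" batch arrives channel-last as (B, 220, 1) and is
# permuted below to (B, 1, 220) for Conv1d; each extractor then emits a
# (B, 128) embedding, so with K active observation keys the concatenated
# output is (B, 128 * K).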
95 | for key, extractor in self.extractors.items(): 96 | if key in ["rgb", "highlight", "depth", "seg", "ins_seg"]: 97 | observations[key] = observations[key].permute((0, 3, 1, 2)) 98 | elif key in ["scan"]: 99 | observations[key] = observations[key].permute((0, 2, 1)) 100 | encoded_tensor_list.append(extractor(observations[key])) 101 | # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension. 102 | return th.cat(encoded_tensor_list, dim=1) 103 | 104 | def get_causal_matrix(reward_channels_dim, env, robot="fetch", fc_causal=False, sparse_causal=False): 105 | if fc_causal: 106 | causal_matrix = th.ones(reward_channels_dim, env.action_space.shape[0]) 107 | elif sparse_causal: 108 | # Sparse causal is the causal matrix we derived from CMI 109 | if robot == "fetch": 110 | causal_matrix = th.tensor([ 111 | [1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0], # Reach 112 | [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0], # EE Local Orientation 113 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], # EE Local Position 114 | [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # Base Collision 115 | [1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0], # Arm Collision 116 | [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0], # Self Collision 117 | [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], # Head Attention 118 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], # Gripper Grasp 119 | ], dtype=th.float32) 120 | elif robot == "hsr": 121 | # base (3), head (2), arm (5), gripper (1) 122 | causal_matrix = th.tensor([ 123 | #|omni ,|head,| arm ,|gr 124 | [1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0], # Reach 125 | [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0], # EE Local Orientation 126 | [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], # EE Local Position 127 | [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # Base Collision 128 | [1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0], # Arm Collision 129 | [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0], # Self Collision 130 | [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], # Head Attention 131 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], # Gripper Grasp 132 | ], dtype=th.float32) 133 | else: 134 | # We test assigning causality based on base / arm separation 135 | if robot == "fetch": 136 | # base (2), head (2), arm (6), gripper (1) 137 | causal_matrix = th.tensor([ 138 | #|base|head | arm , |gr 139 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # Reach 140 | [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], # EE Local Orientation 141 | [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], # EE Local Position 142 | [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], # Base Collision 143 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # Arm Collision 144 | [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], # Self Collision 145 | [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], # Head Attention 146 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # Gripper Grasp 147 | ], dtype=th.float32) 148 | elif robot == "hsr": 149 | # base (3), head (2), arm (5), gripper (1) 150 | causal_matrix = th.tensor([ 151 | # |omni ,|head,| arm ,|gr 152 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # Reach 153 | [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], # EE Local Orientation 154 | [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], # EE Local Position 155 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], # Base Collision 156 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # Arm Collision 157 | [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], # Self Collision 158 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], # Head Attention 159 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # Gripper Grasp 160 | ], dtype=th.float32) 161 | 162 | print(f"causal_matrix: {causal_matrix}") 163 | assert (causal_matrix.shape == (reward_channels_dim, env.action_space.shape[0])) 164 | return causal_matrix 165 | 166 | def main(args, simple_test=False): 167 | """ 168 | Train an RL agent using the selected 
algorithm 169 | on a multi-step reaching task 170 | """ 171 | print("*" * 80 + "\nDescription:" + main.__doc__ + '\n' + "*" * 80) 172 | 173 | # Check string arguments 174 | assert args.robot in ["fetch", "hsr"] 175 | assert args.algo_name in ["ppo", "sac", "a2c", "td3", "fppo"] 176 | 177 | if args.robot == "fetch": 178 | config_name = "fetch_reaching.yaml" 179 | elif args.robot == "hsr": 180 | config_name = "hsr_reaching.yaml" 181 | 182 | config_filename = os.path.join(igibson.configs_path, config_name) 183 | config_data = yaml.load(open(config_filename, "r"), Loader=yaml.FullLoader) 184 | device = th.device("cuda:" + str(args.cuda_id) if th.cuda.is_available() else "cpu") 185 | 186 | # Parse the arguments 187 | short_exec = args.short 188 | factored = args.factored 189 | sep_value = args.seperate_v_net 190 | rd_target = args.random_target 191 | load_obstacles = args.load_obstacles 192 | 193 | if factored: 194 | args.algo_name = "fppo" 195 | if args.multi_step: 196 | config_data["task"] = "factored_multistep_reaching_random" 197 | else: 198 | config_data["task"] = "factored_reaching_random" 199 | print("using factored environment") 200 | else: 201 | if args.multi_step: 202 | config_data["task"] = "multistep_reaching_random" 203 | 204 | # We only log the important entries, for simplicity 205 | log_names = {} 206 | log_parm_names = { 207 | "algoName": args.algo_name, 208 | "robotName": args.robot, 209 | } 210 | if "ppo" in args.algo_name: 211 | log_names = { 212 | "NormAD": args.normalize_advantage, 213 | "FcCausal": args.fc_causal, 214 | "SparseCausal": args.sparse_causal, 215 | } 216 | 217 | log_names["randomEnv"] = args.rand_env 218 | log_names["complexOri"] = not args.simple_orientation 219 | 220 | # Task specific arguments 221 | if not load_obstacles: 222 | config_data["load_room_types"] = "kitchen" 223 | config_data["rd_target"] = rd_target 224 | config_data["simple_orientation"] = args.simple_orientation 225 | config_data["enum_orientation"] = args.enum_orientation 226 | config_data["position_reward"] = not args.no_local_pos_reward 227 | config_data["proportional_local_reward"] = args.proportional_local_reward 228 | 229 | tensorboard_log_dir = "log_dir" 230 | log_base = uuid.uuid4().hex.upper()[:5] # Each log has a unique identifier 231 | for k, v in log_parm_names.items(): 232 | if v is not None: 233 | log_base = log_base + "_" + k + ":" + str(v) 234 | for k, v in log_names.items(): 235 | if v: 236 | log_base = log_base + "_" + k 237 | weight_dir = os.path.join("weight_dir", log_base) 238 | num_environments = 8 if not short_exec else 2 239 | 240 | scene_id_list = ["Rs_int", "Beechwood_0_int", "Merom_0_int", "Wainscott_0_int", 241 | "Ihlen_0_int", "Benevolence_1_int", "Pomaria_1_int", "Ihlen_1_int", ] 242 | 243 | # Function callback to create environments 244 | def make_env(rank: int, seed: int = 0, rand_env=False) -> Callable: 245 | if rand_env: 246 | cur_config = config_data.copy() 247 | cur_config["scene_id"] = scene_id_list[rank] 248 | else: 249 | cur_config = config_data 250 | 251 | def _init() -> iGibsonEnv: 252 | env = iGibsonEnv( 253 | config_file=cur_config, 254 | mode="headless", 255 | action_timestep=1 / 10.0, 256 | physics_timestep=1 / 120.0, 257 | ) 258 | env.seed(seed + rank) 259 | return env 260 | 261 | set_random_seed(seed) 262 | return _init 263 | 264 | # generate a random seed 265 | seed = np.random.randint(200000000) 266 | 267 | # Multiprocess 268 | env = SubprocVecEnv([make_env(i, seed, args.rand_env) for i in range(num_environments)]) 269 | env = VecMonitor(env)
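# VecMonitor wraps the vectorized env to record per-episode returns and lengths, so they show up in the TensorBoard logs.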
270 | 271 | # Create a new environment for evaluation 272 | eval_env = iGibsonEnv( 273 | config_file=config_data, 274 | mode="headless", 275 | action_timestep=1 / 10.0, 276 | physics_timestep=1 / 120.0, 277 | ) 278 | 279 | # Obtain the arguments/parameters for the policy and create the PPO model 280 | policy_kwargs = dict( 281 | features_extractor_class=CustomCombinedExtractor, 282 | ) 283 | os.makedirs(tensorboard_log_dir, exist_ok=True) 284 | os.makedirs(weight_dir, exist_ok=True) 285 | if simple_test: 286 | print("Performing simple testing") 287 | 288 | start_n_eval = 0 289 | final_n_eval = 1 290 | if "ppo" in args.algo_name: 291 | kwargs = {'n_steps': 16} 292 | elif args.algo_name == "a2c": 293 | kwargs = {'n_steps': 16} 294 | else: 295 | kwargs = {} 296 | else: 297 | kwargs = {} 298 | start_n_eval = 10 299 | final_n_eval = 20 300 | 301 | if "ppo" in args.algo_name: 302 | kwargs["clip_range"] = args.clip_range 303 | kwargs["target_kl"] = args.target_kl 304 | if not args.normalize_advantage: 305 | kwargs["normalize_advantage"] = False 306 | kwargs["gae_lambda"] = args.gae_lambda 307 | else: 308 | print("PPO-specific parameters are ignored because they are not applicable to this algorithm") 309 | 310 | # learning rate can either be a constant or a schedule 311 | if args.scheduled_lr: 312 | def linear_schedule(initial_value: float) -> Callable[[float], float]: 313 | """ 314 | Linear learning rate schedule. 315 | 316 | :param initial_value: Initial learning rate. 317 | :return: schedule that computes 318 | current learning rate depending on remaining progress 319 | """ 320 | 321 | def func(progress_remaining: float) -> float: 322 | """ 323 | Progress will decrease from 1 (beginning) to 0. 324 | 325 | :param progress_remaining: 326 | :return: current learning rate 327 | """ 328 | return progress_remaining * initial_value 329 | 330 | return func 331 | kwargs["learning_rate"] = linear_schedule(args.learning_rate) 332 | else: 333 | kwargs["learning_rate"] = args.learning_rate 334 | 335 | if factored: 336 | # Factored defaults to FPPO 337 | reward_channels_dim = 8 338 | causal_matrix = get_causal_matrix(reward_channels_dim, env, robot=args.robot, 339 | fc_causal=args.fc_causal, 340 | sparse_causal=args.sparse_causal) 341 | kwargs["sep_vnet"] = sep_value 342 | kwargs["value_loss_normalization"] = args.normalize_vnet_error 343 | kwargs["value_grad_rescale"] = args.rescale_vnet_grad 344 | kwargs["approx_var_gamma"] = args.normalize_vnet_error 345 | kwargs["episode_length"] = config_data["max_step"] 346 | model = FPPO("MultiInputPolicy", env, reward_channels_dim, causal_matrix, verbose=1, 347 | tensorboard_log=tensorboard_log_dir, policy_kwargs=policy_kwargs, 348 | device=device, **kwargs) 349 | elif args.algo_name == "ppo": 350 | model = PPO("MultiInputPolicy", env, verbose=1, tensorboard_log=tensorboard_log_dir, 351 | device=device, policy_kwargs=policy_kwargs, **kwargs) 352 | elif args.algo_name == "sac": 353 | model = SAC("MultiInputPolicy", env, verbose=1, tensorboard_log=tensorboard_log_dir, 354 | device=device, policy_kwargs=policy_kwargs, **kwargs) 355 | elif args.algo_name == "a2c": 356 | model = A2C("MultiInputPolicy", env, verbose=1, tensorboard_log=tensorboard_log_dir, 357 | device=device, policy_kwargs=policy_kwargs, **kwargs) 358 | elif args.algo_name == "td3": 359 | model = TD3("MultiInputPolicy", env, verbose=1, tensorboard_log=tensorboard_log_dir, 360 | device=device, policy_kwargs=policy_kwargs, **kwargs) 361 | 362 | print(model.policy) 363 | 364 | if start_n_eval > 0: 365 | # Random Agent, evaluation before training
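# (evaluates the untrained policy on eval_env for start_n_eval episodes to get a pre-training baseline)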
366 | mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=start_n_eval) 367 | print(f"Before Training: Mean reward: {mean_reward} +/- {std_reward:.2f}") 368 | 369 | # Information related to storing weights 370 | # Save a checkpoint every 10000 steps 371 | model_path = os.path.join(weight_dir, "ckpt") 372 | checkpoint_callback = CheckpointCallback( 373 | save_freq=10000, 374 | save_path=weight_dir, 375 | name_prefix="ckpt", 376 | ) 377 | 378 | # Train the model for the given number of steps 379 | total_timesteps = 100 if short_exec else args.training_length 380 | model.learn(total_timesteps, tb_log_name=log_base, callback=checkpoint_callback) 381 | 382 | # Evaluate the policy after training 383 | mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=final_n_eval) 384 | print(f"After Training: Mean reward: {mean_reward} +/- {std_reward:.2f}") 385 | 386 | # Save the trained model and delete it 387 | model.save(model_path) 388 | del model 389 | 390 | 391 | if __name__ == "__main__": 392 | parser = argparse.ArgumentParser(description='Train an RL agent for iGibson.') 393 | 394 | parser.add_argument('--short', '-s', action='store_true', 395 | help='whether to run a short (debug) execution') 396 | parser.add_argument('--seperate_v_net', '-sv', action='store_true', 397 | help='whether to use a separate value network') 398 | parser.add_argument('--fc_causal', '-fc', action='store_true') 399 | parser.add_argument('--sparse_causal', '-sc', action='store_true') 400 | parser.add_argument('--cuda_id', '-cid', type=int, default=0) 401 | parser.add_argument('--clip_range', '-cr', type=float, default=0.2) 402 | parser.add_argument('--target_kl', '-tkl', type=float, default=0.15) 403 | parser.add_argument('--scheduled_lr', '-slr', action='store_true') 404 | parser.add_argument('--learning_rate', '-lr', type=float, default=5e-5) 405 | parser.add_argument('--training_length', '-tr', type=int, default=5000000) 406 | parser.add_argument('--gae_lambda', '-gl', type=float, default=0.95) 407 | parser.add_argument('--robot', type=str) 408 | parser.add_argument('--normalize_vnet_error', '-nve', action='store_true') 409 | parser.add_argument('--rescale_vnet_grad', '-rvg', action='store_true') 410 | parser.add_argument('--no_local_pos_reward', '-npr', action='store_true', 411 | help='whether to disable the local position reward') 412 | parser.add_argument('--algo_name', type=str, default="ppo", 413 | help='which baseline algorithm to use when not using the factored variant') 414 | parser.add_argument('--rand_env', action='store_true') 415 | 416 | # For now, these arguments are always true 417 | parser.add_argument('--factored', '-f', action='store_true', 418 | help='whether to factorize the action space') 419 | parser.add_argument('--load_obstacles', '-obst', action='store_true') 420 | parser.add_argument('--random_target', '-rdt', action='store_true', 421 | help='whether to use a random target for the local end-effector') 422 | parser.add_argument('--normalize_advantage', '-nad', action='store_true') 423 | parser.add_argument('--simple_orientation', '-sor', action='store_true', 424 | help='whether to simplify the orientation penalty') 425 | parser.add_argument('--multi_step', '-mts', action='store_true', 426 | help='whether to use the multi-step environment') 427 | parser.add_argument('--enum_orientation', '-eor', action='store_true', 428 | help='whether to limit the orientation range') 429 | parser.add_argument('--proportional_local_reward', '-plr', action='store_true',) 430 | 431 | args = parser.parse_args() 432 | 
433 | # We make these arguments always True 434 | args.load_obstacles = args.random_target = args.normalize_advantage = args.simple_orientation = args.multi_step = \ 435 | args.enum_orientation = args.proportional_local_reward = args.factored = True 436 | 437 | logging.basicConfig(level=logging.INFO) 438 | 439 | simple_test = th.cuda.device_count() < 2 440 | main(args, simple_test) 441 | -------------------------------------------------------------------------------- /model/contrastive_cmi.py: -------------------------------------------------------------------------------- 1 | raise NotImplementedError 2 | # need to fix the obs dimension error 3 | # grep -nrw "action_part_dim + 1" 4 | 5 | import os 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from model.contrastive import Contrastive 13 | from model.inference_utils import reset_layer, forward_network, forward_network_batch 14 | 15 | import ipdb 16 | 17 | 18 | class ContrastiveCMI(Contrastive): 19 | def __init__(self, encoder, decoder, params): 20 | 21 | # initialize hard-coded variables 22 | self.action_part_dim = params.action_part_dim 23 | self.reward_dim = params.reward_dim 24 | 25 | self.cmi_params = params.contrastive_params.cmi_params 26 | self.init_graph(params, encoder) 27 | # TODO: change the init_graph 28 | super(ContrastiveCMI, self).__init__(encoder, decoder, params) 29 | self.aggregation = self.cmi_params.aggregation 30 | self.train_all_masks = self.cmi_params.train_all_masks 31 | self.mask_opt_freq = self.cmi_params.mask_opt_freq 32 | self.full_opt_freq = self.cmi_params.full_opt_freq 33 | self.causal_opt_freq = self.cmi_params.causal_opt_freq 34 | 35 | replay_buffer_params = params.training_params.replay_buffer_params 36 | self.parallel_sample = replay_buffer_params.prioritized_buffer and replay_buffer_params.parallel_sample 37 | 38 | self.update_num = 0 39 | 40 | 41 | 42 | def init_model(self): 43 | params = self.params 44 | cmi_params = self.cmi_params 45 | self.learn_bo = learn_bo = cmi_params.learn_bo 46 | self.dot_product_energy = dot_product_energy = cmi_params.dot_product_energy 47 | 48 | # model params 49 | continuous_state = self.continuous_state 50 | if not continuous_state: 51 | raise NotImplementedError 52 | 53 | # TODO 54 | # This is a bit hacky: we include the observation-encoder networks here. 
55 | # Another thing is that there should actually be n copies of these 56 | if params.domain == "igibson": 57 | lidar_shape = [220, 1] 58 | lidar_out = 128 59 | n_input_channels = lidar_shape[1] # channel last 60 | self.obs_extractor = [] 61 | for _ in range(params.reward_dim): 62 | cnn = nn.Sequential( 63 | nn.Conv1d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), 64 | nn.ReLU(), 65 | nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=0), 66 | nn.ReLU(), 67 | nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=0), 68 | nn.ReLU(), 69 | nn.Flatten(), 70 | ) 71 | test_tensor = torch.zeros([lidar_shape[1], lidar_shape[0]]) 72 | with torch.no_grad(): 73 | n_flatten = cnn(test_tensor[None]).shape[1] 74 | fc = nn.Sequential(nn.Linear(n_flatten, lidar_out), nn.ReLU()) 75 | self.obs_extractor.append(nn.Sequential(cnn, fc).to(params.device)) 76 | elif params.domain == "minigrid": 77 | # Define image embedding 78 | img_shape = [3, 7, 7] 79 | img_out = 128 80 | self.obs_extractor = [] 81 | for _ in range(params.reward_dim): 82 | cnn = nn.Sequential( 83 | nn.Conv2d(3, 16, (2, 2)), 84 | nn.ReLU(), 85 | nn.MaxPool2d((2, 2)), 86 | nn.Conv2d(16, 32, (2, 2)), 87 | nn.ReLU(), 88 | nn.Conv2d(32, 64, (2, 2)), 89 | nn.ReLU(), 90 | nn.Flatten(), 91 | ) 92 | test_tensor = torch.zeros(img_shape) 93 | with torch.no_grad(): 94 | n_flatten = cnn(test_tensor[None]).shape[1] 95 | fc = nn.Sequential(nn.Linear(n_flatten, img_out), nn.ReLU()) 96 | self.obs_extractor.append(nn.Sequential(cnn, fc).to(params.device)) 97 | else: 98 | raise NotImplementedError 99 | 100 | 101 | action_part_dim = self.action_part_dim 102 | reward_dim = self.reward_dim 103 | ar_dim = action_part_dim * reward_dim 104 | 105 | ### from here 106 | 107 | self.action_part_feature_weights = nn.ParameterList() 108 | self.action_part_feature_biases = nn.ParameterList() 109 | self.reward_feature_feature_weights = nn.ParameterList() 110 | self.reward_feature_feature_biases = nn.ParameterList() 111 | # self.delta_state_feature_weights = nn.ParameterList() 112 | # self.delta_state_feature_biases = nn.ParameterList() 113 | 114 | self.energy_weights = nn.ParameterList() 115 | self.energy_biases = nn.ParameterList() 116 | self.cond_energy_weights = nn.ParameterList() 117 | self.cond_energy_biases = nn.ParameterList() 118 | 119 | self.sa_encoder_weights = nn.ParameterList() 120 | self.sa_encoder_biases = nn.ParameterList() 121 | self.d_encoder_weights = nn.ParameterList() 122 | self.d_encoder_biases = nn.ParameterList() 123 | self.cond_sa_encoder_weights = nn.ParameterList() 124 | self.cond_sa_encoder_biases = nn.ParameterList() 125 | 126 | # action-part feature extractor 127 | in_dim = params.ind_action_dim * self.num_observation_steps 128 | for out_dim in cmi_params.feature_fc_dims: 129 | self.action_part_feature_weights.append(nn.Parameter(torch.zeros(ar_dim, in_dim, out_dim))) 130 | self.action_part_feature_biases.append(nn.Parameter(torch.zeros(ar_dim, 1, out_dim))) 131 | in_dim = out_dim 132 | 133 | # reward feature extractor 134 | in_dim = params.ind_reward_dim 135 | for out_dim in cmi_params.feature_fc_dims: 136 | self.reward_feature_feature_weights.append(nn.Parameter(torch.zeros(reward_dim, in_dim, out_dim))) 137 | self.reward_feature_feature_biases.append(nn.Parameter(torch.zeros(reward_dim, 1, out_dim))) 138 | in_dim = out_dim 139 | 140 | if dot_product_energy: 141 | # sa_feature encoder 142 | in_dim = cmi_params.feature_fc_dims[-1] 143 | for out_dim in cmi_params.enery_fc_dims: 144 | 
self.sa_encoder_weights.append(nn.Parameter(torch.zeros(reward_dim, in_dim, out_dim))) 145 | self.sa_encoder_biases.append(nn.Parameter(torch.zeros(reward_dim, 1, out_dim))) 146 | in_dim = out_dim 147 | 148 | # delta feature encoder 149 | in_dim = cmi_params.feature_fc_dims[-1] 150 | for out_dim in cmi_params.enery_fc_dims: 151 | self.d_encoder_weights.append(nn.Parameter(torch.zeros(reward_dim, in_dim, out_dim))) 152 | self.d_encoder_biases.append(nn.Parameter(torch.zeros(reward_dim, 1, out_dim))) 153 | in_dim = out_dim 154 | 155 | # conditional sa_feature encoder 156 | if learn_bo: 157 | in_dim = 2 * cmi_params.feature_fc_dims[-1] 158 | for out_dim in cmi_params.enery_fc_dims: 159 | self.cond_sa_encoder_weights.append(nn.Parameter(torch.zeros(reward_dim, in_dim, out_dim))) 160 | self.cond_sa_encoder_biases.append(nn.Parameter(torch.zeros(reward_dim, 1, out_dim))) 161 | in_dim = out_dim 162 | else: 163 | # energy 164 | in_dim = 2 * cmi_params.feature_fc_dims[-1] 165 | for out_dim in cmi_params.enery_fc_dims: 166 | self.energy_weights.append(nn.Parameter(torch.zeros(reward_dim, in_dim, out_dim))) 167 | self.energy_biases.append(nn.Parameter(torch.zeros(reward_dim, 1, out_dim))) 168 | in_dim = out_dim 169 | self.energy_weights.append(nn.Parameter(torch.zeros(reward_dim, in_dim, 1))) 170 | self.energy_biases.append(nn.Parameter(torch.zeros(reward_dim, 1, 1))) 171 | 172 | if learn_bo: 173 | # conditional energy 174 | in_dim = 3 * cmi_params.feature_fc_dims[-1] 175 | for out_dim in cmi_params.enery_fc_dims: 176 | self.cond_energy_weights.append(nn.Parameter(torch.zeros(reward_dim, in_dim, out_dim))) 177 | self.cond_energy_biases.append(nn.Parameter(torch.zeros(reward_dim, 1, out_dim))) 178 | in_dim = out_dim 179 | self.cond_energy_weights.append(nn.Parameter(torch.zeros(reward_dim, in_dim, 1))) 180 | self.cond_energy_biases.append(nn.Parameter(torch.zeros(reward_dim, 1, 1))) 181 | 182 | # TODO 183 | training_masks = [] 184 | for i in range(action_part_dim): 185 | training_masks.append(self.get_eval_mask((1,), i)) 186 | 187 | # 1st feature_dim: variable to predict, 2nd feature_dim: input variable to ignore 188 | training_masks = torch.stack(training_masks, dim=2) # (1, feature_dim, feature_dim, feature_dim + 1) 189 | self.training_masks = training_masks.view(reward_dim, action_part_dim, action_part_dim + 1, 1, 1) 190 | 191 | def reset_params(self): 192 | # feature_dim = self.feature_dim 193 | module_weights = [self.action_part_feature_weights, 194 | self.reward_feature_feature_weights, 195 | # self.delta_state_feature_weights, 196 | self.energy_weights, 197 | self.cond_energy_weights, 198 | self.sa_encoder_weights, 199 | self.d_encoder_weights, 200 | self.cond_sa_encoder_weights] 201 | module_biases = [self.action_part_feature_biases, 202 | self.reward_feature_feature_biases, 203 | # self.delta_state_feature_biases, 204 | self.energy_biases, 205 | self.cond_energy_biases, 206 | self.sa_encoder_biases, 207 | self.d_encoder_biases, 208 | self.cond_sa_encoder_biases] 209 | for weights, biases in zip(module_weights, module_biases): 210 | for w, b in zip(weights, biases): 211 | assert w.ndim == b.ndim == 3 212 | for i in range(w.shape[0]): 213 | reset_layer(w[i], b[i]) 214 | 215 | def init_graph(self, params, encoder): 216 | 217 | device = params.device 218 | self.CMI_threshold = self.cmi_params.CMI_threshold 219 | 220 | # feature_dim = encoder.feature_dim 221 | action_part_dim = self.action_part_dim 222 | reward_dim = self.reward_dim 223 | 224 | # TODO: Figure out diag_mask. 
and what happened in our experiments 225 | # # used for masking diagonal elements 226 | # self.diag_mask = torch.eye(reward_dim, action_part_dim + 1, dtype=torch.bool, device=device) 227 | self.diag_mask = torch.zeros(reward_dim, action_part_dim + 1, dtype=torch.bool, device=device) 228 | self.diag_mask[:, action_part_dim:] = True 229 | self.mask_CMI = torch.ones(reward_dim, action_part_dim + 1, device=device) * self.CMI_threshold 230 | self.mask = torch.ones(reward_dim, action_part_dim + 1, dtype=torch.bool, device=device) 231 | 232 | # modified from extract_action_feature 233 | def extract_observation_feature(self, obs): 234 | """ 235 | Modified: takes in a lidar scan and outputs an obs feature 236 | :return: (reward_dim, 1, bs, out_dim) 237 | """ 238 | 239 | multi_bs = self.parallel_sample and self.training 240 | if not multi_bs: 241 | obs = obs.unsqueeze(dim=0) 242 | obs_shape = obs.shape[2:] 243 | obs_dim = len(obs_shape) 244 | dummy_axis = obs_dim * [-1] 245 | obs = obs.expand(self.reward_dim, -1, *dummy_axis) 246 | obs_features = [] 247 | for i in range(self.reward_dim): 248 | obs_features.append(self.obs_extractor[i](obs[i])) 249 | 250 | obs_features = torch.stack(obs_features) 251 | 252 | dim_out = obs_features.shape[-1] 253 | obs_features = obs_features.reshape([self.reward_dim, 1, -1, dim_out]) 254 | return obs_features 255 | 256 | # modified from extract_delta_state_feature 257 | def extract_reward_feature(self, rewards): 258 | """ 259 | :param rewards: 260 | if state space is continuous: (bs, num_samples, reward_dim). 261 | else: [(bs, num_samples, feature_i_dim)] * feature_dim 262 | notice that bs must be 1D 263 | :return: (reward_dim, bs, num_samples, out_dim) 264 | """ 265 | 266 | reward_dim = self.reward_dim 267 | if self.continuous_state: 268 | bs, num_samples, _ = rewards.shape 269 | x = rewards.view(-1, reward_dim).T # (reward_dim, bs * num_samples) 270 | x = x.unsqueeze(dim=-1) # (reward_dim, bs * num_samples, 1) 271 | else: 272 | raise NotImplementedError 273 | 274 | reward_feature = forward_network(x, self.reward_feature_feature_weights, self.reward_feature_feature_biases) 275 | reward_feature = reward_feature.view(reward_dim, bs, num_samples, -1) 276 | return reward_feature 277 | 278 | # modified from extract_state_feature 279 | def extract_action_part_feature(self, actions): 280 | """ 281 | :param actions: 282 | if action space is continuous: (bs, num_observation_steps, action_part_dim). 283 | else: NotImplementedError 284 | notice that bs must be 1D 285 | :return: (reward_dim, action_part_dim, bs, out_dim), 286 | the first dim indexes the reward channel to predict, the second indexes the 287 | action parts used as inputs for the prediction 288 | """ 289 | action_part_dim = self.action_part_dim 290 | reward_dim = self.reward_dim 291 | 292 | if self.continuous_state: 293 | if self.parallel_sample and self.training: 294 | bs = actions.shape[1] 295 | x = actions.permute(0, 3, 1, 2) # (reward_dim, action_part_dim, bs, num_observation_steps) 296 | else: 297 | bs = actions.shape[0] 298 | x = actions.permute(2, 0, 1).unsqueeze(dim=0) # (1, action_part_dim, bs, num_observation_steps) 299 | x = x.repeat(reward_dim, 1, 1, 1) # (reward_dim, action_part_dim, bs, num_observation_steps) 300 | x = x.reshape(action_part_dim * reward_dim, bs, -1) # (reward_dim * action_part_dim, bs, num_observation_steps) 301 | else: 302 | raise NotImplementedError 303 | 304 | actions_feature = forward_network(x, self.action_part_feature_weights, self.action_part_feature_biases) 305 | actions_feature = actions_feature.view(reward_dim, action_part_dim, bs, -1) # TODO 306 | return actions_feature 307 | 308 | @staticmethod 309 | def dot_product(sa_encoding, delta_encoding): 310 | """ 311 | compute the dot product between sa_encoding and delta_encoding 312 | :param sa_encoding: (feature_dim, bs, encoding_dim) or (feature_dim, feature_dim, bs, encoding_dim), 313 | notice that bs must be 1D 314 | :param delta_encoding: (feature_dim, bs, num_samples, encoding_dim), global feature used for prediction, 315 | notice that bs must be 1D 316 | :return: energy: (bs, num_samples, feature_dim) 317 | """ 318 | # (feature_dim, bs, 1, out_dim) or (feature_dim, feature_dim, bs, 1, out_dim) 319 | sa_encoding = sa_encoding.unsqueeze(dim=-2) 320 | 321 | if sa_encoding.ndim == 5: 322 | num_samples = delta_encoding.shape[-2] 323 | if num_samples < 5000: 324 | delta_encoding = delta_encoding.unsqueeze(dim=1) # (feature_dim, 1, bs, num_samples, out_dim) 325 | energy = (sa_encoding * delta_encoding).sum(dim=-1) # (feature_dim, feature_dim, bs, num_samples) 326 | else: 327 | # likely to run out of memory, so compute the energy in batches 328 | energy = [] 329 | for sa_encoding in torch.unbind(sa_encoding, dim=1): 330 | energy.append((sa_encoding * delta_encoding).sum(dim=-1)) 331 | energy = torch.stack(energy, dim=1) 332 | energy = energy.permute(2, 3, 0, 1) # (bs, num_samples, feature_dim, feature_dim) 333 | else: 334 | energy = (sa_encoding * delta_encoding).sum(dim=-1) # (feature_dim, bs, num_samples) 335 | energy = energy.permute(1, 2, 0) # (bs, num_samples, feature_dim) 336 | 337 | return energy 338 | 339 | def compute_energy_dot(self, sa_feature, delta_feature, full_sa_feature=None): 340 | """ 341 | compute the conditional energy from the conditional and total state-action-delta features 342 | :param sa_feature: (feature_dim, bs, sa_feature_dim) or (feature_dim, feature_dim, bs, sa_feature_dim), 343 | notice that bs must be 1D 344 | :param delta_feature: (feature_dim, bs, num_samples, delta_feature_dim), global feature used for prediction, 345 | notice that bs must be 1D 346 | :param full_sa_feature: (feature_dim, bs, sa_feature_dim), 347 | notice that bs must be 1D 348 | :return: energy: (bs, num_samples, feature_dim) 349 | """ 350 | reward_dim, bs, num_samples, delta_feature_dim = delta_feature.shape 351 | action_part_dim = self.action_part_dim 352 | # (reward_dim, bs * num_samples, delta_feature_dim) 353 | delta_feature = delta_feature.view(reward_dim, bs * num_samples, -1) 354 | assert sa_feature.ndim in [3, 4] 355 | is_mask_feature = sa_feature.ndim == 4 356 | 357 | # (feature_dim, bs * num_samples, out_dim) 358 | delta_encoding = forward_network(delta_feature, self.d_encoder_weights, self.d_encoder_biases) 359 | # (feature_dim, bs, num_samples, out_dim) 360 | delta_encoding = delta_encoding.view(reward_dim, bs, num_samples, -1) 361 | 362 | if is_mask_feature: 363 | 364 | # (feature_dim, feature_dim * bs, sa_feature_dim) 365 | sa_feature = sa_feature.view(reward_dim, action_part_dim * bs, -1) 366 | 367 | # (feature_dim, bs, out_dim) or (feature_dim, feature_dim * bs, out_dim) 368 | sa_encoding = forward_network(sa_feature, self.sa_encoder_weights, self.sa_encoder_biases) 369 | 370 | if is_mask_feature: 371 | # (feature_dim, feature_dim, bs, out_dim) 372 | sa_encoding = sa_encoding.view(reward_dim, action_part_dim, bs, -1) 373 | 374 | # (bs, num_samples, feature_dim) or (bs, num_samples, feature_dim, feature_dim) 375 | energy = self.dot_product(sa_encoding, delta_encoding) 376 | 377 | if full_sa_feature is None: 378 | return energy 379 | 380 | if not self.learn_bo: 381 | return energy, torch.zeros_like(energy) 382 | 383 | if is_mask_feature: 384 | # (feature_dim, feature_dim, bs, out_dim) 385 | ipdb.set_trace() 386 | # TODO: this part needs to be double checked 387 | raise NotImplementedError 388 | full_sa_feature = full_sa_feature.unsqueeze(dim=1).expand(-1, action_part_dim, -1, -1) 389 | # (feature_dim, feature_dim * bs, out_dim) 390 | full_sa_feature = full_sa_feature.reshape(reward_dim, action_part_dim * bs, -1) 391 | # (feature_dim, feature_dim * bs, 2 * out_dim) 392 | cond_sa_feature = torch.cat([sa_feature, full_sa_feature], dim=-1) 393 | else: 394 | # (feature_dim, bs, 2 * out_dim) 395 | cond_sa_feature = torch.cat([sa_feature, full_sa_feature], dim=-1) 396 | 397 | # (feature_dim, bs, out_dim) or (feature_dim, feature_dim * bs, out_dim) 398 | cond_sa_encoding = forward_network(cond_sa_feature, self.cond_sa_encoder_weights, self.cond_sa_encoder_biases) 399 | 400 | if is_mask_feature: 401 | # (feature_dim, feature_dim, bs, out_dim) 402 | cond_sa_encoding = cond_sa_encoding.view(reward_dim, action_part_dim, bs, -1) 403 | 404 | # (bs, num_samples, feature_dim) or (bs, num_samples, feature_dim, feature_dim) 405 | cond_energy = self.dot_product(cond_sa_encoding, delta_encoding) 406 | 407 | return energy, cond_energy 408 | 409 | @staticmethod 410 | def unsqueeze_expand_tensor(tensor, dim, expand_size): 411 | tensor = tensor.unsqueeze(dim=dim) 412 | expand_sizes = [-1] * tensor.ndim 413 | expand_sizes[dim] = expand_size 414 | tensor = tensor.expand(*expand_sizes) 415 | return tensor 416 | 417 | def net(self, sa_feature, delta_feature, weights, biases, full_sa_feature=None): 418 | is_mask_feature = sa_feature.ndim == 5 419 | sa_feature_dim = sa_feature.shape[-1] 420 | feature_dim, bs, num_samples, delta_feature_dim = delta_feature.shape 421 | 422 | if is_mask_feature and num_samples >= 1024: 423 | energy = [] 424 | for sa_feature_i in torch.unbind(sa_feature, dim=1): 425 | if full_sa_feature is None: 426 | sad_feature = torch.cat([sa_feature_i, delta_feature], dim=-1) 427 | else: 428 | sad_feature = torch.cat([sa_feature_i, full_sa_feature, delta_feature], dim=-1) 429 | sad_feature_dim = sad_feature.shape[-1] 430 | 431 | # (feature_dim, bs * num_samples, sad_feature_dim) 432 | sad_feature = sad_feature.view(feature_dim, -1, sad_feature_dim) 433 | 434 | # 
(feature_dim, bs * num_samples, 1) 435 | energy_i = forward_network(sad_feature, weights, biases) 436 | energy.append(energy_i) 437 | energy = torch.stack(energy, dim=1) # (feature_dim, feature_dim, bs * num_samples, 1) 438 | energy = energy.view(feature_dim, -1, 1) 439 | else: 440 | if is_mask_feature: 441 | # (feature_dim, feature_dim, bs, num_samples, delta_feature_dim) 442 | delta_feature = self.unsqueeze_expand_tensor(delta_feature, 1, feature_dim) 443 | 444 | if full_sa_feature is None: 445 | sad_feature = torch.cat([sa_feature, delta_feature], dim=-1) 446 | else: 447 | if is_mask_feature: 448 | full_sa_feature = self.unsqueeze_expand_tensor(full_sa_feature, 1, feature_dim) 449 | sad_feature = torch.cat([sa_feature, full_sa_feature, delta_feature], dim=-1) 450 | 451 | sad_feature_dim = sad_feature.shape[-1] 452 | 453 | # (feature_dim, bs * num_samples, sad_feature_dim) or 454 | # (feature_dim, feature_dim * bs * num_samples, sad_feature_dim) 455 | sad_feature = sad_feature.view(feature_dim, -1, sad_feature_dim) 456 | 457 | # (feature_dim, bs * num_samples, 1) or (feature_dim, feature_dim * bs * num_samples, 1) 458 | energy = forward_network(sad_feature, weights, biases) 459 | 460 | if is_mask_feature: 461 | energy = energy.view(feature_dim, feature_dim, bs, num_samples) 462 | energy = energy.permute(2, 3, 0, 1) # (bs, num_samples, feature_dim, feature_dim) 463 | else: 464 | energy = energy.view(feature_dim, bs, num_samples) # (feature_dim, bs, num_samples) 465 | energy = energy.permute(1, 2, 0) # (bs, num_samples, feature_dim) 466 | 467 | return energy 468 | 469 | def compute_energy_net(self, sa_feature, delta_feature, full_sa_feature=None): 470 | """ 471 | compute the conditional energy from the conditional and total state-action-delta features 472 | :param sa_feature: (feature_dim, bs, sa_feature_dim) or (feature_dim, feature_dim, bs, sa_feature_dim), 473 | notice that bs must be 1D 474 | :param delta_feature: (feature_dim, bs, num_samples, delta_feature_dim), global feature used for prediction, 475 | notice that bs must be 1D 476 | :param full_sa_feature: (feature_dim, bs, sa_feature_dim), 477 | notice that bs must be 1D 478 | :return: energy: (bs, num_samples, feature_dim) or (bs, num_samples, feature_dim, feature_dim) 479 | """ 480 | 481 | assert sa_feature.ndim in [3, 4] 482 | is_mask_feature = sa_feature.ndim == 4 483 | 484 | sa_feature_dim = sa_feature.shape[-1] 485 | feature_dim, bs, num_samples, delta_feature_dim = delta_feature.shape 486 | 487 | # (feature_dim, bs, num_samples, sa_feature_dim) or (feature_dim, feature_dim, bs, num_samples, sa_feature_dim) 488 | sa_feature = self.unsqueeze_expand_tensor(sa_feature, -2, num_samples) 489 | 490 | # (bs, num_samples, feature_dim) or (bs, num_samples, feature_dim, feature_dim) 491 | energy = self.net(sa_feature, delta_feature, self.energy_weights, self.energy_biases) 492 | 493 | if full_sa_feature is None: 494 | return energy 495 | 496 | if not self.learn_bo: 497 | return energy, torch.zeros_like(energy) 498 | 499 | # (feature_dim, bs, num_samples, sa_feature_dim) or (feature_dim, feature_dim, bs, num_samples, sa_feature_dim) 500 | full_sa_feature = self.unsqueeze_expand_tensor(full_sa_feature, -2, num_samples) 501 | 502 | # (bs, num_samples, feature_dim) or (bs, num_samples, feature_dim, feature_dim) 503 | cond_energy = self.net(sa_feature, delta_feature, self.cond_energy_weights, self.cond_energy_biases, full_sa_feature) 504 | 505 | return energy, cond_energy 506 | 507 | def compute_energy(self, sa_feature, delta_feature, 
full_sa_feature=None): 508 | """ 509 | compute the conditional energy from the conditional and total state-action-delta features 510 | :param sa_feature: (feature_dim, bs, sa_feature_dim) or (feature_dim, feature_dim, bs, sa_feature_dim), 511 | notice that bs must be 1D 512 | :param delta_feature: (feature_dim, bs, num_samples, delta_feature_dim), global feature used for prediction, 513 | notice that bs must be 1D 514 | :param full_sa_feature: (feature_dim, bs, sa_feature_dim), 515 | notice that bs must be 1D 516 | :return: energy: (bs, num_samples, feature_dim) or (bs, num_samples, feature_dim, feature_dim) 517 | """ 518 | if self.dot_product_energy: 519 | return self.compute_energy_dot(sa_feature, delta_feature, full_sa_feature) 520 | else: 521 | return self.compute_energy_net(sa_feature, delta_feature, full_sa_feature) 522 | 523 | def forward_step(self, actions, obss, rewards, forward_mode=("full", "mask", "causal")): 524 | """ 525 | # TODO: change these feature-extraction functions. 526 | # Figure out the full energy first. 527 | # Replace delta_feature with reward, replace sa features with oa features 528 | compute energy for the following combinations 529 | if using (1) next_feature + neg_delta_features for training 530 | a. feature from randomly masking one state variable + conditional feature from all variables 531 | b. feature from all variables 532 | c. feature from causal parents + conditional feature from all variable (? probably for eval only) 533 | elif using (2) pred_delta_features for evaluation 534 | a. feature from causal parents + conditional feature from all variable (? probably for eval only) 535 | :param actions: 536 | if action space is continuous: (bs, num_observation_steps, action_part_dim). 537 | else: NotImplementedError 538 | notice that bs must be 1D 539 | :param obss: (bs, obs_dim, obs_channels) 540 | :param rewards: 541 | if reward space is continuous: (bs, num_samples, reward_dim). 
542 | else: NotImplementedError 543 | :param forward_mode: which energy to compute 544 | :return: 545 | energy 546 | for training, (bs, 1 + num_negative_samples, feature_dim) 547 | for evaluation, (bs, num_pred_samples, feature_dim) 548 | """ 549 | bs, _, feature_dim = rewards.shape 550 | reward_dim = self.reward_dim 551 | action_part_dim = self.action_part_dim 552 | 553 | obs_feature = self.extract_observation_feature(obss) 554 | actions_feature = self.extract_action_part_feature(actions) 555 | reward_feature = self.extract_reward_feature(rewards) 556 | ao_feature = torch.cat([actions_feature, obs_feature], dim=1) # (reward_dim, action_dim + 1, bs, out_dim) 557 | 558 | if self.aggregation == "max": 559 | full_ao_feature, _ = ao_feature.max(dim=1) # (feature_dim, bs, out_dim) 560 | elif self.aggregation == "mean": 561 | full_ao_feature = ao_feature.mean(dim=1) # (feature_dim, bs, out_dim) 562 | else: 563 | raise NotImplementedError 564 | 565 | # (bs, num_samples, feature_dim) 566 | full_energy = mask_energy = mask_cond_energy = causal_energy = causal_cond_energy = None 567 | 568 | if "full" in forward_mode: 569 | full_energy = self.compute_energy(full_ao_feature, reward_feature) # (bs, num_samples, feature_dim) 570 | 571 | if "mask" in forward_mode: 572 | mask_sa_feature = ao_feature.clone() # (feature_dim, feature_dim + 1, bs, out_dim) 573 | if self.train_all_masks or not self.training: 574 | mask = self.training_masks # (feature_dim, feature_dim, feature_dim + 1, 1, 1) 575 | mask_sa_feature = mask_sa_feature.unsqueeze(dim=1) # (feature_dim, 1, feature_dim + 1, bs, out_dim) 576 | else: 577 | mask = self.get_training_mask(bs) # (bs, feature_dim, feature_dim + 1) 578 | mask = torch.permute(mask, (1, 2, 0)) # (feature_dim, feature_dim + 1, bs) 579 | mask = mask.unsqueeze(dim=-1) # (feature_dim, feature_dim + 1, bs, 1) 580 | 581 | # (feature_dim, feature_dim, feature_dim + 1, bs, out_dim) or (feature_dim, feature_dim + 1, bs, out_dim) 582 | mask_sa_feature = mask_sa_feature * mask 583 | 584 | # (feature_dim, feature_dim, bs, out_dim) or (feature_dim, bs, out_dim) 585 | if self.aggregation == "max": 586 | mask_sa_feature, _ = mask_sa_feature.max(dim=-3) 587 | elif self.aggregation == "mean": 588 | raise NotImplementedError 589 | mask_sa_feature = mask_sa_feature.sum(dim=-3) / feature_dim 590 | else: 591 | raise NotImplementedError 592 | 593 | # (bs, num_samples, feature_dim) or (bs, num_samples, feature_dim, feature_dim) 594 | mask_energy, mask_cond_energy = self.compute_energy(mask_sa_feature, reward_feature, full_ao_feature) 595 | 596 | if "causal" in forward_mode: 597 | causal_sa_feature = ao_feature.clone() # (feature_dim, feature_dim + 1, bs, out_dim) 598 | causal_mask = self.mask.detach().view(reward_dim, action_part_dim + 1, 1, 1) 599 | causal_sa_feature = causal_sa_feature * causal_mask # (feature_dim, feature_dim + 1, bs, out_dim) 600 | 601 | if self.aggregation == "max": 602 | causal_sa_feature, _ = causal_sa_feature.max(dim=1) # (feature_dim, bs, out_dim) 603 | elif self.aggregation == "mean": 604 | num_parents = causal_mask.sum(dim=1) 605 | causal_sa_feature = causal_sa_feature.sum(dim=1) / num_parents # (feature_dim, bs, out_dim) 606 | else: 607 | raise NotImplementedError 608 | 609 | # (bs, num_samples, feature_dim) 610 | causal_energy, causal_cond_energy = self.compute_energy(causal_sa_feature, reward_feature, 611 | full_ao_feature) 612 | 613 | return full_energy, mask_energy, mask_cond_energy, causal_energy, causal_cond_energy 614 | 615 | 616 | def forward_with_feature(self, 
actions, obss, rewards, 617 | forward_mode=("full", "mask", "causal")): 618 | """ 619 | :param actions: 620 | if action space is continuous: (bs, num_observation_steps, action_part_dim). 621 | else: NotImplementedError 622 | notice that bs can be a multi-dimensional batch size 623 | :param obss: (bs, num_pred_steps, obs_dim, obs_channels) 624 | :param rewards: 625 | if reward space is continuous: (bs, num_pred_steps, reward_dim). 626 | else: NotImplementedError 627 | (negative reward samples of shape (bs, num_pred_steps, num_negative_samples, reward_dim) 628 | are drawn internally via sample_boolean_neg_feature rather than passed in) 629 | 630 | :param forward_mode: which energy to compute 631 | :return: energy: 632 | if reward space is continuous: (bs, num_pred_steps, 1 + num_negative_samples, reward_dim) 633 | else: NotImplementedError 634 | """ 635 | num_observation_steps, _ = actions.shape[-2:] 636 | bs = actions.shape[:-2] 637 | if self.parallel_sample and self.training: 638 | bs = bs[1:] 639 | reward_dim = self.reward_dim 640 | action_part_dim = self.action_part_dim 641 | 642 | # This is how we should proceed 643 | assert (num_observation_steps == 1) 644 | 645 | num_pred_steps = num_observation_steps 646 | 647 | 648 | 649 | if num_pred_steps > 1: 650 | raise NotImplementedError 651 | 652 | # This branch is rarely executed during training, but we still need it during inference 653 | flatten_bs = len(bs) > 1 654 | if flatten_bs: 655 | import ipdb 656 | ipdb.set_trace() 657 | 658 | if self.parallel_sample and self.training: 659 | obss = obss[:, :, 0, ...] # (bs, lidar, n_channels) 660 | else: 661 | obss = obss[:, 0, ...] # (bs, lidar, n_channels) 662 | obss.requires_grad = True 663 | 664 | actions.requires_grad = True 665 | delta_feature = rewards.detach() # (bs, 1, feature_dim) 666 | 667 | delta_feature = delta_feature[..., 0, :] 668 | 669 | if self.parallel_sample and self.training: 670 | eye = torch.eye(reward_dim, device=self.device).unsqueeze(dim=-2) 671 | delta_feature = (delta_feature * eye).sum(dim=-1).T 672 | 673 | delta_feature = delta_feature.unsqueeze(dim=-2) # (bs, 1, feature_dim) 674 | 675 | 676 | # sample negative delta features based on current delta features 677 | # For now, treat inputs as arrays rather than dictionaries (we can fix this later) 678 | # (bs, num_pred_steps, num_negative_samples, feature_dim) 679 | neg_delta_features = self.sample_boolean_neg_feature(bs + (num_pred_steps,), delta_feature.detach()) 680 | num_negative_samples = neg_delta_features.shape[-2] 681 | 682 | # TODO: add this if using standard neg sample 683 | # neg_delta_features = neg_delta_features[:, 0] # (bs, num_negative_samples, feature_dim) 684 | delta_features = torch.cat([delta_feature, neg_delta_features], dim=-2) # (bs, 1 + num_negative_samples, feature_dim) 685 | 686 | delta_features.requires_grad = True 687 | grad_tensors = (actions, obss, delta_features) 688 | 689 | full_energy, mask_energy, mask_cond_energy, causal_energy, causal_cond_energy = \ 690 | self.forward_step(actions, obss, delta_features, forward_mode) 691 | 692 | # TODO: figure out the exact dimensions 693 | 694 | # (bs, 1, 1 + num_negative_samples, feature_dim) 695 | if "full" in forward_mode: 696 | full_energy = full_energy.view(*bs, num_pred_steps, 1 + num_negative_samples, reward_dim) 697 | if "mask" in forward_mode: 698 | if mask_energy.ndim == 4: 699 | mask_energy = mask_energy.view(*bs, num_pred_steps, 1 + num_negative_samples, reward_dim, action_part_dim) 700 | mask_cond_energy = mask_cond_energy.view(*bs, num_pred_steps, 1 + num_negative_samples, reward_dim, action_part_dim) 701 | elif mask_energy.ndim == 3: 702 | mask_energy = mask_energy.view(*bs, num_pred_steps, 1 + num_negative_samples, reward_dim) 703 | mask_cond_energy = mask_cond_energy.view(*bs, num_pred_steps, 1 + num_negative_samples, reward_dim) 704 | else: 705 | raise NotImplementedError 706 | if "causal" in forward_mode: 707 | causal_energy = causal_energy.view(*bs, num_pred_steps, 1 + num_negative_samples, reward_dim) 708 | causal_cond_energy = causal_cond_energy.view(*bs, num_pred_steps, 1 + num_negative_samples, reward_dim) 709 | 710 | return full_energy, mask_energy, mask_cond_energy, causal_energy, causal_cond_energy, grad_tensors 711 | 712 | # We assume that args are passed in as tensors instead of dictionaries 713 | def forward(self, actions, obss, rewards, forward_mode=("full", "mask", "causal")): 714 | # features = self.get_feature(actions) 715 | # next_features = self.get_feature(rewards) 716 | return self.forward_with_feature(actions, obss, rewards, forward_mode) 717 | 718 | # modified to fit our scheme 719 | # We should never mask out the last dimension 720 | def get_mask_by_id(self, mask_ids): 721 | """ 722 | :param mask_ids: (bs, feature_dim), idxes of the action part to drop 723 | notice that bs can be a multi-dimensional batch size 724 | :return: (bs, feature_dim, feature_dim + 1), bool mask of state variables to use 725 | """ 726 | int_mask = F.one_hot(mask_ids, self.action_part_dim + 1) 727 | bool_mask = int_mask < 1 728 | return bool_mask 729 | 730 | def get_training_mask(self, bs): 731 | # uniformly select one action part to omit when predicting each reward channel 732 | if isinstance(bs, int): 733 | bs = (bs,) 734 | 735 | idxes = torch.randint(self.action_part_dim, bs + (self.reward_dim,), device=self.device) 736 | return self.get_mask_by_id(idxes) # (bs, feature_dim, feature_dim + 1) 737 | 738 | # Modified so that it matches our desired dimensions 739 | def get_eval_mask(self, bs, i): 740 | # omit the i-th action part when predicting the reward channels 741 | 742 | if isinstance(bs, int): 743 | bs = (bs,) 744 | 745 | feature_dim = self.reward_dim 746 | idxes = torch.full(size=bs + (feature_dim,), fill_value=i, dtype=torch.int64, device=self.device) 747 | 748 | # # We don't need any of these: we simply mask out the given i 749 | # # this is quite hacky tbh 750 | # # each state variable must depend on itself when predicting the next time step value 751 | # self_mask = torch.arange(feature_dim, device=self.device) 752 | # idxes[idxes >= self_mask] += 1 753 | 754 | return self.get_mask_by_id(idxes) # (bs, feature_dim, feature_dim + 1) 755 | 756 | def bo_loss(self, energy, cond_energy): 757 | """ 758 | :param energy: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) or 759 | (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 760 | :param cond_energy: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) or 761 | (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 762 | :return: 763 | loss: scalar 764 | """ 765 | return self.nce_loss(energy.detach() + cond_energy) 766 | 767 | @staticmethod 768 | def energy_norm_loss(energy): 769 | """ 770 | :param energy: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) or 771 | (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 772 | :return: 773 | loss: scalar 774 | """ 775 | if energy.ndim == 4: 776 | energy_sq = (energy ** 2).sum(dim=(-3, -1)).mean() 777 | energy_abs = energy.abs().sum(dim=(-3, -1)).mean() 
778 | elif energy.ndim == 5: 779 | energy_sq = (energy ** 2).sum(dim=(-4, -2, -1)).mean() 780 | energy_abs = energy.abs().sum(dim=(-4, -2, -1)).mean() 781 | else: 782 | raise NotImplementedError 783 | 784 | norm_reg_coef = 1e-6 785 | return energy_sq * norm_reg_coef, energy_abs 786 | 787 | @staticmethod 788 | def energy_grad_loss(energy, tensors): 789 | """ 790 | :param energy: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) or 791 | (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 792 | :param tensors: a list of tensors 793 | :return: 794 | grads_penalty (scalar) and grads_abs (scalar) 795 | """ 796 | grad_reg_coef = 1e-6 797 | grad_thre = 0 798 | 799 | if energy.ndim == 4: 800 | grads = torch.autograd.grad(energy.sum(), tensors, create_graph=True) 801 | elif energy.ndim == 5: 802 | feature_dim = energy.shape[-1] 803 | grads = [torch.autograd.grad(energy[..., i].sum(), tensors, create_graph=True) 804 | for i in range(feature_dim)] 805 | grads = list(map(list, zip(*grads))) 806 | grads = [torch.stack(grad, dim=-1) for grad in grads] 807 | else: 808 | raise NotImplementedError 809 | grads_abs = 0 810 | grads_penalty = 0 811 | 812 | for tensor, grad in zip(tensors, grads): 813 | if tensor.ndim == 2: 814 | bs = tensor.shape[0] 815 | grad = grad.view(bs, -1) 816 | elif tensor.ndim == 3: 817 | bs, num_samples_or_steps, _ = tensor.shape 818 | grad = grad.view(bs * num_samples_or_steps, -1) 819 | else: 820 | raise NotImplementedError 821 | grads_abs += grad.abs().mean() 822 | grads_penalty += (F.relu(grad.abs() - grad_thre) ** 2).sum(dim=-1).mean() 823 | 824 | return grads_penalty * grad_reg_coef, grads_abs 825 | 826 | def update(self, actions, obss, rewards, eval=False): 827 | """ 828 | :param actions: (bs, num_observation_steps, action_part_dim), 829 | notice that bs can be a multi-dimensional batch size 830 | :param obss: (bs, num_pred_steps, obs_dim, obs_channel) 831 | :param rewards: (bs, num_pred_steps, reward_dim) 832 | :return: {"loss_name": loss_value} 833 | """ 834 | if eval: 835 | return self.update_mask(actions, obss, rewards) 836 | 837 | self.update_num += 1 838 | 839 | bs, num_pred_steps = actions.shape[:-2], actions.shape[-2] 840 | 841 | if self.parallel_sample: 842 | bs = bs[1:] 843 | 844 | forward_mode = [] 845 | opt_mask = self.mask_opt_freq > 0 and self.update_num % self.mask_opt_freq == 0 846 | opt_full = self.full_opt_freq > 0 and self.update_num % self.full_opt_freq == 0 847 | opt_causal = self.causal_opt_freq > 0 and self.update_num % self.causal_opt_freq == 0 848 | imit_full = False 849 | if opt_mask: 850 | forward_mode.append("mask") 851 | if opt_full or (opt_mask and imit_full): 852 | forward_mode.append("full") 853 | if self.use_prioritized_buffer or opt_causal: 854 | forward_mode.append("causal") 855 | 856 | # (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) 857 | full_energy, mask_energy, mask_cond_energy, causal_energy, causal_cond_energy, grad_tensors = \ 858 | self.forward(actions, obss, rewards, forward_mode) 859 | features, action, delta_features = grad_tensors 860 | 861 | grad_tensors = (delta_features,) 862 | 863 | loss = 0 864 | loss_detail = {} 865 | if "mask" in forward_mode: 866 | # mask_energy: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim) or 867 | # (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 868 | mask_nce_loss = self.nce_loss(mask_energy) 869 | loss_detail["mask_nce_loss"] = mask_nce_loss 870 | if opt_mask: 871 | energy_norm_loss, 
energy_norm = self.energy_norm_loss(mask_energy) 872 | energy_grad_loss, energy_grad = self.energy_grad_loss(mask_energy, grad_tensors) 873 | loss += mask_nce_loss + energy_norm_loss + energy_grad_loss 874 | loss_detail["mask_energy_norm"] = energy_norm 875 | loss_detail["mask_energy_grad"] = energy_grad 876 | 877 | if self.learn_bo: 878 | mask_bo_loss = self.bo_loss(mask_energy, mask_cond_energy) 879 | loss += mask_bo_loss 880 | loss_detail["mask_bo_gain"] = mask_nce_loss - mask_bo_loss 881 | 882 | if "full" in forward_mode: 883 | full_nce_loss = self.nce_loss(full_energy) 884 | loss_detail["full_nce_loss"] = full_nce_loss 885 | 886 | if opt_mask and imit_full: 887 | mask_cond_energy = full_energy.detach() - mask_energy 888 | energy_norm_loss, energy_norm = self.energy_norm_loss(mask_cond_energy) 889 | energy_grad_loss, energy_grad = self.energy_grad_loss(mask_cond_energy, grad_tensors) 890 | loss += energy_norm_loss + energy_grad_loss 891 | 892 | if opt_full: 893 | energy_norm_loss, energy_norm = self.energy_norm_loss(full_energy) 894 | energy_grad_loss, energy_grad = self.energy_grad_loss(full_energy, grad_tensors) 895 | loss += full_nce_loss + energy_norm_loss + energy_grad_loss 896 | loss_detail["full_energy_norm"] = energy_norm 897 | loss_detail["full_energy_grad"] = energy_grad 898 | 899 | if "causal" in forward_mode: 900 | causal_nce_loss = self.nce_loss(causal_energy) 901 | loss_detail["causal_nce_loss"] = causal_nce_loss 902 | 903 | if self.use_prioritized_buffer: 904 | priority = 1 - F.softmax(causal_energy, dim=-2)[..., 0, :].mean(dim=-2) # (bs, feature_dim) 905 | 906 | if self.parallel_sample: 907 | priority = priority.T 908 | else: 909 | priority = priority.mean(dim=-1) 910 | loss_detail["priority"] = priority 911 | 912 | if opt_causal: 913 | energy_norm_loss, energy_norm = self.energy_norm_loss(causal_energy) 914 | energy_grad_loss, energy_grad = self.energy_grad_loss(causal_energy, grad_tensors) 915 | loss += causal_nce_loss + energy_norm_loss + energy_grad_loss 916 | loss_detail["causal_energy_norm"] = energy_norm 917 | loss_detail["causal_energy_grad"] = energy_grad 918 | 919 | if self.learn_bo: 920 | causal_bo_loss = self.bo_loss(causal_energy, causal_cond_energy) 921 | loss += causal_bo_loss 922 | loss_detail["causal_bo_gain"] = causal_nce_loss - causal_bo_loss 923 | 924 | self.backprop(loss, loss_detail) 925 | 926 | return loss_detail 927 | 928 | @staticmethod 929 | def compute_cmi(energy, cond_energy, unbiased=True): 930 | """ 931 | https://arxiv.org/pdf/2106.13401, proposition 3 932 | :param energy: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 933 | notice that bs can be a multi-dimensional batch size 934 | :param cond_energy: (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 935 | :return: cmi: (feature_dim,feature_dim) (previous documentation is wrong) 936 | """ 937 | pos_cond_energy = cond_energy[..., 0, :, :] # (bs, num_pred_steps, feature_dim, feature_dim) 938 | 939 | if unbiased: 940 | K = energy.shape[-3] # num_negative_samples 941 | neg_energy = energy[..., 1:, :, :] # (bs, num_pred_steps, num_negative_samples, feature_dim, feature_dim) 942 | neg_cond_energy = cond_energy[..., 1:, :, :] # (bs, num_pred_steps, num_negative_samples, feature_dim, feature_dim) 943 | else: 944 | K = energy.shape[-3] + 1 # num_negative_samples 945 | neg_energy = energy # (bs, num_pred_steps, num_negative_samples, feature_dim, feature_dim) 946 | neg_cond_energy = cond_energy # (bs, num_pred_steps, num_negative_samples, feature_dim, 
feature_dim) 947 | 948 | log_w_neg = F.log_softmax(neg_energy, dim=-3) # (bs, num_pred_steps, num_negative_samples, feature_dim, feature_dim) 949 | # (bs, num_pred_steps, num_negative_samples, feature_dim, feature_dim) 950 | weighted_neg_cond_energy = np.log(K - 1) + log_w_neg + neg_cond_energy 951 | # (bs, num_pred_steps, 1 + num_negative_samples, feature_dim, feature_dim) 952 | cond_energy = torch.cat([pos_cond_energy.unsqueeze(dim=-3), weighted_neg_cond_energy], dim=-3) 953 | log_denominator = -np.log(K) + torch.logsumexp(cond_energy, dim=-3) # (bs, num_pred_steps, feature_dim, feature_dim) 954 | cmi = pos_cond_energy - log_denominator # (bs, num_pred_steps, feature_dim, feature_dim) 955 | 956 | cmi_dim = cmi.shape[-2:] 957 | cmi = cmi.sum(dim=-3).view(-1, *cmi_dim).mean(dim=0) 958 | return cmi 959 | 960 | def update_mask(self, actions, obss, rewards): 961 | """ 962 | :param actions: (bs, num_observation_steps, action_part_dim), 963 | notice that bs can be a multi-dimensional batch size 964 | :param obss: (bs, num_pred_steps, obs_dim, obs_channel) 965 | :param rewards: (bs, num_pred_steps, reward_dim) 966 | :return: {"loss_name": loss_value} 967 | """ 968 | bs, num_pred_steps = actions.shape[:-2], actions.shape[-2] 969 | # feature_dim = self.feature_dim 970 | 971 | with torch.no_grad(): 972 | # features = self.encoder(actions) 973 | # next_features = self.encoder(rewards) 974 | features = actions 975 | next_features = rewards 976 | 977 | full_energy, mask_energy, mask_cond_energy, causal_energy, causal_cond_energy, _ = \ 978 | self.forward_with_feature(features, obss, next_features) 979 | 980 | mask_nce_loss = self.nce_loss(mask_energy) 981 | mask_bo_loss = self.bo_loss(mask_energy, mask_cond_energy) 982 | full_nce_loss = self.nce_loss(full_energy) 983 | causal_nce_loss = self.nce_loss(causal_energy) 984 | causal_bo_loss = self.bo_loss(causal_energy, causal_cond_energy) 985 | 986 | eval_details = {"mask_nce_loss": mask_nce_loss, 987 | "mask_bo_gain": mask_nce_loss - mask_bo_loss, 988 | "full_nce_loss": full_nce_loss, 989 | "causal_nce_loss": causal_nce_loss, 990 | "causal_bo_gain": causal_nce_loss - causal_bo_loss} 991 | 992 | if not self.learn_bo: 993 | mask_cond_energy = full_energy.unsqueeze(dim=-1) - mask_energy 994 | 995 | # energy: (bs, num_samples, reward_dim) 996 | cmi = self.compute_cmi(mask_energy, mask_cond_energy) # (reward_dim, action_dim) 997 | 998 | 999 | reward_dim = self.reward_dim 1000 | action_part_dim = self.action_part_dim 1001 | 1002 | diag = torch.ones(reward_dim, action_part_dim + 1, dtype=torch.float32, device=self.device) 1003 | diag *= self.CMI_threshold 1004 | 1005 | # # (feature_dim, feature_dim), (feature_dim, feature_dim) 1006 | # upper_tri, lower_tri = torch.triu(cmi), torch.tril(cmi, diagonal=-1) 1007 | # diag[:, 1:] += upper_tri 1008 | diag[:, :-1] = cmi 1009 | 1010 | eval_tau = self.cmi_params.eval_tau 1011 | self.mask_CMI = self.mask_CMI * eval_tau + diag * (1 - eval_tau) 1012 | self.mask = self.mask_CMI >= self.CMI_threshold 1013 | self.mask[self.diag_mask] = True 1014 | 1015 | return eval_details 1016 | 1017 | def get_mask(self): 1018 | return self.mask 1019 | 1020 | def get_adjacency(self): 1021 | return self.mask_CMI 1022 | 1023 | def save(self, path): 1024 | torch.save({"model": self.state_dict(), 1025 | "optimizer": self.optimizer.state_dict(), 1026 | "mask_CMI": self.mask_CMI, 1027 | }, path) 1028 | 1029 | def load(self, path, device): 1030 | if path is not None and os.path.exists(path): 1031 | print("contrastive loaded", path)
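# The checkpoint below restores the model weights, optimizer state, and the running CMI estimates written by save().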
1032 | checkpoint = torch.load(path, map_location=device) 1033 | self.load_state_dict(checkpoint["model"]) 1034 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 1035 | self.mask_CMI = checkpoint["mask_CMI"] 1036 | self.mask = self.mask_CMI >= self.CMI_threshold 1037 | self.mask_CMI[self.diag_mask] = self.CMI_threshold 1038 | self.mask[self.diag_mask] = True 1039 | --------------------------------------------------------------------------------
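Usage sketch (illustrative, not a file in this repository): assuming `causal_inference.py` trains a `ContrastiveCMI` model and stores a checkpoint via the `save()` method above, the inferred adjacency can be recovered and binarized into the kind of causal matrix that `train.py`'s `get_causal_matrix()` hard-codes. The checkpoint path and threshold value below are hypothetical placeholders.

import torch

# Load the checkpoint written by ContrastiveCMI.save(); the path is hypothetical.
ckpt = torch.load("causal/contrastive_cmi.pt", map_location="cpu")
mask_CMI = ckpt["mask_CMI"]        # (reward_dim, action_part_dim + 1) running CMI estimates
CMI_threshold = 0.02               # hypothetical; the real value comes from cmi_params.CMI_threshold
mask = mask_CMI >= CMI_threshold   # bool adjacency: reward channel <- action part / observation
causal_matrix = mask[:, :-1].float()  # drop the always-on observation column (see diag_mask)
print(causal_matrix)               # rows: reward channels, columns: action parts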