├── README.md
├── environment
│   ├── __init__.py
│   ├── env_pomdp.py
│   └── env_wrappers.py
├── experiments
│   ├── toy1
│   │   ├── 01_train_models.py
│   │   ├── 02_eval_models.py
│   │   ├── 03_kallus_et_al.py
│   │   ├── 04_plots.py
│   │   └── config.json
│   ├── toy2
│   │   ├── 01_train_models.py
│   │   ├── 02_eval_models.py
│   │   ├── 03_train_agents.py
│   │   ├── 04_eval_agents.py
│   │   ├── 05_plots.py
│   │   └── config.json
│   └── toy3
│       ├── 01_train_models.py
│       ├── 02_eval_models.py
│       ├── 03_train_agents.py
│       ├── 04_eval_agents.py
│       ├── 05_plots.py
│       └── config.json
├── learning.py
├── models.py
├── policies.py
├── rl_agents
│   ├── ac.py
│   └── reinforce.py
├── stat_tests.py
├── utils.py
└── utils_kallus.py
/README.md:
--------------------------------------------------------------------------------
1 | # Causal Reinforcement Learning using Observational and Interventional Data
2 | 
3 | ## Requirements
4 | 
5 | You must have Python 3 and the following packages installed:
6 | ```
7 | pytorch
8 | gym
9 | scipy
10 | matplotlib
11 | ```
12 | 
13 | ## Toy problem 1 (door)
14 | 
15 | This toy problem is configured in the following file:
16 | ```
17 | experiments/toy1/config.json
18 | ```
19 | 
20 | To run this experiment, execute the following commands:
21 | ```shell
22 | GPU=0  # -1 for CPU
23 | 
24 | for EXPERT in noisy_good perfect_good perfect_bad random strong_bad_bias strong_good_bias; do
25 |     for SEED in {0..19}; do
26 |         python experiments/toy1/01_train_models.py $EXPERT -s $SEED -g $GPU
27 |         python experiments/toy1/02_eval_models.py $EXPERT -s $SEED -g $GPU
28 |     done
29 |     python experiments/toy1/03_kallus_et_al.py $EXPERT  # optional
30 |     python experiments/toy1/04_plots.py $EXPERT --kallus=True  # drop --kallus=True if the optional step above was skipped
31 | done
32 | ```
33 | 
34 | Results are stored in the following folders:
35 | ```
36 | experiments/
37 |     toy1/
38 |         plots/
39 |         results/
40 |         trained_models/
41 | ```
42 | 
43 | ## Toy problem 2 (tiger)
44 | 
45 | This toy problem is configured in the following file:
46 | ```
47 | experiments/toy2/config.json
48 | ```
49 | 
50 | To run this experiment, execute the following commands:
51 | ```shell
52 | GPU=0  # -1 for CPU
53 | 
54 | for EXPERT in noisy_good very_good very_bad random strong_bad_bias strong_good_bias; do
55 |     for SEED in {0..19}; do
56 |         python experiments/toy2/01_train_models.py $EXPERT -s $SEED -g $GPU
57 |         python experiments/toy2/02_eval_models.py $EXPERT -s $SEED -g $GPU
58 |         python experiments/toy2/03_train_agents.py $EXPERT -s $SEED -g $GPU
59 |         python experiments/toy2/04_eval_agents.py $EXPERT -s $SEED -g $GPU
60 |     done
61 |     python experiments/toy2/05_plots.py $EXPERT
62 | done
63 | ```
64 | 
65 | ## Toy problem 3 (gridworld)
66 | 
67 | This toy problem is configured in the following file:
68 | ```
69 | experiments/toy3/config.json
70 | ```
71 | 
72 | To run this experiment, execute the following commands:
73 | ```shell
74 | GPU=0  # -1 for CPU
75 | 
76 | for EXPERT in noisy_good very_good very_bad random strong_bad_bias strong_good_bias; do
77 |     for SEED in {0..19}; do
78 |         python experiments/toy3/01_train_models.py $EXPERT -s $SEED -g $GPU
79 |         python experiments/toy3/02_eval_models.py $EXPERT -s $SEED -g $GPU
80 |         python experiments/toy3/03_train_agents.py $EXPERT -s $SEED -g $GPU
81 |         python experiments/toy3/04_eval_agents.py $EXPERT -s $SEED -g $GPU
82 |     done
83 |     python experiments/toy3/05_plots.py $EXPERT
84 | done
85 | ```
86 | 
--------------------------------------------------------------------------------
/environment/__init__.py:
--------------------------------------------------------------------------------
1 | from environment.env_pomdp
import PomdpEnv 2 | -------------------------------------------------------------------------------- /environment/env_pomdp.py: -------------------------------------------------------------------------------- 1 | ################################################ 2 | #################### GYM POMDP ################# 3 | ################################################ 4 | 5 | # Packages 6 | import torch 7 | import numpy as np 8 | 9 | # Gym 10 | import gym 11 | from gym import spaces 12 | 13 | class PomdpEnv(gym.Env): 14 | 15 | def __init__(self, p_s, p_or_s, p_s_sa, max_length, categorical_obs=False): 16 | 17 | """ Environment for POMDP. Needs : 18 | - p(s) initial distribution for latent variable. 19 | - p(o|s) distribution of noisy observation given latent state. 20 | - p(s|s,a) transition distribution of next latent state given current action and state. 21 | - the maximum length of a rollout. 22 | """ 23 | 24 | self.episode_length = max_length 25 | self.categorical_obs = categorical_obs 26 | 27 | # Distribution 28 | self.p_s_sa = p_s_sa 29 | self.p_s = p_s 30 | self.p_or_s = p_or_s 31 | 32 | # Initialize game indicators 33 | self.initialize_on_reset() 34 | self.action = 0 35 | self.n_rewards = p_or_s.shape[2] 36 | 37 | # Action and Obersvation Space 38 | self.action_space = spaces.Discrete(p_s_sa.shape[1]) 39 | if categorical_obs : 40 | self.observation_space = spaces.Box(low=0, high=1, shape=(p_or_s.shape[1],), dtype=np.uint8) 41 | 42 | else : 43 | self.observation_space = spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8) 44 | 45 | def initialize_on_reset(self): 46 | 47 | """ Reset the state of the environment to an initial state. """ 48 | 49 | self.current_step = 0 50 | self.score = 0 51 | self.done = False 52 | 53 | def reset(self): 54 | 55 | """ Reset Environement : 56 | 1. Draw initial hidden state from (s). 57 | 2. Get observation from joint p(o, r|s). 58 | 3. Return observation. 59 | """ 60 | 61 | # Reset the state of the environment to an initial state 62 | self.initialize_on_reset() 63 | 64 | # Draw Initial latent state S 65 | if self.categorical_obs : 66 | self.s = torch.distributions.one_hot_categorical.OneHotCategorical(probs=self.p_s,).sample() 67 | else : 68 | self.s = torch.multinomial(self.p_s, 1) 69 | 70 | # Sample o, r from p(o,r|s) 71 | self.o, self.r = self.sample_ro_s() 72 | #self.r = torch.tensor(int(r), dtype=torch.float) 73 | 74 | # Return only observation (gym like) 75 | return self.o 76 | 77 | 78 | def sample_s_sa(self): 79 | 80 | """ Sample next hidden state from current state and action p(s|s,a) """ 81 | 82 | if self.categorical_obs: 83 | s = self.s.argmax() 84 | else : 85 | s = self.s[0] 86 | 87 | probs = self.p_s_sa[s,self.action,:] 88 | if self.categorical_obs : 89 | return torch.distributions.one_hot_categorical.OneHotCategorical(probs=probs,).sample() 90 | return torch.multinomial(probs, 1) 91 | 92 | def sample_ro_s(self): 93 | 94 | """ Sample reward and observation from joint (conditional) distribution p(o, r|s)""" 95 | 96 | # If categorical var, find corresponding value to index tables 97 | if self.categorical_obs: 98 | s = self.s.argmax() 99 | 100 | else : 101 | s = self.s.clone() 102 | 103 | # Sample from joint multinomial. 
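# p_or_s[s] is the joint table p(o, r | s), with one row per observation value and
# one column per reward value. It is flattened so that a single categorical draw
# selects a joint (o, r) outcome, and the flat index is decoded just below via
# integer division / modulo by the number of reward values (e.g. with 2 reward
# values, flat index 3 corresponds to o = 1, r = 1).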
104 | ind = torch.multinomial(self.p_or_s[s.reshape(-1),:].reshape(s.reshape(-1).size(0), -1), 1) 105 | size = torch.tensor(self.p_or_s[0].size(), dtype=torch.float) 106 | ro = torch.cat([ind//size[1], ind%size[1]], dim = -1).reshape(-1) 107 | 108 | if self.categorical_obs : 109 | o = torch.zeros(self.p_or_s.shape[1]) 110 | o[int(ro[0])] = 1. 111 | 112 | r = torch.zeros(self.p_or_s.shape[2]) 113 | r[int(ro[1])] = 1. 114 | else : 115 | o = ro[:1] 116 | r = ro[1:] 117 | 118 | return o, r 119 | 120 | 121 | def step(self, action): 122 | 123 | """ Take a step in the env : 124 | 1. Generate new hidden state p(s|s,a) 125 | 2. Get Obeservation, Reward, Flag Done from p(o,r|s) 126 | 3. return obs, reward, flag_done, info (s) 127 | """ 128 | 129 | # Increase current steptime 130 | self.current_step += 1 131 | # Save picked action 132 | self.action = action 133 | # Generate new latent state 134 | self.s = self.sample_s_sa() 135 | 136 | # Sample obs and reward from hidden state 137 | self.o, self.r = self.sample_ro_s() 138 | 139 | observation = self.o 140 | done = int(self.current_step >= self.episode_length) 141 | reward = self.r 142 | 143 | #Increment total score 144 | self.score += reward 145 | return observation, reward, torch.tensor(done, dtype=torch.float), {"s" : self.s} 146 | 147 | -------------------------------------------------------------------------------- /environment/env_wrappers.py: -------------------------------------------------------------------------------- 1 | 2 | ############################################################## 3 | ##################### WRAPPERS ############################### 4 | ############################################################## 5 | 6 | # Packages 7 | # import cv2 8 | import gym 9 | import time 10 | import torch 11 | 12 | import numpy as np 13 | 14 | ####################### To MDP ############################### 15 | 16 | class SqueezeEnv(gym.Wrapper): 17 | def __init__(self, env): 18 | super(SqueezeEnv, self).__init__(env) 19 | self.env = env 20 | 21 | def reset(self): 22 | o = self.env.reset() 23 | return o.unsqueeze(0) 24 | 25 | def step(self, action): 26 | 27 | o, r, done, info = self.env.step(action) 28 | return o.unsqueeze(0), r.unsqueeze(0), done.unsqueeze(0), info 29 | 30 | class RewardWrapper(gym.Wrapper): 31 | 32 | def __init__(self, env, reward_dic): 33 | super(RewardWrapper, self).__init__(env) 34 | self.env = env 35 | self.reward_dic = reward_dic 36 | 37 | def reset(self): 38 | o = self.env.reset() 39 | self.r = self.reward_dic[int(self.env.r.argmax())] 40 | return o 41 | 42 | def step(self, action): 43 | 44 | o, r, done, info = self.env.step(action) 45 | r = self.reward_dic[int(r.argmax())] 46 | return o, r, done, info 47 | 48 | 49 | ################## Augmented POMDP ############################# 50 | 51 | class BeliefStateRepresentation(gym.Wrapper): 52 | 53 | """ Return the same POMDP env with a belief state representation as attribute. 
54 | Compute p(s|h), ie the proba distribution of the hidden state given trajectory history, 55 | using the model estimation of p_s_h, and store the computed vector 56 | as a belief state representation attributes """ 57 | 58 | def __init__(self, env, belief_state_model, with_done=False): 59 | super(BeliefStateRepresentation, self).__init__(env) 60 | 61 | self.env = env 62 | self.internal_model = belief_state_model 63 | self.with_done = with_done 64 | 65 | self.observation_space = gym.spaces.Box( 66 | low=0, high=1, 67 | shape=(belief_state_model.s_nvals,), dtype=np.float) 68 | 69 | def update_belief_state(self, a, o, r, d): 70 | 71 | ''' Update the proba distribution of hidden state representation p(s|h). 72 | Ie the distribution of the hidden state estimated s given the whole history ''' 73 | 74 | with torch.no_grad(): 75 | self.log_q_s_h = self.internal_model.log_q_s_h( 76 | regime=torch.tensor([1.]), # interventional regime 77 | loq_q_sprev_hprev=self.log_q_s_h, 78 | a=a, o=o, r=r, d=d, 79 | with_done=self.with_done) 80 | 81 | def reset(self): 82 | 83 | ''' Reset hidden state beliefs and last action ''' 84 | 85 | self.log_q_s_h = None # reset belief state 86 | 87 | o = self.env.reset() 88 | r = self.env.r.unsqueeze(0) 89 | d = torch.tensor([0]) 90 | 91 | self.update_belief_state(None, o, r, d) 92 | 93 | return torch.exp(self.log_q_s_h) 94 | 95 | def step(self, action): 96 | 97 | ''' Take a step, update hidden state beliefs and 98 | return the hidden state p(s|h) as observation ''' 99 | 100 | o, r, d, info = self.env.step(action) 101 | a = torch.tensor([1. if action==i else 0. for i in range(self.action_space.n)]).unsqueeze(0) 102 | 103 | self.update_belief_state(a, o, r, d) 104 | 105 | return torch.exp(self.log_q_s_h), r, d, info 106 | 107 | -------------------------------------------------------------------------------- /experiments/toy1/01_train_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy1/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 | help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | 'privileged_policy', 31 | type=str, 32 | choices=cfg['privileged_policies'].keys(), 33 | ) 34 | args = parser.parse_args() 35 | 36 | # process command-line arguments 37 | if args.gpu == -1: 38 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 39 | device = "cpu" 40 | else: 41 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 42 | device = f"cuda:{args.gpu}" 43 | 44 | seed = args.seed 45 | privileged_policy = args.privileged_policy 46 | 47 | print(f"device: {device}") 48 | print(f"seed: {seed}") 49 | print(f"privileged_policy : {privileged_policy}") 50 | 51 | 52 | import torch 53 | 54 | # Ugly hack 55 | sys.path.insert(0, os.path.abspath(f".")) 56 | 57 | from environment import PomdpEnv 58 | from policies import UniformPolicy, ExpertPolicy 59 | from models import TabularAugmentedModel 60 | 61 | from utils import construct_dataset 62 | from learning import fit_model 63 | 64 | 65 | ## SET UP THE ENVIRONMENT ## 66 | 67 | p_s = 
torch.tensor(cfg['p_s']) 68 | p_r_s = torch.tensor(cfg['p_r_s']) 69 | p_o_s = torch.tensor(cfg['p_o_s']) 70 | p_s_sa = torch.tensor(cfg['p_s_sa']) 71 | 72 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 73 | 74 | o_nvals=p_o_s.shape[1] 75 | a_nvals=p_s_sa.shape[1] 76 | r_nvals=p_r_s.shape[1] 77 | s_nvals = cfg["latent_space_size"] 78 | 79 | episode_length = cfg["episode_length"] 80 | 81 | # POMDP dynamics 82 | env = PomdpEnv(p_s=p_s, 83 | p_or_s=p_r_s.unsqueeze(-2) * p_o_s.unsqueeze(-1), 84 | p_s_sa=p_s_sa, 85 | categorical_obs=True, 86 | max_length=episode_length) 87 | 88 | # Policy in the observational regime (priviledged) 89 | obs_policy = ExpertPolicy(p_a_s) 90 | 91 | # Policy in the interventional regime 92 | int_policy = UniformPolicy(a_nvals) 93 | 94 | 95 | ## SET UP THE SEEDS ## 96 | 97 | rng = np.random.RandomState(seed) 98 | seed_data_obs = rng.randint(0, 2**10) 99 | seed_data_int = rng.randint(0, 2**10) 100 | seed_training = rng.randint(0, 2**10) 101 | 102 | 103 | ## GENERATE THE DATASETS ## 104 | 105 | nsamples_obs_subsets = cfg['nsamples_obs'] 106 | nsamples_int_subsets = cfg['nsamples_int'] 107 | 108 | # we perform experiments on subsets of the same 109 | # dataset, so that each sequentially growing experiment 110 | # reuses the same samples, complemented with new ones 111 | nsamples_obs_total = np.max(nsamples_obs_subsets) 112 | nsamples_int_total = np.max(nsamples_int_subsets) 113 | 114 | torch.manual_seed(seed_data_obs) 115 | data_obs_all = construct_dataset(env=env, 116 | policy=obs_policy, 117 | n_samples=nsamples_obs_total, 118 | regime=torch.tensor(0)) 119 | 120 | torch.manual_seed(seed_data_int) 121 | data_int_all = construct_dataset(env=env, 122 | policy=int_policy, 123 | n_samples=nsamples_int_total, 124 | regime=torch.tensor(1)) 125 | 126 | 127 | ## LEARN THE TRANSITION MODELS ## 128 | 129 | loss_type = 'nll' 130 | with_done = False 131 | 132 | n_epochs = 500 133 | epoch_size = 50 134 | batch_size = 32 135 | lr = 1e-2 136 | patience = 10 137 | 138 | device = torch.device(device) 139 | 140 | training_schemes = cfg["training_schemes"] 141 | 142 | for nsamples_obs in nsamples_obs_subsets: 143 | for nsamples_int in nsamples_int_subsets: 144 | 145 | data_obs = data_obs_all[:nsamples_obs] 146 | data_int = data_int_all[:nsamples_int] 147 | 148 | modeldir = pathlib.Path(f"experiments/toy1/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 149 | modeldir.mkdir(parents=True, exist_ok=True) 150 | 151 | print(f"saving results to: {modeldir}") 152 | 153 | for training_scheme in training_schemes: 154 | 155 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 156 | 157 | logfile = modeldir / f"{training_scheme}_log.txt" 158 | paramsfile = modeldir / f"{training_scheme}.pt" 159 | 160 | if training_scheme == 'int': 161 | train_data = data_int 162 | elif training_scheme == 'obs+int': 163 | train_data = [(torch.tensor(1), episode) for (_, episode) in data_obs + data_int] 164 | elif training_scheme == 'augmented_obs+int': 165 | train_data = data_obs + data_int 166 | else: 167 | raise NotImplemented 168 | 169 | torch.manual_seed(seed_training) 170 | 171 | m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 172 | m = m.to(device) 173 | 174 | fit_model(m, 175 | train_data=train_data, 176 | valid_data=train_data, # we want to overfit 177 | loss_type=loss_type, 178 | with_done=with_done, 179 | n_epochs=n_epochs, 180 | 
epoch_size=epoch_size, 181 | batch_size=batch_size, 182 | lr=lr, 183 | patience=patience, 184 | log=True, 185 | logfile=logfile) 186 | 187 | torch.save(m.state_dict(), paramsfile) 188 | 189 | -------------------------------------------------------------------------------- /experiments/toy1/02_eval_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy1/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 | help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | 'privileged_policy', 31 | type=str, 32 | choices=cfg['privileged_policies'].keys(), 33 | ) 34 | args = parser.parse_args() 35 | 36 | # process command-line arguments 37 | if args.gpu == -1: 38 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 39 | device = "cpu" 40 | else: 41 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 42 | device = f"cuda:{args.gpu}" 43 | 44 | seed = args.seed 45 | privileged_policy = args.privileged_policy 46 | 47 | print(f"device: {device}") 48 | print(f"seed: {seed}") 49 | print(f"privileged_policy : {privileged_policy}") 50 | 51 | 52 | import torch 53 | 54 | # Ugly hack 55 | sys.path.insert(0, os.path.abspath(f".")) 56 | 57 | from environment import PomdpEnv 58 | from policies import UniformPolicy, ExpertPolicy 59 | from models import TabularAugmentedModel 60 | 61 | from utils import js_div, kl_div 62 | 63 | 64 | ## SET UP THE ENVIRONMENT ## 65 | 66 | p_s = torch.tensor(cfg['p_s']) 67 | p_r_s = torch.tensor(cfg['p_r_s']) 68 | p_o_s = torch.tensor(cfg['p_o_s']) 69 | p_s_sa = torch.tensor(cfg['p_s_sa']) 70 | 71 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 72 | 73 | o_nvals=p_o_s.shape[1] 74 | a_nvals=p_s_sa.shape[1] 75 | r_nvals=p_r_s.shape[1] 76 | s_nvals = cfg["latent_space_size"] 77 | 78 | # POMDP dynamics 79 | env = PomdpEnv(p_s=p_s, 80 | p_or_s=p_r_s.unsqueeze(-2) * p_o_s.unsqueeze(-1), 81 | p_s_sa=p_s_sa, 82 | categorical_obs=True, 83 | max_length=1) 84 | 85 | # Policy in the observational regime (priviledged) 86 | obs_policy = ExpertPolicy(p_a_s) 87 | 88 | # Policy in the interventional regime 89 | int_policy = UniformPolicy(a_nvals) 90 | 91 | # recovering the true bandit transition model 92 | with torch.no_grad(): 93 | p_ssr_a_int = p_s.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\ 94 | * p_s_sa.permute(1, 0, 2).unsqueeze(-1) \ 95 | * p_r_s.unsqueeze(0).unsqueeze(0) 96 | p_r_a_int = p_ssr_a_int.sum(dim=(1, 2)) 97 | 98 | 99 | ## EVALUATE THE TRANSITION MODELS ## 100 | 101 | nsamples_obs_subsets = cfg['nsamples_obs'] 102 | nsamples_int_subsets = cfg['nsamples_int'] 103 | training_schemes = cfg["training_schemes"] 104 | 105 | with_done = False 106 | 107 | device = torch.device(device) 108 | 109 | resultsdir = pathlib.Path(f"experiments/toy1/results/{privileged_policy}/seed_{seed}") 110 | resultsdir.mkdir(parents=True, exist_ok=True) 111 | 112 | results = np.full((len(nsamples_obs_subsets), len(nsamples_int_subsets), len(training_schemes), 3), np.nan) 113 | 114 | p_r_a_int = p_r_a_int.to(device) 115 | 116 | m = 
TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 117 | m = m.to(device) 118 | 119 | for i, nsamples_obs in enumerate(nsamples_obs_subsets): 120 | for j, nsamples_int in enumerate(nsamples_int_subsets): 121 | for k, training_scheme in enumerate(training_schemes): 122 | 123 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 124 | 125 | modeldir = pathlib.Path(f"experiments/toy1/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 126 | 127 | print(f"reading results from: {modeldir}") 128 | 129 | paramsfile = modeldir / f"{training_scheme}.pt" 130 | m.load_state_dict(torch.load(paramsfile, map_location=device)) 131 | 132 | # recovering the learnt bandit transition model 133 | with torch.no_grad(): 134 | q_s = torch.nn.functional.softmax(m.params_s, dim=-1) 135 | q_r_s = torch.nn.functional.softmax(m.params_r_s, dim=-1) 136 | q_s_sa = torch.nn.functional.softmax(m.params_s_sa, dim=-1) 137 | 138 | # a, s, s, r 139 | q_ssr_a_int = q_s.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\ 140 | * q_s_sa.permute(1, 0, 2).unsqueeze(-1) \ 141 | * q_r_s.unsqueeze(0).unsqueeze(0) 142 | q_r_a_int = q_ssr_a_int.sum(dim=(1,2)) 143 | 144 | # computing the evaluation measures 145 | with torch.no_grad(): 146 | jsd = js_div(p_r_a_int, q_r_a_int).mean(0) # expectation over uniform policy 147 | kld = kl_div(p_r_a_int, q_r_a_int).mean(0) # expectation over uniform policy 148 | reward = p_r_a_int[q_r_a_int[:, 1].argmax(dim=0), 1] 149 | 150 | jsd = jsd.item() 151 | kld = kld.item() 152 | reward = reward.item() 153 | 154 | print(f"jsd: {jsd}") 155 | print(f"kld: {kld}") 156 | print(f"reward: {reward}") 157 | 158 | results[i, j, k] = (jsd, kld, reward) 159 | 160 | with open(resultsdir / "results.npy", 'wb') as f: 161 | np.save(f, results) 162 | -------------------------------------------------------------------------------- /experiments/toy1/03_kallus_et_al.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy1/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--nseeds', 19 | type=int, 20 | help = 'Number of random seed.', 21 | default=20, 22 | ) 23 | parser.add_argument( 24 | 'privileged_policy', 25 | type=str, 26 | choices=cfg['privileged_policies'].keys(), 27 | ) 28 | args = parser.parse_args() 29 | 30 | nseeds = args.nseeds 31 | privileged_policy = args.privileged_policy 32 | 33 | print(f"nseeds: {nseeds}") 34 | print(f"privileged_policy : {privileged_policy}") 35 | 36 | # Ugly hack 37 | sys.path.insert(0, os.path.abspath(f".")) 38 | 39 | import torch 40 | from utils_kallus import run_method, get_best_for_data, regs 41 | 42 | ## SET UP RANGE OF EXPERIMENTS ## 43 | nsample_obs = 512 44 | nsamples_int = [4, 8, 16, 32, 64, 128] 45 | 46 | ## SET UP THE SEEDS ## 47 | seed = 0 48 | rng = np.random.RandomState(seed) 49 | seed_data_obs = rng.randint(0, 2**10) 50 | seed_data_int = rng.randint(0, 2**10) 51 | seed_training = rng.randint(0, 2**10) 52 | 53 | ## SET UP THE ENVIRONMENT ## 54 | 55 | p_s = torch.tensor(cfg['p_s']) 56 | p_r_s = torch.tensor(cfg['p_r_s']) 57 | p_o_s = torch.tensor(cfg['p_o_s']) 58 | p_s_sa = 
torch.tensor(cfg['p_s_sa']) 59 | 60 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 61 | 62 | nsamples_obs = 512 63 | nsamples_int_subsets = cfg['nsamples_int'] 64 | training_schemes = cfg["training_schemes"] 65 | 66 | resultsdir = pathlib.Path(f"experiments/toy1/results/kallus") 67 | resultsdir.mkdir(parents=True, exist_ok=True) 68 | 69 | results_kallus = np.full((nseeds, len(nsamples_int_subsets)),np.nan) 70 | 71 | best_eta_regs = np.zeros((nseeds, len(nsamples_int_subsets))) 72 | predicted_taus = np.zeros((nseeds, len(nsamples_int_subsets))) 73 | predicted_omegas = np.zeros((nseeds, len(nsamples_int_subsets))) 74 | 75 | ## Calcul du true tau 76 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 77 | 78 | for seed in range(nseeds): 79 | 80 | #ugly hack for unexpected errors 81 | if seed in [13, 19] : 82 | seed += 8 83 | 84 | ## SET UP THE SEEDS ## 85 | rng = np.random.RandomState(seed) 86 | seed_data_obs = rng.randint(0, 2**10) 87 | seed_data_int = rng.randint(0, 2**10) 88 | seed_training = rng.randint(0, 2**10) 89 | 90 | #ugly hack for unexpected errors 91 | if seed in [21, 27] : 92 | seed -= 8 93 | 94 | np.random.seed(seed_data_obs) 95 | n_obs = nsample_obs 96 | X_obs = np.array([1 for i in range(n_obs)]) 97 | U_obs = np.random.choice([0,1,2], p=np.array(p_s), size=n_obs) 98 | T_obs = np.array([np.random.choice([0,1], p=np.array(p_a_s[u]), size=1)[0] for u in U_obs]) 99 | Y_obs = np.array([p_r_s[p_s_sa[U_obs[i], T_obs[i]].argmax()].argmax() for i in range(n_obs)], dtype=int) 100 | 101 | np.random.seed(seed_data_int) 102 | X_int_ = np.array([1 for i in range(nsamples_int_subsets[-1])]) 103 | U_int_ = np.random.choice([0,1,2], p=np.array(p_s), size=nsamples_int_subsets[-1]) 104 | T_int_ = np.array([np.random.choice([0,1], p=[0.5, 0.5], size=1)[0] for u in U_int_]) 105 | Y_int_ = np.array([p_r_s[p_s_sa[U_int_[i], T_int_[i]].argmax()].argmax() for i in range(nsamples_int_subsets[-1])], dtype=int) 106 | 107 | for k, nsamples_int in enumerate(nsamples_int_subsets): 108 | 109 | X_int = X_int_[:nsamples_int] 110 | U_int = U_int_[:nsamples_int] 111 | T_int = T_int_[:nsamples_int] 112 | Y_int = Y_int_[:nsamples_int] 113 | 114 | np.random.seed(seed_training) 115 | f1pred_obs = get_best_for_data(X_obs[T_obs>0].reshape(-1,1), Y_obs[T_obs>0], regs) 116 | f0pred_obs = get_best_for_data(X_obs[T_obs==0].reshape(-1,1), Y_obs[T_obs==0], regs) 117 | omega_est_int = f1pred_obs.predict(X_int.mean().reshape(-1,1)) - f0pred_obs.predict(X_int.mean().reshape(-1,1)) 118 | 119 | best_eta_reg , eta_est_cf, _ = run_method(X_int, Y_int, T_int, X_obs, Y_obs, T_obs) 120 | 121 | best_eta_regs[seed, k] = best_eta_reg.predict(X_obs.mean().reshape(-1, 1))[0] 122 | predicted_taus[seed, k] = omega_est_int[0] + best_eta_reg.predict(X_obs.mean().reshape(-1, 1))[0] 123 | predicted_omegas[seed, k] = omega_est_int[0] 124 | 125 | pred_tau = predicted_taus[seed, k] 126 | optimal_policy = torch.tensor(np.repeat(np.array([pred_tau<0, pred_tau>0], dtype=int).reshape(1,-1),3, axis=0)) 127 | if pred_tau == 0: 128 | optimal_policy = torch.tensor(np.repeat(np.array([0.5, 0.5]).reshape(1,-1),3, axis=0)) 129 | reward = (p_r_s * (p_s_sa * (optimal_policy * p_s.unsqueeze(-1)).unsqueeze(-1)).sum(1).sum(0).unsqueeze(1)).sum(0) 130 | results_kallus[seed, k] = reward[1] 131 | 132 | with open(f"experiments/toy1/results/kallus/{privileged_policy}.npy", 'wb') as f: 133 | np.save(f, results_kallus) 134 | 135 | -------------------------------------------------------------------------------- 
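The plotting script below (04_plots.py) overlays the Kallus et al. baseline saved above when it is run with `--kallus=True` (and only for `nobs == 512`). As a minimal sketch of how that results file can be loaded and inspected on its own — the policy name is an arbitrary example, and the shape assumes the default 20 seeds and the six `nsamples_int` values from `experiments/toy1/config.json`:

```python
import numpy as np

privileged_policy = "noisy_good"           # any key of cfg['privileged_policies']
nints = [4, 8, 16, 32, 64, 128]            # cfg['nsamples_int'] in experiments/toy1/config.json

with open(f"experiments/toy1/results/kallus/{privileged_policy}.npy", "rb") as f:
    rewards_kallus = np.load(f)

# one row per seed, one column per interventional-dataset size
print(rewards_kallus.shape)                # (20, 6) with the default settings
print(np.nanmean(rewards_kallus, axis=0))  # per-nint mean reward of the Kallus et al. baseline
```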
/experiments/toy1/04_plots.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | from xmlrpc.client import boolean 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | 11 | # Ugly hack 12 | sys.path.insert(0, os.path.abspath(f".")) 13 | 14 | from stat_tests import run_test 15 | 16 | 17 | if __name__ == '__main__': 18 | 19 | # read experiment config 20 | with open("experiments/toy1/config.json", "r") as json_data_file: 21 | cfg = json.load(json_data_file) 22 | 23 | # read command-line arguments 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | 'privileged_policy', 27 | type=str, 28 | choices=cfg['privileged_policies'].keys(), 29 | ) 30 | 31 | parser.add_argument( 32 | '-k', '--kallus', 33 | type=bool, 34 | help = 'Add Kallus et al experiment to the plot', 35 | default=False, 36 | ) 37 | args = parser.parse_args() 38 | 39 | privileged_policy = args.privileged_policy 40 | 41 | print(f"privileged_policy : {privileged_policy }") 42 | 43 | 44 | ## COLLECT THE RESULTS ## 45 | 46 | nobss = cfg['nsamples_obs'] 47 | nints = cfg['nsamples_int'] 48 | training_schemes = cfg["training_schemes"] 49 | 50 | nseeds = 20 51 | 52 | results = [] 53 | for seed in range(nseeds): 54 | with open(f"experiments/toy1/results/{privileged_policy}/seed_{seed}/results.npy", 'rb') as f: 55 | results.append(np.load(f)) 56 | 57 | results = np.asarray(results) 58 | 59 | jss = results[..., 0] 60 | # kls = results[..., 1] 61 | rewards = results[..., 2] 62 | 63 | 64 | ## CREATE AND SAVE THE PLOTS ## 65 | 66 | plotsdir = pathlib.Path(f"experiments/toy1/plots") 67 | plotsdir.mkdir(parents=True, exist_ok=True) 68 | 69 | rmin = np.min(rewards) 70 | rmax = np.max(rewards) 71 | 72 | jsmin = np.min(jss) 73 | jsmax = np.max(jss) 74 | 75 | r_int = rewards[..., 0] 76 | r_naiv = rewards[..., 1] 77 | r_augm = rewards[..., 2] 78 | 79 | js_int = jss[..., 0] 80 | js_naiv = jss[..., 1] 81 | js_augm = jss[..., 2] 82 | 83 | fig, axes = plt.subplots(2, 5, figsize=(20, 6), dpi=300) 84 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 85 | 86 | ax = axes[0, 0] 87 | cf = ax.pcolormesh(r_int.mean(0), vmin=rmin, vmax=rmax) 88 | fig.colorbar(cf, ax=ax) 89 | ax.set_title(f"no obs") 90 | ax.set_ylabel('nobs') 91 | ax.set_xlabel('nints') 92 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 93 | ax.set_xticklabels(nints) 94 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 95 | ax.set_yticklabels(nobss) 96 | 97 | ax = axes[0, 1] 98 | cf = ax.pcolormesh(r_naiv.mean(0), vmin=rmin, vmax=rmax) 99 | fig.colorbar(cf, ax=ax) 100 | ax.set_title(f"naive obs+int") 101 | ax.set_ylabel('nobs') 102 | ax.set_xlabel('nints') 103 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 104 | ax.set_xticklabels(nints) 105 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 106 | ax.set_yticklabels(nobss) 107 | 108 | ax = axes[0, 2] 109 | cf = ax.pcolormesh(r_augm.mean(0), vmin=rmin, vmax=rmax) 110 | fig.colorbar(cf, ax=ax) 111 | ax.set_title(f"augmented obs+int") 112 | ax.set_ylabel('nobs') 113 | ax.set_xlabel('nints') 114 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 115 | ax.set_xticklabels(nints) 116 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 117 | ax.set_yticklabels(nobss) 118 | 119 | r_gain_int = (r_augm - r_int).mean(0) 120 | r_gain_naiv = (r_augm - r_naiv).mean(0) 121 | r_gain_max = np.max([np.abs(r_gain_int), np.abs(r_gain_naiv)]) 122 | r_gain_min = 
-r_gain_max 123 | 124 | ax = axes[0, 3] 125 | cf = ax.pcolormesh(r_gain_int, cmap=plt.get_cmap('PiYG'), vmin=r_gain_min, vmax=r_gain_max) 126 | fig.colorbar(cf, ax=ax) 127 | ax.set_title(f"augmented - no obs") 128 | ax.set_ylabel('nobs') 129 | ax.set_xlabel('nints') 130 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 131 | ax.set_xticklabels(nints) 132 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 133 | ax.set_yticklabels(nobss) 134 | 135 | ax = axes[0, 4] 136 | cf = ax.pcolormesh(r_gain_naiv, cmap=plt.get_cmap('PiYG'), vmin=r_gain_min, vmax=r_gain_max) 137 | fig.colorbar(cf, ax=ax) 138 | ax.set_title(f"augmented - naive") 139 | ax.set_ylabel('nobs') 140 | ax.set_xlabel('nints') 141 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 142 | ax.set_xticklabels(nints) 143 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 144 | ax.set_yticklabels(nobss) 145 | 146 | ax = axes[1, 0] 147 | cf = ax.pcolormesh(js_int.mean(0), vmin=jsmin, vmax=jsmax) 148 | fig.colorbar(cf, ax=ax) 149 | ax.set_title(f"no obs") 150 | ax.set_ylabel('nobs') 151 | ax.set_xlabel('nints') 152 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 153 | ax.set_xticklabels(nints) 154 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 155 | ax.set_yticklabels(nobss) 156 | 157 | ax = axes[1, 1] 158 | cf = ax.pcolormesh(js_naiv.mean(0), vmin=jsmin, vmax=jsmax) 159 | fig.colorbar(cf, ax=ax) 160 | ax.set_title(f"naive obs+int") 161 | ax.set_ylabel('nobs') 162 | ax.set_xlabel('nints') 163 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 164 | ax.set_xticklabels(nints) 165 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 166 | ax.set_yticklabels(nobss) 167 | 168 | ax = axes[1, 2] 169 | cf = ax.pcolormesh(js_augm.mean(0), vmin=jsmin, vmax=jsmax) 170 | fig.colorbar(cf, ax=ax) 171 | ax.set_title(f"augmented obs+int") 172 | ax.set_ylabel('nobs') 173 | ax.set_xlabel('nints') 174 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 175 | ax.set_xticklabels(nints) 176 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 177 | ax.set_yticklabels(nobss) 178 | 179 | js_gain_int = (js_augm - js_int).mean(0) 180 | js_gain_naiv = (js_augm - js_naiv).mean(0) 181 | js_gain_max = np.max([np.abs(js_gain_int), np.abs(js_gain_naiv)]) 182 | js_gain_min = -js_gain_max 183 | 184 | ax = axes[1, 3] 185 | cf = ax.pcolormesh(js_gain_int, cmap=plt.get_cmap('PiYG'), vmin=js_gain_min, vmax=js_gain_max) 186 | fig.colorbar(cf, ax=ax) 187 | ax.set_title(f"augmented - no obs") 188 | ax.set_ylabel('nobs') 189 | ax.set_xlabel('nints') 190 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 191 | ax.set_xticklabels(nints) 192 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 193 | ax.set_yticklabels(nobss) 194 | 195 | ax = axes[1, 4] 196 | cf = ax.pcolormesh(js_gain_naiv, cmap=plt.get_cmap('PiYG'), vmin=js_gain_min, vmax=js_gain_max) 197 | fig.colorbar(cf, ax=ax) 198 | ax.set_title(f"augmented - naive") 199 | ax.set_ylabel('nobs') 200 | ax.set_xlabel('nints') 201 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 202 | ax.set_xticklabels(nints) 203 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 204 | ax.set_yticklabels(nobss) 205 | 206 | fig.savefig(plotsdir / f"{privileged_policy}_reward_js_grids.pdf", bbox_inches='tight', pad_inches=0) 207 | plt.close(fig) 208 | 209 | def plot_mean_std(ax, x, y, label, color): 210 | ax.plot(x, y.mean(0), label=label, color=color) 211 | ax.fill_between(x, y.mean(0) - y.std(0), y.mean(0) + y.std(0), color=color, alpha=0.2) 212 | 213 | def 
plot_mean_lowhigh(ax, x, mean, low, high, label, color): 214 | ax.plot(x, mean, label=label, color=color) 215 | ax.fill_between(x, low, high, color=color, alpha=0.2) 216 | 217 | def compute_central_tendency_and_error(id_central, id_error, sample): 218 | if id_central == 'mean': 219 | central = np.nanmean(sample, axis=0) 220 | elif id_central == 'median': 221 | central = np.nanmedian(sample, axis=0) 222 | else: 223 | raise NotImplementedError 224 | 225 | if isinstance(id_error, int): 226 | low = np.nanpercentile(sample, q=int((100 - id_error) / 2), axis=0) 227 | high = np.nanpercentile(sample, q=int(100 - (100 - id_error) / 2), axis=0) 228 | elif id_error == 'std': 229 | low = central - np.nanstd(sample, axis=0) 230 | high = central + np.nanstd(sample, axis=0) 231 | elif id_error == 'sem': 232 | low = central - np.nanstd(sample, axis=0) / np.sqrt(sample.shape[0]) 233 | high = central + np.nanstd(sample, axis=0) / np.sqrt(sample.shape[0]) 234 | else: 235 | raise NotImplementedError 236 | 237 | return central, low, high 238 | 239 | for i, nobs in enumerate(nobss): 240 | 241 | test = 'Wilcoxon' 242 | deviation = 'sem' #'std' 243 | confidence_level = 0.05 244 | 245 | ### Jensen-Shannon ### 246 | 247 | fig, axes = plt.subplots(1, 1, figsize=(3, 2.25), dpi=300) 248 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 249 | 250 | # statistical tests 251 | test_int_augm = [run_test(test, js_augm[:, i, j], js_int[:, i, j], alpha=confidence_level) for j in range(len(nints))] 252 | test_naiv_augm = [run_test(test, js_augm[:, i, j], js_naiv[:, i, j], alpha=confidence_level) for j in range(len(nints))] 253 | 254 | # mean and standard error 255 | mean0, low0, high0 = compute_central_tendency_and_error('mean', deviation, js_int[:, i]) 256 | mean1, low1, high1 = compute_central_tendency_and_error('mean', deviation, js_naiv[:, i]) 257 | mean2, low2, high2 = compute_central_tendency_and_error('mean', deviation, js_augm[:, i]) 258 | 259 | # plot JS curves 260 | ax = axes 261 | plot_mean_lowhigh(ax, nints, mean0, low0, high0, label="no obs", color="tab:blue") 262 | plot_mean_lowhigh(ax, nints, mean1, low1, high1, label="naive", color="tab:orange") 263 | plot_mean_lowhigh(ax, nints, mean2, low2, high2, label="augmented", color="tab:green") 264 | 265 | ymax = np.nanmax([high0, high1, high2]) 266 | ymin = np.nanmin([low0, low1, low2]) 267 | 268 | # plot significative difference as dots 269 | y = ymax + 0.05 * (ymax-ymin) 270 | x = np.asarray(nints)[np.argwhere(test_int_augm)] 271 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:blue', marker='v') 272 | 273 | y = ymax + 0.10 * (ymax-ymin) 274 | x = np.asarray(nints)[np.argwhere(test_naiv_augm)] 275 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:orange', marker='s') 276 | 277 | ax.set_title(f"JS divergence") 278 | ax.set_xlabel('nints (log scale)') 279 | ax.set_xscale('log', base=2) 280 | ax.set_ylim(bottom=0) 281 | ax.legend() 282 | 283 | fig.savefig(plotsdir / f"{privileged_policy}_js_nobs_{nobs}.pdf", bbox_inches='tight', pad_inches=0) 284 | plt.close(fig) 285 | 286 | 287 | ### Reward ### 288 | 289 | fig, axes = plt.subplots(1, 1, figsize=(3, 2.25), dpi=300) 290 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 291 | 292 | # statistical tests 293 | test_int_augm = [run_test(test, r_int[:, i, j], r_augm[:, i, j], alpha=confidence_level) for j in range(len(nints))] 294 | test_naiv_augm = [run_test(test, r_naiv[:, i, j], r_augm[:, i, j], alpha=confidence_level) for j in range(len(nints))] 295 | 296 | # mean and standard error 297 | mean0, low0, high0 = 
compute_central_tendency_and_error('mean', deviation, r_int[:, i]) 298 | mean1, low1, high1 = compute_central_tendency_and_error('mean', deviation, r_naiv[:, i]) 299 | mean2, low2, high2 = compute_central_tendency_and_error('mean', deviation, r_augm[:, i]) 300 | 301 | # plot reward curves 302 | ax = axes 303 | plot_mean_lowhigh(ax, nints, mean0, low0, high0, label="no obs", color="tab:blue") 304 | plot_mean_lowhigh(ax, nints, mean1, low1, high1, label="naive", color="tab:orange") 305 | plot_mean_lowhigh(ax, nints, mean2, low2, high2, label="augmented", color="tab:green") 306 | 307 | ymax = np.nanmax([high0, high1, high2]) 308 | ymin = np.nanmin([low0, low1, low2]) 309 | 310 | # plot significative difference as dots 311 | y = ymax + 0.05 * (ymax-ymin) 312 | x = np.asarray(nints)[np.argwhere(test_int_augm)] 313 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:blue', marker='v') 314 | 315 | y = ymax + 0.10 * (ymax-ymin) 316 | x = np.asarray(nints)[np.argwhere(test_naiv_augm)] 317 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:orange', marker='s') 318 | 319 | ax.set_title(f"reward") 320 | ax.set_xlabel('nints (log scale)') 321 | ax.set_xscale('log', base=2) 322 | ax.legend() 323 | 324 | 325 | if args.kallus and nobs == 512: 326 | 327 | with open(f"experiments/toy1/results/kallus/{privileged_policy}.npy", 'rb') as f: 328 | rewards_kallus = np.load(f) 329 | 330 | test_kallus_augm = [run_test(test, rewards_kallus[:, j], r_augm[:, i, j], alpha=confidence_level) for j in range(len(nints))] 331 | # plot significative difference as dots 332 | y = ymax + 0.155 * (ymax-ymin) 333 | x = np.asarray(nints)[np.argwhere(test_kallus_augm)] 334 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:red', marker='^') 335 | 336 | # mean and standard error 337 | mean0, low0, high0 = compute_central_tendency_and_error('mean', deviation, rewards_kallus) 338 | 339 | # plot reward curves 340 | plot_mean_lowhigh(ax, nints, mean0, low0, high0, label="Kallus et al", color="tab:red") 341 | 342 | fig.savefig(plotsdir / f"{privileged_policy}_reward_nobs_{nobs}.pdf", bbox_inches='tight', pad_inches=0) 343 | 344 | 345 | plt.close(fig) 346 | -------------------------------------------------------------------------------- /experiments/toy1/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "s_desc": ["red light 0", "green light 0", "red light 1"], 3 | "o_desc": ["none"], 4 | "a_desc": ["button A", "button B"], 5 | "r_desc": [0, 1], 6 | 7 | "episode_length": 1, 8 | 9 | "p_s": [0.6, 0.4, 0.0], 10 | 11 | "p_s_sa": [ 12 | [ 13 | [0.0, 0.0, 1.0], 14 | [1.0, 0.0, 0.0] 15 | ], [ 16 | [1.0, 0.0, 0.0], 17 | [0.0, 0.0, 1.0] 18 | ], [ 19 | [0.0, 0.0, 1.0], 20 | [0.0, 0.0, 1.0] 21 | ] 22 | ], 23 | 24 | "p_r_s": [ 25 | [1.0, 0.0], 26 | [1.0, 0.0], 27 | [0.0, 1.0] 28 | ], 29 | 30 | "p_o_s": [ 31 | [1.0], 32 | [1.0], 33 | [1.0] 34 | ], 35 | 36 | "privileged_policies": { 37 | "noisy_good": [ 38 | [0.9, 0.1], 39 | [0.4, 0.6], 40 | [1.0, 0.0] 41 | ], 42 | "perfect_good": [ 43 | [1.0, 0.0], 44 | [0.0, 1.0], 45 | [1.0, 0.0] 46 | ], 47 | "perfect_bad": [ 48 | [0.0, 1.0], 49 | [1.0, 0.0], 50 | [1.0, 0.0] 51 | ], 52 | "random": [ 53 | [0.5, 0.5], 54 | [0.5, 0.5], 55 | [1.0, 0.0] 56 | ], 57 | "strong_good_bias": [ 58 | [0.8, 0.2], 59 | [1.0, 0.0], 60 | [1.0, 0.0] 61 | ], 62 | "strong_bad_bias": [ 63 | [1.0, 0.0], 64 | [0.8, 0.2], 65 | [1.0, 0.0] 66 | ] 67 | }, 68 | 69 | "nsamples_obs": [4, 8, 16, 32, 64, 128, 256, 512], 70 | "nsamples_int": [4, 8, 16, 32, 64, 128], 71 | "training_schemes": ["int", 
"obs+int", "augmented_obs+int"], 72 | "latent_space_size": 32 73 | } 74 | -------------------------------------------------------------------------------- /experiments/toy2/01_train_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy2/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 | help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | '--nobs', 31 | type=int, 32 | help = 'Number of observational samples.', 33 | default=argparse.SUPPRESS, 34 | ) 35 | parser.add_argument( 36 | '--nint', 37 | type=int, 38 | help = 'Number of interventional samples.', 39 | default=argparse.SUPPRESS, 40 | ) 41 | parser.add_argument( 42 | '--scheme', 43 | type=str, 44 | choices=cfg['training_schemes'], 45 | help='Training scheme.', 46 | default=argparse.SUPPRESS, 47 | ) 48 | parser.add_argument( 49 | 'privileged_policy', 50 | type=str, 51 | choices=cfg['privileged_policies'].keys(), 52 | ) 53 | args = parser.parse_args() 54 | 55 | # process command-line arguments 56 | if args.gpu == -1: 57 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 58 | device = "cpu" 59 | else: 60 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 61 | device = f"cuda:{args.gpu}" 62 | 63 | seed = args.seed 64 | privileged_policy = args.privileged_policy 65 | 66 | print(f"device: {device}") 67 | print(f"seed: {seed}") 68 | print(f"privileged_policy : {privileged_policy}") 69 | 70 | 71 | import torch 72 | 73 | # Ugly hack 74 | sys.path.insert(0, os.path.abspath(f".")) 75 | 76 | from environment import PomdpEnv 77 | from policies import UniformPolicy, ExpertPolicy 78 | from models import TabularAugmentedModel 79 | 80 | from utils import construct_dataset 81 | from learning import fit_model 82 | 83 | 84 | ## SET UP THE ENVIRONMENT ## 85 | 86 | p_s = torch.tensor(cfg['p_s']) 87 | p_r_s = torch.tensor(cfg['p_r_s']) 88 | p_o_s = torch.tensor(cfg['p_o_s']) 89 | p_s_sa = torch.tensor(cfg['p_s_sa']) 90 | 91 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 92 | 93 | o_nvals=p_o_s.shape[1] 94 | a_nvals=p_s_sa.shape[1] 95 | r_nvals=p_r_s.shape[1] 96 | s_nvals = cfg["latent_space_size"] 97 | 98 | # POMDP dynamics 99 | env = PomdpEnv(p_s=p_s, 100 | p_or_s=p_r_s.unsqueeze(-2) * p_o_s.unsqueeze(-1), 101 | p_s_sa=p_s_sa, 102 | categorical_obs=True, 103 | max_length=50) 104 | 105 | # Policy in the observational regime (priviledged) 106 | obs_policy = ExpertPolicy(p_a_s) 107 | 108 | # Policy in the interventional regime 109 | int_policy = UniformPolicy(a_nvals) 110 | 111 | 112 | ## SET UP THE SEEDS ## 113 | 114 | rng = np.random.RandomState(seed) 115 | seed_data_obs = rng.randint(0, 2**10) 116 | seed_data_int = rng.randint(0, 2**10) 117 | seed_training = rng.randint(0, 2**10) 118 | 119 | 120 | ## GENERATE THE DATASETS ## 121 | 122 | # from command-line argument if provided, otherwise from config file 123 | nsamples_obs_subsets = [args.nobs] if "nobs" in args else cfg['nsamples_obs'] 124 | nsamples_int_subsets = [args.nint] if "nint" in args else cfg['nsamples_int'] 125 
| training_schemes = [args.scheme] if "scheme" in args else cfg["training_schemes"] 126 | 127 | print(f"nsamples_obs_subsets: {nsamples_obs_subsets}") 128 | print(f"nsamples_int_subsets: {nsamples_int_subsets}") 129 | print(f"training_schemes: {training_schemes}") 130 | 131 | # we perform experiments on subsets of the same 132 | # dataset, so that each sequentially growing experiment 133 | # reuses the same samples, complemented with new ones 134 | nsamples_obs_total = np.max(nsamples_obs_subsets) 135 | nsamples_int_total = np.max(nsamples_int_subsets) 136 | 137 | torch.manual_seed(seed_data_obs) 138 | data_obs_all = construct_dataset(env=env, 139 | policy=obs_policy, 140 | n_samples=nsamples_obs_total, 141 | regime=torch.tensor(0)) 142 | 143 | torch.manual_seed(seed_data_int) 144 | data_int_all = construct_dataset(env=env, 145 | policy=int_policy, 146 | n_samples=nsamples_int_total, 147 | regime=torch.tensor(1)) 148 | 149 | 150 | ## LEARN THE TRANSITION MODELS ## 151 | 152 | loss_type = 'nll' 153 | with_done = False 154 | 155 | n_epochs = 500 156 | epoch_size = 50 157 | batch_size = 32 158 | lr = 1e-2 159 | patience = 10 160 | 161 | device = torch.device(device) 162 | 163 | for nsamples_obs in nsamples_obs_subsets: 164 | for nsamples_int in nsamples_int_subsets: 165 | 166 | data_obs = data_obs_all[:nsamples_obs] 167 | data_int = data_int_all[:nsamples_int] 168 | 169 | modeldir = pathlib.Path(f"experiments/toy2/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 170 | modeldir.mkdir(parents=True, exist_ok=True) 171 | 172 | print(f"saving results to: {modeldir}") 173 | 174 | for training_scheme in training_schemes: 175 | 176 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 177 | 178 | logfile = modeldir / f"{training_scheme}_log.txt" 179 | paramsfile = modeldir / f"{training_scheme}.pt" 180 | 181 | if pathlib.Path(paramsfile).exists(): 182 | print(f"Found trained model {paramsfile}, skip training.") 183 | continue 184 | 185 | if training_scheme == 'int': 186 | train_data = data_int 187 | elif training_scheme == 'obs+int': 188 | train_data = [(torch.tensor(1), episode) for (_, episode) in data_obs + data_int] 189 | elif training_scheme == 'augmented_obs+int': 190 | train_data = data_obs + data_int 191 | else: 192 | raise NotImplemented 193 | 194 | torch.manual_seed(seed_training) 195 | 196 | m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 197 | m = m.to(device) 198 | 199 | fit_model(m, 200 | train_data=train_data, 201 | valid_data=train_data, # we want to overfit 202 | loss_type=loss_type, 203 | with_done=with_done, 204 | n_epochs=n_epochs, 205 | epoch_size=epoch_size, 206 | batch_size=batch_size, 207 | lr=lr, 208 | patience=patience, 209 | log=True, 210 | logfile=logfile) 211 | 212 | torch.save(m.state_dict(), paramsfile) 213 | 214 | -------------------------------------------------------------------------------- /experiments/toy2/02_eval_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy2/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 
| help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | 'privileged_policy', 31 | type=str, 32 | choices=cfg['privileged_policies'].keys(), 33 | ) 34 | args = parser.parse_args() 35 | 36 | # process command-line arguments 37 | if args.gpu == -1: 38 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 39 | device = "cpu" 40 | else: 41 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 42 | device = f"cuda:{args.gpu}" 43 | 44 | seed = args.seed 45 | privileged_policy = args.privileged_policy 46 | 47 | print(f"device: {device}") 48 | print(f"seed: {seed}") 49 | print(f"privileged_policy : {privileged_policy}") 50 | 51 | 52 | import torch 53 | 54 | # Ugly hack 55 | sys.path.insert(0, os.path.abspath(f".")) 56 | 57 | from environment import PomdpEnv 58 | from policies import UniformPolicy, ExpertPolicy 59 | from models import TabularAugmentedModel 60 | 61 | from utils import construct_dataset 62 | from utils import js_div_empirical, kl_div_empirical, cross_entropy_empirical 63 | 64 | 65 | ## SET UP THE ENVIRONMENT ## 66 | 67 | p_s = torch.tensor(cfg['p_s']) 68 | p_r_s = torch.tensor(cfg['p_r_s']) 69 | p_o_s = torch.tensor(cfg['p_o_s']) 70 | p_s_sa = torch.tensor(cfg['p_s_sa']) 71 | 72 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 73 | 74 | o_nvals=p_o_s.shape[1] 75 | a_nvals=p_s_sa.shape[1] 76 | r_nvals=p_r_s.shape[1] 77 | s_nvals = cfg["latent_space_size"] 78 | 79 | episode_length = cfg["episode_length"] 80 | 81 | # POMDP dynamics 82 | env = PomdpEnv(p_s=p_s, 83 | p_or_s=p_r_s.unsqueeze(-2) * p_o_s.unsqueeze(-1), 84 | p_s_sa=p_s_sa, 85 | categorical_obs=True, 86 | max_length=episode_length) 87 | 88 | # Policy in the observational regime (priviledged) 89 | obs_policy = ExpertPolicy(p_a_s) 90 | 91 | # Policy in the interventional regime 92 | int_policy = UniformPolicy(a_nvals) 93 | 94 | 95 | ## SET UP THE SEEDS ## 96 | 97 | rng = np.random.RandomState(seed) 98 | seed_data_obs = rng.randint(0, 2**10) 99 | seed_data_int = rng.randint(0, 2**10) 100 | seed_training = rng.randint(0, 2**10) 101 | seed_data_eval = rng.randint(0, 2**10) 102 | seed_eval = rng.randint(0, 2**10) 103 | 104 | 105 | ## GENERATE THE EVALUATION DATASET ## 106 | 107 | nsamples_eval = cfg["nsamples_eval"] 108 | 109 | torch.manual_seed(seed_data_eval) 110 | data_eval_p = construct_dataset(env=env, 111 | policy=int_policy, 112 | n_samples=nsamples_eval, 113 | regime=torch.tensor(1)) 114 | 115 | 116 | ## EVALUATE THE TRANSITION MODELS ## 117 | 118 | nsamples_obs_subsets = cfg['nsamples_obs'] 119 | nsamples_int_subsets = cfg['nsamples_int'] 120 | training_schemes = cfg["training_schemes"] 121 | 122 | with_done = False 123 | batch_size = 32 124 | 125 | device = torch.device(device) 126 | 127 | # true model 128 | m_true = TabularAugmentedModel(s_nvals=p_s.shape[0], o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 129 | m_true.set_probs(p_s=p_s, p_o_s=p_o_s, p_r_s=p_r_s, p_s_sa=p_s_sa, p_a_s=p_a_s) 130 | m_true.to(device) 131 | 132 | # learnt model 133 | m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 134 | m = m.to(device) 135 | 136 | resultsdir = pathlib.Path(f"experiments/toy2/results/{privileged_policy}/seed_{seed}") 137 | resultsdir.mkdir(parents=True, exist_ok=True) 138 | 139 | results = np.full((len(nsamples_obs_subsets), len(nsamples_int_subsets), len(training_schemes), 3), np.nan) 140 | 141 | for i, nsamples_obs 
in enumerate(nsamples_obs_subsets): 142 | for j, nsamples_int in enumerate(nsamples_int_subsets): 143 | for k, training_scheme in enumerate(training_schemes): 144 | 145 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 146 | 147 | modeldir = pathlib.Path(f"experiments/toy2/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 148 | 149 | print(f"reading results from: {modeldir}") 150 | 151 | paramsfile = modeldir / f"{training_scheme}.pt" 152 | m.load_state_dict(torch.load(paramsfile, map_location=device)) 153 | 154 | # sample from the learnt model 155 | with torch.no_grad(): 156 | q_s = torch.nn.functional.softmax(m.params_s, dim=-1) 157 | q_r_s = torch.nn.functional.softmax(m.params_r_s, dim=-1) 158 | q_o_s = torch.nn.functional.softmax(m.params_o_s, dim=-1) 159 | q_s_sa = torch.nn.functional.softmax(m.params_s_sa, dim=-1) 160 | 161 | # imaginary POMDP dynamics 162 | env_q = PomdpEnv(p_s=q_s, 163 | p_or_s=q_r_s.unsqueeze(-2) * q_o_s.unsqueeze(-1), 164 | p_s_sa=q_s_sa, 165 | categorical_obs=True, 166 | max_length=episode_length) 167 | 168 | # imaginary data 169 | torch.manual_seed(seed_eval) 170 | data_eval_q = construct_dataset(env=env_q, 171 | policy=int_policy, 172 | n_samples=nsamples_eval, 173 | regime=torch.tensor(1)) 174 | 175 | # compute empirical cross-entropy (NLL) 176 | ce = cross_entropy_empirical(model_q=m, data_p=data_eval_p, 177 | batch_size=batch_size, with_done=with_done) 178 | 179 | # compute empirical KL 180 | kld = kl_div_empirical(model_q=m, model_p=m_true, 181 | data_p=data_eval_p, 182 | batch_size=batch_size, with_done=with_done) 183 | 184 | # compute empirical JS 185 | jsd = js_div_empirical(model_q=m, model_p=m_true, 186 | data_q=data_eval_q, data_p=data_eval_p, 187 | batch_size=batch_size, with_done=with_done) 188 | 189 | ce = ce.item() 190 | kld = kld.item() 191 | jsd = jsd.item() 192 | 193 | print(f"ce: {ce}") 194 | print(f"kld: {kld}") 195 | print(f"jsd: {jsd}") 196 | 197 | results[i, j, k] = (kld, jsd, ce) 198 | 199 | with open(resultsdir / "model_results.npy", 'wb') as f: 200 | np.save(f, results) 201 | -------------------------------------------------------------------------------- /experiments/toy2/03_train_agents.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy2/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 | help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | '--nobs', 31 | type=int, 32 | help = 'Number of observational samples.', 33 | default=argparse.SUPPRESS, 34 | ) 35 | parser.add_argument( 36 | '--nint', 37 | type=int, 38 | help = 'Number of interventional samples.', 39 | default=argparse.SUPPRESS, 40 | ) 41 | parser.add_argument( 42 | '--scheme', 43 | type=str, 44 | choices=cfg['training_schemes'], 45 | help='Training scheme.', 46 | default=argparse.SUPPRESS, 47 | ) 48 | parser.add_argument( 49 | 'privileged_policy', 50 | type=str, 51 | choices=cfg['privileged_policies'].keys(), 52 | ) 53 
| args = parser.parse_args() 54 | 55 | # process command-line arguments 56 | if args.gpu == -1: 57 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 58 | device = "cpu" 59 | else: 60 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 61 | device = f"cuda:{args.gpu}" 62 | 63 | seed = args.seed 64 | privileged_policy = args.privileged_policy 65 | 66 | print(f"device: {device}") 67 | print(f"seed: {seed}") 68 | print(f"privileged_policy : {privileged_policy}") 69 | 70 | 71 | import torch 72 | 73 | # Ugly hack 74 | sys.path.insert(0, os.path.abspath(f".")) 75 | 76 | from models import TabularAugmentedModel 77 | 78 | # ENVIRONMENT 79 | from environment import PomdpEnv 80 | from environment.env_wrappers import BeliefStateRepresentation, RewardWrapper, SqueezeEnv 81 | 82 | from rl_agents.ac import ActorCritic, run_actorcritic 83 | # from rl_agents.reinforce import Actor, run_reinforce 84 | 85 | ## SET UP THE ENVIRONMENT ## 86 | 87 | p_s = torch.tensor(cfg['p_s']) 88 | p_r_s = torch.tensor(cfg['p_r_s']) 89 | p_o_s = torch.tensor(cfg['p_o_s']) 90 | p_s_sa = torch.tensor(cfg['p_s_sa']) 91 | 92 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 93 | 94 | o_nvals=p_o_s.shape[1] 95 | a_nvals=p_s_sa.shape[1] 96 | r_nvals=p_r_s.shape[1] 97 | s_nvals = cfg["latent_space_size"] 98 | 99 | episode_length = cfg["episode_length"] 100 | 101 | reward_map = cfg["r_desc"] 102 | 103 | 104 | ## SET UP THE SEEDS ## 105 | 106 | rng = np.random.RandomState(seed) 107 | seed_data_obs = rng.randint(0, 2**10) 108 | seed_data_int = rng.randint(0, 2**10) 109 | seed_model_training = rng.randint(0, 2**10) 110 | seed_data_eval = rng.randint(0, 2**10) 111 | seed_eval = rng.randint(0, 2**10) 112 | seed_agent_training = rng.randint(0, 2**10) 113 | 114 | ## EVALUATE THE TRANSITION MODELS ## 115 | 116 | # from command-line argument if provided, otherwise from config file 117 | nsamples_obs_subsets = [args.nobs] if "nobs" in args else cfg['nsamples_obs'] 118 | nsamples_int_subsets = [args.nint] if "nint" in args else cfg['nsamples_int'] 119 | training_schemes = [args.scheme] if "scheme" in args else cfg["training_schemes"] 120 | 121 | print(f"nsamples_obs_subsets: {nsamples_obs_subsets}") 122 | print(f"nsamples_int_subsets: {nsamples_int_subsets}") 123 | print(f"training_schemes: {training_schemes}") 124 | 125 | ## EVALUATE THE TRANSITION MODELS ## 126 | 127 | with_done = False 128 | lr = 1e-2 129 | gamma = 0.9 130 | n_epochs = 1000 131 | log_every = 10 132 | batch_size = 8 133 | 134 | device = torch.device(device) 135 | 136 | # learnt model 137 | m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 138 | m = m.to(device) 139 | 140 | for nsamples_obs in nsamples_obs_subsets: 141 | for nsamples_int in nsamples_int_subsets: 142 | for training_scheme in training_schemes: 143 | 144 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 145 | 146 | model_dir = pathlib.Path(f"experiments/toy2/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 147 | agent_dir = pathlib.Path(f"experiments/toy2/trained_agents/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 148 | 149 | agent_dir.mkdir(parents=True, exist_ok=True) 150 | 151 | model_paramsfile = model_dir / f"{training_scheme}.pt" 152 | agent_paramsfile = agent_dir / f"{training_scheme}.pt" 153 | logfile = agent_dir / f"{training_scheme}_log.txt" 154 | 155 | if agent_paramsfile.exists(): 156 | print(f"Found trained agent 
{agent_paramsfile}, skip training.") 157 | continue 158 | 159 | print(f"reading model from: {model_paramsfile}") 160 | 161 | m.load_state_dict(torch.load(model_paramsfile, map_location=device)) 162 | 163 | # recover learned POMDP dynamics 164 | with torch.no_grad(): 165 | q_s = torch.nn.functional.softmax(m.params_s, dim=-1) 166 | q_r_s = torch.nn.functional.softmax(m.params_r_s, dim=-1) 167 | q_o_s = torch.nn.functional.softmax(m.params_o_s, dim=-1) 168 | q_s_sa = torch.nn.functional.softmax(m.params_s_sa, dim=-1) 169 | 170 | # learned POMDP 171 | env_q = PomdpEnv(p_s=q_s, 172 | p_or_s=q_r_s.unsqueeze(-2) * q_o_s.unsqueeze(-1), 173 | p_s_sa=q_s_sa, 174 | categorical_obs=True, 175 | max_length=episode_length) 176 | 177 | # POMDP -> MDP (using the model's belief state) 178 | env_q = BeliefStateRepresentation(SqueezeEnv(env_q), m) 179 | 180 | # map categorical reward to numerical values 181 | env_q = RewardWrapper(env_q, reward_dic=reward_map) 182 | 183 | # agent training (dream) 184 | torch.manual_seed(seed_agent_training) 185 | 186 | # agent = Actor(s_nvals=s_nvals, a_nvals=a_nvals) 187 | # run_reinforce(env=env_q, agent=agent, 188 | # lr=lr, gamma=gamma, 189 | # batch_size=batch_size, 190 | # n_epochs=n_epochs, 191 | # log_every=log_every, 192 | # logfile=logfile) 193 | 194 | agent = ActorCritic(s_nvals=s_nvals, a_nvals=a_nvals) 195 | run_actorcritic(env_q, agent, 196 | lr=lr, gamma=gamma, 197 | batch_size=batch_size, 198 | n_epochs=n_epochs, 199 | log_every=log_every, 200 | logfile=logfile) 201 | 202 | torch.save(agent.state_dict(), agent_paramsfile) 203 | print(f"saving agent to: {agent_paramsfile}") 204 | -------------------------------------------------------------------------------- /experiments/toy2/04_eval_agents.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy2/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 | help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | 'privileged_policy', 31 | type=str, 32 | choices=cfg['privileged_policies'].keys(), 33 | ) 34 | args = parser.parse_args() 35 | 36 | # process command-line arguments 37 | if args.gpu == -1: 38 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 39 | device = "cpu" 40 | else: 41 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 42 | device = f"cuda:{args.gpu}" 43 | 44 | seed = args.seed 45 | privileged_policy = args.privileged_policy 46 | 47 | print(f"device: {device}") 48 | print(f"seed: {seed}") 49 | print(f"privileged_policy : {privileged_policy}") 50 | 51 | 52 | import torch 53 | 54 | # Ugly hack 55 | sys.path.insert(0, os.path.abspath(f".")) 56 | 57 | from models import TabularAugmentedModel 58 | 59 | # ENVIRONMENT 60 | from environment import PomdpEnv 61 | from environment.env_wrappers import BeliefStateRepresentation, RewardWrapper, SqueezeEnv 62 | 63 | from rl_agents.ac import ActorCritic, evaluate_agent 64 | 65 | ## SET UP THE ENVIRONMENT ## 66 | 67 | p_s = torch.tensor(cfg['p_s']) 68 | p_r_s = torch.tensor(cfg['p_r_s']) 69 | p_o_s = 
torch.tensor(cfg['p_o_s']) 70 | p_s_sa = torch.tensor(cfg['p_s_sa']) 71 | 72 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 73 | 74 | o_nvals=p_o_s.shape[1] 75 | a_nvals=p_s_sa.shape[1] 76 | r_nvals=p_r_s.shape[1] 77 | s_nvals = cfg["latent_space_size"] 78 | 79 | episode_length = cfg["episode_length"] 80 | 81 | reward_map = cfg["r_desc"] 82 | 83 | 84 | ## SET UP THE SEEDS ## 85 | 86 | rng = np.random.RandomState(seed) 87 | seed_data_obs = rng.randint(0, 2**10) 88 | seed_data_int = rng.randint(0, 2**10) 89 | seed_model_training = rng.randint(0, 2**10) 90 | seed_data_eval = rng.randint(0, 2**10) 91 | seed_eval = rng.randint(0, 2**10) 92 | seed_agent_training = rng.randint(0, 2**10) 93 | seed_agent_eval = rng.randint(0, 2**10) 94 | 95 | ## EVALUATE THE TRANSITION MODELS ## 96 | 97 | nsamples_obs_subsets = cfg['nsamples_obs'] 98 | nsamples_int_subsets = cfg['nsamples_int'] 99 | training_schemes = cfg["training_schemes"] 100 | 101 | ## EVALUATE THE TRANSITION MODELS ## 102 | 103 | n_episodes = 100 104 | 105 | device = torch.device(device) 106 | 107 | # learnt model 108 | m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 109 | m = m.to(device) 110 | 111 | # learnt agent 112 | agent = ActorCritic(s_nvals=s_nvals, a_nvals=a_nvals) 113 | agent.to(device) 114 | 115 | # true POMDP 116 | env_p = PomdpEnv(p_s=p_s, 117 | p_or_s=p_r_s.unsqueeze(-2) * p_o_s.unsqueeze(-1), 118 | p_s_sa=p_s_sa, 119 | categorical_obs=True, 120 | max_length=episode_length) 121 | 122 | resultsdir = pathlib.Path(f"experiments/toy2/results/{privileged_policy}/seed_{seed}") 123 | resultsdir.mkdir(parents=True, exist_ok=True) 124 | 125 | results = np.full((len(nsamples_obs_subsets), len(nsamples_int_subsets), len(training_schemes), 1), np.nan) 126 | 127 | for i, nsamples_obs in enumerate(nsamples_obs_subsets): 128 | for j, nsamples_int in enumerate(nsamples_int_subsets): 129 | for k, training_scheme in enumerate(training_schemes): 130 | 131 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 132 | 133 | model_dir = pathlib.Path(f"experiments/toy2/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 134 | model_paramsfile = model_dir / f"{training_scheme}.pt" 135 | 136 | agent_dir = pathlib.Path(f"experiments/toy2/trained_agents/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 137 | agent_paramsfile = agent_dir / f"{training_scheme}.pt" 138 | 139 | print(f"reading model from: {model_paramsfile}") 140 | m.load_state_dict(torch.load(model_paramsfile, map_location=device)) 141 | 142 | print(f"reading agent from: {agent_paramsfile}") 143 | agent.load_state_dict(torch.load(agent_paramsfile, map_location=device)) 144 | 145 | # POMDP -> MDP (using the model's belief state) 146 | env = BeliefStateRepresentation(SqueezeEnv(env_p), m) 147 | 148 | # map categorical reward to its numerical values 149 | env = RewardWrapper(env, reward_dic=reward_map) 150 | 151 | # agent evaluation (true environment) 152 | torch.manual_seed(seed_agent_eval) 153 | 154 | reward = evaluate_agent(env, agent, n_episodes) 155 | 156 | print(f"reward: {reward}") 157 | 158 | results[i, j, k] = reward 159 | 160 | with open(resultsdir / "agent_results.npy", 'wb') as f: 161 | np.save(f, results) 162 | -------------------------------------------------------------------------------- /experiments/toy2/05_plots.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | 10 | # Ugly hack 11 | sys.path.insert(0, os.path.abspath(f".")) 12 | 13 | from stat_tests import run_test 14 | 15 | 16 | if __name__ == '__main__': 17 | 18 | # read experiment config 19 | with open("experiments/toy2/config.json", "r") as json_data_file: 20 | cfg = json.load(json_data_file) 21 | 22 | # read command-line arguments 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | 'privileged_policy', 26 | type=str, 27 | choices=cfg['privileged_policies'].keys(), 28 | ) 29 | args = parser.parse_args() 30 | 31 | privileged_policy = args.privileged_policy 32 | 33 | print(f"privileged_policy : {privileged_policy }") 34 | 35 | 36 | ## COLLECT THE RESULTS ## 37 | 38 | nobss = cfg['nsamples_obs'] 39 | nints = cfg['nsamples_int'] 40 | training_schemes = cfg["training_schemes"] 41 | 42 | nseeds = 20 43 | 44 | model_results = [] 45 | agent_results = [] 46 | for seed in range(nseeds): 47 | with open(f"experiments/toy2/results/{privileged_policy}/seed_{seed}/model_results.npy", 'rb') as f: 48 | model_results.append(np.load(f)) 49 | with open(f"experiments/toy2/results/{privileged_policy}/seed_{seed}/agent_results.npy", 'rb') as f: 50 | agent_results.append(np.load(f)) 51 | 52 | model_results = np.asarray(model_results) 53 | agent_results = np.asarray(agent_results) 54 | 55 | # kls = model_results[..., 0] 56 | jss = model_results[..., 1] 57 | # ces = model_results[..., 2] 58 | rewards = agent_results[..., 0] 59 | 60 | 61 | ## CREATE AND SAVE THE PLOTS ## 62 | 63 | plotsdir = pathlib.Path(f"experiments/toy2/plots") 64 | plotsdir.mkdir(parents=True, exist_ok=True) 65 | 66 | rmin = np.min(rewards) 67 | rmax = np.max(rewards) 68 | 69 | jsmin = np.min(jss) 70 | jsmax = np.max(jss) 71 | 72 | r_int = rewards[..., 0] 73 | r_naiv = rewards[..., 1] 74 | r_augm = rewards[..., 2] 75 | 76 | js_int = jss[..., 0] 77 | js_naiv = jss[..., 1] 78 | js_augm = jss[..., 2] 79 | 80 | fig, axes = plt.subplots(2, 5, figsize=(20, 6), dpi=300) 81 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 82 | 83 | ax = axes[0, 0] 84 | cf = ax.pcolormesh(r_int.mean(0), vmin=rmin, vmax=rmax) 85 | fig.colorbar(cf, ax=ax) 86 | ax.set_title(f"no obs") 87 | ax.set_ylabel('nobs') 88 | ax.set_xlabel('nints') 89 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 90 | ax.set_xticklabels(nints) 91 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 92 | ax.set_yticklabels(nobss) 93 | 94 | ax = axes[0, 1] 95 | cf = ax.pcolormesh(r_naiv.mean(0), vmin=rmin, vmax=rmax) 96 | fig.colorbar(cf, ax=ax) 97 | ax.set_title(f"naive obs+int") 98 | ax.set_ylabel('nobs') 99 | ax.set_xlabel('nints') 100 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 101 | ax.set_xticklabels(nints) 102 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 103 | ax.set_yticklabels(nobss) 104 | 105 | ax = axes[0, 2] 106 | cf = ax.pcolormesh(r_augm.mean(0), vmin=rmin, vmax=rmax) 107 | fig.colorbar(cf, ax=ax) 108 | ax.set_title(f"augmented obs+int") 109 | ax.set_ylabel('nobs') 110 | ax.set_xlabel('nints') 111 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 112 | ax.set_xticklabels(nints) 113 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 114 | ax.set_yticklabels(nobss) 115 | 116 | r_gain_int = (r_augm - r_int).mean(0) 117 | r_gain_naiv = (r_augm - r_naiv).mean(0) 118 | r_gain_max = np.max([np.abs(r_gain_int), np.abs(r_gain_naiv)]) 119 | r_gain_min = -r_gain_max 120 | 
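    # heatmaps of the mean reward gain of the augmented scheme over each baseline
    # (augmented minus "no obs", augmented minus "naive"); plotted on a diverging
    # colormap centred at zero, so green cells indicate a gain for the augmented scheme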
121 | ax = axes[0, 3] 122 | cf = ax.pcolormesh(r_gain_int, cmap=plt.get_cmap('PiYG'), vmin=r_gain_min, vmax=r_gain_max) 123 | fig.colorbar(cf, ax=ax) 124 | ax.set_title(f"augmented - no obs") 125 | ax.set_ylabel('nobs') 126 | ax.set_xlabel('nints') 127 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 128 | ax.set_xticklabels(nints) 129 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 130 | ax.set_yticklabels(nobss) 131 | 132 | ax = axes[0, 4] 133 | cf = ax.pcolormesh(r_gain_naiv, cmap=plt.get_cmap('PiYG'), vmin=r_gain_min, vmax=r_gain_max) 134 | fig.colorbar(cf, ax=ax) 135 | ax.set_title(f"augmented - naive") 136 | ax.set_ylabel('nobs') 137 | ax.set_xlabel('nints') 138 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 139 | ax.set_xticklabels(nints) 140 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 141 | ax.set_yticklabels(nobss) 142 | 143 | ax = axes[1, 0] 144 | cf = ax.pcolormesh(js_int.mean(0), vmin=jsmin, vmax=jsmax) 145 | fig.colorbar(cf, ax=ax) 146 | ax.set_title(f"no obs") 147 | ax.set_ylabel('nobs') 148 | ax.set_xlabel('nints') 149 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 150 | ax.set_xticklabels(nints) 151 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 152 | ax.set_yticklabels(nobss) 153 | 154 | ax = axes[1, 1] 155 | cf = ax.pcolormesh(js_naiv.mean(0), vmin=jsmin, vmax=jsmax) 156 | fig.colorbar(cf, ax=ax) 157 | ax.set_title(f"naive obs+int") 158 | ax.set_ylabel('nobs') 159 | ax.set_xlabel('nints') 160 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 161 | ax.set_xticklabels(nints) 162 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 163 | ax.set_yticklabels(nobss) 164 | 165 | ax = axes[1, 2] 166 | cf = ax.pcolormesh(js_augm.mean(0), vmin=jsmin, vmax=jsmax) 167 | fig.colorbar(cf, ax=ax) 168 | ax.set_title(f"augmented obs+int") 169 | ax.set_ylabel('nobs') 170 | ax.set_xlabel('nints') 171 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 172 | ax.set_xticklabels(nints) 173 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 174 | ax.set_yticklabels(nobss) 175 | 176 | js_gain_int = (js_augm - js_int).mean(0) 177 | js_gain_naiv = (js_augm - js_naiv).mean(0) 178 | js_gain_max = np.max([np.abs(js_gain_int), np.abs(js_gain_naiv)]) 179 | js_gain_min = -js_gain_max 180 | 181 | ax = axes[1, 3] 182 | cf = ax.pcolormesh(js_gain_int, cmap=plt.get_cmap('PiYG'), vmin=js_gain_min, vmax=js_gain_max) 183 | fig.colorbar(cf, ax=ax) 184 | ax.set_title(f"augmented - no obs") 185 | ax.set_ylabel('nobs') 186 | ax.set_xlabel('nints') 187 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 188 | ax.set_xticklabels(nints) 189 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 190 | ax.set_yticklabels(nobss) 191 | 192 | ax = axes[1, 4] 193 | cf = ax.pcolormesh(js_gain_naiv, cmap=plt.get_cmap('PiYG'), vmin=js_gain_min, vmax=js_gain_max) 194 | fig.colorbar(cf, ax=ax) 195 | ax.set_title(f"augmented - naive") 196 | ax.set_ylabel('nobs') 197 | ax.set_xlabel('nints') 198 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 199 | ax.set_xticklabels(nints) 200 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 201 | ax.set_yticklabels(nobss) 202 | 203 | fig.savefig(plotsdir / f"{privileged_policy}_reward_js_grids.pdf", bbox_inches='tight', pad_inches=0) 204 | plt.close(fig) 205 | 206 | def plot_mean_std(ax, x, y, label, color): 207 | ax.plot(x, y.mean(0), label=label, color=color) 208 | ax.fill_between(x, y.mean(0) - y.std(0), y.mean(0) + y.std(0), color=color, alpha=0.2) 209 | 210 | def plot_mean_lowhigh(ax, 
x, mean, low, high, label, color): 211 | ax.plot(x, mean, label=label, color=color) 212 | ax.fill_between(x, low, high, color=color, alpha=0.2) 213 | 214 | def compute_central_tendency_and_error(id_central, id_error, sample): 215 | if id_central == 'mean': 216 | central = np.nanmean(sample, axis=0) 217 | elif id_central == 'median': 218 | central = np.nanmedian(sample, axis=0) 219 | else: 220 | raise NotImplementedError 221 | 222 | if isinstance(id_error, int): 223 | low = np.nanpercentile(sample, q=int((100 - id_error) / 2), axis=0) 224 | high = np.nanpercentile(sample, q=int(100 - (100 - id_error) / 2), axis=0) 225 | elif id_error == 'std': 226 | low = central - np.nanstd(sample, axis=0) 227 | high = central + np.nanstd(sample, axis=0) 228 | elif id_error == 'sem': 229 | low = central - np.nanstd(sample, axis=0) / np.sqrt(sample.shape[0]) 230 | high = central + np.nanstd(sample, axis=0) / np.sqrt(sample.shape[0]) 231 | else: 232 | raise NotImplementedError 233 | 234 | return central, low, high 235 | 236 | for i, nobs in enumerate(nobss): 237 | 238 | test = 'Wilcoxon' 239 | deviation = 'std' # 'sem' 240 | confidence_level = 0.05 241 | 242 | ### Jensen-Shannon ### 243 | 244 | fig, axes = plt.subplots(1, 1, figsize=(3, 2.25), dpi=300) 245 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 246 | 247 | # statistical tests 248 | test_int_augm = [run_test(test, js_augm[:, i, j], js_int[:, i, j], alpha=confidence_level) for j in range(len(nints))] 249 | test_naiv_augm = [run_test(test, js_augm[:, i, j], js_naiv[:, i, j], alpha=confidence_level) for j in range(len(nints))] 250 | 251 | # mean and standard error 252 | mean0, low0, high0 = compute_central_tendency_and_error('mean', deviation, js_int[:, i]) 253 | mean1, low1, high1 = compute_central_tendency_and_error('mean', deviation, js_naiv[:, i]) 254 | mean2, low2, high2 = compute_central_tendency_and_error('mean', deviation, js_augm[:, i]) 255 | 256 | # plot JS curves 257 | ax = axes 258 | plot_mean_lowhigh(ax, nints, mean0, low0, high0, label="no obs", color="tab:blue") 259 | plot_mean_lowhigh(ax, nints, mean1, low1, high1, label="naive", color="tab:orange") 260 | plot_mean_lowhigh(ax, nints, mean2, low2, high2, label="augmented", color="tab:green") 261 | 262 | ymax = np.nanmax([high0, high1, high2]) 263 | ymin = np.nanmin([low0, low1, low2]) 264 | 265 | # plot significative difference as dots 266 | y = ymax + 0.05 * (ymax-ymin) 267 | x = np.asarray(nints)[np.argwhere(test_int_augm)] 268 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:blue', marker='v') 269 | 270 | y = ymax + 0.10 * (ymax-ymin) 271 | x = np.asarray(nints)[np.argwhere(test_naiv_augm)] 272 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:orange', marker='s') 273 | 274 | ax.set_title(f"JS divergence") 275 | ax.set_xlabel('nints (log scale)') 276 | ax.set_xscale('log', base=2) 277 | ax.set_ylim(bottom=0) 278 | ax.legend() 279 | 280 | fig.savefig(plotsdir / f"{privileged_policy}_js_nobs_{nobs}.pdf", bbox_inches='tight', pad_inches=0) 281 | plt.close(fig) 282 | 283 | 284 | ### Reward ### 285 | 286 | fig, axes = plt.subplots(1, 1, figsize=(3, 2.25), dpi=300) 287 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 288 | 289 | # statistical tests 290 | test_int_augm = [run_test(test, r_int[:, i, j], r_augm[:, i, j], alpha=confidence_level) for j in range(len(nints))] 291 | test_naiv_augm = [run_test(test, r_naiv[:, i, j], r_augm[:, i, j], alpha=confidence_level) for j in range(len(nints))] 292 | 293 | # mean and standard error 294 | mean0, low0, high0 = 
compute_central_tendency_and_error('mean', deviation, r_int[:, i]) 295 | mean1, low1, high1 = compute_central_tendency_and_error('mean', deviation, r_naiv[:, i]) 296 | mean2, low2, high2 = compute_central_tendency_and_error('mean', deviation, r_augm[:, i]) 297 | 298 | # y-axis scaling 299 | # offset = 100 300 | # scale_forward = lambda a: -1 * np.log((a - offset) * -1) 301 | # scale_inverse = lambda a: (np.exp(a / -1) / -1) + offset 302 | 303 | # plot reward curves 304 | ax = axes 305 | # mean0, low0, high0 = scale_forward(mean0), scale_forward(low0), scale_forward(high0) 306 | # mean1, low1, high1 = scale_forward(mean1), scale_forward(low1), scale_forward(high1) 307 | # mean2, low2, high2 = scale_forward(mean2), scale_forward(low2), scale_forward(high2) 308 | plot_mean_lowhigh(ax, nints, mean0, low0, high0, label="no obs", color="tab:blue") 309 | plot_mean_lowhigh(ax, nints, mean1, low1, high1, label="naive", color="tab:orange") 310 | plot_mean_lowhigh(ax, nints, mean2, low2, high2, label="augmented", color="tab:green") 311 | 312 | ymax = np.nanmax([high0, high1, high2]) 313 | ymin = np.nanmin([low0, low1, low2]) 314 | 315 | # plot significative difference as dots 316 | y = ymax + 0.05 * (ymax - ymin) 317 | x = np.asarray(nints)[np.argwhere(test_int_augm)] 318 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:blue', marker='v') 319 | 320 | y = ymax + 0.10 * (ymax - ymin) 321 | x = np.asarray(nints)[np.argwhere(test_naiv_augm)] 322 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:orange', marker='s') 323 | 324 | ax.set_title(f"reward") 325 | ax.set_xlabel('nints (log scale)') 326 | ax.set_xscale('log', base=2) 327 | # # ax.set_yscale('function', functions=(scale_forward, scale_inverse)) 328 | # yticks = np.array([50, 0, -150, -600]) # scale_inverse(ax.get_yticks()) 329 | # ax.set_yticks(scale_forward(yticks)) # ax.set_yticks(yticks) 330 | # ax.set_yticklabels(yticks) 331 | 332 | # ax.legend() 333 | 334 | fig.savefig(plotsdir / f"{privileged_policy}_reward_nobs_{nobs}.pdf", bbox_inches='tight', pad_inches=0) 335 | plt.close(fig) 336 | -------------------------------------------------------------------------------- /experiments/toy2/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "s_desc": ["tiger left -1", "tiger left -100", "tiger left +10", "tiger right -1", "tiger right -100", "tiger right +10"], 3 | "o_desc": ["roar left", "roar right"], 4 | "a_desc": ["listen", "open left", "open right"], 5 | "r_desc": [-1, -100, 10], 6 | 7 | "episode_length": 50, 8 | 9 | "p_s": [0.5, 0.0, 0.0, 0.5, 0.0, 0.0], 10 | 11 | "p_s_sa": [ 12 | [ 13 | [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], 14 | [0.0, 0.5, 0.0, 0.0, 0.5, 0.0], 15 | [0.0, 0.0, 0.5, 0.0, 0.0, 0.5] 16 | ], [ 17 | [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], 18 | [0.0, 0.5, 0.0, 0.0, 0.5, 0.0], 19 | [0.0, 0.0, 0.5, 0.0, 0.0, 0.5] 20 | ], [ 21 | [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], 22 | [0.0, 0.5, 0.0, 0.0, 0.5, 0.0], 23 | [0.0, 0.0, 0.5, 0.0, 0.0, 0.5] 24 | ], [ 25 | [0.0, 0.0, 0.0, 1.0, 0.0, 0.0], 26 | [0.0, 0.0, 0.5, 0.0, 0.0, 0.5], 27 | [0.0, 0.5, 0.0, 0.0, 0.5, 0.0] 28 | ], [ 29 | [0.0, 0.0, 0.0, 1.0, 0.0, 0.0], 30 | [0.0, 0.0, 0.5, 0.0, 0.0, 0.5], 31 | [0.0, 0.5, 0.0, 0.0, 0.5, 0.0] 32 | ], [ 33 | [0.0, 0.0, 0.0, 1.0, 0.0, 0.0], 34 | [0.0, 0.0, 0.5, 0.0, 0.0, 0.5], 35 | [0.0, 0.5, 0.0, 0.0, 0.5, 0.0] 36 | ] 37 | ], 38 | 39 | "p_r_s": [ 40 | [1.0, 0.0, 0.0], 41 | [0.0, 1.0, 0.0], 42 | [0.0, 0.0, 1.0], 43 | [1.0, 0.0, 0.0], 44 | [0.0, 1.0, 0.0], 45 | [0.0, 0.0, 1.0] 46 | ], 47 | 48 | "p_o_s": [ 49 | [0.85, 0.15], 
50 | [0.85, 0.15], 51 | [0.85, 0.15], 52 | [0.15, 0.85], 53 | [0.15, 0.85], 54 | [0.15, 0.85] 55 | ], 56 | 57 | "privileged_policies": { 58 | "noisy_good": [ 59 | [0.05, 0.3, 0.65], 60 | [0.05, 0.3, 0.65], 61 | [0.05, 0.3, 0.65], 62 | [0.05, 0.8, 0.15], 63 | [0.05, 0.8, 0.15], 64 | [0.05, 0.8, 0.15] 65 | ], 66 | "very_good": [ 67 | [0.05, 0.0, 0.95], 68 | [0.05, 0.0, 0.95], 69 | [0.05, 0.0, 0.95], 70 | [0.05, 0.95, 0.0], 71 | [0.05, 0.95, 0.0], 72 | [0.05, 0.95, 0.0] 73 | ], 74 | "very_bad": [ 75 | [0.05, 0.95, 0.0], 76 | [0.05, 0.95, 0.0], 77 | [0.05, 0.95, 0.0], 78 | [0.05, 0.0, 0.95], 79 | [0.05, 0.0, 0.95], 80 | [0.05, 0.0, 0.95] 81 | ], 82 | "random": [ 83 | [0.5, 0.5, 0.5], 84 | [0.5, 0.5, 0.5], 85 | [0.5, 0.5, 0.5], 86 | [0.5, 0.5, 0.5], 87 | [0.5, 0.5, 0.5], 88 | [0.5, 0.5, 0.5] 89 | ], 90 | "strong_good_bias": [ 91 | [0.05, 0.2, 0.75], 92 | [0.05, 0.2, 0.75], 93 | [0.05, 0.2, 0.75], 94 | [0.05, 0.95, 0.0], 95 | [0.05, 0.95, 0.0], 96 | [0.05, 0.95, 0.0] 97 | ], 98 | "strong_bad_bias": [ 99 | [0.05, 0.95, 0.0], 100 | [0.05, 0.95, 0.0], 101 | [0.05, 0.95, 0.0], 102 | [0.05, 0.2, 0.75], 103 | [0.05, 0.2, 0.75], 104 | [0.05, 0.2, 0.75] 105 | ] 106 | }, 107 | 108 | "nsamples_obs": [8192], 109 | "nsamples_int": [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192], 110 | "training_schemes": ["int", "obs+int", "augmented_obs+int"], 111 | "latent_space_size": 32, 112 | "nsamples_eval": 500 113 | } 114 | -------------------------------------------------------------------------------- /experiments/toy3/01_train_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy3/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 | help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | '--nobs', 31 | type=int, 32 | help = 'Number of observational samples.', 33 | default=argparse.SUPPRESS, 34 | ) 35 | parser.add_argument( 36 | '--nint', 37 | type=int, 38 | help = 'Number of interventional samples.', 39 | default=argparse.SUPPRESS, 40 | ) 41 | parser.add_argument( 42 | '--scheme', 43 | type=str, 44 | choices=cfg['training_schemes'], 45 | help='Training scheme.', 46 | default=argparse.SUPPRESS, 47 | ) 48 | parser.add_argument( 49 | 'privileged_policy', 50 | type=str, 51 | choices=cfg['privileged_policies'].keys(), 52 | ) 53 | args = parser.parse_args() 54 | 55 | # process command-line arguments 56 | if args.gpu == -1: 57 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 58 | device = "cpu" 59 | else: 60 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 61 | device = f"cuda:{args.gpu}" 62 | 63 | seed = args.seed 64 | privileged_policy = args.privileged_policy 65 | 66 | print(f"device: {device}") 67 | print(f"seed: {seed}") 68 | print(f"privileged_policy : {privileged_policy}") 69 | 70 | 71 | import torch 72 | 73 | # Ugly hack 74 | sys.path.insert(0, os.path.abspath(f".")) 75 | 76 | from environment import PomdpEnv 77 | from policies import UniformPolicy, ExpertPolicy 78 | from models import TabularAugmentedModel 79 | 80 | from utils import 
construct_dataset 81 | from learning import fit_model 82 | 83 | 84 | ## SET UP THE ENVIRONMENT ## 85 | 86 | p_s = torch.tensor(cfg['p_s']) 87 | p_r_s = torch.tensor(cfg['p_r_s']) 88 | p_o_s = torch.tensor(cfg['p_o_s']) 89 | p_s_sa = torch.tensor(cfg['p_s_sa']) 90 | 91 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 92 | 93 | o_nvals=p_o_s.shape[1] 94 | a_nvals=p_s_sa.shape[1] 95 | r_nvals=p_r_s.shape[1] 96 | s_nvals = cfg["latent_space_size"] 97 | 98 | # POMDP dynamics 99 | env = PomdpEnv(p_s=p_s, 100 | p_or_s=p_r_s.unsqueeze(-2) * p_o_s.unsqueeze(-1), 101 | p_s_sa=p_s_sa, 102 | categorical_obs=True, 103 | max_length=50) 104 | 105 | # Policy in the observational regime (priviledged) 106 | obs_policy = ExpertPolicy(p_a_s) 107 | 108 | # Policy in the interventional regime 109 | int_policy = UniformPolicy(a_nvals) 110 | 111 | 112 | ## SET UP THE SEEDS ## 113 | 114 | rng = np.random.RandomState(seed) 115 | seed_data_obs = rng.randint(0, 2**10) 116 | seed_data_int = rng.randint(0, 2**10) 117 | seed_training = rng.randint(0, 2**10) 118 | 119 | 120 | ## GENERATE THE DATASETS ## 121 | 122 | # from command-line argument if provided, otherwise from config file 123 | nsamples_obs_subsets = [args.nobs] if "nobs" in args else cfg['nsamples_obs'] 124 | nsamples_int_subsets = [args.nint] if "nint" in args else cfg['nsamples_int'] 125 | training_schemes = [args.scheme] if "scheme" in args else cfg["training_schemes"] 126 | 127 | print(f"nsamples_obs_subsets: {nsamples_obs_subsets}") 128 | print(f"nsamples_int_subsets: {nsamples_int_subsets}") 129 | print(f"training_schemes: {training_schemes}") 130 | 131 | # we perform experiments on subsets of the same 132 | # dataset, so that each sequentially growing experiment 133 | # reuses the same samples, complemented with new ones 134 | nsamples_obs_total = np.max(nsamples_obs_subsets) 135 | nsamples_int_total = np.max(nsamples_int_subsets) 136 | 137 | torch.manual_seed(seed_data_obs) 138 | data_obs_all = construct_dataset(env=env, 139 | policy=obs_policy, 140 | n_samples=nsamples_obs_total, 141 | regime=torch.tensor(0)) 142 | 143 | torch.manual_seed(seed_data_int) 144 | data_int_all = construct_dataset(env=env, 145 | policy=int_policy, 146 | n_samples=nsamples_int_total, 147 | regime=torch.tensor(1)) 148 | 149 | 150 | ## LEARN THE TRANSITION MODELS ## 151 | 152 | loss_type = 'nll' 153 | with_done = False 154 | 155 | n_epochs = 500 156 | epoch_size = 50 157 | batch_size = 32 158 | lr = 1e-2 159 | patience = 10 160 | 161 | device = torch.device(device) 162 | 163 | for nsamples_obs in nsamples_obs_subsets: 164 | for nsamples_int in nsamples_int_subsets: 165 | 166 | data_obs = data_obs_all[:nsamples_obs] 167 | data_int = data_int_all[:nsamples_int] 168 | 169 | modeldir = pathlib.Path(f"experiments/toy3/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 170 | modeldir.mkdir(parents=True, exist_ok=True) 171 | 172 | print(f"saving results to: {modeldir}") 173 | 174 | for training_scheme in training_schemes: 175 | 176 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 177 | 178 | logfile = modeldir / f"{training_scheme}_log.txt" 179 | paramsfile = modeldir / f"{training_scheme}.pt" 180 | 181 | if pathlib.Path(paramsfile).exists(): 182 | print(f"Found trained model {paramsfile}, skip training.") 183 | continue 184 | 185 | if training_scheme == 'int': 186 | train_data = data_int 187 | elif training_scheme == 'obs+int': 188 | train_data = [(torch.tensor(1), 
episode) for (_, episode) in data_obs + data_int]
189 |                 elif training_scheme == 'augmented_obs+int':
190 |                     train_data = data_obs + data_int
191 |                 else:
192 |                     raise NotImplementedError
193 |
194 |                 torch.manual_seed(seed_training)
195 |
196 |                 m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals)
197 |                 m = m.to(device)
198 |
199 |                 fit_model(m,
200 |                           train_data=train_data,
201 |                           valid_data=train_data, # we want to overfit
202 |                           loss_type=loss_type,
203 |                           with_done=with_done,
204 |                           n_epochs=n_epochs,
205 |                           epoch_size=epoch_size,
206 |                           batch_size=batch_size,
207 |                           lr=lr,
208 |                           patience=patience,
209 |                           log=True,
210 |                           logfile=logfile)
211 |
212 |                 torch.save(m.state_dict(), paramsfile)
213 |
214 |
--------------------------------------------------------------------------------
/experiments/toy3/02_eval_models.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import sys
  3 | import pathlib
  4 | import json
  5 | import argparse
  6 | import numpy as np
  7 |
  8 |
  9 | if __name__ == '__main__':
 10 |
 11 |     # read experiment config
 12 |     with open("experiments/toy3/config.json", "r") as json_data_file:
 13 |         cfg = json.load(json_data_file)
 14 |
 15 |     # read command-line arguments
 16 |     parser = argparse.ArgumentParser()
 17 |     parser.add_argument(
 18 |         '-s', '--seed',
 19 |         type=int,
 20 |         help = 'Random generator seed.',
 21 |         default=0,
 22 |     )
 23 |     parser.add_argument(
 24 |         '-g', '--gpu',
 25 |         type=int,
 26 |         help='CUDA GPU id (-1 for CPU).',
 27 |         default=-1,
 28 |     )
 29 |     parser.add_argument(
 30 |         'privileged_policy',
 31 |         type=str,
 32 |         choices=cfg['privileged_policies'].keys(),
 33 |     )
 34 |     args = parser.parse_args()
 35 |
 36 |     # process command-line arguments
 37 |     if args.gpu == -1:
 38 |         os.environ['CUDA_VISIBLE_DEVICES'] = ''
 39 |         device = "cpu"
 40 |     else:
 41 |         os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}"
 42 |         device = f"cuda:{args.gpu}"
 43 |
 44 |     seed = args.seed
 45 |     privileged_policy = args.privileged_policy
 46 |
 47 |     print(f"device: {device}")
 48 |     print(f"seed: {seed}")
 49 |     print(f"privileged_policy : {privileged_policy}")
 50 |
 51 |
 52 |     import torch
 53 |
 54 |     # Ugly hack
 55 |     sys.path.insert(0, os.path.abspath(f"."))
 56 |
 57 |     from environment import PomdpEnv
 58 |     from policies import UniformPolicy, ExpertPolicy
 59 |     from models import TabularAugmentedModel
 60 |
 61 |     from utils import construct_dataset
 62 |     from utils import js_div_empirical, kl_div_empirical, cross_entropy_empirical
 63 |
 64 |
 65 |     ## SET UP THE ENVIRONMENT ##
 66 |
 67 |     p_s = torch.tensor(cfg['p_s'])
 68 |     p_r_s = torch.tensor(cfg['p_r_s'])
 69 |     p_o_s = torch.tensor(cfg['p_o_s'])
 70 |     p_s_sa = torch.tensor(cfg['p_s_sa'])
 71 |
 72 |     p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy])
 73 |
 74 |     o_nvals=p_o_s.shape[1]
 75 |     a_nvals=p_s_sa.shape[1]
 76 |     r_nvals=p_r_s.shape[1]
 77 |     s_nvals = cfg["latent_space_size"]
 78 |
 79 |     episode_length = cfg["episode_length"]
 80 |
 81 |     # POMDP dynamics
 82 |     env = PomdpEnv(p_s=p_s,
 83 |                    p_or_s=p_r_s.unsqueeze(-2) * p_o_s.unsqueeze(-1),
 84 |                    p_s_sa=p_s_sa,
 85 |                    categorical_obs=True,
 86 |                    max_length=episode_length)
 87 |
 88 |     # Policy in the observational regime (privileged)
 89 |     obs_policy = ExpertPolicy(p_a_s)
 90 |
 91 |     # Policy in the interventional regime
 92 |     int_policy = UniformPolicy(a_nvals)
 93 |
 94 |
 95 |     ## SET UP THE SEEDS ##
 96 |
 97 |     rng = np.random.RandomState(seed)
 98 |     seed_data_obs = rng.randint(0, 2**10)
 99 |     seed_data_int = rng.randint(0, 2**10)
100 |     seed_training = rng.randint(0, 2**10)
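    # seeds for the evaluation phase: seed_data_eval draws the held-out evaluation
    # episodes from the true POMDP, seed_eval draws the imaginary episodes sampled
    # from each learnt model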
101 | seed_data_eval = rng.randint(0, 2**10) 102 | seed_eval = rng.randint(0, 2**10) 103 | 104 | 105 | ## GENERATE THE EVALUATION DATASET ## 106 | 107 | nsamples_eval = cfg["nsamples_eval"] 108 | 109 | torch.manual_seed(seed_data_eval) 110 | data_eval_p = construct_dataset(env=env, 111 | policy=int_policy, 112 | n_samples=nsamples_eval, 113 | regime=torch.tensor(1)) 114 | 115 | 116 | ## EVALUATE THE TRANSITION MODELS ## 117 | 118 | nsamples_obs_subsets = cfg['nsamples_obs'] 119 | nsamples_int_subsets = cfg['nsamples_int'] 120 | training_schemes = cfg["training_schemes"] 121 | 122 | with_done = False 123 | batch_size = 32 124 | 125 | device = torch.device(device) 126 | 127 | # true model 128 | m_true = TabularAugmentedModel(s_nvals=p_s.shape[0], o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 129 | m_true.set_probs(p_s=p_s, p_o_s=p_o_s, p_r_s=p_r_s, p_s_sa=p_s_sa, p_a_s=p_a_s) 130 | m_true.to(device) 131 | 132 | # learnt model 133 | m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 134 | m = m.to(device) 135 | 136 | resultsdir = pathlib.Path(f"experiments/toy3/results/{privileged_policy}/seed_{seed}") 137 | resultsdir.mkdir(parents=True, exist_ok=True) 138 | 139 | results = np.full((len(nsamples_obs_subsets), len(nsamples_int_subsets), len(training_schemes), 3), np.nan) 140 | 141 | for i, nsamples_obs in enumerate(nsamples_obs_subsets): 142 | for j, nsamples_int in enumerate(nsamples_int_subsets): 143 | for k, training_scheme in enumerate(training_schemes): 144 | 145 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 146 | 147 | modeldir = pathlib.Path(f"experiments/toy3/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 148 | 149 | print(f"reading results from: {modeldir}") 150 | 151 | paramsfile = modeldir / f"{training_scheme}.pt" 152 | m.load_state_dict(torch.load(paramsfile, map_location=device)) 153 | 154 | # sample from the learnt model 155 | with torch.no_grad(): 156 | q_s = torch.nn.functional.softmax(m.params_s, dim=-1) 157 | q_r_s = torch.nn.functional.softmax(m.params_r_s, dim=-1) 158 | q_o_s = torch.nn.functional.softmax(m.params_o_s, dim=-1) 159 | q_s_sa = torch.nn.functional.softmax(m.params_s_sa, dim=-1) 160 | 161 | # imaginary POMDP dynamics 162 | env_q = PomdpEnv(p_s=q_s, 163 | p_or_s=q_r_s.unsqueeze(-2) * q_o_s.unsqueeze(-1), 164 | p_s_sa=q_s_sa, 165 | categorical_obs=True, 166 | max_length=episode_length) 167 | 168 | # imaginary data 169 | torch.manual_seed(seed_eval) 170 | data_eval_q = construct_dataset(env=env_q, 171 | policy=int_policy, 172 | n_samples=nsamples_eval, 173 | regime=torch.tensor(1)) 174 | 175 | # compute empirical cross-entropy (NLL) 176 | ce = cross_entropy_empirical(model_q=m, data_p=data_eval_p, 177 | batch_size=batch_size, with_done=with_done) 178 | 179 | # compute empirical KL 180 | kld = kl_div_empirical(model_q=m, model_p=m_true, 181 | data_p=data_eval_p, 182 | batch_size=batch_size, with_done=with_done) 183 | 184 | # compute empirical JS 185 | jsd = js_div_empirical(model_q=m, model_p=m_true, 186 | data_q=data_eval_q, data_p=data_eval_p, 187 | batch_size=batch_size, with_done=with_done) 188 | 189 | ce = ce.item() 190 | kld = kld.item() 191 | jsd = jsd.item() 192 | 193 | print(f"ce: {ce}") 194 | print(f"kld: {kld}") 195 | print(f"jsd: {jsd}") 196 | 197 | results[i, j, k] = (kld, jsd, ce) 198 | 199 | with open(resultsdir / "model_results.npy", 'wb') as f: 200 | np.save(f, results) 201 | 
-------------------------------------------------------------------------------- /experiments/toy3/03_train_agents.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy3/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 | help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | '--nobs', 31 | type=int, 32 | help = 'Number of observational samples.', 33 | default=argparse.SUPPRESS, 34 | ) 35 | parser.add_argument( 36 | '--nint', 37 | type=int, 38 | help = 'Number of interventional samples.', 39 | default=argparse.SUPPRESS, 40 | ) 41 | parser.add_argument( 42 | '--scheme', 43 | type=str, 44 | choices=cfg['training_schemes'], 45 | help='Training scheme.', 46 | default=argparse.SUPPRESS, 47 | ) 48 | parser.add_argument( 49 | 'privileged_policy', 50 | type=str, 51 | choices=cfg['privileged_policies'].keys(), 52 | ) 53 | args = parser.parse_args() 54 | 55 | # process command-line arguments 56 | if args.gpu == -1: 57 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 58 | device = "cpu" 59 | else: 60 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 61 | device = f"cuda:{args.gpu}" 62 | 63 | seed = args.seed 64 | privileged_policy = args.privileged_policy 65 | 66 | print(f"device: {device}") 67 | print(f"seed: {seed}") 68 | print(f"privileged_policy : {privileged_policy}") 69 | 70 | 71 | import torch 72 | 73 | # Ugly hack 74 | sys.path.insert(0, os.path.abspath(f".")) 75 | 76 | from models import TabularAugmentedModel 77 | 78 | # ENVIRONMENT 79 | from environment import PomdpEnv 80 | from environment.env_wrappers import BeliefStateRepresentation, RewardWrapper, SqueezeEnv 81 | 82 | from rl_agents.ac import ActorCritic, run_actorcritic 83 | # from rl_agents.reinforce import Actor, run_reinforce 84 | 85 | ## SET UP THE ENVIRONMENT ## 86 | 87 | p_s = torch.tensor(cfg['p_s']) 88 | p_r_s = torch.tensor(cfg['p_r_s']) 89 | p_o_s = torch.tensor(cfg['p_o_s']) 90 | p_s_sa = torch.tensor(cfg['p_s_sa']) 91 | 92 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 93 | 94 | o_nvals=p_o_s.shape[1] 95 | a_nvals=p_s_sa.shape[1] 96 | r_nvals=p_r_s.shape[1] 97 | s_nvals = cfg["latent_space_size"] 98 | 99 | episode_length = cfg["episode_length"] 100 | 101 | reward_map = cfg["r_desc"] 102 | 103 | 104 | ## SET UP THE SEEDS ## 105 | 106 | rng = np.random.RandomState(seed) 107 | seed_data_obs = rng.randint(0, 2**10) 108 | seed_data_int = rng.randint(0, 2**10) 109 | seed_model_training = rng.randint(0, 2**10) 110 | seed_data_eval = rng.randint(0, 2**10) 111 | seed_eval = rng.randint(0, 2**10) 112 | seed_agent_training = rng.randint(0, 2**10) 113 | 114 | ## EVALUATE THE TRANSITION MODELS ## 115 | 116 | # from command-line argument if provided, otherwise from config file 117 | nsamples_obs_subsets = [args.nobs] if "nobs" in args else cfg['nsamples_obs'] 118 | nsamples_int_subsets = [args.nint] if "nint" in args else cfg['nsamples_int'] 119 | training_schemes = [args.scheme] if "scheme" in args else cfg["training_schemes"] 120 | 121 | 
print(f"nsamples_obs_subsets: {nsamples_obs_subsets}") 122 | print(f"nsamples_int_subsets: {nsamples_int_subsets}") 123 | print(f"training_schemes: {training_schemes}") 124 | 125 | ## EVALUATE THE TRANSITION MODELS ## 126 | 127 | with_done = False 128 | lr = 1e-2 129 | gamma = 1 130 | n_epochs = 1000 131 | log_every = 10 132 | batch_size = 32 133 | 134 | device = torch.device(device) 135 | 136 | # learnt model 137 | m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 138 | m = m.to(device) 139 | 140 | for nsamples_obs in nsamples_obs_subsets: 141 | for nsamples_int in nsamples_int_subsets: 142 | for training_scheme in training_schemes: 143 | 144 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 145 | 146 | model_dir = pathlib.Path(f"experiments/toy3/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 147 | agent_dir = pathlib.Path(f"experiments/toy3/trained_agents/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 148 | 149 | agent_dir.mkdir(parents=True, exist_ok=True) 150 | 151 | model_paramsfile = model_dir / f"{training_scheme}.pt" 152 | agent_paramsfile = agent_dir / f"{training_scheme}.pt" 153 | logfile = agent_dir / f"{training_scheme}_log.txt" 154 | 155 | if agent_paramsfile.exists(): 156 | print(f"Found trained agent {agent_paramsfile}, skip training.") 157 | continue 158 | 159 | print(f"reading model from: {model_paramsfile}") 160 | 161 | m.load_state_dict(torch.load(model_paramsfile, map_location=device)) 162 | 163 | # recover learned POMDP dynamics 164 | with torch.no_grad(): 165 | q_s = torch.nn.functional.softmax(m.params_s, dim=-1) 166 | q_r_s = torch.nn.functional.softmax(m.params_r_s, dim=-1) 167 | q_o_s = torch.nn.functional.softmax(m.params_o_s, dim=-1) 168 | q_s_sa = torch.nn.functional.softmax(m.params_s_sa, dim=-1) 169 | 170 | # learned POMDP 171 | env_q = PomdpEnv(p_s=q_s, 172 | p_or_s=q_r_s.unsqueeze(-2) * q_o_s.unsqueeze(-1), 173 | p_s_sa=q_s_sa, 174 | categorical_obs=True, 175 | max_length=episode_length) 176 | 177 | # POMDP -> MDP (using the model's belief state) 178 | env_q = BeliefStateRepresentation(SqueezeEnv(env_q), m) 179 | 180 | # map categorical reward to numerical values 181 | env_q = RewardWrapper(env_q, reward_dic=reward_map) 182 | 183 | # agent training (dream) 184 | torch.manual_seed(seed_agent_training) 185 | 186 | # agent = Actor(s_nvals=s_nvals, a_nvals=a_nvals) 187 | # run_reinforce(env=env_q, agent=agent, 188 | # lr=lr, gamma=gamma, 189 | # batch_size=batch_size, 190 | # n_epochs=n_epochs, 191 | # log_every=log_every, 192 | # logfile=logfile) 193 | 194 | agent = ActorCritic(s_nvals=s_nvals, a_nvals=a_nvals) 195 | run_actorcritic(env_q, agent, 196 | lr=lr, gamma=gamma, 197 | batch_size=batch_size, 198 | n_epochs=n_epochs, 199 | log_every=log_every, 200 | logfile=logfile) 201 | 202 | torch.save(agent.state_dict(), agent_paramsfile) 203 | print(f"saving agent to: {agent_paramsfile}") 204 | -------------------------------------------------------------------------------- /experiments/toy3/04_eval_agents.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | # read experiment config 12 | with open("experiments/toy3/config.json", "r") as json_data_file: 13 | cfg = json.load(json_data_file) 14 | 15 | # read 
command-line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '-s', '--seed', 19 | type=int, 20 | help = 'Random generator seed.', 21 | default=0, 22 | ) 23 | parser.add_argument( 24 | '-g', '--gpu', 25 | type=int, 26 | help='CUDA GPU id (-1 for CPU).', 27 | default=-1, 28 | ) 29 | parser.add_argument( 30 | 'privileged_policy', 31 | type=str, 32 | choices=cfg['privileged_policies'].keys(), 33 | ) 34 | args = parser.parse_args() 35 | 36 | # process command-line arguments 37 | if args.gpu == -1: 38 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 39 | device = "cpu" 40 | else: 41 | os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}" 42 | device = f"cuda:{args.gpu}" 43 | 44 | seed = args.seed 45 | privileged_policy = args.privileged_policy 46 | 47 | print(f"device: {device}") 48 | print(f"seed: {seed}") 49 | print(f"privileged_policy : {privileged_policy}") 50 | 51 | 52 | import torch 53 | 54 | # Ugly hack 55 | sys.path.insert(0, os.path.abspath(f".")) 56 | 57 | from models import TabularAugmentedModel 58 | 59 | # ENVIRONMENT 60 | from environment import PomdpEnv 61 | from environment.env_wrappers import BeliefStateRepresentation, RewardWrapper, SqueezeEnv 62 | 63 | from rl_agents.ac import ActorCritic, evaluate_agent 64 | 65 | ## SET UP THE ENVIRONMENT ## 66 | 67 | p_s = torch.tensor(cfg['p_s']) 68 | p_r_s = torch.tensor(cfg['p_r_s']) 69 | p_o_s = torch.tensor(cfg['p_o_s']) 70 | p_s_sa = torch.tensor(cfg['p_s_sa']) 71 | 72 | p_a_s = torch.tensor(cfg['privileged_policies'][privileged_policy]) 73 | 74 | o_nvals=p_o_s.shape[1] 75 | a_nvals=p_s_sa.shape[1] 76 | r_nvals=p_r_s.shape[1] 77 | s_nvals = cfg["latent_space_size"] 78 | 79 | episode_length = cfg["episode_length"] 80 | 81 | reward_map = cfg["r_desc"] 82 | 83 | 84 | ## SET UP THE SEEDS ## 85 | 86 | rng = np.random.RandomState(seed) 87 | seed_data_obs = rng.randint(0, 2**10) 88 | seed_data_int = rng.randint(0, 2**10) 89 | seed_model_training = rng.randint(0, 2**10) 90 | seed_data_eval = rng.randint(0, 2**10) 91 | seed_eval = rng.randint(0, 2**10) 92 | seed_agent_training = rng.randint(0, 2**10) 93 | seed_agent_eval = rng.randint(0, 2**10) 94 | 95 | ## EVALUATE THE TRANSITION MODELS ## 96 | 97 | nsamples_obs_subsets = cfg['nsamples_obs'] 98 | nsamples_int_subsets = cfg['nsamples_int'] 99 | training_schemes = cfg["training_schemes"] 100 | 101 | ## EVALUATE THE TRANSITION MODELS ## 102 | 103 | n_episodes = 100 104 | 105 | device = torch.device(device) 106 | 107 | # learnt model 108 | m = TabularAugmentedModel(s_nvals=s_nvals, o_nvals=o_nvals, a_nvals=a_nvals, r_nvals=r_nvals) 109 | m = m.to(device) 110 | 111 | # learnt agent 112 | agent = ActorCritic(s_nvals=s_nvals, a_nvals=a_nvals) 113 | agent.to(device) 114 | 115 | # true POMDP 116 | env_p = PomdpEnv(p_s=p_s, 117 | p_or_s=p_r_s.unsqueeze(-2) * p_o_s.unsqueeze(-1), 118 | p_s_sa=p_s_sa, 119 | categorical_obs=True, 120 | max_length=episode_length) 121 | 122 | resultsdir = pathlib.Path(f"experiments/toy3/results/{privileged_policy}/seed_{seed}") 123 | resultsdir.mkdir(parents=True, exist_ok=True) 124 | 125 | results = np.full((len(nsamples_obs_subsets), len(nsamples_int_subsets), len(training_schemes), 1), np.nan) 126 | 127 | for i, nsamples_obs in enumerate(nsamples_obs_subsets): 128 | for j, nsamples_int in enumerate(nsamples_int_subsets): 129 | for k, training_scheme in enumerate(training_schemes): 130 | 131 | print(f"nsamples_obs: {nsamples_obs} nsamples_int: {nsamples_int} training_scheme: {training_scheme}") 132 | 133 | model_dir = 
pathlib.Path(f"experiments/toy3/trained_models/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 134 | model_paramsfile = model_dir / f"{training_scheme}.pt" 135 | 136 | agent_dir = pathlib.Path(f"experiments/toy3/trained_agents/{privileged_policy}/seed_{seed}/nobs_{nsamples_obs}/nint_{nsamples_int}") 137 | agent_paramsfile = agent_dir / f"{training_scheme}.pt" 138 | 139 | print(f"reading model from: {model_paramsfile}") 140 | m.load_state_dict(torch.load(model_paramsfile, map_location=device)) 141 | 142 | print(f"reading agent from: {agent_paramsfile}") 143 | agent.load_state_dict(torch.load(agent_paramsfile, map_location=device)) 144 | 145 | # POMDP -> MDP (using the model's belief state) 146 | env = BeliefStateRepresentation(SqueezeEnv(env_p), m) 147 | 148 | # map categorical reward to its numerical values 149 | env = RewardWrapper(env, reward_dic=reward_map) 150 | 151 | # agent evaluation (true environment) 152 | torch.manual_seed(seed_agent_eval) 153 | 154 | reward = evaluate_agent(env, agent, n_episodes) 155 | 156 | print(f"reward: {reward}") 157 | 158 | results[i, j, k] = reward 159 | 160 | with open(resultsdir / "agent_results.npy", 'wb') as f: 161 | np.save(f, results) 162 | -------------------------------------------------------------------------------- /experiments/toy3/05_plots.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import json 5 | import argparse 6 | import numpy as np 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | 10 | # Ugly hack 11 | sys.path.insert(0, os.path.abspath(f".")) 12 | 13 | from stat_tests import run_test 14 | 15 | 16 | if __name__ == '__main__': 17 | 18 | # read experiment config 19 | with open("experiments/toy3/config.json", "r") as json_data_file: 20 | cfg = json.load(json_data_file) 21 | 22 | # read command-line arguments 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | 'privileged_policy', 26 | type=str, 27 | choices=cfg['privileged_policies'].keys(), 28 | ) 29 | args = parser.parse_args() 30 | 31 | privileged_policy = args.privileged_policy 32 | 33 | print(f"privileged_policy : {privileged_policy }") 34 | 35 | 36 | ## COLLECT THE RESULTS ## 37 | 38 | nobss = cfg['nsamples_obs'] 39 | nints = cfg['nsamples_int'] 40 | training_schemes = cfg["training_schemes"] 41 | 42 | nseeds = 20 43 | 44 | model_results = [] 45 | agent_results = [] 46 | for seed in range(nseeds): 47 | with open(f"experiments/toy3/results/{privileged_policy}/seed_{seed}/model_results.npy", 'rb') as f: 48 | model_results.append(np.load(f)) 49 | with open(f"experiments/toy3/results/{privileged_policy}/seed_{seed}/agent_results.npy", 'rb') as f: 50 | agent_results.append(np.load(f)) 51 | 52 | model_results = np.asarray(model_results) 53 | agent_results = np.asarray(agent_results) 54 | 55 | # kls = model_results[..., 0] 56 | jss = model_results[..., 1] 57 | # ces = model_results[..., 2] 58 | rewards = agent_results[..., 0] 59 | 60 | 61 | ## CREATE AND SAVE THE PLOTS ## 62 | 63 | plotsdir = pathlib.Path(f"experiments/toy3/plots") 64 | plotsdir.mkdir(parents=True, exist_ok=True) 65 | 66 | rmin = np.min(rewards) 67 | rmax = np.max(rewards) 68 | 69 | jsmin = np.min(jss) 70 | jsmax = np.max(jss) 71 | 72 | r_int = rewards[..., 0] 73 | r_naiv = rewards[..., 1] 74 | r_augm = rewards[..., 2] 75 | 76 | js_int = jss[..., 0] 77 | js_naiv = jss[..., 1] 78 | js_augm = jss[..., 2] 79 | 80 | fig, axes = plt.subplots(2, 5, figsize=(20, 6), dpi=300) 81 | 
plt.subplots_adjust(wspace=0.4, hspace=0.4) 82 | 83 | ax = axes[0, 0] 84 | cf = ax.pcolormesh(r_int.mean(0), vmin=rmin, vmax=rmax) 85 | fig.colorbar(cf, ax=ax) 86 | ax.set_title(f"no obs") 87 | ax.set_ylabel('nobs') 88 | ax.set_xlabel('nints') 89 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 90 | ax.set_xticklabels(nints) 91 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 92 | ax.set_yticklabels(nobss) 93 | 94 | ax = axes[0, 1] 95 | cf = ax.pcolormesh(r_naiv.mean(0), vmin=rmin, vmax=rmax) 96 | fig.colorbar(cf, ax=ax) 97 | ax.set_title(f"naive obs+int") 98 | ax.set_ylabel('nobs') 99 | ax.set_xlabel('nints') 100 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 101 | ax.set_xticklabels(nints) 102 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 103 | ax.set_yticklabels(nobss) 104 | 105 | ax = axes[0, 2] 106 | cf = ax.pcolormesh(r_augm.mean(0), vmin=rmin, vmax=rmax) 107 | fig.colorbar(cf, ax=ax) 108 | ax.set_title(f"augmented obs+int") 109 | ax.set_ylabel('nobs') 110 | ax.set_xlabel('nints') 111 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 112 | ax.set_xticklabels(nints) 113 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 114 | ax.set_yticklabels(nobss) 115 | 116 | r_gain_int = (r_augm - r_int).mean(0) 117 | r_gain_naiv = (r_augm - r_naiv).mean(0) 118 | r_gain_max = np.max([np.abs(r_gain_int), np.abs(r_gain_naiv)]) 119 | r_gain_min = -r_gain_max 120 | 121 | ax = axes[0, 3] 122 | cf = ax.pcolormesh(r_gain_int, cmap=plt.get_cmap('PiYG'), vmin=r_gain_min, vmax=r_gain_max) 123 | fig.colorbar(cf, ax=ax) 124 | ax.set_title(f"augmented - no obs") 125 | ax.set_ylabel('nobs') 126 | ax.set_xlabel('nints') 127 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 128 | ax.set_xticklabels(nints) 129 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 130 | ax.set_yticklabels(nobss) 131 | 132 | ax = axes[0, 4] 133 | cf = ax.pcolormesh(r_gain_naiv, cmap=plt.get_cmap('PiYG'), vmin=r_gain_min, vmax=r_gain_max) 134 | fig.colorbar(cf, ax=ax) 135 | ax.set_title(f"augmented - naive") 136 | ax.set_ylabel('nobs') 137 | ax.set_xlabel('nints') 138 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 139 | ax.set_xticklabels(nints) 140 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 141 | ax.set_yticklabels(nobss) 142 | 143 | ax = axes[1, 0] 144 | cf = ax.pcolormesh(js_int.mean(0), vmin=jsmin, vmax=jsmax) 145 | fig.colorbar(cf, ax=ax) 146 | ax.set_title(f"no obs") 147 | ax.set_ylabel('nobs') 148 | ax.set_xlabel('nints') 149 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 150 | ax.set_xticklabels(nints) 151 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 152 | ax.set_yticklabels(nobss) 153 | 154 | ax = axes[1, 1] 155 | cf = ax.pcolormesh(js_naiv.mean(0), vmin=jsmin, vmax=jsmax) 156 | fig.colorbar(cf, ax=ax) 157 | ax.set_title(f"naive obs+int") 158 | ax.set_ylabel('nobs') 159 | ax.set_xlabel('nints') 160 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 161 | ax.set_xticklabels(nints) 162 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 163 | ax.set_yticklabels(nobss) 164 | 165 | ax = axes[1, 2] 166 | cf = ax.pcolormesh(js_augm.mean(0), vmin=jsmin, vmax=jsmax) 167 | fig.colorbar(cf, ax=ax) 168 | ax.set_title(f"augmented obs+int") 169 | ax.set_ylabel('nobs') 170 | ax.set_xlabel('nints') 171 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 172 | ax.set_xticklabels(nints) 173 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 174 | ax.set_yticklabels(nobss) 175 | 176 | js_gain_int = (js_augm - 
js_int).mean(0) 177 | js_gain_naiv = (js_augm - js_naiv).mean(0) 178 | js_gain_max = np.max([np.abs(js_gain_int), np.abs(js_gain_naiv)]) 179 | js_gain_min = -js_gain_max 180 | 181 | ax = axes[1, 3] 182 | cf = ax.pcolormesh(js_gain_int, cmap=plt.get_cmap('PiYG'), vmin=js_gain_min, vmax=js_gain_max) 183 | fig.colorbar(cf, ax=ax) 184 | ax.set_title(f"augmented - no obs") 185 | ax.set_ylabel('nobs') 186 | ax.set_xlabel('nints') 187 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 188 | ax.set_xticklabels(nints) 189 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 190 | ax.set_yticklabels(nobss) 191 | 192 | ax = axes[1, 4] 193 | cf = ax.pcolormesh(js_gain_naiv, cmap=plt.get_cmap('PiYG'), vmin=js_gain_min, vmax=js_gain_max) 194 | fig.colorbar(cf, ax=ax) 195 | ax.set_title(f"augmented - naive") 196 | ax.set_ylabel('nobs') 197 | ax.set_xlabel('nints') 198 | ax.xaxis.set_ticks([i+0.5 for i in range(len(nints))]) 199 | ax.set_xticklabels(nints) 200 | ax.yaxis.set_ticks([i+0.5 for i in range(len(nobss))]) 201 | ax.set_yticklabels(nobss) 202 | 203 | fig.savefig(plotsdir / f"{privileged_policy}_reward_js_grids.pdf", bbox_inches='tight', pad_inches=0) 204 | plt.close(fig) 205 | 206 | def plot_mean_std(ax, x, y, label, color): 207 | ax.plot(x, y.mean(0), label=label, color=color) 208 | ax.fill_between(x, y.mean(0) - y.std(0), y.mean(0) + y.std(0), color=color, alpha=0.2) 209 | 210 | def plot_mean_lowhigh(ax, x, mean, low, high, label, color): 211 | ax.plot(x, mean, label=label, color=color) 212 | ax.fill_between(x, low, high, color=color, alpha=0.2) 213 | 214 | def compute_central_tendency_and_error(id_central, id_error, sample): 215 | if id_central == 'mean': 216 | central = np.nanmean(sample, axis=0) 217 | elif id_central == 'median': 218 | central = np.nanmedian(sample, axis=0) 219 | else: 220 | raise NotImplementedError 221 | 222 | if isinstance(id_error, int): 223 | low = np.nanpercentile(sample, q=int((100 - id_error) / 2), axis=0) 224 | high = np.nanpercentile(sample, q=int(100 - (100 - id_error) / 2), axis=0) 225 | elif id_error == 'std': 226 | low = central - np.nanstd(sample, axis=0) 227 | high = central + np.nanstd(sample, axis=0) 228 | elif id_error == 'sem': 229 | low = central - np.nanstd(sample, axis=0) / np.sqrt(sample.shape[0]) 230 | high = central + np.nanstd(sample, axis=0) / np.sqrt(sample.shape[0]) 231 | else: 232 | raise NotImplementedError 233 | 234 | return central, low, high 235 | 236 | for i, nobs in enumerate(nobss): 237 | 238 | test = 'Wilcoxon' 239 | deviation = 'std' # 'sem' 240 | confidence_level = 0.05 241 | 242 | ### Jensen-Shannon ### 243 | 244 | fig, axes = plt.subplots(1, 1, figsize=(3, 2.25), dpi=300) 245 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 246 | 247 | # statistical tests 248 | test_int_augm = [run_test(test, js_augm[:, i, j], js_int[:, i, j], alpha=confidence_level) for j in range(len(nints))] 249 | test_naiv_augm = [run_test(test, js_augm[:, i, j], js_naiv[:, i, j], alpha=confidence_level) for j in range(len(nints))] 250 | 251 | # mean and standard error 252 | mean0, low0, high0 = compute_central_tendency_and_error('mean', deviation, js_int[:, i]) 253 | mean1, low1, high1 = compute_central_tendency_and_error('mean', deviation, js_naiv[:, i]) 254 | mean2, low2, high2 = compute_central_tendency_and_error('mean', deviation, js_augm[:, i]) 255 | 256 | # plot JS curves 257 | ax = axes 258 | plot_mean_lowhigh(ax, nints, mean0, low0, high0, label="no obs", color="tab:blue") 259 | plot_mean_lowhigh(ax, nints, mean1, low1, high1, label="naive", 
color="tab:orange") 260 | plot_mean_lowhigh(ax, nints, mean2, low2, high2, label="augmented", color="tab:green") 261 | 262 | ymax = np.nanmax([high0, high1, high2]) 263 | ymin = np.nanmin([low0, low1, low2]) 264 | 265 | # plot significative difference as dots 266 | y = ymax + 0.05 * (ymax-ymin) 267 | x = np.asarray(nints)[np.argwhere(test_int_augm)] 268 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:blue', marker='v') 269 | 270 | y = ymax + 0.10 * (ymax-ymin) 271 | x = np.asarray(nints)[np.argwhere(test_naiv_augm)] 272 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:orange', marker='s') 273 | 274 | ax.set_title(f"JS divergence") 275 | ax.set_xlabel('nints (log scale)') 276 | ax.set_xscale('log', base=2) 277 | ax.set_ylim(bottom=0) 278 | ax.legend() 279 | 280 | fig.savefig(plotsdir / f"{privileged_policy}_js_nobs_{nobs}.pdf", bbox_inches='tight', pad_inches=0) 281 | plt.close(fig) 282 | 283 | 284 | ### Reward ### 285 | 286 | fig, axes = plt.subplots(1, 1, figsize=(3, 2.25), dpi=300) 287 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 288 | 289 | # statistical tests 290 | test_int_augm = [run_test(test, r_int[:, i, j], r_augm[:, i, j], alpha=confidence_level) for j in range(len(nints))] 291 | test_naiv_augm = [run_test(test, r_naiv[:, i, j], r_augm[:, i, j], alpha=confidence_level) for j in range(len(nints))] 292 | 293 | # mean and standard error 294 | mean0, low0, high0 = compute_central_tendency_and_error('mean', deviation, r_int[:, i]) 295 | mean1, low1, high1 = compute_central_tendency_and_error('mean', deviation, r_naiv[:, i]) 296 | mean2, low2, high2 = compute_central_tendency_and_error('mean', deviation, r_augm[:, i]) 297 | 298 | # plot reward curves 299 | ax = axes 300 | plot_mean_lowhigh(ax, nints, mean0, low0, high0, label="no obs", color="tab:blue") 301 | plot_mean_lowhigh(ax, nints, mean1, low1, high1, label="naive", color="tab:orange") 302 | plot_mean_lowhigh(ax, nints, mean2, low2, high2, label="augmented", color="tab:green") 303 | 304 | ymax = np.nanmax([high0, high1, high2]) 305 | ymin = np.nanmin([low0, low1, low2]) 306 | 307 | # plot significative difference as dots 308 | y = ymax + 0.05 * (ymax - ymin) 309 | x = np.asarray(nints)[np.argwhere(test_int_augm)] 310 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:blue', marker='v') 311 | 312 | y = ymax + 0.10 * (ymax - ymin) 313 | x = np.asarray(nints)[np.argwhere(test_naiv_augm)] 314 | ax.scatter(x, y * np.ones_like(x), s=10, c='tab:orange', marker='s') 315 | 316 | ax.set_title(f"reward") 317 | ax.set_xlabel('nints (log scale)') 318 | ax.set_xscale('log', base=2) 319 | # ax.legend() 320 | 321 | fig.savefig(plotsdir / f"{privileged_policy}_reward_nobs_{nobs}.pdf", bbox_inches='tight', pad_inches=0) 322 | plt.close(fig) 323 | -------------------------------------------------------------------------------- /learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils import Dataset, print_log 4 | 5 | 6 | def fit_model(m, train_data, valid_data, loss_type='nll', with_done=False, 7 | n_epochs=200, epoch_size=100, batch_size=16, 8 | lr=1e-2, patience=10, log=False, logfile=None, min_int_ratio=0.0, threshold=1e-4): 9 | 10 | # infer the device from the model 11 | device = next(m.parameters()).device 12 | 13 | if log: 14 | print_log(f"loss_type: {loss_type}", logfile) 15 | print_log(f"with_done: {with_done}", logfile) 16 | print_log(f"n_epochs: {n_epochs}", logfile) 17 | print_log(f"epoch_size: {epoch_size}", logfile) 18 | 
print_log(f"batch_size: {batch_size}", logfile) 19 | print_log(f"lr: {lr}", logfile) 20 | print_log(f"patience: {patience}", logfile) 21 | print_log(f"device: {device}", logfile) 22 | print_log(f"min_int_ratio: {min_int_ratio}", logfile) 23 | 24 | def compute_weights(data): 25 | nint = np.sum([regime == 1 for regime, _ in data]) 26 | nobs = len(data) - nint 27 | int_ratio = nint / (nint + nobs) 28 | 29 | if int_ratio >= min_int_ratio: 30 | weights = [1] * len(data) 31 | else: 32 | weights = [(1 - min_int_ratio) / nobs, min_int_ratio / nint] # obs, int 33 | weights = [weights[int(regime)] for regime, _ in data] 34 | 35 | return weights 36 | 37 | train_weights = compute_weights(train_data) 38 | valid_weights = compute_weights(valid_data) 39 | 40 | # Build training and validation data 41 | train_dataset = Dataset(train_data) 42 | valid_dataset = Dataset(list(zip(valid_data, valid_weights))) # to reweight the loss 43 | 44 | sampler = torch.utils.data.WeightedRandomSampler(train_weights, replacement=True, num_samples=epoch_size*batch_size) 45 | 46 | # Initiate DataLoader for training and validation 47 | train_loader = torch.utils.data.DataLoader(train_dataset, sampler=sampler, batch_size=batch_size) 48 | valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size) 49 | 50 | # Adam Optimizer with learning rate lr 51 | optimizer = torch.optim.Adam(m.parameters(), lr=lr) 52 | 53 | # Scheduler. Reduce learning rate on plateau. 54 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience, verbose=log, threshold=threshold) 55 | 56 | # Early stopping 57 | best_valid_loss = float("Inf") 58 | best_parameters = m.state_dict().copy() 59 | best_epoch = -1 60 | 61 | # Start training loop 62 | for epoch in range(n_epochs + 1): 63 | 64 | # Set initial training loss as +inf 65 | if epoch == 0: 66 | train_loss = float("Inf") 67 | 68 | else: 69 | train_loss = 0 70 | train_nsamples = 0 71 | 72 | for batch in train_loader: 73 | regime, episode = batch 74 | regime = regime.to(device) 75 | episode = [tensor.to(device) for tensor in episode] 76 | 77 | batch_size = regime.shape[0] 78 | 79 | if loss_type == 'em': 80 | loss = m.loss_em(regime, episode, with_done=with_done).mean() 81 | elif loss_type == 'nll': 82 | loss = m.loss_nll(regime, episode, with_done=with_done).mean() 83 | elif loss_type == 'elbo': 84 | raise NotImplementedError() 85 | else: 86 | raise NotImplementedError() 87 | 88 | optimizer.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | 92 | train_loss += loss.item() * batch_size 93 | train_nsamples += batch_size 94 | 95 | train_loss /= train_nsamples 96 | 97 | # validation 98 | valid_loss = 0 99 | valid_nsamples = 0 100 | 101 | for batch in valid_loader: 102 | (regime, episode), weight = batch 103 | regime = regime.to(device) 104 | episode = [tensor.to(device) for tensor in episode] 105 | weight = weight.to(device) 106 | 107 | batch_size = regime.shape[0] 108 | 109 | with torch.no_grad(): 110 | 111 | loss = m.loss_nll(regime, episode, with_done=with_done) 112 | loss = (loss * weight).sum() # re-weighting the loss here 113 | 114 | valid_loss += loss.item() 115 | valid_nsamples += weight.sum().item() 116 | 117 | valid_loss /= valid_nsamples 118 | 119 | if log: 120 | print_log(f"epoch {epoch:04d} / {n_epochs:04d} train loss={train_loss:0.3f} valid loss={valid_loss:0.3f}", logfile) 121 | # q_s = torch.nn.functional.softmax(m.params_s.detach(), dim=-1) 122 | # print_log(f" q_s: {((q_s.cpu().numpy() * 100) // 1) / 100}", logfile) 123 | 124 | # 
check for best model 125 | if valid_loss < (best_valid_loss * (1 - threshold)): 126 | best_valid_loss = valid_loss 127 | best_parameters = m.state_dict().copy() 128 | best_epoch = epoch 129 | 130 | # check for early stopping 131 | if epoch > best_epoch + 2*patience: 132 | if log: 133 | print_log(f"{epoch-best_epoch} epochs without improvement, stopping.", logfile) 134 | break 135 | 136 | scheduler.step(valid_loss) 137 | 138 | # restore best model 139 | m.load_state_dict(best_parameters) 140 | 141 | 142 | 143 | def eval_model(m, data, batch_size=32, with_done=False): 144 | 145 | """ Evaluate model m on data using Negative Log-Likehood on episodes """ 146 | 147 | # infer the device from the model 148 | device = next(m.parameters()).device 149 | 150 | # Initialize NLL 151 | nll = 0 152 | 153 | # Load data as Dataset and create torch Dataloader 154 | dataset = Dataset(data) 155 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size) 156 | 157 | # Iterate on dataloader 158 | for batch in dataloader: 159 | regime, episode = batch 160 | # Switch device if needed 161 | regime = regime.to(device) 162 | episode = [tensor.to(device) for tensor in episode] 163 | 164 | # Get no grad NLL on batch 165 | with torch.no_grad(): 166 | nll += m.loss_nll(regime, episode, with_done=with_done).sum().item() 167 | 168 | # Get mean NLL on data 169 | nll /= len(data) 170 | 171 | return nll 172 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class AugmentedModel(torch.nn.Module): 7 | 8 | """ Augmented Model Base Class """ 9 | 10 | def log_q_s(self, s=None): 11 | raise NotImplementedError 12 | 13 | def log_q_snext_sa(self, a, s=None, snext=None): 14 | raise NotImplementedError 15 | 16 | def log_q_o_s(self, o, s=None): 17 | raise NotImplementedError 18 | 19 | def log_q_r_s(self, r, s=None): 20 | raise NotImplementedError 21 | 22 | def log_q_d_s(self, d, s=None): 23 | raise NotImplementedError 24 | 25 | def log_q_a_s(self, a, s=None): 26 | raise NotImplementedError 27 | 28 | def log_q_s_h(self, regime, loq_q_sprev_hprev, a, o, r, d, with_done=False): 29 | 30 | assert (loq_q_sprev_hprev is None) == (a is None) 31 | 32 | # hprev = (o_0, r_0, d_0, a_0, ..., o_t-1, r_t-1, d_t-1) 33 | # sprev = s_t-1 34 | # a = a_t-1 35 | # o = o_t 36 | # r = r_t 37 | # d = d_t 38 | # h = (o_0, r_0, d_0, a_0, ..., o_t-1, r_t-1, d_t-1, a_t-1, o_t, r_t, d_t) 39 | # s = s_t 40 | 41 | no_hprev = (a is None) 42 | 43 | if no_hprev: 44 | # (batch_size, s_nvals) 45 | log_q_s_hpreva = self.log_q_s() 46 | 47 | else: 48 | 49 | # (batch_size, s_nvals) - (batch, sprev) 50 | log_q_a_sprev = self.log_q_a_s(a=a) 51 | log_q_a_sprev = log_q_a_sprev * (1 - d).unsqueeze(-1) # discard actions if done=True 52 | log_q_a_sprev = log_q_a_sprev * (1 - regime).unsqueeze(-1) # discard actions in interventional regime 53 | 54 | # (batch_size, s_nvals) - (batch, sprev) 55 | log_q_spreva_hprev = loq_q_sprev_hprev + log_q_a_sprev 56 | 57 | # (batch_size,) - (batch,) 58 | log_q_a_hprev = torch.logsumexp(log_q_spreva_hprev, dim=-1) 59 | 60 | # (batch_size, s_nvals) - (batch, sprev) 61 | log_q_sprev_hpreva = log_q_spreva_hprev - log_q_a_hprev.unsqueeze(-1) 62 | 63 | # (batch_size, s_nvals, s_nvals) - (batch, sprev, s) 64 | log_q_s_spreva = self.log_q_snext_sa(a=a) 65 | 66 | # (batch_size, s_nvals, s_nvals) - (batch, sprev, s) 67 | loq_q_sprevs_hpreva 
= log_q_sprev_hpreva.unsqueeze(-1) + log_q_s_spreva 68 | 69 | # (batch_size, s_nvals) - (batch, s) 70 | log_q_s_hpreva = torch.logsumexp(loq_q_sprevs_hpreva, dim=-2) 71 | 72 | log_q_o_s = self.log_q_o_s(o=o) 73 | log_q_r_s = self.log_q_r_s(r=r) 74 | log_q_d_s = self.log_q_d_s(d=d) if with_done else 0 75 | 76 | # (batch_size, s_nvals) 77 | log_q_ord_s = log_q_o_s + log_q_r_s + log_q_d_s 78 | 79 | # (batch_size, s_nvals) 80 | log_q_sord_hpreva = log_q_s_hpreva + log_q_ord_s 81 | 82 | # (batch_size,) 83 | log_q_ord_hpreva = torch.logsumexp(log_q_sord_hpreva, dim=-1) 84 | 85 | # (batch_size, s_nvals) 86 | log_q_s_h = log_q_sord_hpreva - log_q_ord_hpreva.unsqueeze(-1) 87 | 88 | return log_q_s_h 89 | 90 | @torch.jit.export 91 | def log_prob_joint(self, regime, episode, states, with_done=False): 92 | log_prob = 0 93 | 94 | n_transitions = len(episode) // 4 95 | for t in range(n_transitions + 1): 96 | 97 | # s_t, o_t, r_t, d_t 98 | state, obs, reward, done = states[t], episode[4*t], episode[4*t+1], episode[4*t+2] 99 | 100 | if t == 0: 101 | was_done = torch.zeros_like(done) 102 | 103 | # (batch_size, ) 104 | log_q_s_saprev = self.log_q_s(s=state) 105 | 106 | else: 107 | 108 | # safety fix, in case a done flag goes back down 109 | done = torch.max(was_done, done) 110 | 111 | # s_t-1, a_t-1 112 | state_prev, action_prev = states[t-1], episode[4*t-1] 113 | 114 | # (batch_size, ) 115 | log_q_s_saprev = self.log_q_snext_sa(a=action_prev, s=state_prev, snext=state) 116 | 117 | # (batch_size, ) 118 | log_q_o_s = self.log_q_o_s(o=obs, s=state) 119 | log_q_r_s = self.log_q_r_s(r=reward, s=state) 120 | 121 | if with_done: 122 | # (batch_size, ) 123 | log_q_d_s = self.log_q_d_s(d=done, s=state) 124 | else: 125 | log_q_d_s = 0 126 | 127 | # a_t (if any) 128 | if t < n_transitions: 129 | action = episode[4*(t+1)-1] 130 | 131 | # (batch_size, ) 132 | log_q_a_s = self.log_q_a_s(a=action, s=state) 133 | log_q_a_s = log_q_a_s * (1 - done) # discard actions if done=True 134 | log_q_a_s = log_q_a_s * (1 - regime) # discard actions in interventional regime 135 | else: 136 | log_q_a_s = 0 137 | 138 | # (batch_size, ) 139 | log_q_sorda_saprev = log_q_s_saprev + log_q_o_s + log_q_r_s + log_q_d_s + log_q_a_s 140 | 141 | # discard transitions after done=True (due to padding) 142 | log_q_sorda_saprev = log_q_sorda_saprev * (1 - was_done) 143 | 144 | # (batch_size, ) 145 | log_prob = log_prob + log_q_sorda_saprev 146 | 147 | was_done = done 148 | 149 | return log_prob 150 | 151 | @torch.jit.export 152 | def log_prob(self, regime: torch.Tensor, episode: typing.List[torch.Tensor], with_done: bool=False, return_loq_q_s_h: bool=False): 153 | 154 | # if requested, store all q(s_t | h_t) and q(s_t+1 | h_t) during forward 155 | if return_loq_q_s_h: 156 | seq_loq_q_s_h = [] 157 | 158 | log_prob = 0 159 | done = torch.tensor([0.]) 160 | 161 | n_transitions = len(episode) // 4 162 | for t in range(n_transitions + 1): 163 | 164 | # o_t, r_t, d_t 165 | obs, reward, done = episode[4*t], episode[4*t+1], episode[4*t+2] 166 | 167 | if t == 0: 168 | was_done = torch.zeros_like(done) 169 | 170 | # (batch_size, s_nvals) 171 | log_q_s_hprev = self.log_q_s().unsqueeze(0) 172 | 173 | else: 174 | # safety fix, in case a done flag goes back down 175 | done = torch.max(was_done, done) 176 | 177 | # (batch_size, s_nvals) 178 | log_q_s_hprev = log_q_snext_h 179 | 180 | # (batch_size, s_nvals) 181 | log_q_o_s = self.log_q_o_s(o=obs) 182 | log_q_r_s = self.log_q_r_s(r=reward) 183 | 184 | if with_done: 185 | # (batch_size, s_nvals) 186 | log_q_d_s = 
self.log_q_d_s(d=done) 187 | else: 188 | log_q_d_s = 0 189 | 190 | # a_t (if any) 191 | if t < n_transitions: 192 | action = episode[4*(t+1)-1] 193 | 194 | # (batch_size, s_nvals) 195 | log_q_a_s = self.log_q_a_s(a=action) 196 | log_q_a_s = torch.where((done==1).unsqueeze(-1), torch.zeros_like(log_q_a_s), log_q_a_s) # discard actions if done=True 197 | log_q_a_s = torch.where((regime==1).unsqueeze(-1), torch.zeros_like(log_q_a_s), log_q_a_s) # discard actions in interventional regime 198 | else: 199 | log_q_a_s = 0 200 | 201 | # hprev = (o_0, r_0, d_0, a_0, ..., o_t-1, r_t-1, d_t-1, a_t-1) 202 | 203 | # (batch_size, s_nvals) 204 | log_q_sorda_hprev = log_q_s_hprev + log_q_o_s + log_q_r_s + log_q_d_s + log_q_a_s 205 | 206 | # (batch_size, ) 207 | log_q_orda_hprev = torch.logsumexp(log_q_sorda_hprev, dim=-1) 208 | 209 | # discard transitions after done=True (due to padding) 210 | log_q_orda_hprev = log_q_orda_hprev * (1 - was_done) 211 | 212 | if t == 0: 213 | # (batch_size, ) 214 | log_prob = log_prob + log_q_orda_hprev 215 | else: 216 | # (batch_size, ) 217 | log_prob = torch.where(log_prob.isinf(), log_prob, log_prob + log_q_orda_hprev) # bugfix, otherwise NaNs will appear 218 | 219 | # h = (o_0, r_0, d_0, a_0, ..., o_t, r_t, d_t, a_t) 220 | 221 | # (batch_size, s_nvals) 222 | log_q_s_h = log_q_sorda_hprev - log_q_orda_hprev.unsqueeze(1) 223 | 224 | # snext = s_t+1 225 | 226 | if t < n_transitions: 227 | # (batch_size, s_nvals, s_nvals) 228 | log_q_ssnext_h = log_q_s_h.unsqueeze(2) + self.log_q_snext_sa(a=action) 229 | # (batch_size, s_nvals) 230 | log_q_snext_h = torch.logsumexp(log_q_ssnext_h, dim=1) 231 | 232 | else: 233 | log_q_snext_h = None 234 | 235 | # if requested, store all q(s_t | h_t) and q(s_t+1 | h_t) during forward 236 | if return_loq_q_s_h: 237 | seq_loq_q_s_h.append((log_q_s_h, log_q_snext_h)) 238 | 239 | was_done = done 240 | 241 | if return_loq_q_s_h: 242 | return log_prob, seq_loq_q_s_h 243 | else: 244 | return log_prob 245 | 246 | @torch.jit.export 247 | def sample_states(self, regime: torch.Tensor, episode: typing.List[torch.Tensor], with_done: bool=False): 248 | 249 | with torch.no_grad(): 250 | 251 | # collect all q(s_t | h_t) with a forward pass 252 | _, seq_log_q_s_h = self.log_prob(regime, episode, with_done=with_done, return_loq_q_s_h=True) 253 | 254 | # collect all s_t ~ q(s_t | h_t, s_t+1) with a backward pass 255 | states = [] 256 | n_transitions = len(episode) // 4 257 | for t in range(n_transitions, -1, -1): 258 | log_q_s_h, log_q_snext_h = seq_log_q_s_h[t] 259 | 260 | if t == n_transitions: 261 | # (batch_size, s_nvals) 262 | log_q_s_hsnext = log_q_s_h 263 | 264 | else: 265 | action = episode[4*(t+1)-1] 266 | 267 | # (batch_size, s_nvals) 268 | log_q_snext_sa = self.log_q_snext_sa(a=action, snext=state) 269 | 270 | # (batch_size, s_nvals) 271 | log_q_ssnext_h = log_q_s_h + log_q_snext_sa 272 | 273 | # (batch_size, s_nvals) 274 | log_q_s_hsnext = log_q_ssnext_h - log_q_snext_h 275 | 276 | state = torch.distributions.one_hot_categorical.OneHotCategorical( 277 | logits=log_q_s_hsnext, 278 | ).sample() 279 | 280 | states.insert(0, state) 281 | 282 | return states 283 | 284 | @torch.jit.export 285 | def loss_nll(self, regime: torch.Tensor, episode: typing.List[torch.Tensor], with_done: bool=False): 286 | return -self.log_prob(regime, episode, with_done=with_done) 287 | 288 | @torch.jit.export 289 | def loss_em(self, regime: torch.Tensor, episode: typing.List[torch.Tensor], with_done: bool=False): 290 | states = self.sample_states(regime, episode, 
with_done=with_done) 291 | return -self.log_prob_joint(regime, episode, states, with_done=with_done) 292 | 293 | 294 | class TabularAugmentedModel(AugmentedModel): 295 | 296 | """ Learnable Augmented Model using tabular probability distribution parameters. """ 297 | 298 | def __init__(self, s_nvals, o_nvals, a_nvals, r_nvals): 299 | super().__init__() 300 | self.s_nvals = s_nvals 301 | self.o_nvals = o_nvals 302 | self.a_nvals = a_nvals 303 | self.r_nvals = r_nvals 304 | 305 | # p(s_0) 306 | self.params_s = torch.nn.Parameter(torch.empty([s_nvals])) 307 | torch.nn.init.normal_(self.params_s) 308 | 309 | # p(s_t+1 | s_t, a_t) 310 | self.params_s_sa = torch.nn.Parameter(torch.empty([s_nvals, a_nvals, s_nvals])) 311 | torch.nn.init.normal_(self.params_s_sa) 312 | 313 | # p(o_t | s_t) 314 | self.params_o_s = torch.nn.Parameter(torch.empty([s_nvals, o_nvals])) 315 | torch.nn.init.normal_(self.params_o_s) 316 | 317 | # p(r_t | s_t) 318 | self.params_r_s = torch.nn.Parameter(torch.empty([s_nvals, r_nvals])) 319 | torch.nn.init.normal_(self.params_r_s) 320 | 321 | # p(d_t | s_t) 322 | self.params_d_s = torch.nn.Parameter(torch.empty([s_nvals])) 323 | torch.nn.init.normal_(self.params_d_s) 324 | 325 | # p(a_t | s_t, i=0) 326 | self.params_a_s = torch.nn.Parameter(torch.empty([s_nvals, a_nvals])) 327 | torch.nn.init.normal_(self.params_a_s) 328 | 329 | # @torch.jit.export 330 | @torch.jit.ignore 331 | def log_q_s(self, s: typing.Optional[torch.Tensor]=None): 332 | 333 | """ Log proba of state distribution p(s) """ 334 | 335 | log_q_s = torch.nn.functional.log_softmax(self.params_s, dim=-1) 336 | 337 | if s is not None: 338 | s_index = s.max(-1)[1] 339 | log_q_s = log_q_s[s_index] 340 | 341 | return log_q_s 342 | 343 | # @torch.jit.export 344 | @torch.jit.ignore 345 | def log_q_snext_sa(self, a: torch.Tensor, 346 | s: typing.Optional[torch.Tensor]=None, 347 | snext: typing.Optional[torch.Tensor]=None): 348 | 349 | """ Log proba of state transition distribution p(s|s, a). """ 350 | 351 | # (s_nvals, a_nvals, s_nvals) 352 | log_q_snext_sa = torch.nn.functional.log_softmax(self.params_s_sa, dim=-1) 353 | indices = [] 354 | 355 | if s is not None: 356 | s_index = s.max(-1)[1] 357 | indices.insert(0, s_index) 358 | 359 | if a is not None: 360 | a_index = a.max(-1)[1] 361 | indices.insert(0, a_index) 362 | log_q_snext_sa = log_q_snext_sa.permute(1, 0, 2) 363 | 364 | if snext is not None: 365 | snext_index = snext.max(-1)[1] 366 | indices.insert(0, snext_index) 367 | log_q_snext_sa = log_q_snext_sa.permute(2, 0, 1) 368 | 369 | if len(indices): 370 | log_q_snext_sa = log_q_snext_sa[indices] 371 | 372 | return log_q_snext_sa 373 | 374 | # @torch.jit.export 375 | @torch.jit.ignore 376 | def log_q_o_s(self, o: torch.Tensor, 377 | s: typing.Optional[torch.Tensor]=None): 378 | 379 | """ Log proba of conditional observation distribution from state p(o|s). """ 380 | 381 | log_q_o_s = torch.nn.functional.log_softmax(self.params_o_s, dim=-1) 382 | 383 | indices = [] 384 | 385 | if s is not None: 386 | s_index = s.max(-1)[1] 387 | indices.insert(0, s_index) 388 | 389 | if o is not None: 390 | o_index = o.max(-1)[1] 391 | indices.insert(0, o_index) 392 | log_q_o_s = log_q_o_s.permute(1, 0) 393 | 394 | if len(indices): 395 | log_q_o_s = log_q_o_s[indices] 396 | 397 | return log_q_o_s 398 | 399 | # @torch.jit.export 400 | @torch.jit.ignore 401 | def log_q_a_s(self, a: torch.Tensor, 402 | s: typing.Optional[torch.Tensor]=None): 403 | 404 | """ Log proba of conditional action distribution from state p(a|s). 
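        Shapes, for reference: `params_a_s` has shape (s_nvals, a_nvals). With a one-hot `a` of
        shape (batch, a_nvals) and `s=None`, the result has shape (batch, s_nvals), i.e. the
        log-probability of the given action under every possible state; passing a one-hot `s`
        as well reduces it to shape (batch,).

        Minimal sketch (the sizes and tensors below are made up for illustration):

            m = TabularAugmentedModel(s_nvals=4, o_nvals=3, a_nvals=2, r_nvals=2)
            a = torch.nn.functional.one_hot(torch.tensor([0, 1]), num_classes=2).float()
            s = torch.nn.functional.one_hot(torch.tensor([2, 0]), num_classes=4).float()
            m.log_q_a_s(a=a).shape       # expected: (2, 4)
            m.log_q_a_s(a=a, s=s).shape  # expected: (2,)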
""" 405 | 406 | log_q_a_s = torch.nn.functional.log_softmax(self.params_a_s, dim=-1) 407 | 408 | indices = [] 409 | 410 | if s is not None: 411 | s_index = s.max(-1)[1] 412 | indices.insert(0, s_index) 413 | 414 | if a is not None: 415 | a_index = a.max(-1)[1] 416 | indices.insert(0, a_index) 417 | log_q_a_s = log_q_a_s.permute(1, 0) 418 | 419 | if len(indices): 420 | log_q_a_s = log_q_a_s[indices] 421 | 422 | return log_q_a_s 423 | 424 | # @torch.jit.export 425 | @torch.jit.ignore 426 | def log_q_r_s(self, r: torch.Tensor, 427 | s: typing.Optional[torch.Tensor]=None): 428 | 429 | """ Log proba of conditional reward distribution from state p(r|s). """ 430 | 431 | log_q_r_s = torch.nn.functional.log_softmax(self.params_r_s, dim=-1) 432 | 433 | indices = [] 434 | 435 | if s is not None: 436 | s_index = s.max(-1)[1] 437 | indices.insert(0, s_index) 438 | 439 | if r is not None: 440 | r_index = r.max(-1)[1] 441 | indices.insert(0, r_index) 442 | log_q_r_s = log_q_r_s.permute(1, 0) 443 | 444 | if len(indices): 445 | log_q_r_s = log_q_r_s[indices] 446 | 447 | return log_q_r_s 448 | 449 | # @torch.jit.export 450 | @torch.jit.ignore 451 | def log_q_d_s(self, d: torch.Tensor, 452 | s: typing.Optional[torch.Tensor]=None): 453 | 454 | """ Log proba of conditional flagDone distribution from state p(d|s). """ 455 | 456 | if s is not None: 457 | s_index = s.max(-1)[1] 458 | log_q_d_s = torch.distributions.bernoulli.Bernoulli( 459 | logits=self.params_d_s[s_index], 460 | ).log_prob(d.unsqueeze(1)) 461 | 462 | else: 463 | log_q_d_s = torch.distributions.bernoulli.Bernoulli( 464 | logits=self.params_d_s, 465 | ).log_prob(d.unsqueeze(1)) 466 | 467 | return log_q_d_s 468 | 469 | def get_settings(self): 470 | 471 | """ Return a dictionnary with all proba distributions. 472 | Straightforward, as we provided previously those same proba distributions. 
""" 473 | 474 | settings = {} 475 | settings["p_s"] = torch.nn.functional.softmax(self.params_s, dim=-1) 476 | settings["p_o_s"] = torch.nn.functional.softmax(self.params_o_s, dim=-1) 477 | settings["p_s_sa"] = torch.nn.functional.softmax(self.params_s_sa, dim=-1) 478 | settings["p_r_s"] = torch.nn.functional.softmax(self.params_r_s, dim=-1) 479 | 480 | q_d_s = torch.sigmoid(self.params_d_s) 481 | settings["p_d_s"] = torch.stack([1-q_d_s, q_d_s], dim=-1) 482 | 483 | settings["p_or_s"] = settings["p_r_s"].unsqueeze(-2) * settings["p_o_s"].unsqueeze(-1) 484 | 485 | return settings 486 | 487 | def set_probs(self, p_s=None, p_o_s=None, p_s_sa=None, p_r_s=None, p_d_s=None, p_a_s=None): 488 | with torch.no_grad(): 489 | if p_s is not None: 490 | self.params_s[:] = torch.as_tensor(p_s).log() 491 | if p_o_s is not None: 492 | self.params_o_s[:] = torch.as_tensor(p_o_s).log() 493 | if p_s_sa is not None: 494 | self.params_s_sa[:] = torch.as_tensor(p_s_sa).log() 495 | if p_r_s is not None: 496 | self.params_r_s[:] = torch.as_tensor(p_r_s).log() 497 | if p_d_s is not None: 498 | self.params_d_s[:] = torch.as_tensor(p_d_s).log() 499 | if p_a_s is not None: 500 | self.params_a_s[:] = torch.as_tensor(p_a_s).log() 501 | 502 | 503 | class NNProba(torch.nn.Module): 504 | 505 | """ NN Module to estimate (log) proba for any (input/output) pair """ 506 | 507 | def __init__(self, input_dim, output_dim, h_dim = 16): 508 | super(NNProba, self).__init__() 509 | 510 | self.fc1 = torch.nn.Linear(input_dim, h_dim) 511 | self.fc2 = torch.nn.Linear(h_dim, output_dim) 512 | 513 | def forward(self, x): 514 | 515 | x = F.relu(self.fc1(x)) 516 | x = self.fc2(x) 517 | 518 | return torch.nn.functional.softmax(x, dim=-1), torch.nn.functional.log_softmax(x, dim=-1) 519 | 520 | 521 | class NNAugmentedModel(AugmentedModel): 522 | 523 | """ NN Augmented Model """ 524 | 525 | def __init__(self, s_nvals, o_nvals, a_nvals, r_nvals, h_dim): 526 | super().__init__() 527 | self.s_nvals = s_nvals 528 | self.o_nvals = o_nvals 529 | self.a_nvals = a_nvals 530 | self.r_nvals = r_nvals 531 | 532 | # p(s_0) 533 | self.params_s = torch.nn.Parameter(torch.empty([s_nvals])) 534 | torch.nn.init.normal_(self.params_s) 535 | 536 | # p(s_t+1 | s_t, a_t) 537 | self.params_s_sa = NNProba(input_dim = s_nvals + a_nvals, output_dim = s_nvals, h_dim=h_dim) 538 | 539 | # p(o_t | s_t) 540 | self.params_o_s = NNProba(input_dim = s_nvals , output_dim = o_nvals, h_dim=h_dim) 541 | 542 | # p(r_t | s_t) 543 | self.params_r_s = NNProba(input_dim = s_nvals , output_dim = r_nvals, h_dim=h_dim) 544 | 545 | # p(d_t | s_t) 546 | self.params_d_s = torch.nn.Parameter(torch.empty([s_nvals])) 547 | torch.nn.init.normal_(self.params_d_s) 548 | 549 | # p(a_t | s_t, i=0) 550 | self.params_a_s = NNProba(input_dim = s_nvals, output_dim = a_nvals, h_dim=h_dim) 551 | 552 | def log_q_s(self, 553 | s: typing.Optional[torch.Tensor]=None): 554 | 555 | """ Log proba of state distribution p(s) """ 556 | 557 | log_q_s = torch.nn.functional.log_softmax(self.params_s, dim=-1) 558 | 559 | if s is not None: 560 | s_index = s.max(-1)[1] 561 | log_q_s = log_q_s[s_index] 562 | 563 | return log_q_s 564 | 565 | def log_q_snext_sa(self, 566 | a: torch.Tensor, 567 | s: typing.Optional[torch.Tensor]=None, 568 | snext: typing.Optional[torch.Tensor]=None): 569 | 570 | """ Log proba of state transition distribution p(s|s, a). 571 | 572 | Can be improved ? REPEAT, EXPAND ? 
""" 573 | 574 | batch_size, _ = a.shape 575 | 576 | if s is None and a is not None: 577 | 578 | s = torch.diag(torch.ones(self.s_nvals)).unsqueeze(0) # (1, 16, 16) 579 | s = s.expand(batch_size, self.s_nvals, self.s_nvals) # (batch_size, 16, 16) 580 | s = s.reshape(-1,self.s_nvals) # (batch_size*16, 16) 581 | 582 | a = torch.repeat_interleave(a, repeats=self.s_nvals, dim=0) #(batch_size*16, 2) 583 | 584 | _ , log_q_snext_sa = self.params_s_sa(torch.cat([s, a], dim=-1)) # (batsh_size*16, 16) from (batsh_size*16, 16+2) 585 | log_q_snext_sa = log_q_snext_sa.reshape(batch_size, self.s_nvals, 1, self.s_nvals) # batsh_size, 16, 16 586 | 587 | elif a is None and s is not None : 588 | 589 | a = torch.diag(torch.ones(self.a_nvals)).unsqueeze(0) # (1, 2, 2) 590 | a = a.expand(batch_size, self.a_nvals, self.a_nvals) # (batch_size, 2, 2) 591 | a = a.reshape(-1,self.a_nvals) # (batch_size*2, 2) 592 | 593 | s = torch.repeat_interleave(s, repeats=self.a_nvals, dim=0) # (batsh_size*2, 16) 594 | 595 | _ , log_q_snext_sa = self.params_s_sa(torch.cat([s, a], dim=-1) ) # (batsh_size*2, 16) 596 | log_q_snext_sa = log_q_snext_sa.reshape(batch_size, self.a_nvals, self.s_nvals) # batsh_size, 2, 16 597 | 598 | elif a is None and s is None: 599 | print("BOTH A and S are none") 600 | pass 601 | 602 | else : 603 | _ , log_q_snext_sa = self.params_s_sa(torch.cat([s, a], dim=-1)) # output (batch_size, 16), input (batsh_size, 16+2) 604 | log_q_snext_sa = log_q_snext_sa.unsqueeze(1).unsqueeze(1) #batsh_size, 1, 1, 32 605 | 606 | if snext is not None: 607 | 608 | snext_index = snext.max(-1)[1] 609 | log_q_snext_sa = log_q_snext_sa.permute(3, 0, 1, 2) # (16, batch_size, ?, ?) 610 | log_q_snext_sa = log_q_snext_sa[snext_index] # (batch_size, ?, ?) 611 | 612 | return torch.squeeze(log_q_snext_sa) 613 | 614 | def log_q_o_s(self, 615 | o: torch.Tensor, 616 | s: typing.Optional[torch.Tensor]=None): 617 | 618 | """ Log proba of conditional observation distribution from state p(o|s). """ 619 | 620 | batch_size, _ = o.shape 621 | 622 | if s is None: 623 | s = torch.diag(torch.ones(self.s_nvals)).unsqueeze(0) # (1, 16, 16) 624 | s = s.expand(batch_size, self.s_nvals, self.s_nvals) # (batch_size, 16, 16) 625 | s = s.reshape(-1,self.s_nvals) # (batch_size*16, 16) 626 | _ , log_q_o_s = self.params_o_s(s) #((batch_size*16, 2)) 627 | log_q_o_s = log_q_o_s.reshape(batch_size, self.s_nvals, self.o_nvals) # 32, 16, 2 628 | else : 629 | _ , log_q_o_s = self.params_o_s(s) 630 | log_q_o_s = log_q_o_s.unsqueeze(1) #32, 1, 2 631 | 632 | if o is not None: 633 | log_q_o_s = log_q_o_s * o.unsqueeze(1) # 32, ?, 2 634 | log_q_o_s = log_q_o_s.sum(-1) 635 | 636 | return torch.squeeze(log_q_o_s) 637 | 638 | @torch.jit.ignore 639 | def log_q_a_s(self, 640 | a: torch.Tensor, 641 | s: typing.Optional[torch.Tensor]=None): 642 | 643 | """ Log proba of conditional action distribution from state p(o|s). 
""" 644 | 645 | batch_size, _ = a.shape 646 | 647 | if s is None: 648 | s = torch.diag(torch.ones(self.s_nvals)).unsqueeze(0) # (1, 16, 16) 649 | s = s.expand(batch_size, self.s_nvals, self.s_nvals) # (batch_size, 16, 16) 650 | s = s.reshape(-1,self.s_nvals) # (batch_size*16, 16) 651 | _ , log_q_a_s = self.params_a_s(s) #((batch_size*16, 2)) 652 | log_q_a_s = log_q_a_s.reshape(batch_size, self.s_nvals, self.a_nvals) #32, 16, 2 653 | else : 654 | _ , log_q_a_s = self.params_a_s(s) 655 | log_q_a_s = log_q_a_s.unsqueeze(1) 656 | 657 | if a is not None: 658 | log_q_a_s = log_q_a_s * a.unsqueeze(1) # 32, ?, 2 659 | log_q_a_s = log_q_a_s.sum(-1) 660 | 661 | return torch.squeeze(log_q_a_s) 662 | 663 | def log_q_r_s(self, 664 | r: torch.Tensor, 665 | s: typing.Optional[torch.Tensor]=None): 666 | 667 | """ Log proba of conditional reward distribution from state p(o|s). """ 668 | 669 | batch_size, _ = r.shape 670 | 671 | if s is None: 672 | s = torch.diag(torch.ones(self.s_nvals)).unsqueeze(0) # (1, 16, 16) 673 | s = s.expand(batch_size, self.s_nvals, self.s_nvals) # (batch_size, 16, 16) 674 | s = s.reshape(-1,self.s_nvals) # (batch_size*16, 16) 675 | _ , log_q_r_s = self.params_r_s(s) #((batch_size*16, 2)) 676 | log_q_r_s = log_q_r_s.reshape(batch_size, self.s_nvals, self.r_nvals) # 32, 16, 2 677 | else : 678 | _ , log_q_r_s = self.params_r_s(s) 679 | log_q_r_s = log_q_r_s.unsqueeze(1) #32, 1, 2 680 | 681 | if r is not None: 682 | log_q_r_s = log_q_r_s * r.unsqueeze(1) # 32, ?, 2 683 | log_q_r_s = log_q_r_s.sum(-1) 684 | 685 | #print('Final Shape', torch.squeeze(log_q_o_s).shape) 686 | return torch.squeeze(log_q_r_s) 687 | 688 | def log_q_d_s(self, d: torch.Tensor, 689 | s: typing.Optional[torch.Tensor]=None): 690 | 691 | """ Log proba of conditional flagDone distribution from state p(o|s). """ 692 | 693 | if s is not None: 694 | s_index = s.max(-1)[1] 695 | log_q_d_s = torch.distributions.bernoulli.Bernoulli( 696 | logits=self.params_d_s[s_index], 697 | ).log_prob(d.unsqueeze(1)) 698 | 699 | else: 700 | log_q_d_s = torch.distributions.bernoulli.Bernoulli( 701 | logits=self.params_d_s, 702 | ).log_prob(d.unsqueeze(1)) 703 | 704 | return log_q_d_s 705 | 706 | def get_settings(self): 707 | 708 | """ Return a dictionnary with all proba distributions. 709 | Straightforward, as we provided previously those same proba distributions. 
""" 710 | 711 | settings = {} 712 | settings["p_s"] = torch.nn.functional.softmax(self.params_s, dim=-1) 713 | 714 | s = torch.diag(torch.ones(self.s_nvals)) # (16, 16) 715 | q_o_s, _ = self.params_o_s(s) # (16, 3) 716 | 717 | settings["p_o_s"] = q_o_s.reshape(self.s_nvals, self.o_nvals) 718 | 719 | a = torch.diag(torch.ones(self.a_nvals)) #(2, 2) 720 | a = a.expand(self.s_nvals, self.a_nvals, self.a_nvals).reshape(-1, self.a_nvals) # (16*2, 2) 721 | 722 | s = torch.repeat_interleave(s, repeats=self.a_nvals, dim=0) # (16*2, 16) 723 | concat_sa = torch.cat([s, a], dim=-1) # (16*2, 16+2) 724 | q_snext_sa, _ = self.params_s_sa(concat_sa) # (16*2, 16) 725 | 726 | settings["p_s_sa"] = q_snext_sa.reshape(self.s_nvals, self.a_nvals, self.s_nvals) # 16, 2, 16 727 | 728 | s = torch.diag(torch.ones(self.s_nvals)) # (16, 16) 729 | q_r_s, _ = self.params_r_s(s) # (16, 2) 730 | settings["p_r_s"] = q_r_s.reshape(self.s_nvals, self.r_nvals) 731 | 732 | q_d_s = torch.sigmoid(self.params_d_s) 733 | settings["p_d_s"] = torch.stack([1-q_d_s, q_d_s], dim=-1) 734 | 735 | settings["p_or_s"] = settings["p_r_s"].unsqueeze(-2) * settings["p_o_s"].unsqueeze(-1) 736 | 737 | return settings 738 | -------------------------------------------------------------------------------- /policies.py: -------------------------------------------------------------------------------- 1 | ########################################################## 2 | ######################## Policies ######################## 3 | ########################################################## 4 | 5 | """ 6 | One can use two kinds of policies to do rollout of the game : 7 | 8 | - UniformPolicy() which plays randomly from the tuple (o, r, d) 9 | - ExpertPolicy(p_a_s) which mimics an expert that has access to the true state `s`. 10 | One needs to provide the action distribution `p_a_s`. 11 | 12 | In both cases, one would only need to call of the `.action()` property of the Policy class, 13 | providing all the environment data. As a remark, it returns a one-hot encoded action, 14 | and the gym environment requires to have an *int*-type action (just use `action.argmax()` then). 15 | 16 | action = policy.action(o, r, done, **info) 17 | 18 | One can collect episodes by simply calling `rollout`, given an env and a policy: 19 | 20 | episode = rollout(env, default_policy) 21 | 22 | """ 23 | 24 | import torch 25 | 26 | class Policy(torch.nn.Module): 27 | 28 | """ Empty class, must include 29 | 1. a Reset method 30 | 3. 
an action() method to act in the environment from the 4-tuple (o, r, d, info)
31 |     """
32 | 
33 |     def reset(self):
34 |         raise NotImplementedError
35 | 
36 |     def action(self, o, r, d, **info):
37 |         raise NotImplementedError
38 | 
39 | 
40 | class UniformPolicy(Policy):
41 | 
42 |     """ Uniform policy that acts in the environment with uniformly distributed actions """
43 | 
44 |     def __init__(self, a_nvals):
45 |         super().__init__()
46 | 
47 |         self.a_nvals = a_nvals
48 |         self.h = []
49 | 
50 |     def reset(self):
51 |         self.h.clear()
52 | 
53 |     def action(self, o, r, d, **info):
54 | 
55 |         self.h += [o, r, d]
56 |         a = torch.distributions.one_hot_categorical.OneHotCategorical(\
57 |             probs=torch.ones(self.a_nvals)/self.a_nvals).sample(o.shape[:-1])
58 |         self.h += [a]
59 | 
60 |         return a
61 | 
62 | class ExpertPolicy(Policy):
63 | 
64 |     """ Expert policy that chooses its actions from the hidden state s """
65 | 
66 |     def __init__(self, p_a_s):
67 |         super().__init__()
68 |         self.probs_a_s = p_a_s  # p(a_t | s_t, i=0)
69 | 
70 |     def reset(self):
71 |         pass
72 | 
73 |     def action(self, o, r, done, **info):
74 | 
75 |         s_index = info["s"].argmax()
76 |         a = torch.distributions.one_hot_categorical.OneHotCategorical(probs=self.probs_a_s[s_index], ).sample()
77 | 
78 |         return a
79 | 
80 | 
81 | class AugmentedPolicy(Policy):
82 | 
83 |     """ Augmented policy which uses an augmented model to estimate the belief state
84 |     of the environment. It then chooses its actions according to the belief state estimate."""
85 | 
86 |     def __init__(self, augmentedmodel, regime=torch.tensor(0), with_done=False):
87 |         super().__init__()
88 | 
89 |         self.m = augmentedmodel
90 |         self.q_s_h = None
91 |         self.last_action = None
92 |         self.regime = regime
93 | 
94 |     def update_hidden_state(self, o, r, d, with_done=False):
95 |         log_q_s_h = self.m.log_q_s_h(regime=self.regime,
96 |                                      loq_q_sprev_hprev=None if self.q_s_h is None else self.q_s_h.log(),
97 |                                      a=self.last_action,
98 |                                      o=o.unsqueeze(0),
99 |                                      r=r.unsqueeze(0),
100 |                                      d=d.unsqueeze(0),
101 |                                      with_done=with_done)
102 |         self.q_s_h = torch.exp(log_q_s_h)
103 | 
104 |     def reset(self):
105 |         self.q_s_h = None
106 |         self.last_action = None
107 | 
108 |     def action(self, o, r, done, deterministic=False, **info):
109 | 
110 |         # update the belief state q(s_t | h_t) from the latest (o, r, done)
111 |         self.update_hidden_state(o, r, done)
112 |         new_action_p = torch.exp(self.m.log_q_a_s(a=None, s=self.q_s_h))
113 | 
114 |         if deterministic:
115 |             new_action_p = new_action_p.round()
116 | 
117 |         a = torch.distributions.one_hot_categorical.OneHotCategorical(
118 |             probs=new_action_p,).sample()
119 | 
120 |         self.last_action = a
121 | 
122 |         return a
123 | 
124 | 
--------------------------------------------------------------------------------
/rl_agents/ac.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn, optim
4 | import torch.nn.functional as F
5 | 
6 | from utils import print_log
7 | 
8 | eps = np.finfo(np.float32).eps.item()
9 | 
10 | ############################################################
11 | 
12 | class ActorCritic(nn.Module):
13 | 
14 |     def __init__(self, s_nvals, a_nvals, hidden_size=32):
15 | 
16 |         super(ActorCritic, self).__init__()
17 | 
18 |         self.a_nvals = a_nvals
19 |         self.s_nvals = s_nvals
20 |         self.hidden_size = hidden_size
21 | 
22 |         self.actor = torch.nn.Sequential(
23 |             torch.nn.Linear(self.s_nvals, hidden_size),
24 |             torch.nn.ReLU(),
25 |             # torch.nn.Linear(self.hidden_size, hidden_size),
26 |             # torch.nn.ReLU(),
27 |             torch.nn.Linear(hidden_size, self.a_nvals),
28 |             torch.nn.LogSoftmax(dim=-1),
29 | ) 30 | 31 | self.critic = torch.nn.Sequential( 32 | torch.nn.Linear(self.s_nvals, hidden_size), 33 | torch.nn.ReLU(), 34 | # torch.nn.Linear(self.hidden_size, hidden_size), 35 | # torch.nn.ReLU(), 36 | torch.nn.Linear(hidden_size, 1), 37 | ) 38 | 39 | def forward(self, state): 40 | return self.actor(state), self.critic(state) 41 | 42 | def run_episode(env, model, max_steps_per_episode): 43 | 44 | action_log_probs, rewards, values = [], [], [] 45 | 46 | with torch.no_grad(): 47 | state = env.reset() 48 | 49 | for t in range(max_steps_per_episode): 50 | 51 | action_log_probs_t, value = model.forward(state) 52 | action = int(torch.multinomial(action_log_probs_t.exp(), 1)[0]) 53 | 54 | with torch.no_grad(): 55 | state, reward, done, info = env.step(action) 56 | 57 | action_log_prob = action_log_probs_t[:, action] 58 | 59 | action_log_probs.append(action_log_prob) 60 | values.append(value) 61 | rewards.append(reward) 62 | 63 | # if tmp_print_flag: 64 | # action_desc = ["top", "right", "bottom", "left", "noop"] 65 | # print(f"action={action_desc[action]} (p={torch.exp(action_log_prob).item()}), reward={reward}, value={value.detach().item()}") 66 | 67 | if done: 68 | break 69 | 70 | return action_log_probs, values, rewards 71 | 72 | def get_discounted_returns(rewards, gamma): 73 | 74 | n = len(rewards) 75 | returns = np.zeros(n) 76 | 77 | discounted_sum = 0. 78 | for t in range(n)[::-1]: 79 | discounted_sum = rewards[t] + gamma * discounted_sum 80 | returns[t] = discounted_sum 81 | 82 | return returns 83 | 84 | def loss_episode(env, model, gamma, max_steps_per_episode): 85 | 86 | action_log_probs, values, rewards = run_episode(env, model, max_steps_per_episode) 87 | returns = get_discounted_returns(rewards, gamma) 88 | 89 | action_log_probs = torch.cat(action_log_probs, dim=-1).unsqueeze(0) 90 | values = torch.cat(values, dim=-1) 91 | returns = torch.tensor(returns, dtype=torch.float).unsqueeze(0) 92 | 93 | # compute actor-critic loss values 94 | actor_loss = - torch.sum(action_log_probs * (returns - values.detach())) 95 | critic_loss = F.mse_loss(values, returns, reduction = 'sum') 96 | 97 | return actor_loss, critic_loss, np.sum(rewards) 98 | 99 | def run_actorcritic(env, agent, 100 | gamma=0.99, 101 | n_epochs=10000, 102 | batch_size=1, 103 | max_steps_per_episode=1000, 104 | log_every=1000, 105 | lr=1e-2, 106 | logfile=None): 107 | 108 | optimizer = optim.Adam(agent.parameters(), lr=lr) 109 | 110 | best_running_return = -float("inf") 111 | best_params = agent.state_dict().copy() 112 | 113 | for ep in range(n_epochs) : 114 | 115 | epoch_return = 0 116 | epoch_actor_loss = 0 117 | epoch_critic_loss = 0 118 | optimizer.zero_grad() 119 | 120 | for i in range(batch_size): 121 | actor_loss, critic_loss, episode_return = loss_episode(env, agent, gamma, max_steps_per_episode) 122 | 123 | epoch_return += episode_return / batch_size 124 | epoch_actor_loss += actor_loss.detach().item() / batch_size 125 | epoch_critic_loss += critic_loss.detach().item() / batch_size 126 | 127 | loss = (actor_loss + critic_loss) / batch_size 128 | loss.backward() 129 | 130 | if ep == 0: 131 | running_return = epoch_return 132 | else: 133 | running_return = epoch_return * 0.1 + running_return * 0.9 134 | 135 | if ep % log_every == 0: 136 | print_log(f'Epoch {ep}: running return= {np.round(running_return, 4)}, critic loss={np.round(epoch_critic_loss, 4)}', logfile=logfile) 137 | 138 | # store best agent 139 | if best_running_return < running_return: 140 | print_log(f" best agent so far ({np.round(running_return, 4)})", 
logfile=logfile) 141 | best_running_return = running_return 142 | best_params = agent.state_dict().copy() 143 | 144 | optimizer.step() 145 | 146 | # restore best agent 147 | agent.load_state_dict(best_params) 148 | 149 | def evaluate_agent(env, agent, n_episodes, max_steps_per_episode=1000): 150 | with torch.no_grad(): 151 | mean_return = 0 152 | for ep in range(n_episodes): 153 | _, _, rewards = run_episode(env, agent, max_steps_per_episode=max_steps_per_episode) 154 | ep_reward = np.sum(rewards) 155 | mean_return += np.sum(rewards) / n_episodes 156 | 157 | return mean_return 158 | -------------------------------------------------------------------------------- /rl_agents/reinforce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.autograd import Variable 4 | 5 | import numpy as np 6 | 7 | from utils import print_log 8 | 9 | class Actor(torch.nn.Module): 10 | 11 | def __init__(self, s_nvals, a_nvals, hidden_size=32): 12 | super().__init__() 13 | 14 | self.s_nvals = s_nvals 15 | self.a_nvals = a_nvals 16 | 17 | self.log_q_a_s = torch.nn.Sequential( 18 | torch.nn.Linear(self.s_nvals, hidden_size), 19 | torch.nn.ReLU(), 20 | torch.nn.Linear(hidden_size, self.a_nvals), 21 | torch.nn.LogSoftmax(dim=-1) 22 | ) 23 | 24 | def forward(self, state): 25 | return self.log_q_a_s(state) 26 | 27 | def get_action(self, state, with_log_prob=False): 28 | log_q_a_s = self.log_q_a_s(state) 29 | action = torch.distributions.categorical.Categorical(logits=log_q_a_s).sample() 30 | if with_log_prob: 31 | log_prob = log_q_a_s[:, action] 32 | return action, log_prob 33 | else: 34 | return action 35 | 36 | def reinforce_loss(action_log_probs, rewards, gamma): 37 | 38 | assert gamma > 0 and gamma <= 1 39 | 40 | # cumulated discounted rewards 41 | returns = rewards.float().clone() 42 | for t in range(1, rewards.shape[1]): 43 | returns[:, :-t] += gamma**t * rewards[:, t:] 44 | 45 | # return normalization 46 | returns = (returns - returns.mean()) / (returns.std() + 1e-9) 47 | 48 | # policy gradient loss 49 | loss = (-action_log_probs * returns).sum(dim=1).mean(dim=0) 50 | 51 | return loss 52 | 53 | def run_reinforce(env, agent, lr, batch_size=1, gamma=0.99, n_epochs=500, log_every=10, logfile=None): 54 | 55 | optimizer = torch.optim.Adam(agent.parameters(), lr=lr) 56 | 57 | best_running_return = -float("inf") 58 | best_params = agent.state_dict().copy() 59 | 60 | for epoch in range(n_epochs): 61 | 62 | epoch_return = 0 63 | optimizer.zero_grad() 64 | 65 | for i in range(batch_size): 66 | state = env.reset() 67 | done = False 68 | log_probs = [] 69 | rewards = [] 70 | 71 | while not done: 72 | state = state.float().detach() 73 | action, log_prob = agent.get_action(state, with_log_prob=True) 74 | state, reward, done, _ = env.step(action) 75 | 76 | log_probs.append(log_prob) 77 | rewards.append(reward) 78 | 79 | if epoch % log_every == 0 and i == 0: 80 | action_desc = ["top", "right", "bottom", "left", "noop"] 81 | print(f"action={action_desc[action]} (p={torch.exp(log_prob).item()}), reward={reward}") 82 | 83 | log_probs = torch.cat(log_probs, dim=1) 84 | rewards = torch.tensor(rewards, dtype=float).unsqueeze(0) 85 | 86 | episode_return = rewards.sum() 87 | epoch_return += episode_return / batch_size 88 | 89 | loss = reinforce_loss(log_probs, rewards, gamma=gamma) / batch_size 90 | 91 | loss.backward() 92 | 93 | if epoch == 0: 94 | running_return = epoch_return 95 | else: 96 | running_return = 0.9 * running_return + 0.1 * 
epoch_return 97 | 98 | if epoch % log_every == 0: 99 | print_log(f"epoch {epoch} return {np.round(running_return, 4)}", logfile=logfile) 100 | 101 | # store best agent 102 | if best_running_return < running_return: 103 | print_log(f" best agent so far ({np.round(running_return, 4)})", logfile=logfile) 104 | best_running_return = running_return 105 | best_params = agent.state_dict().copy() 106 | 107 | optimizer.step() 108 | 109 | # restore best agent 110 | agent.load_state_dict(best_params) 111 | -------------------------------------------------------------------------------- /stat_tests.py: -------------------------------------------------------------------------------- 1 | # Courtesy of Cédric Colas 2 | # https://github.com/ccolas/rl_stats 3 | 4 | import numpy as np 5 | from scipy.stats import ttest_ind, mannwhitneyu, rankdata, median_test, wilcoxon 6 | 7 | tests_list = ['t-test', "Welch t-test", 'Mann-Whitney', 'Ranked t-test', 'permutation'] 8 | 9 | 10 | def run_permutation_test(all_data, n1, n2): 11 | np.random.shuffle(all_data) 12 | data_a = all_data[:n1] 13 | data_b = all_data[-n2:] 14 | return data_a.mean() - data_b.mean() 15 | 16 | 17 | def run_test(test_id, data1, data2, alpha=0.05): 18 | """ 19 | Compute tests comparing data1 and data2 with confidence level alpha 20 | :param test_id: (str) refers to what test should be used 21 | :param data1: (np.ndarray) sample 1 22 | :param data2: (np.ndarray) sample 2 23 | :param alpha: (float) confidence level of the test 24 | :return: (bool) if True, the null hypothesis is rejected 25 | """ 26 | data1 = data1.squeeze() 27 | data2 = data2.squeeze() 28 | n1 = data1.size 29 | n2 = data2.size 30 | 31 | if all(data1 == data2): 32 | return False 33 | 34 | if test_id == 't-test': 35 | _, p = ttest_ind(data1, data2, equal_var=True) 36 | return p < alpha 37 | 38 | elif test_id == "Welch t-test": 39 | _, p = ttest_ind(data1, data2, equal_var=False) 40 | return p < alpha 41 | 42 | elif test_id == 'Mann-Whitney': 43 | _, p = mannwhitneyu(data1, data2, alternative='two-sided') 44 | return p < alpha 45 | 46 | elif test_id == 'Wilcoxon': 47 | _, p = wilcoxon(data1, data2, correction=True, alternative='two-sided', zero_method="pratt") 48 | return p < alpha 49 | 50 | elif test_id == 'Ranked t-test': 51 | all_data = np.concatenate([data1.copy(), data2.copy()], axis=0) 52 | ranks = rankdata(all_data) 53 | ranks1 = ranks[: n1] 54 | ranks2 = ranks[n1:n1 + n2] 55 | assert ranks2.size == n2 56 | _, p = ttest_ind(ranks1, ranks2, equal_var=True) 57 | return p < alpha 58 | 59 | elif test_id == 'permutation': 60 | all_data = np.concatenate([data1.copy(), data2.copy()], axis=0) 61 | delta = np.abs(data1.mean() - data2.mean()) 62 | num_samples = 1000 63 | estimates = [] 64 | for _ in range(num_samples): 65 | estimates.append(run_permutation_test(all_data.copy(), n1, n2)) 66 | estimates = np.abs(np.array(estimates)) 67 | diff_count = len(np.where(estimates <= delta)[0]) 68 | return (1.0 - (float(diff_count) / float(num_samples))) < alpha 69 | 70 | else: 71 | raise NotImplementedError 72 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | 4 | 5 | def print_log(str, logfile=None): 6 | str = f'[{datetime.datetime.now()}] {str}' 7 | print(str) 8 | if logfile is not None: 9 | with open(logfile, mode='a') as f: 10 | print(str, file=f) 11 | 12 | 13 | #################################### Metrics 
#################################### 14 | 15 | @torch.jit.script 16 | def kl_div(p, q, ndims: int=1): 17 | # div = torch.nn.functional.kl_div(p, q, reduction='none') 18 | div = p * (torch.log(p) - torch.log(q)) 19 | div[p == 0] = 0 # NaNs quick fix 20 | dims = [i for i in range(-1, -(ndims+1), -1)] 21 | div = div.sum(dims) 22 | return div 23 | 24 | @torch.jit.script 25 | def js_div(p, q, ndims: int=1): 26 | m = (p + q) / 2 27 | div = (kl_div(p, m, ndims) + kl_div(q, m, ndims)) / 2 28 | return div 29 | 30 | ############################################################################# 31 | 32 | def rollout(env, policy): 33 | 34 | """ Perform rollout of the game and returning episodes in a list """ 35 | 36 | episode = [] 37 | o = env.reset() 38 | r, done, info = env.r, torch.tensor(0.), {"s" : env.s} 39 | episode += [o, r, done] 40 | 41 | while not done : 42 | action = policy.action(o, r, done, **info) 43 | o, r, done, info = env.step(action.argmax()) 44 | episode += [action, o, r, done] 45 | return episode 46 | 47 | def construct_dataset(env, policy, n_samples, regime): 48 | 49 | """ Construct a dataset (of n samples) by collecting rollouts using a given 50 | policy in a given environment """ 51 | 52 | data = [] 53 | for _ in range(n_samples): 54 | policy.reset() 55 | episode = rollout(env, policy) 56 | data.append((regime, episode)) 57 | return data 58 | 59 | ################################################################################# 60 | 61 | class Dataset(torch.utils.data.Dataset): 62 | def __init__(self, data): 63 | self.data = data 64 | 65 | def __len__(self): 66 | return len(self.data) 67 | 68 | def __getitem__(self, idx): 69 | return self.data[idx] 70 | 71 | ####################################### Empirical JS ####################################### 72 | 73 | from environment.env_pomdp import PomdpEnv 74 | 75 | def empiricalJS(model_q, model_p, policy, max_length=1, n_iter=500): 76 | 77 | settings_mp = model_p.get_settings() 78 | env_p = PomdpEnv(p_s=settings_mp["p_s"], p_or_s=settings_mp["p_or_s"], p_s_sa=settings_mp["p_s_sa"], 79 | categorical_obs = True, max_length=max_length) 80 | 81 | settings_mq = model_q.get_settings() 82 | env_q = PomdpEnv(p_s=settings_mq["p_s"], p_or_s=settings_mq["p_or_s"], p_s_sa=settings_mq["p_s_sa"], 83 | categorical_obs = True, max_length=max_length) 84 | 85 | 86 | n_iter = n_iter 87 | loss_q, loss_p = 0, 0 88 | 89 | # E x~q(x) [log(q(x)) - log(q(x) + p(x))] 90 | for _ in range(n_iter): 91 | ep = rollout(env_q, policy) 92 | ep = [t.unsqueeze(0) for t in ep] 93 | regime = torch.tensor(1.).unsqueeze(0) 94 | 95 | loss_q += model_q.log_prob(regime, ep)[0] 96 | loss_q -= torch.logsumexp(torch.cat([model_q.log_prob(regime, ep), \ 97 | model_p.log_prob(regime, ep)]).unsqueeze(0), 1)[0] 98 | 99 | 100 | # E x~p(x) [log(p(x)) - log(q(x) + p(x))] 101 | for _ in range(n_iter): 102 | ep = rollout(env_p, policy) 103 | ep = [t.unsqueeze(0) for t in ep] 104 | regime = torch.tensor(1.).unsqueeze(0) 105 | 106 | loss_p += model_p.log_prob(regime, ep)[0] 107 | loss_p -= torch.logsumexp(torch.cat([model_q.log_prob(regime, ep), \ 108 | model_p.log_prob(regime, ep)]).unsqueeze(0), 1)[0] 109 | 110 | return torch.log(torch.tensor(2.)) + (loss_p + loss_q)/(2*n_iter) 111 | 112 | 113 | def cross_entropy_empirical(model_q, data_p, batch_size, with_done=False): 114 | 115 | device = next(model_q.parameters()).device 116 | 117 | dataloader_p = torch.utils.data.DataLoader(Dataset(data_p), batch_size=batch_size) 118 | 119 | ce = 0 120 | 121 | for batch in dataloader_p: 122 | 
regime, episode = batch 123 | regime, episode = regime.to(device), [tensor.to(device) for tensor in episode] 124 | 125 | log_prob_q = model_q.log_prob(regime, episode, with_done=with_done) 126 | 127 | ce += -log_prob_q.sum(dim=0) 128 | 129 | ce /= len(data_p) 130 | 131 | return ce 132 | 133 | 134 | def kl_div_empirical(model_p, model_q, data_p, batch_size, with_done=False): 135 | 136 | assert next(model_q.parameters()).device == next(model_p.parameters()).device 137 | 138 | device = next(model_p.parameters()).device 139 | 140 | # Build DataLoaders 141 | dataloader_p = torch.utils.data.DataLoader(Dataset(data_p), batch_size=batch_size) 142 | 143 | # KL(p|q) = E x~p(x) [log(p(x)) - log(q(x))] 144 | kl_p_q = 0 145 | 146 | for batch in dataloader_p: 147 | regime, episode = batch 148 | regime, episode = regime.to(device), [tensor.to(device) for tensor in episode] 149 | 150 | log_prob_q = model_q.log_prob(regime, episode, with_done=with_done) 151 | log_prob_p = model_p.log_prob(regime, episode, with_done=with_done) 152 | 153 | kl_p_q += (log_prob_p - log_prob_q).sum(dim=0) 154 | 155 | kl_p_q /= len(data_p) 156 | 157 | return kl_p_q 158 | 159 | 160 | def js_div_empirical(model_q, model_p, data_q, data_p, batch_size, with_done=False): 161 | 162 | assert next(model_q.parameters()).device == next(model_p.parameters()).device 163 | 164 | device = next(model_p.parameters()).device 165 | 166 | # Build DataLoaders 167 | dataloader_q = torch.utils.data.DataLoader(Dataset(data_q), batch_size=batch_size) 168 | dataloader_p = torch.utils.data.DataLoader(Dataset(data_p), batch_size=batch_size) 169 | 170 | # m = (p + q) / 2 171 | 172 | # KL(p|m) = E x~p(x) [log(p(x)) - log(q(x) + p(x)) + log(2)] 173 | kl_p_m = 0 174 | 175 | for batch in dataloader_p: 176 | regime, episode = batch 177 | regime, episode = regime.to(device), [tensor.to(device) for tensor in episode] 178 | 179 | log_prob_q = model_q.log_prob(regime, episode, with_done=with_done) 180 | log_prob_p = model_p.log_prob(regime, episode, with_done=with_done) 181 | log_prob_m = torch.logsumexp(torch.stack([log_prob_q, log_prob_p], dim=0), dim=0) # - torch.log(torch.tensor(2, device=device)) 182 | 183 | kl_p_m += (log_prob_p - log_prob_m).sum(dim=0) 184 | 185 | kl_p_m /= len(data_p) 186 | kl_p_m += torch.log(torch.tensor(2, device=device)) 187 | 188 | # KL(q|m) = E x~q(x) [log(q(x)) - log(q(x) + p(x)) + log(2)] 189 | kl_q_m = 0 190 | 191 | for batch in dataloader_q: 192 | regime, episode = batch 193 | regime, episode = regime.to(device), [tensor.to(device) for tensor in episode] 194 | 195 | log_prob_q = model_q.log_prob(regime, episode, with_done=with_done) 196 | log_prob_p = model_p.log_prob(regime, episode, with_done=with_done) 197 | log_prob_m = torch.logsumexp(torch.stack([log_prob_q, log_prob_p], dim=0), dim=0) # - torch.log(torch.tensor(2, device=device)) 198 | 199 | kl_q_m += (log_prob_q - log_prob_m).sum(dim=0) 200 | 201 | kl_q_m /= len(data_q) 202 | kl_q_m += torch.log(torch.tensor(2, device=device)) 203 | 204 | # JS(p|q) = (KL(p|m) + KL(q|m)) / 2 205 | 206 | return (kl_q_m + kl_p_m) / 2 207 | 208 | 209 | from collections import Counter 210 | def get_sampler_weights(data): 211 | # Get ratio Interventional/Observation 212 | indices_count = Counter([int(source) for source, ep in data]) 213 | 214 | # If there is more observational data that interventional, re-weigth the train data sampling 215 | if indices_count[0] > indices_count[1] : 216 | #if indices_count[0] > indices_count[1] : 217 | # 1/(2*Nint) for interventional data, 1/(2*Nobs) for 
observational data
218 |         weights = [1./(2*indices_count[int(source)]) for source, ep in data]
219 |         #weigths = [3./(4*indices_count[int(source)]) if source == torch.tensor(0) else 1./(4*indices_count[int(source)]) for source, ep in train_data]
220 |         return weights
221 |     else :
222 |         return [1 for source, ep in data]
223 | 
224 | 
--------------------------------------------------------------------------------
/utils_kallus.py:
--------------------------------------------------------------------------------
1 | #from sklearn.cluster import SpectralClustering as sc
2 | import numpy.random as rand
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from sklearn.linear_model import Ridge as ridge
6 | from sklearn.linear_model import Lasso as lasso
7 | from sklearn.linear_model import LinearRegression as ols
8 | 
9 | from sklearn.ensemble import RandomForestRegressor as rfr
10 | from sklearn.tree import DecisionTreeRegressor as reg_tree
11 | from sklearn.ensemble import AdaBoostRegressor as ada_reg
12 | from sklearn.ensemble import GradientBoostingRegressor as gbr
13 | from sklearn.metrics import mean_squared_error as mse
14 | import copy
15 | 
16 | import rpy2.robjects as robjects
17 | import rpy2.robjects.numpy2ri as np2ri
18 | 
19 | robjects.numpy2ri.activate()
20 | 
21 | from rpy2.robjects.packages import importr
22 | grf = importr('grf')
23 | cf = grf.causal_forest
24 | 
25 | regs = [rfr(n_estimators=i) for i in [10, 20, 40, 60, 100, 150, 200]]
26 | regs += [reg_tree(max_depth=i) for i in [5, 10, 20, 30, 40, 50]]
27 | regs += [ada_reg(n_estimators=i) for i in [10, 20, 50, 70, 100, 150, 200]]
28 | regs += [gbr(n_estimators=i) for i in [50, 70, 100, 150, 200]]
29 | 
30 | def get_best_for_data(X, Y, regs):
31 |     x_train, x_test, y_train, y_test = X, X, Y, Y
32 |     val_errs = []
33 |     models = []
34 |     for reg in regs:
35 |         model = copy.deepcopy(reg)
36 |         model.fit(x_train, y_train)
37 |         val_errs.append(mse(y_test, model.predict(x_test)))
38 |         models.append(copy.deepcopy(model))
39 |     min_ind = val_errs.index(min(val_errs))
40 |     return copy.deepcopy(models[min_ind])
41 | 
42 | linear_regs = [ols()]
43 | linear_regs.extend([lasso(alpha=alph) for alph in [1e-5,1e-3,1e-1,1,1e+1,1e+3,1e+5]])
44 | linear_regs.extend([ridge(alpha=alph) for alph in [1e-5,1e-3,1e-1,1,1e+1,1e+3,1e+5]])
45 | 
46 | def run_method(X_rct, Y_rct, T_rct, X_obs, Y_obs, T_obs):
47 | 
48 |     # 1. Estimate ^w with a causal forest (function family Q) on the observational dataset...
49 |     _cf_model = cf(X_obs.reshape(-1,1), Y_obs.reshape(-1,1), T_obs.astype(int).reshape(-1,1), num_trees=200)
50 |     # ... and evaluate it on the interventional dataset: ^w(Xint)
51 |     omega_int_pred = np.array([a[0] for a in grf.predict_causal_forest(_cf_model, X_rct.reshape(-1,1))]).ravel()
52 | 
53 |     # 2. With interventional data e^int(x) = 0.5, so the re-weighting formula from the paper
54 |     #    gives q(X)Y = 2(2T - 1)Y; Lemma 1 shows that q(X)Y is an unbiased estimate of tau(Xint).
55 |     cate_int_est = 2*np.multiply(Y_rct, 2*T_rct-1).ravel()
56 |     assert(cate_int_est.shape == omega_int_pred.shape)
57 | 
58 |     # 3. theta * x = ^tau - ^w is the solution of (1) in the article, i.e. a linear regression with eta_est as the target.
59 |     eta_est = cate_int_est - omega_int_pred
60 |     assert(len(eta_est.shape) == 1)
61 |     best_eta_est_linear = get_best_for_data(X_rct.reshape(-1,1), eta_est, regs=linear_regs)
62 | 
63 |     return copy.deepcopy(best_eta_est_linear), eta_est, omega_int_pred
--------------------------------------------------------------------------------
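A minimal usage sketch for `run_method` (utils_kallus.py). The data below is synthetic and purely illustrative; the function itself requires the R `grf` package through rpy2, as imported above.

```python
import numpy as np
from utils_kallus import run_method

rng = np.random.default_rng(0)

# Illustrative 1-d covariates, binary treatments and noisy outcomes.
# Observational data: covariate-dependent (confounded-looking) treatment assignment.
X_obs = rng.normal(size=500)
T_obs = (rng.random(500) < 0.5 + 0.3 * (X_obs > 0)).astype(float)
Y_obs = X_obs * T_obs + 0.1 * rng.normal(size=500)

# Interventional (RCT) data: e^int(x) = 0.5.
X_rct = rng.normal(size=100)
T_rct = (rng.random(100) < 0.5).astype(float)
Y_rct = X_rct * T_rct + 0.1 * rng.normal(size=100)

eta_model, eta_est, omega_int_pred = run_method(X_rct, Y_rct, T_rct, X_obs, Y_obs, T_obs)

# eta_model predicts the estimated correction term on new covariates.
print(eta_model.predict(X_rct.reshape(-1, 1))[:5])
```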