├── .gitignore
├── README.md
├── ccil
│   ├── __init__.py
│   ├── environments
│   │   ├── __init__.py
│   │   └── mountain_car.py
│   ├── gen_data.py
│   ├── imitate.py
│   ├── intervention_policy_execution.py
│   └── utils
│       ├── __init__.py
│       ├── data.py
│       ├── models.py
│       ├── policy_runner.py
│       └── utils.py
├── data
│   └── experts
│       └── mountaincar_deepq_custom.pickle
└── environment.yml

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | data/*
3 | !data/experts/
4 | **/*.pyc
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Causal Confusion in Imitation Learning
2 | This is the code accompanying the paper:
3 | "[Causal Confusion in Imitation Learning](https://arxiv.org/abs/1905.11979)"
4 | by Pim de Haan, Dinesh Jayaraman and Sergey Levine, published at NeurIPS 2019.
5 | See the [website](https://sites.google.com/view/causal-confusion) for a video presentation of the work.
6 | 
7 | This simplified code implements graph-conditioned policy learning and intervention by policy execution for the MountainCar environment.
8 | Code for the other environments and intervention modes may be published at a later stage.
9 | 
10 | For questions or comments, feel free to submit an issue.
11 | 
12 | ## Dependencies
13 | Assumes a machine with CUDA 10. For machines without a GPU or with a different CUDA version, you may need to tweak the PyTorch and TensorFlow dependencies.
14 | 
15 | Full dependency setup:
16 | ```
17 | conda env create
18 | ```
19 | Or by hand:
20 | ```
21 | conda create -n causal-confusion python=3.6
22 | conda activate causal-confusion
23 | conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 ignite -c pytorch
24 | conda install tensorflow-gpu==1.14 mpi4py scikit-learn
25 | pip install git+https://github.com/pimdh/baselines@no-mujoco
26 | ```
27 | Note that I reference a modified version of OpenAI Baselines, as the provided pickle of the MountainCar expert does not work with the upstream version.
28 | Also, I modified Baselines' `setup.py` to remove the MuJoCo dependency, to allow for easier setup.
29 | 
30 | ## Usage
31 | First, generate demonstrations:
32 | ```
33 | python -m ccil.gen_data
34 | ```
35 | 
36 | To show causal confusion with a simple behavioural cloning agent on the original and the confounded state:
37 | ```
38 | python -m ccil.imitate original simple
39 | python -m ccil.imitate confounded simple
40 | ```
41 | 
42 | To train a graph-parametrized policy on the confounded state:
43 | ```
44 | python -m ccil.imitate confounded uniform --save
45 | ```
46 | 
47 | To perform intervention by policy execution:
48 | ```
49 | python -m ccil.intervention_policy_execution --num_its 10
50 | ```
51 | Optionally, set the `DATA_PATH` environment variable to change the location of the data files from the default `./data`.
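
A trained graph-parametrized policy can also be loaded and evaluated with a fixed causal graph (mask) directly from Python, using the helpers in `ccil.utils`. The snippet below is a minimal sketch: the policy filename is a placeholder for whatever `--save` wrote under `data/policies`, and mask index 6 corresponds to keeping position and velocity while dropping the previous action.
```
import gym
import torch

from ccil.environments.mountain_car import MountainCarStateEncoder
from ccil.utils.data import Trajectory
from ccil.utils.policy_runner import run_fixed_mask
from ccil.utils.utils import data_root_path, mask_idx_to_mask

# Placeholder filename: use the file written by `python -m ccil.imitate confounded uniform --save`.
policy_model = torch.load(data_root_path / "policies" / "confounded_uniform_<timestamp>.pkl")

env = gym.make("MountainCar-v0")
# Confounded state: [position, velocity, previous action].
state_encoder = MountainCarStateEncoder(random=False)

# Mask index 6 -> [1, 1, 0]: keep position and velocity, mask out the previous action.
mask = mask_idx_to_mask(3, 6)
trajectories = run_fixed_mask(env, policy_model, state_encoder, mask, num_episodes=20)
print(f"Mean reward: {Trajectory.reward_sum_mean(trajectories)}")
```
This mirrors what `ccil.intervention_policy_execution` does internally once the soft-Q search has settled on a final mask.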
52 | -------------------------------------------------------------------------------- /ccil/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pimdh/causal-confusion/3c0d31601cdee160f12eaae0de4747ff5703d857/ccil/__init__.py -------------------------------------------------------------------------------- /ccil/environments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pimdh/causal-confusion/3c0d31601cdee160f12eaae0de4747ff5703d857/ccil/environments/__init__.py -------------------------------------------------------------------------------- /ccil/environments/mountain_car.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from gym.envs.classic_control import MountainCarEnv 4 | from gym.envs.registration import register 5 | from gym import Wrapper 6 | 7 | from ccil.utils.utils import data_root_path 8 | 9 | 10 | class MCRichDenseEnv(Wrapper): 11 | """Richer initial conditions + dense rewards for easy training of expert.""" 12 | def __init__(self): 13 | super().__init__(MountainCarEnv()) 14 | 15 | def reset(self): 16 | self.env.state = np.array([ 17 | self.np_random.uniform(low=-1, high=0.5), 18 | self.unwrapped.np_random.randn() * 0.07, 19 | ]) 20 | return np.array(self.env.state) 21 | 22 | def step(self, action): 23 | state, reward, done, info = super().step(action) 24 | 25 | reward = self.env._height(self.env.state[0]) * 0.5 - 1 26 | 27 | return state, reward, done, info 28 | 29 | 30 | class MCRichEnv(Wrapper): 31 | """Richer initial conditions.""" 32 | 33 | def __init__(self): 34 | super().__init__(MountainCarEnv()) 35 | 36 | def reset(self): 37 | self.env.state = np.array([ 38 | self.np_random.uniform(low=-1, high=0.5), 39 | self.unwrapped.np_random.randn() * 0.07, 40 | ]) 41 | return np.array(self.env.state) 42 | 43 | 44 | register( 45 | id="MountainCarRichDense-v0", 46 | entry_point="ccil.environments.mountain_car:MCRichDenseEnv", 47 | max_episode_steps=200, 48 | reward_threshold=-110.0, 49 | ) 50 | 51 | register( 52 | id="MountainCarRich-v0", 53 | entry_point="ccil.environments.mountain_car:MCRichEnv", 54 | max_episode_steps=200, 55 | reward_threshold=-110.0, 56 | ) 57 | 58 | 59 | class MountainCarStateEncoder: 60 | """ 61 | Map batch from TransitionDataset or Trajectory into state vector. 62 | """ 63 | def __init__(self, random): 64 | """ 65 | :param random: Whether to use random action. 66 | """ 67 | self.random = random 68 | 69 | def batch(self, batch): 70 | assert batch.states.shape[1] >= 2 71 | x = batch.states[:, -1, :] 72 | 73 | if self.random: 74 | prev_action = torch.randint(0, 3, (x.shape[0], 1), device=x.device, dtype=torch.float) 75 | else: 76 | prev_action = batch.actions[:, -2] 77 | 78 | return torch.cat([x.float(), prev_action.float()], 1) 79 | 80 | def step(self, state, trajectory): 81 | if trajectory and not self.random: 82 | prev_action = trajectory.actions[-1] 83 | else: 84 | prev_action = np.atleast_1d(np.random.randint(0, 3)) 85 | x = np.concatenate([state, prev_action]) 86 | return x 87 | 88 | 89 | class MountainExpertCarStateEncoder: 90 | """ 91 | Map batch from TransitionDataset or Trajectory into state vector. 
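    Unlike MountainCarStateEncoder, this encoder passes the raw two-dimensional MountainCar state (position, velocity) through unchanged, with no previous action appended; this is the input the pickled expert expects.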
92 | """ 93 | def batch(self, batch): 94 | return batch.states[:, -1, :] 95 | 96 | def step(self, state, trajectory): 97 | return state 98 | 99 | 100 | class MountainCarExpert: 101 | def __init__(self): 102 | expert_path = data_root_path / "experts/mountaincar_deepq_custom.pickle" 103 | from baselines import deepq 104 | self.expert = deepq.load_act(expert_path) 105 | 106 | def __call__(self, state): 107 | return self.expert(state) 108 | 109 | -------------------------------------------------------------------------------- /ccil/gen_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import gym 5 | import torch 6 | 7 | from ccil.environments.mountain_car import MountainExpertCarStateEncoder, MountainCarExpert 8 | from ccil.utils.data import TransitionDataset, Trajectory 9 | from ccil.utils.policy_runner import PolicyRunner 10 | from ccil.utils.utils import data_root_path 11 | 12 | 13 | def gen_data(args): 14 | if args.env == 'mountain_car': 15 | expert = MountainCarExpert() 16 | expert_state_encode = MountainExpertCarStateEncoder() 17 | env = gym.make("MountainCarRich-v0") 18 | else: 19 | raise ValueError() 20 | 21 | runner = PolicyRunner(env, expert, expert_state_encode) 22 | trajectories = runner.run_num_steps(args.num_steps, True) 23 | dataset = TransitionDataset.from_trajectories(trajectories, stack_size=2, expert_trajectories=True) 24 | 25 | if args.save_path is None: 26 | save_path = data_root_path / 'demonstrations' / f'{args.env}.pkl' 27 | else: 28 | save_path = Path(args.save_path) 29 | 30 | save_path.parent.mkdir(exist_ok=True, parents=True) 31 | torch.save(dataset, save_path) 32 | print(f'Mean reward: {Trajectory.reward_sum_mean(trajectories)}') 33 | print('Done') 34 | 35 | 36 | def main(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--env', default='mountain_car') 39 | parser.add_argument('--num_steps', type=int, default=100000) 40 | parser.add_argument('--save_path') 41 | gen_data(parser.parse_args()) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /ccil/imitate.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import argparse 3 | from functools import partial 4 | 5 | import gym 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from ignite.engine import Engine, Events 10 | from ignite.metrics import Accuracy, Loss 11 | from torch.utils.data import DataLoader 12 | 13 | from ccil.environments.mountain_car import MountainCarStateEncoder 14 | from ccil.utils.data import random_split, batch_cat, DataLoaderRepeater, Trajectory 15 | from ccil.utils.models import SimplePolicy, MLP, UniformMaskPolicy 16 | from ccil.utils.policy_runner import PolicyRunner, RandomMaskPolicyAgent, FixedMaskPolicyAgent 17 | from ccil.utils.utils import random_mask_from_state, data_root_path, mask_idx_to_mask 18 | 19 | 20 | def train_step(engine, batch, state_encoder, policy_model, optimizer, criterion, device): 21 | x, y = state_encoder.batch(batch), batch.labels() 22 | x, y = x.to(device), y.to(device) 23 | 24 | mask = random_mask_from_state(x) 25 | output = policy_model.forward(x, mask) 26 | loss = criterion(output, y) 27 | 28 | optimizer.zero_grad() 29 | loss.backward() 30 | optimizer.step() 31 | return loss 32 | 33 | 34 | def inference_step(engine, batch, state_encoder, policy_model, device): 
35 | x, y = state_encoder.batch(batch), batch.labels() 36 | x, y = x.to(device), y.to(device) 37 | mask = random_mask_from_state(x) 38 | output = policy_model.forward(x, mask) 39 | return output, y 40 | 41 | 42 | def print_metrics(engine, trainer, evaluator_name): 43 | print( 44 | f"Epoch: {trainer.state.epoch:> 3} {evaluator_name.title(): <5} " 45 | f"loss={engine.state.metrics['loss']:.4f} " 46 | f"acc={engine.state.metrics['acc']:.4f}") 47 | 48 | 49 | def run_simple(policy_model, state_encoder): 50 | """ 51 | Run the policy in environment. 52 | """ 53 | env = gym.make("MountainCar-v0") 54 | agent = RandomMaskPolicyAgent(policy_model) 55 | runner = PolicyRunner(env, agent, state_encoder) 56 | trajectories = runner.run_num_episodes(20) 57 | print(f'Mean reward: {Trajectory.reward_sum_mean(trajectories)}') 58 | 59 | 60 | def run_uniform(policy_model, state_encoder): 61 | """ 62 | Run all 8 policies in environment. 63 | """ 64 | env = gym.make("MountainCar-v0") 65 | for mask_idx in range(8): 66 | agent = FixedMaskPolicyAgent(policy_model, mask_idx_to_mask(3, mask_idx)) 67 | runner = PolicyRunner(env, agent, state_encoder) 68 | trajectories = runner.run_num_episodes(20) 69 | mask = mask_idx_to_mask(3, mask_idx).tolist() 70 | print(f'Mean reward mask {mask}: {Trajectory.reward_sum_mean(trajectories)}') 71 | 72 | 73 | def imitate(args): 74 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 75 | 76 | dataset = torch.load(data_root_path / 'demonstrations' / 'mountain_car.pkl') 77 | train_dataset, test_dataset = random_split(dataset, [args.num_samples, args.num_samples], args.data_seed) 78 | 79 | dataloaders = { 80 | 'train': DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=batch_cat), 81 | 'test': DataLoader(test_dataset, batch_size=64, shuffle=True, collate_fn=batch_cat), 82 | } 83 | # So that 1 train epoch has fixed number of samples (500 batches) regardless of dataset size 84 | dataloaders['train_repeated'] = DataLoaderRepeater(dataloaders['train'], 500) 85 | 86 | if args.network == 'simple': 87 | policy_model = SimplePolicy(MLP([3, 50, 50, 3])).to(device) 88 | max_epochs = 10 89 | elif args.network == 'uniform': 90 | policy_model = UniformMaskPolicy(MLP([6, 50, 50, 50, 3])).to(device) 91 | max_epochs = 20 92 | else: 93 | raise ValueError() 94 | 95 | optimizer = torch.optim.Adam(policy_model.parameters()) 96 | 97 | def criterion(x, y): 98 | return F.cross_entropy(x, y[:, 0]) 99 | 100 | metrics = { 101 | 'loss': Loss(F.cross_entropy, output_transform=lambda x: (x[0], x[1][:, 0])), 102 | 'acc': Accuracy(output_transform=lambda x: (x[0], x[1][:, 0])), 103 | } 104 | 105 | state_encoder = MountainCarStateEncoder(args.input_mode == 'original') 106 | 107 | trainer = Engine(partial( 108 | train_step, state_encoder=state_encoder, policy_model=policy_model, 109 | optimizer=optimizer, criterion=criterion, device=device 110 | )) 111 | evaluators = { 112 | name: Engine(partial( 113 | inference_step, state_encoder=state_encoder, policy_model=policy_model, device=device)) 114 | for name in ['train', 'test']} 115 | for evaluator_name, evaluator in evaluators.items(): 116 | for name, metric in metrics.items(): 117 | metric.attach(evaluator, name) 118 | evaluator.add_event_handler(Events.COMPLETED, print_metrics, evaluator_name=evaluator_name, trainer=trainer) 119 | 120 | @trainer.on(Events.EPOCH_COMPLETED) 121 | def run_eval(_trainer): 122 | for name, evaluator in evaluators.items(): 123 | evaluator.run(dataloaders[name]) 124 | 125 | 
trainer.run(dataloaders['train_repeated'], max_epochs=max_epochs) 126 | print("Trained") 127 | 128 | # Run policies in environment 129 | run_fn = dict(simple=run_simple, uniform=run_uniform)[args.network] 130 | run_fn(policy_model, state_encoder) 131 | 132 | if args.save: 133 | name = args.name or f"{args.input_mode}_{args.network}_{datetime.now():%Y%m%d-%H%M%S}" 134 | save_dir = data_root_path / 'policies' 135 | save_dir.mkdir(parents=True, exist_ok=True) 136 | path = save_dir / f"{name}.pkl" 137 | torch.save(policy_model, path) 138 | print(f"Policy saved to {path}") 139 | 140 | 141 | def main(): 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('input_mode', choices=['original', 'confounded']) 144 | parser.add_argument('network', choices=['simple', 'uniform']) 145 | parser.add_argument('--data_seed', type=int, help="Seed for splitting train/test data. Default=random") 146 | parser.add_argument('--num_samples', type=int, default=300) 147 | parser.add_argument('--save', action='store_true') 148 | parser.add_argument('--name', help="Policy save filename") 149 | imitate(parser.parse_args()) 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /ccil/intervention_policy_execution.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | import gym 4 | import numpy as np 5 | import argparse 6 | from time import perf_counter 7 | 8 | import torch 9 | from sklearn.linear_model import Ridge 10 | from torch.distributions import Bernoulli 11 | 12 | from ccil.environments.mountain_car import MountainCarStateEncoder 13 | from ccil.utils.data import Trajectory 14 | from ccil.utils.policy_runner import PolicyRunner, FixedMaskPolicyAgent, run_fixed_mask 15 | from ccil.utils.utils import data_root_path 16 | 17 | 18 | def sample(weights, temperature): 19 | return Bernoulli(logits=torch.from_numpy(weights) / temperature).sample().long().numpy() 20 | 21 | 22 | def linear_regression(masks, rewards, alpha=1.0): 23 | model = Ridge(alpha).fit(masks, rewards) 24 | return model.coef_, model.intercept_ 25 | 26 | 27 | class SoftQAlgo: 28 | def __init__( 29 | self, 30 | num_dims, 31 | reward_fn, 32 | its, 33 | temperature=1.0, 34 | device=None, 35 | evals_per_it=1, 36 | ): 37 | self.num_dims = num_dims 38 | self.reward_fn = reward_fn 39 | self.its = its 40 | self.device = device 41 | self.temperature = lambda t: temperature 42 | self.evals_per_it = evals_per_it 43 | 44 | def run(self): 45 | t = self.temperature(0) 46 | weights = np.zeros(self.num_dims) 47 | 48 | trace = [] 49 | masks = [] 50 | rewards = [] 51 | for it in range(self.its): 52 | start = perf_counter() 53 | mask = sample(weights, t) 54 | reward = np.mean([self.reward_fn(mask) for _ in range(self.evals_per_it)]) 55 | masks.append(mask) 56 | rewards.append(reward) 57 | 58 | weights, _ = linear_regression(masks, rewards, alpha=1.0) 59 | 60 | trace.append( 61 | { 62 | "it": it, 63 | "reward": reward, 64 | "mask": mask, 65 | "weights": weights, 66 | "mode": (np.sign(weights).astype(np.int64) + 1) // 2, 67 | "time": perf_counter() - start, 68 | "past_mean_reward": np.mean(rewards), 69 | } 70 | ) 71 | pprint(trace[-1]) 72 | 73 | return trace 74 | 75 | 76 | def intervention_policy_execution(args): 77 | policy_save_dir = data_root_path / 'policies' 78 | if args.policy_name: 79 | policy_path = policy_save_dir / f"{args.policy_name}.pkl" 80 | else: 81 | policy_paths = 
list(policy_save_dir.glob('confounded_uniform*.pkl'))  # materialize so the emptiness check below works
82 |         if not policy_paths:
83 |             raise RuntimeError("No policy found")
84 |         policy_path = next(iter(sorted(policy_paths, reverse=True)))
85 |     policy_model = torch.load(policy_path)
86 |     print(f"Loaded policy from {policy_path}")
87 | 
88 |     env = gym.make("MountainCar-v0")
89 |     state_encoder = MountainCarStateEncoder(random=False)
90 | 
91 |     def run_step(mask):
92 |         trajectories = run_fixed_mask(env, policy_model, state_encoder, mask, 1)
93 |         return Trajectory.reward_sum_mean(trajectories)
94 | 
95 |     trace = SoftQAlgo(3, run_step, args.num_its, temperature=args.temperature).run()
96 | 
97 |     best_mask = trace[-1]['mode']
98 |     trajectories = run_fixed_mask(env, policy_model, state_encoder, best_mask, 20)
99 |     print(f"Final mask {best_mask.tolist()}")
100 |     print(f"Final reward {Trajectory.reward_sum_mean(trajectories)}")
101 | 
102 | 
103 | def main():
104 |     parser = argparse.ArgumentParser()
105 |     parser.add_argument('--policy_name', help="Policy save filename")
106 |     parser.add_argument('--num_its', type=int, default=20)
107 |     parser.add_argument('--temperature', type=float, default=10)
108 |     intervention_policy_execution(parser.parse_args())
109 | 
110 | 
111 | if __name__ == '__main__':
112 |     main()
113 | 
--------------------------------------------------------------------------------
/ccil/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pimdh/causal-confusion/3c0d31601cdee160f12eaae0de4747ff5703d857/ccil/utils/__init__.py
--------------------------------------------------------------------------------
/ccil/utils/data.py:
--------------------------------------------------------------------------------
1 | from itertools import accumulate
2 | from typing import List
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | class Subset:
8 |     def __init__(self, dataset, indices):
9 |         self.dataset = dataset
10 |         self.indices = indices
11 | 
12 |     def __getitem__(self, idx):
13 |         return self.dataset[self.indices[idx]]
14 | 
15 |     def __len__(self):
16 |         return len(self.indices)
17 | 
18 |     def to_stack_size(self, stack_size):
19 |         self.dataset = self.dataset.to_stack_size(stack_size)
20 |         return self
21 | 
22 | 
23 | def random_split(dataset, lengths, seed=None):
24 |     s = sum([l for l in lengths if l > 0])
25 |     lengths = [l if l > 0 else len(dataset) - s for l in lengths]
26 |     assert sum(lengths) <= len(dataset), "Too many samples requested"
27 |     if seed is not None:
28 |         prev_seed = np.random.get_state()
29 |         np.random.seed(seed)
30 |         indices = np.random.permutation(len(dataset))
31 |         np.random.set_state(prev_seed)
32 |     else:
33 |         indices = np.random.permutation(len(dataset))
34 |     indices = torch.tensor(indices, dtype=torch.long)
35 | 
36 |     return [
37 |         Subset(dataset, indices[offset - length : offset])
38 |         for offset, length in zip(accumulate(lengths), lengths)
39 |     ]
40 | 
41 | 
42 | class GrowingArray:
43 |     """
44 |     Array-like object that allows for efficient appending to the end.
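    Keeps an over-allocated numpy buffer and grows it in chunks, so most add() calls do not need to reallocate.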
45 | """ 46 | def __init__(self, data: np.ndarray, buffer_size=1000): 47 | self._size = len(data) 48 | self._buffer = data 49 | self._expand(buffer_size) 50 | self._buffer_grow_size = buffer_size 51 | self.dtype = self._buffer.dtype 52 | 53 | @property 54 | def shape(self): 55 | return (self._size,) + self._buffer.shape[1:] 56 | 57 | def __len__(self): 58 | return self._size 59 | 60 | def __array__(self): 61 | return self._buffer[: self._size] 62 | 63 | def __getitem__(self, item): 64 | if isinstance(item, slice): 65 | item = slice(*item.indices(self._size)) 66 | return self._buffer[item] 67 | return np.array(self)[item] 68 | 69 | def add(self, new_data): 70 | shortage = self._size + len(new_data) - len(self._buffer) 71 | if shortage > 0: 72 | self._expand(self._buffer_grow_size + shortage) 73 | 74 | self._buffer[self._size : self._size + len(new_data)] = new_data 75 | self._size += len(new_data) 76 | 77 | def _expand(self, n): 78 | zeros = np.zeros((n,) + self._buffer.shape[1:], dtype=self._buffer.dtype) 79 | self._buffer = np.concatenate((self._buffer, zeros), 0) 80 | 81 | 82 | class DataLoaderRepeater: 83 | """ 84 | Create repeat data loader to make larger dataset out of smaller. 85 | """ 86 | def __init__(self, loader, batches_per_epoch): 87 | self.i = 0 88 | self.batches_per_epoch = batches_per_epoch 89 | self.loader = loader 90 | 91 | @staticmethod 92 | def cycle(iterable): 93 | """Cycle iterable non-caching.""" 94 | while True: 95 | for x in iterable: 96 | yield x 97 | 98 | def __iter__(self): 99 | self.iterator = self.cycle(self.loader) 100 | self.i = 0 101 | return self 102 | 103 | def __len__(self): 104 | return self.batches_per_epoch 105 | 106 | def __next__(self): 107 | if self.i == self.batches_per_epoch: 108 | raise StopIteration() 109 | 110 | self.i += 1 111 | return next(self.iterator) 112 | 113 | 114 | class Trajectory: 115 | """ 116 | Data representing single episode, possibly not yet finished. 117 | Uses GrowingArray to efficiently append new steps. 
118 | """ 119 | keys = ["states", "actions", "rewards", "pixels"] 120 | 121 | def __init__( 122 | self, 123 | states: np.ndarray, 124 | actions: np.ndarray, 125 | rewards: np.ndarray, 126 | pixels: np.ndarray, 127 | info=None, 128 | ): 129 | n = len(states) 130 | assert n == len(actions) == len(rewards) == len(pixels) 131 | self.states = GrowingArray(states) 132 | # Ensure 1D actions 133 | self.actions = GrowingArray(actions.reshape((n, -1))) 134 | self.rewards = GrowingArray(np.nan_to_num(rewards)) 135 | self.pixels = GrowingArray(pixels) 136 | self.info = info or {} 137 | 138 | @property 139 | def arrays(self): 140 | return [self.states, self.actions, self.rewards, self.pixels] 141 | 142 | @classmethod 143 | def from_list(cls, lst): 144 | return cls(*zip(*lst)) 145 | 146 | def reward_sum(self): 147 | return np.sum(self.rewards) 148 | 149 | def action_repeat(self): 150 | if self.actions.dtype == np.float32: 151 | return float("NaN") 152 | return (np.array(self.actions[1:]) == np.array(self.actions[:-1])).astype(np.float32).mean() 153 | 154 | def __len__(self): 155 | return len(self.states) 156 | 157 | @classmethod 158 | def add_step(cls, old, *arrays, info=None): 159 | arrays = [np.asarray(x) for x in arrays] 160 | arrays[1] = np.atleast_1d(arrays[1]) # Ensure 1D actions 161 | if old is None: 162 | return cls( 163 | *[x[None] for x in arrays], 164 | info={k: [v] for k, v in (info or {}).items()} 165 | ) 166 | 167 | for old_arr, new_arr in zip(old.arrays, arrays): 168 | old_arr.add(np.asarray(new_arr)[None]) 169 | 170 | newinfo = ( 171 | {k: np.concatenate([old.info[k], [info[k]]]) for k in old.info.keys()} 172 | if info and old.info 173 | else None 174 | ) 175 | old.info = newinfo 176 | return old 177 | 178 | def finished(self): 179 | self.states = np.asarray(self.states) 180 | self.actions = np.asarray(self.actions) 181 | self.rewards = np.asarray(self.rewards) 182 | self.pixels = np.asarray(self.pixels) 183 | 184 | @staticmethod 185 | def reward_sum_mean(trajectories): 186 | return np.mean([t.reward_sum() for t in trajectories]).item() 187 | 188 | @staticmethod 189 | def reward_sum_std(trajectories): 190 | return np.std([t.reward_sum() for t in trajectories]).item() 191 | 192 | @staticmethod 193 | def action_repeat_mean(trajectories): 194 | return np.mean([t.action_repeat() for t in trajectories]).item() 195 | 196 | @staticmethod 197 | def info_sum_mean(key, trajectories): 198 | return np.mean([np.nansum(t.info[key]) for t in trajectories]).item() 199 | 200 | @staticmethod 201 | def info_mean_mean(key, trajectories): 202 | return np.mean([np.nanmean(t.info[key]) for t in trajectories]).item() 203 | 204 | def stack(self, stack_size, pad=False): 205 | """Stack subsequent rows. Order: earlier->later. 206 | 207 | If pad, then prepend with copies of first, so length constant.""" 208 | outs = [] 209 | for arr in self.arrays: 210 | if pad: 211 | # Workaround https://github.com/numpy/numpy/issues/11395 212 | if arr.dtype == np.object_ and arr[0] is None: 213 | arr = np.full(len(arr) + stack_size - 1, None) 214 | else: 215 | p = [(stack_size - 1, 0)] + [(0, 0)] * (len(arr.shape) - 1) 216 | arr = np.pad(arr, p, "edge") 217 | 218 | out = np.stack( 219 | [np.roll(arr, i, 0) for i in range(stack_size - 1, -1, -1)], 1 220 | ) 221 | outs.append(out[(stack_size - 1) :]) 222 | return outs 223 | 224 | 225 | class Batch: 226 | """ 227 | Batch of tensor data from TransitionDataset. 
228 | """ 229 | absent_sentinel = -111 # Can not use NaN for int actions 230 | 231 | def __init__(self, states, actions, rewards, pixels, expert_actions, indices): 232 | self.states = states 233 | self.actions = actions 234 | self.rewards = rewards 235 | self.pixels = pixels 236 | self.expert_actions = expert_actions 237 | self.indices = indices 238 | 239 | def has_labels(self): 240 | labels = self.expert_actions[:, -1] 241 | return not (labels == self.absent_sentinel).any() 242 | 243 | def labels(self): 244 | labels = self.expert_actions[:, -1] 245 | assert self.has_labels(), "No expert action provided" 246 | return labels 247 | 248 | 249 | def batch_cat(batches: List[Batch]): 250 | no_pixels = any([len(b.pixels.shape) == 2 for b in batches]) 251 | tensors = [] 252 | n = sum([b.states.shape[0] for b in batches]) 253 | for k in ["states", "actions", "rewards", "pixels", "expert_actions", "indices"]: 254 | if k == "indices": 255 | x = torch.cat([b.indices + i * 100000000 for i, b in enumerate(batches)]) 256 | elif k == "pixels" and no_pixels: 257 | x = torch.zeros( 258 | n, batches[0].states.shape[1], device=batches[0].states.device 259 | ) 260 | else: 261 | x = torch.cat([getattr(b, k) for b in batches]) 262 | tensors.append(x) 263 | batch = Batch(*tensors) 264 | return batch 265 | 266 | 267 | class TransitionDataset: 268 | """ 269 | Stacked transitions efficiently stored. 270 | 271 | All objects are tensors. 272 | """ 273 | def __init__( 274 | self, states, actions, rewards, pixels, expert_actions, done, stack_size 275 | ): 276 | self.states = states 277 | self.actions = actions 278 | self.rewards = rewards 279 | self.pixels = pixels 280 | self.expert_actions = ( 281 | expert_actions 282 | if expert_actions is not None 283 | else self.actions.new_full(self.actions.shape, Batch.absent_sentinel) 284 | ) 285 | self.done = done 286 | self.stack_size = stack_size 287 | self.starts = self.compute_starts(done, stack_size) 288 | self.stack_arange = torch.arange(self.stack_size, dtype=torch.long) 289 | 290 | @property 291 | def tensors(self): 292 | return [ 293 | self.states, 294 | self.actions, 295 | self.rewards, 296 | self.pixels, 297 | self.expert_actions, 298 | self.done, 299 | ] 300 | 301 | def __getitem__(self, indices: np.ndarray): 302 | """Input [idx], output Batch, tensors of shape [idx, stack].""" 303 | if isinstance(indices, range): 304 | indices = np.arange( 305 | indices.start or 0, indices.stop or len(self), indices.step or 1 306 | ) 307 | 308 | indices = np.atleast_1d(indices) 309 | assert indices.size > 0, "Indices can not be empty" 310 | 311 | indices = self.starts[indices] 312 | indices = indices[:, None] + self.stack_arange[None, :] 313 | tensors = [t[indices] for t in self.tensors[:-1]] 314 | return Batch(*tensors, indices=indices) 315 | 316 | def to(self, device): 317 | self.states = self.states.to(device) 318 | self.actions = self.actions.to(device) 319 | self.rewards = self.rewards.to(device) 320 | self.pixels = self.pixels.to(device) 321 | self.expert_actions = self.expert_actions.to(device) 322 | return self 323 | 324 | @staticmethod 325 | def compute_starts(done: torch.Tensor, stack_size) -> torch.Tensor: 326 | """Compute indices where stack starts.""" 327 | assert done[-1] 328 | done_indices = torch.nonzero(done).view(-1) 329 | # Indices we can not start 330 | if stack_size == 1: 331 | mask_indices = done_indices 332 | else: 333 | mask_indices = ( 334 | done_indices[:, None] 335 | - torch.arange(stack_size - 1, dtype=torch.long)[None, :] 336 | ) 337 | mask_indices = 
mask_indices.view(-1).clamp(0, len(done) - 1) 338 | 339 | # We invert by explicitly creating the mask 340 | mask = torch.zeros_like(done) 341 | mask[mask_indices] = 1 342 | start_indices = torch.nonzero(1 - mask).view(-1) 343 | return start_indices 344 | 345 | def __len__(self): 346 | return len(self.starts) 347 | 348 | @classmethod 349 | def from_trajectories( 350 | cls, 351 | trajectories: List[Trajectory], 352 | stack_size: int, 353 | expert_trajectories=False, 354 | expert_actions=None, 355 | ): 356 | if len(trajectories) == 0: 357 | return None 358 | arrays = zip(*[t.arrays for t in trajectories]) 359 | lenghts = np.array([len(t) for t in trajectories]) 360 | length = np.sum(lenghts).item() 361 | done = np.zeros(length, dtype=np.uint8) 362 | done[np.cumsum(lenghts) - 1] = 1 363 | tensors = [ 364 | torch.from_numpy(np.concatenate(arr)) 365 | if arr[0].dtype != np.object_ 366 | else torch.zeros(length) 367 | for arr in arrays 368 | ] 369 | if expert_trajectories: 370 | expert_actions = tensors[1] 371 | return cls( 372 | *tensors, 373 | expert_actions=expert_actions, 374 | done=torch.from_numpy(done), 375 | stack_size=stack_size 376 | ) 377 | 378 | @classmethod 379 | def cat(cls, a, b): 380 | if a is None: 381 | return b 382 | assert a.stack_size == b.stack_size, "Stack sizes must match" 383 | tensors = [ 384 | torch.cat([t_a, t_b.to(t_a.device)]) 385 | for t_a, t_b in zip(a.tensors, b.tensors) 386 | ] 387 | return cls(*tensors, stack_size=a.stack_size) 388 | 389 | def to_stack_size(self, stack_size): 390 | return TransitionDataset(*self.tensors, stack_size) 391 | 392 | def returns(self, discount): 393 | ret = 0 394 | returns = torch.zeros_like(self.rewards) 395 | for i in range(len(self.rewards) - 1, -1, -1): 396 | if self.done[i]: 397 | ret = 0 398 | ret = returns[i] = discount * ret + self.rewards[i] 399 | return returns 400 | 401 | -------------------------------------------------------------------------------- /ccil/utils/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MLP(nn.Module): 6 | """Helper module to create MLPs.""" 7 | def __init__(self, dims, activation=nn.ReLU): 8 | super().__init__() 9 | blocks = nn.ModuleList() 10 | 11 | for i, (dim_in, dim_out) in enumerate(zip(dims, dims[1:])): 12 | blocks.append(nn.Linear(dim_in, dim_out)) 13 | 14 | if i < len(dims)-2: 15 | blocks.append(activation()) 16 | 17 | self.blocks = nn.Sequential(*blocks) 18 | 19 | def forward(self, x): 20 | return self.blocks(x) 21 | 22 | 23 | class SimplePolicy(nn.Module): 24 | def __init__(self, net): 25 | super().__init__() 26 | self.net = net 27 | 28 | def forward(self, state, mask): 29 | return self.net(state) 30 | 31 | 32 | class UniformMaskPolicy(nn.Module): 33 | def __init__(self, net): 34 | super().__init__() 35 | self.net = net 36 | 37 | def forward(self, state, mask): 38 | mask = mask.to(state).expand_as(state) 39 | x = torch.cat([state * mask, mask], 1) 40 | return self.net(x) 41 | -------------------------------------------------------------------------------- /ccil/utils/policy_runner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Categorical 3 | from tqdm import tqdm 4 | 5 | from ccil.utils.data import Trajectory 6 | from ccil.utils.utils import random_mask_from_state 7 | 8 | 9 | class PolicyRunner: 10 | def __init__(self, env, agent, state_encoder): 11 | self.env = env 12 | self.agent = agent 13 | 
self.state_encoder = state_encoder 14 | 15 | def run_episode(self): 16 | state, done = self.env.reset(), False 17 | trajectory = None 18 | while not done: 19 | x = self.state_encoder.step(state, trajectory) 20 | action = self.agent(x).item() 21 | 22 | prev_action, prev_state = action, state 23 | state, rew, done, info = self.env.step(action) 24 | 25 | trajectory = Trajectory.add_step( 26 | trajectory, prev_state, prev_action, rew, None, info=info 27 | ) 28 | trajectory.finished() 29 | return trajectory 30 | 31 | def run_num_steps(self, num_steps, verbose=False): 32 | progress_bar = tqdm(total=num_steps, disable=not verbose) 33 | steps = 0 34 | trajectories = [] 35 | while True: 36 | trajectory = self.run_episode() 37 | steps += len(trajectory) 38 | progress_bar.update(len(trajectory)) 39 | trajectories.append(trajectory) 40 | if steps >= num_steps: 41 | break 42 | 43 | progress_bar.close() 44 | return trajectories 45 | 46 | def run_num_episodes(self, num_episodes, verbose=False): 47 | trajectories = [] 48 | for _ in tqdm(range(num_episodes), disable=not verbose): 49 | trajectory = self.run_episode() 50 | trajectories.append(trajectory) 51 | return trajectories 52 | 53 | 54 | def run_fixed_mask(env, policy_model, state_encoder, mask, num_episodes): 55 | agent = FixedMaskPolicyAgent(policy_model, mask) 56 | runner = PolicyRunner(env, agent, state_encoder) 57 | trajectories = runner.run_num_episodes(num_episodes) 58 | return trajectories 59 | 60 | 61 | def hard_discrete_action(output): 62 | return output.argmax(-1) 63 | 64 | 65 | def sample_discrete_action(output): 66 | return Categorical(logits=output).sample() 67 | 68 | 69 | class RandomMaskPolicyAgent: 70 | def __init__(self, policy, output_transformation=hard_discrete_action): 71 | self.policy = policy 72 | self.device = next(policy.parameters()).device 73 | self.output_transformation = output_transformation 74 | 75 | def __call__(self, state): 76 | x = torch.tensor(state, device=self.device, dtype=torch.float)[None] 77 | mask = random_mask_from_state(x) 78 | output = self.policy.forward(x, mask) 79 | action = self.output_transformation(output) 80 | return action 81 | 82 | 83 | class FixedMaskPolicyAgent: 84 | def __init__(self, policy, mask, output_transformation=hard_discrete_action): 85 | self.policy = policy 86 | self.device = next(policy.parameters()).device 87 | self.mask = torch.tensor(mask, device=self.device, dtype=torch.float) 88 | self.output_transformation = output_transformation 89 | 90 | def __call__(self, state): 91 | x = torch.tensor(state, device=self.device, dtype=torch.float)[None] 92 | output = self.policy.forward(x, self.mask) 93 | action = self.output_transformation(output) 94 | return action 95 | -------------------------------------------------------------------------------- /ccil/utils/utils.py: -------------------------------------------------------------------------------- 1 | from os import environ 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | data_root_path = Path(environ.get('DATA_PATH', './data')) 9 | 10 | 11 | def random_mask_from_state(x): 12 | return torch.randint(0, 2, size=x.shape, device=x.device) 13 | 14 | 15 | def mask_idx_to_mask(n, i): 16 | i = np.asarray(i) 17 | assert np.all(i < 2**n) 18 | r = 2 ** np.arange(n - 1, -1, -1) 19 | x = (i[..., None] % (2 * r)) // r 20 | return x 21 | 22 | 23 | def mask_to_mask_idx(mask): 24 | mask = np.asarray(mask) 25 | n = mask.shape[-1] 26 | return (mask * 2**np.arange(n - 1, -1, -1)).sum(-1) 27 | 28 | 29 | def 
test_mask_idx_to_mask(): 30 | assert mask_idx_to_mask(3, 0).tolist() == [0, 0, 0] 31 | assert mask_idx_to_mask(3, 1).tolist() == [0, 0, 1] 32 | assert mask_idx_to_mask(3, 2).tolist() == [0, 1, 0] 33 | assert mask_idx_to_mask(3, 7).tolist() == [1, 1, 1] 34 | assert mask_idx_to_mask(3, [1, 2]).tolist() == [[0, 0, 1], [0, 1, 0]] 35 | 36 | 37 | def test_mask_to_mask_idx(): 38 | assert mask_to_mask_idx([0, 0, 0]).tolist() == 0 39 | assert mask_to_mask_idx([[1, 0, 0], [1, 1, 1]]).tolist() == [4, 7] 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /data/experts/mountaincar_deepq_custom.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pimdh/causal-confusion/3c0d31601cdee160f12eaae0de4747ff5703d857/data/experts/mountaincar_deepq_custom.pickle -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: causal-confusion 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _tflow_select=2.1.0=gpu 8 | - absl-py=0.8.0=py36_0 9 | - astor=0.8.0=py36_0 10 | - blas=1.0=mkl 11 | - c-ares=1.15.0=h7b6447c_1001 12 | - ca-certificates=2019.10.16=0 13 | - certifi=2019.9.11=py36_0 14 | - cffi=1.13.0=py36h2e261b9_0 15 | - cudatoolkit=10.0.130=0 16 | - cudnn=7.6.0=cuda10.0_0 17 | - cupti=10.0.130=0 18 | - freetype=2.9.1=h8a8886c_1 19 | - gast=0.3.2=py_0 20 | - google-pasta=0.1.7=py_0 21 | - grpcio=1.16.1=py36hf8bcb03_1 22 | - h5py=2.9.0=py36h7918eee_0 23 | - hdf5=1.10.4=hb1b8bf9_0 24 | - ignite=0.2.1=py36_0 25 | - intel-openmp=2019.4=243 26 | - jpeg=9b=h024ee3a_2 27 | - keras-applications=1.0.8=py_0 28 | - keras-preprocessing=1.1.0=py_1 29 | - libedit=3.1.20181209=hc058e9b_0 30 | - libffi=3.2.1=hd88cf55_4 31 | - libgcc-ng=9.1.0=hdf63c60_0 32 | - libgfortran-ng=7.3.0=hdf63c60_0 33 | - libpng=1.6.37=hbc83047_0 34 | - libprotobuf=3.9.2=hd408876_0 35 | - libstdcxx-ng=9.1.0=hdf63c60_0 36 | - libtiff=4.0.10=h2733197_2 37 | - markdown=3.1.1=py36_0 38 | - mkl=2019.4=243 39 | - mkl-service=2.3.0=py36he904b0f_0 40 | - mkl_fft=1.0.14=py36ha843d7b_0 41 | - mkl_random=1.1.0=py36hd6b4f25_0 42 | - mpi4py=2.0.0=py36_2 43 | - mpich2=1.4.1p1=0 44 | - ncurses=6.1=he6710b0_1 45 | - ninja=1.9.0=py36hfd86e86_0 46 | - numpy=1.17.2=py36haad9e8e_0 47 | - numpy-base=1.17.2=py36hde5b4d6_0 48 | - olefile=0.46=py36_0 49 | - openssl=1.1.1d=h7b6447c_3 50 | - pillow=6.2.0=py36h34e0f95_0 51 | - pip=19.3.1=py36_0 52 | - protobuf=3.9.2=py36he6710b0_0 53 | - pycparser=2.19=py36_0 54 | - python=3.6.9=h265db76_0 55 | - pytorch=1.0.1=py3.6_cuda10.0.130_cudnn7.4.2_2 56 | - readline=7.0=h7b6447c_5 57 | - scikit-learn=0.21.3=py36hd81dba3_0 58 | - scipy=1.3.1=py36h7c811a0_0 59 | - setuptools=41.4.0=py36_0 60 | - six=1.12.0=py36_0 61 | - sqlite=3.30.1=h7b6447c_0 62 | - tensorboard=1.14.0=py36hf484d3e_0 63 | - tensorflow=1.14.0=gpu_py36h57aa796_0 64 | - tensorflow-base=1.14.0=gpu_py36h8d69cac_0 65 | - tensorflow-estimator=1.14.0=py_0 66 | - tensorflow-gpu=1.14.0=h0d30ee6_0 67 | - termcolor=1.1.0=py36_1 68 | - tk=8.6.8=hbc83047_0 69 | - torchvision=0.2.2=py_3 70 | - werkzeug=0.16.0=py_0 71 | - wheel=0.33.6=py36_0 72 | - wrapt=1.11.2=py36h7b6447c_0 73 | - xz=5.2.4=h14c3975_4 74 | - zlib=1.2.11=h7b6447c_3 75 | - zstd=1.3.7=h0b5b093_0 76 | - pip: 77 | - "git+https://github.com/pimdh/baselines@no-mujoco" 78 | - click==7.0 79 | - cloudpickle==1.2.2 80 | - dill==0.3.1.1 81 
| - future==0.18.1 82 | - gym==0.15.3 83 | - joblib==0.14.0 84 | - opencv-python==4.1.1.26 85 | - progressbar2==3.47.0 86 | - pyglet==1.3.2 87 | - python-utils==2.3.0 88 | - tqdm==4.36.1 89 | 90 | --------------------------------------------------------------------------------