├── .gitignore
├── configs
│   └── esper
│       ├── connect_four.yaml
│       ├── gambling.yaml
│       └── tfe.yaml
├── data
│   ├── connect_four.ret
│   ├── gambling.ret
│   └── tfe.ret
├── decision_transformer
│   ├── conda_env.yml
│   ├── data
│   │   ├── download_d4rl_datasets.py
│   │   └── download_esper_datasets.py
│   ├── decision_transformer
│   │   ├── envs
│   │   │   ├── assets
│   │   │   │   └── reacher_2d.xml
│   │   │   └── reacher_2d.py
│   │   ├── evaluation
│   │   │   └── evaluate_episodes.py
│   │   ├── models
│   │   │   ├── decision_transformer.py
│   │   │   ├── mlp_bc.py
│   │   │   ├── model.py
│   │   │   └── trajectory_gpt2.py
│   │   ├── training
│   │   │   ├── act_trainer.py
│   │   │   ├── seq_trainer.py
│   │   │   └── trainer.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── convert_dataset.py
│   │       └── preemption.py
│   ├── experiment.py
│   └── readme-gym.md
├── readme.md
├── requirements.txt
├── return_transforms
│   ├── algos
│   │   ├── __init__.py
│   │   └── esper
│   │       ├── __init__.py
│   │       └── esper.py
│   ├── datasets
│   │   └── esper_dataset.py
│   ├── generate.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── basic
│   │   │   ├── __init__.py
│   │   │   └── mlp.py
│   │   └── esper
│   │       ├── __init__.py
│   │       ├── cluster_model.py
│   │       └── dynamics_model.py
│   └── utils
│       ├── __init__.py
│       └── utils.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
162 | *.pkl
163 | 
164 | # Don't include any wandb folders (regardless of location)
165 | **/wandb/
166 | 
167 | # VScode
168 | .vscode/
--------------------------------------------------------------------------------
/configs/esper/connect_four.yaml:
--------------------------------------------------------------------------------
1 | method: esper
2 | normalize: false
3 | dynamics_model_args:
4 |   hidden_size: 512
5 |   num_layers: 2
6 |   activation: "relu"
7 |   batchnorm: True
8 |   layernorm: False
9 |   dropout: 0.0
10 | cluster_model_args:
11 |   rep_size: 128
12 |   groups: 4
13 |   obs_action_model:
14 |     hidden_size: 512
15 |     num_layers: 2
16 |     activation: "relu"
17 |     batchnorm: True
18 |     layernorm: False
19 |     dropout: 0.0
20 |   ret_obs_action_model:
21 |     hidden_size: 512
22 |     num_layers: 2
23 |     activation: "relu"
24 |     batchnorm: True
25 |     layernorm: False
26 |     dropout: 0.0
27 |   logit_model:
28 |     hidden_size: 512
29 |     num_layers: 1
30 |     activation: "relu"
31 |     batchnorm: True
32 |     layernorm: False
33 |     dropout: 0.0
34 |   return_model:
35 |     hidden_size: 512
36 |     num_layers: 2
37 |     activation: "relu"
38 |     batchnorm: True
39 |     layernorm: False
40 |     dropout: 0.0
41 |   action_model:
42 |     hidden_size: 512
43 |     num_layers: 2
44 |     activation: "relu"
45 |     batchnorm: True
46 |     layernorm: False
47 |     dropout: 0.0
48 | train_args:
49 |   gamma: 1.0
50 |   scale: 1.0
51 |   dynamics_model_lr: 5e-4
52 |   cluster_model_lr: 1e-4
53 |   batch_size: 100
54 |   cluster_epochs: 5
55 |   return_epochs: 5
56 |   adv_loss_weight: 1.0
57 |   act_loss_weight: 0.05
58 | 
--------------------------------------------------------------------------------
/configs/esper/gambling.yaml:
--------------------------------------------------------------------------------
1 | method: esper
2 | normalize: false
3 | dynamics_model_args:
4 |   hidden_size: 512
5 |   num_layers: 2
6 |   activation: "relu"
7 |   batchnorm: True
8 |   layernorm: False
9 |   dropout: 0.0
10 | cluster_model_args:
11 |   rep_size: 8
12 |   groups: 1
13 |   obs_action_model:
14 |     hidden_size: 512
15 |     num_layers: 2
16 |     activation: "relu"
17 |     batchnorm: True
18 |     layernorm: False
19 |     dropout: 0.0
20 |   ret_obs_action_model:
21 |     hidden_size: 512
22 |     num_layers: 2
23 |     activation: "relu"
24 |     batchnorm: True
25 |     layernorm: False
26 |     dropout: 0.0
27 |   logit_model:
28 |     hidden_size: 512
29 |     num_layers: 1
30 |     activation: "relu"
31 |     batchnorm: True
32 |     layernorm: False
33 |     dropout: 0.0
34 |   return_model:
35 |     hidden_size: 512
36 |     num_layers: 2
37 |     activation: "relu"
38 |     batchnorm: True
39 |     layernorm: False
40 |     dropout: 0.0
41 |   action_model:
42 |     hidden_size: 512
43 |     num_layers: 2
44 |     activation: "relu"
45 |     batchnorm: True
46 |     layernorm: False
47 |     dropout: 0.0
48 | train_args:
49 |   gamma: 1.0
50 |   scale: 5.0
51 |   dynamics_model_lr: 5e-4
52 |   cluster_model_lr: 1e-4
53 |   batch_size: 100
54 |   cluster_epochs: 5
55 |   return_epochs: 1
56 |   adv_loss_weight: 1.0
57 |   act_loss_weight: 0.01
58 | 
--------------------------------------------------------------------------------
/configs/esper/tfe.yaml:
--------------------------------------------------------------------------------
1 | method: esper
2 | normalize: false
3 | dynamics_model_args:
4 |   hidden_size: 512
5 |   num_layers: 2
6 |   activation: "relu"
7 |   batchnorm: True
8 |   layernorm: False
9 |   dropout: 0.0
10 | cluster_model_args:
11 |   rep_size: 128
12 |   groups: 4
13 |   obs_action_model:
14 |     hidden_size: 512
15 |     num_layers: 2
16 |     activation: "relu"
17 |     batchnorm: True
18 |     layernorm: False
19 |     dropout: 0.0
20 |   ret_obs_action_model:
21 |     hidden_size: 512
22 |     num_layers: 2
23 |     activation: "relu"
24 |     batchnorm: True
25 |     layernorm: False
26 |     dropout: 0.0
27 |   logit_model:
28 |     hidden_size: 512
29 |     num_layers: 1
30 |     activation: "relu"
31 |     batchnorm: True
32 |     layernorm: False
33 |     dropout: 0.0
34 |   return_model:
35 |     hidden_size: 512
36 |     num_layers: 2
37 |     activation: "relu"
38 |     batchnorm: True
39 |     layernorm: False
40 |     dropout: 0.0
41 |   action_model:
42 |     hidden_size: 512
43 |     num_layers: 2
44 |     activation: "relu"
45 |     batchnorm: True
46 |     layernorm: False
47 |     dropout: 0.0
48 | train_args:
49 |   gamma: 1.0
50 |   scale: 1.0
51 |   dynamics_model_lr: 5e-4
52 |   cluster_model_lr: 1e-4
53 |   batch_size: 100
54 |   cluster_epochs: 4
55 |   return_epochs: 1
56 |   adv_loss_weight: 1.0
57 |   act_loss_weight: 0.02
58 | 
--------------------------------------------------------------------------------
/data/connect_four.ret:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/data/connect_four.ret
--------------------------------------------------------------------------------
/data/gambling.ret:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/data/gambling.ret
--------------------------------------------------------------------------------
/data/tfe.ret:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/data/tfe.ret
--------------------------------------------------------------------------------
/decision_transformer/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: decision-transformer-gym
2 | channels:
3 | - pytorch
4 | dependencies:
5 | - python=3.8.5
6 | - anaconda
7 | - cudatoolkit=10.
8 | - numpy
9 | - pip
10 | - pip:
11 |   - gym==0.18.3
12 |   - mujoco-py==2.0.2.13
13 |   - numpy==1.20.3
14 |   - torch==1.8.1
15 |   - transformers==4.5.1
16 |   - wandb==0.9.1
17 | 
--------------------------------------------------------------------------------
/decision_transformer/data/download_d4rl_datasets.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | 
4 | import collections
5 | import pickle
6 | 
7 | import d4rl
8 | 
9 | 
10 | datasets = []
11 | 
12 | for env_name in ['halfcheetah', 'hopper', 'walker2d']:
13 |     for dataset_type in ['medium', 'medium-replay', 'expert', 'medium-expert']:
14 |         name = f'{env_name}-{dataset_type}-v2'
15 |         env = gym.make(name)
16 |         dataset = env.get_dataset()
17 | 
18 |         N = dataset['rewards'].shape[0]
19 |         data_ = collections.defaultdict(list)
20 | 
21 |         use_timeouts = False
22 |         if 'timeouts' in dataset:
23 |             use_timeouts = True
24 | 
25 |         episode_step = 0
26 |         paths = []
27 |         for i in range(N):
28 |             done_bool = bool(dataset['terminals'][i])
29 |             if use_timeouts:
30 |                 final_timestep = dataset['timeouts'][i]
31 |             else:
32 |                 final_timestep = (episode_step == 1000-1)
33 |             for k in ['observations', 'next_observations', 'actions', 'rewards', 'terminals']:
34 |                 data_[k].append(dataset[k][i])
35 |             if done_bool or final_timestep:
36 |                 episode_step = 0
37 |                 episode_data = {}
38 |                 for k in data_:
39 |                     episode_data[k] = np.array(data_[k])
40 |                 paths.append(episode_data)
41 |                 data_ = collections.defaultdict(list)
42 |             episode_step += 1
43 | 
44 |         returns = np.array([np.sum(p['rewards']) for p in paths])
45 |         num_samples = np.sum([p['rewards'].shape[0] for p in paths])
46 |         print(f'Number of samples collected: {num_samples}')
47 |         print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}')
48 | 
49 |         with open(f'{name}.pkl', 'wb') as f:
50 |             pickle.dump(paths, f)
51 | 
--------------------------------------------------------------------------------
/decision_transformer/data/download_esper_datasets.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | 
4 | import collections
5 | import pickle
6 | 
7 | from stochastic_offline_envs.envs.offline_envs.connect_four_offline_env import ConnectFourOfflineEnv
8 | from stochastic_offline_envs.envs.offline_envs.tfe_offline_env import TFEOfflineEnv
9 | from stochastic_offline_envs.envs.offline_envs.gambling_offline_env import GamblingOfflineEnv
10 | 
11 | 
12 | def save_esper_dataset(name, offline_env):
13 |     """Transforms the traj list format into one that the
14 |     Decision Transformer codebase can understand and saves
15 |     as a pickle."""
16 |     env = offline_env.env_cls()
17 |     n_actions = env.action_space.n
18 | 
19 |     trajs = offline_env.trajs
20 |     episode_step = 0
21 |     paths = []
22 |     for traj in trajs:
23 |         episode_data = collections.defaultdict(list)
24 |         if 'connect_four' in name:
25 |             episode_data['observations'] = np.array(
26 |                 [obs['grid'].reshape(-1) for obs in traj.obs])
27 |         else:
28 |             episode_data['observations'] = np.array(
29 |                 [obs.reshape(-1) for obs in traj.obs])
30 | 
31 |         a = np.array(traj.actions)
32 |         actions = np.zeros((a.size, n_actions))
33 |         actions[np.arange(a.size), a] = 1
34 |         episode_data['actions'] = actions
35 |         episode_data['rewards'] = np.array(traj.rewards)
36 |         terminals = np.array([False] * (len(traj.obs) - 1) + [True])
37 |         episode_data['terminals'] = terminals
38 |         paths.append(episode_data)
39 | 
40 |     returns = np.array([np.sum(p['rewards']) for p in paths])
41 |     num_samples = np.sum([p['rewards'].shape[0] for p in paths])
42 |     print(f'Number of samples collected: {num_samples}')
43 |     print(
44 |         f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}')
45 | 
46 |     with open(f'{name}.pkl', 'wb') as f:
47 |         pickle.dump(paths, f)
48 | 
49 | 
50 | # Gambling task
51 | env_type = 'alias'
52 | name = 'gambling-default-v2'
53 | offline_env = GamblingOfflineEnv()
54 | 
55 | save_esper_dataset(name, offline_env)
56 | 
57 | # Connect Four task
58 | name = 'connect_four-default-v2'
59 | offline_env = ConnectFourOfflineEnv()
60 | save_esper_dataset(name, offline_env)
61 | 
62 | # 2048 task
63 | name = 'tfe-default-v2'
64 | offline_env = TFEOfflineEnv()
65 | save_esper_dataset(name, offline_env)
66 | 
--------------------------------------------------------------------------------
/decision_transformer/decision_transformer/envs/assets/reacher_2d.xml:
--------------------------------------------------------------------------------
(XML markup not captured in this dump; the tag content was stripped during extraction.)
--------------------------------------------------------------------------------
/decision_transformer/decision_transformer/envs/reacher_2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from gym.envs.mujoco import mujoco_env
4 | 
5 | import os
6 | 
7 | 
8 | class Reacher2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
9 | 
10 |     def __init__(self):
11 |         self.fingertip_sid = 0
12 |         self.target_bid = 0
13 |         curr_dir = os.path.dirname(os.path.abspath(__file__))
14 |         mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/reacher_2d.xml', 15)
15 |         self.fingertip_sid = self.sim.model.site_name2id('fingertip')
16 |         self.target_bid = self.sim.model.body_name2id('target')
17 |         utils.EzPickle.__init__(self)
18 | 
19 |     def step(self, action):
20 |         action = np.clip(action, -1.0, 1.0)
21 |         self.do_simulation(action, self.frame_skip)
22 |         tip = self.data.site_xpos[self.fingertip_sid][:2]
23 |         tar = self.data.body_xpos[self.target_bid][:2]
24 |         dist = np.sum(np.abs(tip - tar))
25 |         reward_dist = 0.  # - 0.1 * dist
26 |         reward_ctrl = 0.0
27 |         reward_bonus = 1.0 if dist < 0.1 else 0.0
28 |         reward = reward_bonus + reward_ctrl + reward_dist
29 |         done = False
30 |         ob = self._get_obs()
31 |         return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl, reward_bonus=reward_bonus)
32 | 
33 |     def _get_obs(self):
34 |         theta = self.data.qpos.ravel()
35 |         tip = self.data.site_xpos[self.fingertip_sid][:2]
36 |         tar = self.data.body_xpos[self.target_bid][:2]
37 |         return np.concatenate([
38 |             # self.data.qpos.flat,
39 |             np.sin(theta),
40 |             np.cos(theta),
41 |             self.dt * self.data.qvel.ravel(),
42 |             tip,
43 |             tar,
44 |             tip-tar,
45 |         ])
46 | 
47 |     def reset_model(self):
48 |         # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
49 |         # qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
50 |         qpos = self.np_random.uniform(low=-2.0, high=2.0, size=self.model.nq)
51 |         qvel = self.init_qvel * 0.0
52 |         while True:
53 |             self.goal = self.np_random.uniform(low=-1.5, high=1.5, size=2)
54 |             if np.linalg.norm(self.goal) <= 1.0 and np.linalg.norm(self.goal) >= 0.5:
55 |                 break
56 |         self.set_state(qpos, qvel)
57 |         self.model.body_pos[self.target_bid][:2] = self.goal
58 |         self.sim.forward()
59 |         return self._get_obs()
60 | 
61 |     def viewer_setup(self):
62 |         self.viewer.cam.distance = self.model.stat.extent * 5.0
63 | 
--------------------------------------------------------------------------------
/decision_transformer/decision_transformer/evaluation/evaluate_episodes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | import torch.nn.functional as F
5 | from torch.distributions import Categorical
6 | 
7 | def evaluate_episode(
8 |         env,
9 |         state_dim,
10 |         act_dim,
11 |         model,
12 |         max_ep_len=1000,
13 |         device='cuda',
14 |         target_return=None,
15 |         mode='normal',
16 |         state_mean=0.,
17 |         state_std=1.,
18 | ):
19 | 
20 |     model.eval()
21 |     model.to(device=device)
22 | 
23 |     state_mean = torch.from_numpy(state_mean).to(device=device)
24 |     state_std = torch.from_numpy(state_std).to(device=device)
25 | 
26 |     state = env.reset()
27 | 
28 |     # we keep all the histories on the device
29 |     # note that the latest action and reward will be "padding"
30 |     states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
31 |     actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
32 |     rewards = torch.zeros(0, device=device, dtype=torch.float32)
33 |     target_return = torch.tensor(target_return, device=device, dtype=torch.float32)
34 |     sim_states = []
35 | 
36 |     episode_return, episode_length = 0, 0
37 |     for t in range(max_ep_len):
38 | 
39 |         # add padding
40 |         actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
41 |         rewards = torch.cat([rewards, torch.zeros(1, device=device)])
42 | 
43 |         action = model.get_action(
44 |             (states.to(dtype=torch.float32) - state_mean) / state_std,
45 |             actions.to(dtype=torch.float32),
46 |             rewards.to(dtype=torch.float32),
47 |             target_return=target_return,
48 |         )
49 |         actions[-1] = action
50 |         action = action.detach().cpu().numpy()
51 | 
52 |         state, reward, done, _ = env.step(action)
53 | 
54 |         cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
55 |         states = torch.cat([states, cur_state], dim=0)
56 |         rewards[-1] = reward
57 | 
58 |         episode_return += reward
59 |         episode_length += 1
60 | 
61 |         if done:
62 |             break
63 | 
64 |     return episode_return, episode_length
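# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file):
# `evaluate_episode` above can be driven roughly as follows, assuming a gym
# `env` and a trained TrajectoryModel `model` already exist; the target
# return and the normalization statistics below are hypothetical placeholders.
#
#     state_mean, state_std = ...  # per-dimension np.ndarray statistics
#     episode_return, episode_length = evaluate_episode(
#         env,
#         state_dim=env.observation_space.shape[0],
#         act_dim=env.action_space.shape[0],
#         model=model,
#         max_ep_len=1000,
#         target_return=3600.,
#         state_mean=state_mean,
#         state_std=state_std,
#     )
# ----------------------------------------------------------------------------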
65 | 
66 | 
67 | def evaluate_episode_rtg(
68 |         env,
69 |         state_dim,
70 |         act_dim,
71 |         model,
72 |         max_ep_len=1000,
73 |         scale=1000.,
74 |         state_mean=0.,
75 |         state_std=1.,
76 |         device='cuda',
77 |         target_return=None,
78 |         mode='normal',
79 |         action_type='continuous'
80 | ):
81 | 
82 |     model.eval()
83 |     model.to(device=device)
84 | 
85 |     state_mean = torch.from_numpy(state_mean).to(device=device)
86 |     state_std = torch.from_numpy(state_std).to(device=device)
87 | 
88 |     state = env.reset()
89 |     if mode == 'noise':
90 |         state = state + np.random.normal(0, 0.1, size=state.shape)
91 | 
92 |     # we keep all the histories on the device
93 |     # note that the latest action and reward will be "padding"
94 |     states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
95 |     actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
96 |     rewards = torch.zeros(0, device=device, dtype=torch.float32)
97 | 
98 |     ep_return = target_return
99 |     target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1)
100 |     timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
101 | 
102 |     sim_states = []
103 | 
104 |     episode_return, episode_length = 0, 0
105 |     for t in range(max_ep_len):
106 | 
107 |         # add padding
108 |         actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
109 |         rewards = torch.cat([rewards, torch.zeros(1, device=device)])
110 | 
111 |         action = model.get_action(
112 |             (states.to(dtype=torch.float32) - state_mean) / state_std,
113 |             actions.to(dtype=torch.float32),
114 |             rewards.to(dtype=torch.float32),
115 |             target_return.to(dtype=torch.float32),
116 |             timesteps.to(dtype=torch.long),
117 |         )
118 | 
119 |         if action_type == 'discrete':
120 |             # sample action
121 |             act_probs = F.softmax(action, dim=-1)
122 |             action = Categorical(probs=act_probs).sample()
123 |             # make the action one hot
124 |             one_hot_action = torch.zeros(1, act_dim).float()
125 |             one_hot_action[0, action] = 1
126 |             actions[-1] = one_hot_action
127 |         else:
128 |             actions[-1] = action
129 |         action = action.detach().cpu().numpy()
130 | 
131 |         state, reward, done, _ = env.step(action)
132 | 
133 |         cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
134 |         states = torch.cat([states, cur_state], dim=0)
135 |         rewards[-1] = reward
136 | 
137 |         if mode != 'delayed':
138 |             pred_return = target_return[0,-1] - (reward/scale)
139 |         else:
140 |             pred_return = target_return[0,-1]
141 |         target_return = torch.cat(
142 |             [target_return, pred_return.reshape(1, 1)], dim=1)
143 |         timesteps = torch.cat(
144 |             [timesteps,
145 |              torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1)
146 | 
147 |         episode_return += reward
148 |         episode_length += 1
149 | 
150 |         if done:
151 |             break
152 | 
153 |     return episode_return, episode_length
154 | 
--------------------------------------------------------------------------------
/decision_transformer/decision_transformer/models/decision_transformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | import transformers
6 | 
7 | from decision_transformer.models.model import TrajectoryModel
8 | from decision_transformer.models.trajectory_gpt2 import GPT2Model
9 | 
10 | 
11 | class DecisionTransformer(TrajectoryModel):
12 | 
13 |     """
14 |     This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...)
15 |     """
16 | 
17 |     def __init__(
18 |             self,
19 |             state_dim,
20 |             act_dim,
21 |             hidden_size,
22 |             max_length=None,
23 |             max_ep_len=4096,
24 |             action_tanh=True,
25 |             rtg_seq=True,
26 |             **kwargs
27 |     ):
28 |         super().__init__(state_dim, act_dim, max_length=max_length)
29 | 
30 |         self.hidden_size = hidden_size
31 |         self.rtg_seq = rtg_seq
32 |         config = transformers.GPT2Config(
33 |             vocab_size=1,  # doesn't matter -- we don't use the vocab
34 |             n_embd=hidden_size,
35 |             **kwargs
36 |         )
37 | 
38 |         # note: the only difference between this GPT2Model and the default Huggingface version
39 |         # is that the positional embeddings are removed (since we'll add those ourselves)
40 |         self.transformer = GPT2Model(config)
41 | 
42 |         self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
43 |         self.embed_return = torch.nn.Linear(1, hidden_size)
44 |         self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
45 |         self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)
46 | 
47 |         self.embed_ln = nn.LayerNorm(hidden_size)
48 | 
49 |         # note: we don't predict states or returns for the paper
50 |         self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
51 |         if self.rtg_seq:
52 |             self.predict_action = nn.Sequential(
53 |                 *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
54 |             )
55 |         else:
56 |             self.predict_action = nn.Sequential(
57 |                 *([nn.Linear(hidden_size * 2, hidden_size),
58 |                    nn.ReLU(),
59 |                    nn.Linear(hidden_size, hidden_size),
60 |                    nn.ReLU(),
61 |                    nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
62 |             )
63 |         self.predict_return = torch.nn.Linear(hidden_size, 1)
64 | 
65 |     def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
66 | 
67 |         if self.rtg_seq:
68 |             embed_per_timestep = 3
69 |         else:
70 |             embed_per_timestep = 2
71 | 
72 |         batch_size, seq_length = states.shape[0], states.shape[1]
73 | 
74 |         if attention_mask is None:
75 |             # attention mask for GPT: 1 if can be attended to, 0 if not
76 |             attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
77 | 
78 |         # embed each modality with a different head
79 |         state_embeddings = self.embed_state(states)
80 |         action_embeddings = self.embed_action(actions)
81 |         returns_embeddings = self.embed_return(returns_to_go)
82 |         time_embeddings = self.embed_timestep(timesteps)
83 | 
84 |         # time embeddings are treated similar to positional embeddings
85 |         state_embeddings = state_embeddings + time_embeddings
86 |         action_embeddings = action_embeddings + time_embeddings
87 |         if self.rtg_seq:
88 |             returns_embeddings = returns_embeddings + time_embeddings
89 | 
90 |         if self.rtg_seq:
91 |             # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
92 |             # which works nice in an autoregressive sense since states predict actions
93 |             stacked_inputs = torch.stack(
94 |                 (returns_embeddings, state_embeddings, action_embeddings), dim=1
95 |             ).permute(0, 2, 1, 3).reshape(batch_size, embed_per_timestep*seq_length, self.hidden_size)
96 |             stacked_inputs = self.embed_ln(stacked_inputs)
97 | 
98 |             # to make the attention mask fit the stacked inputs, have to stack it as well
99 |             stacked_attention_mask = torch.stack(
100 |                 (attention_mask, attention_mask, attention_mask), dim=1
101 |             ).permute(0, 2, 1).reshape(batch_size, embed_per_timestep*seq_length)
102 |         else:
103 |             # this makes the sequence look like (s_1, a_1, s_2, a_2, ...);
104 |             # the return is instead concatenated with the state token at the prediction head
105 |             stacked_inputs = torch.stack(
106 |                 (state_embeddings, action_embeddings), dim=1
107 |             ).permute(0, 2, 1, 3).reshape(batch_size, embed_per_timestep*seq_length, self.hidden_size)
108 |             stacked_inputs = self.embed_ln(stacked_inputs)
109 | 
110 |             # to make the attention mask fit the stacked inputs, have to stack it as well
111 |             stacked_attention_mask = torch.stack(
112 |                 (attention_mask, attention_mask), dim=1
113 |             ).permute(0, 2, 1).reshape(batch_size, embed_per_timestep*seq_length)
114 | 
115 |         # we feed in the input embeddings (not word indices as in NLP) to the model
116 |         transformer_outputs = self.transformer(
117 |             inputs_embeds=stacked_inputs,
118 |             attention_mask=stacked_attention_mask,
119 |         )
120 |         x = transformer_outputs['last_hidden_state']
121 | 
122 |         # reshape x so that the second dimension corresponds to the original
123 |         # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
124 |         x = x.reshape(batch_size, seq_length, embed_per_timestep, self.hidden_size).permute(0, 2, 1, 3)
125 | 
126 |         # get predictions
127 |         if self.rtg_seq:
128 |             return_preds = self.predict_return(x[:,2])  # predict next return given state and action
129 |             state_preds = self.predict_state(x[:,2])  # predict next state given state and action
130 |             action_preds = self.predict_action(x[:,1])  # predict next action given state
131 |         else:
132 |             state_preds = self.predict_state(x[:,1])  # predict next state given state and action
133 |             state_return = torch.cat((x[:,0], returns_embeddings), dim=-1)
134 |             action_preds = self.predict_action(state_return)  # predict next action given state
135 |             return_preds = None  # no return token to predict from in this mode
136 |         return state_preds, action_preds, return_preds
137 | 
138 |     def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
139 |         # we don't care about the past rewards in this model
140 | 
141 |         states = states.reshape(1, -1, self.state_dim)
142 |         actions = actions.reshape(1, -1, self.act_dim)
143 |         returns_to_go = returns_to_go.reshape(1, -1, 1)
144 |         timesteps = timesteps.reshape(1, -1)
145 | 
146 |         if self.max_length is not None:
147 |             states = states[:,-self.max_length:]
148 |             actions = actions[:,-self.max_length:]
149 |             returns_to_go = returns_to_go[:,-self.max_length:]
150 |             timesteps = timesteps[:,-self.max_length:]
151 | 
152 |             # pad all tokens to sequence length
153 |             attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
154 |             attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
155 |             states = torch.cat(
156 |                 [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
157 |                 dim=1).to(dtype=torch.float32)
158 |             actions = torch.cat(
159 |                 [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim),
160 |                              device=actions.device), actions],
161 |                 dim=1).to(dtype=torch.float32)
162 |             returns_to_go = torch.cat(
163 |                 [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go],
164 |                 dim=1).to(dtype=torch.float32)
165 |             timesteps = torch.cat(
166 |                 [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps],
167 |                 dim=1
168 |             ).to(dtype=torch.long)
169 |         else:
170 |             attention_mask = None
171 | 
172 |         _, action_preds, return_preds = self.forward(
173 |             states, actions, None, returns_to_go, timesteps,
            attention_mask=attention_mask, **kwargs)
174 | 
175 |         return action_preds[0,-1]
176 | 
--------------------------------------------------------------------------------
/decision_transformer/decision_transformer/models/mlp_bc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | from decision_transformer.models.model import TrajectoryModel
6 | 
7 | 
8 | class MLPBCModel(TrajectoryModel):
9 | 
10 |     """
11 |     Simple MLP that predicts next action a from past states s.
12 |     """
13 | 
14 |     def __init__(self, state_dim, act_dim, hidden_size, n_layer, dropout=0.1, max_length=1, **kwargs):
15 |         super().__init__(state_dim, act_dim)
16 | 
17 |         self.hidden_size = hidden_size
18 |         self.max_length = max_length
19 | 
20 |         layers = [nn.Linear(max_length*self.state_dim, hidden_size)]
21 |         for _ in range(n_layer-1):
22 |             layers.extend([
23 |                 nn.ReLU(),
24 |                 nn.Dropout(dropout),
25 |                 nn.Linear(hidden_size, hidden_size)
26 |             ])
27 |         layers.extend([
28 |             nn.ReLU(),
29 |             nn.Dropout(dropout),
30 |             nn.Linear(hidden_size, self.act_dim),
31 |             nn.Tanh(),
32 |         ])
33 | 
34 |         self.model = nn.Sequential(*layers)
35 | 
36 |     def forward(self, states, actions, rewards, attention_mask=None, target_return=None):
37 | 
38 |         states = states[:,-self.max_length:].reshape(states.shape[0], -1)  # concat states
39 |         actions = self.model(states).reshape(states.shape[0], 1, self.act_dim)
40 | 
41 |         return None, actions, None
42 | 
43 |     def get_action(self, states, actions, rewards, **kwargs):
44 |         states = states.reshape(1, -1, self.state_dim)
45 |         if states.shape[1] < self.max_length:
46 |             states = torch.cat(
47 |                 [torch.zeros((1, self.max_length-states.shape[1], self.state_dim),
48 |                              dtype=torch.float32, device=states.device), states], dim=1)
49 |         states = states.to(dtype=torch.float32)
50 |         _, actions, _ = self.forward(states, None, None, **kwargs)
51 |         return actions[0,-1]
52 | 
--------------------------------------------------------------------------------
/decision_transformer/decision_transformer/models/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | class TrajectoryModel(nn.Module):
7 | 
8 |     def __init__(self, state_dim, act_dim, max_length=None):
9 |         super().__init__()
10 | 
11 |         self.state_dim = state_dim
12 |         self.act_dim = act_dim
13 |         self.max_length = max_length
14 | 
15 |     def forward(self, states, actions, rewards, masks=None, attention_mask=None):
16 |         # "masked" tokens or unspecified inputs can be passed in as None
17 |         return None, None, None
18 | 
19 |     def get_action(self, states, actions, rewards, **kwargs):
20 |         # these will come as tensors on the correct device
21 |         return torch.zeros_like(actions[-1])
22 | 
--------------------------------------------------------------------------------
/decision_transformer/decision_transformer/models/trajectory_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch OpenAI GPT-2 model.""" 17 | 18 | import math 19 | import os 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple, Union 22 | 23 | import torch 24 | import torch.utils.checkpoint 25 | from packaging import version 26 | from torch import nn 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 28 | 29 | 30 | if version.parse(torch.__version__) >= version.parse("1.6"): 31 | is_amp_available = True 32 | from torch.cuda.amp import autocast 33 | else: 34 | is_amp_available = False 35 | 36 | from transformers.activations import ACT2FN 37 | from transformers.modeling_outputs import ( 38 | BaseModelOutputWithPastAndCrossAttentions, 39 | CausalLMOutputWithCrossAttentions, 40 | SequenceClassifierOutputWithPast, 41 | TokenClassifierOutput, 42 | ) 43 | from transformers.modeling_utils import PreTrainedModel, SequenceSummary 44 | from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer 45 | from transformers.utils import ( 46 | ModelOutput, 47 | add_code_sample_docstrings, 48 | add_start_docstrings, 49 | add_start_docstrings_to_model_forward, 50 | logging, 51 | replace_return_docstrings, 52 | ) 53 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 54 | from transformers.models.gpt2.configuration_gpt2 import GPT2Config 55 | 56 | 57 | logger = logging.get_logger(__name__) 58 | 59 | _CHECKPOINT_FOR_DOC = "gpt2" 60 | _CONFIG_FOR_DOC = "GPT2Config" 61 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer" 62 | 63 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ 64 | "gpt2", 65 | "gpt2-medium", 66 | "gpt2-large", 67 | "gpt2-xl", 68 | "distilgpt2", 69 | # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 70 | ] 71 | 72 | 73 | def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): 74 | """Load tf checkpoints in a pytorch model""" 75 | try: 76 | import re 77 | 78 | import tensorflow as tf 79 | except ImportError: 80 | logger.error( 81 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 82 | "https://www.tensorflow.org/install/ for installation instructions." 
83 | ) 84 | raise 85 | tf_path = os.path.abspath(gpt2_checkpoint_path) 86 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 87 | # Load weights from TF model 88 | init_vars = tf.train.list_variables(tf_path) 89 | names = [] 90 | arrays = [] 91 | for name, shape in init_vars: 92 | logger.info(f"Loading TF weight {name} with shape {shape}") 93 | array = tf.train.load_variable(tf_path, name) 94 | names.append(name) 95 | arrays.append(array.squeeze()) 96 | 97 | for name, array in zip(names, arrays): 98 | name = name[6:] # skip "model/" 99 | name = name.split("/") 100 | pointer = model 101 | for m_name in name: 102 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 103 | scope_names = re.split(r"(\d+)", m_name) 104 | else: 105 | scope_names = [m_name] 106 | if scope_names[0] == "w" or scope_names[0] == "g": 107 | pointer = getattr(pointer, "weight") 108 | elif scope_names[0] == "b": 109 | pointer = getattr(pointer, "bias") 110 | elif scope_names[0] == "wpe" or scope_names[0] == "wte": 111 | pointer = getattr(pointer, scope_names[0]) 112 | pointer = getattr(pointer, "weight") 113 | else: 114 | pointer = getattr(pointer, scope_names[0]) 115 | if len(scope_names) >= 2: 116 | num = int(scope_names[1]) 117 | pointer = pointer[num] 118 | try: 119 | assert ( 120 | pointer.shape == array.shape 121 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 122 | except AssertionError as e: 123 | e.args += (pointer.shape, array.shape) 124 | raise 125 | logger.info(f"Initialize PyTorch weight {name}") 126 | pointer.data = torch.from_numpy(array) 127 | return model 128 | 129 | 130 | class GPT2Attention(nn.Module): 131 | def __init__(self, config, is_cross_attention=False, layer_idx=None): 132 | super().__init__() 133 | 134 | max_positions = config.max_position_embeddings 135 | self.register_buffer( 136 | "bias", 137 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( 138 | 1, 1, max_positions, max_positions 139 | ), 140 | ) 141 | self.register_buffer("masked_bias", torch.tensor(-1e4)) 142 | 143 | self.embed_dim = config.hidden_size 144 | self.num_heads = config.num_attention_heads 145 | self.head_dim = self.embed_dim // self.num_heads 146 | self.split_size = self.embed_dim 147 | if self.head_dim * self.num_heads != self.embed_dim: 148 | raise ValueError( 149 | f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 150 | f" {self.num_heads})." 
151 | ) 152 | 153 | self.scale_attn_weights = config.scale_attn_weights 154 | self.is_cross_attention = is_cross_attention 155 | 156 | # Layer-wise attention scaling, reordering, and upcasting 157 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx 158 | self.layer_idx = layer_idx 159 | self.reorder_and_upcast_attn = config.reorder_and_upcast_attn 160 | 161 | if self.is_cross_attention: 162 | self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) 163 | self.q_attn = Conv1D(self.embed_dim, self.embed_dim) 164 | else: 165 | self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) 166 | self.c_proj = Conv1D(self.embed_dim, self.embed_dim) 167 | 168 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 169 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 170 | 171 | self.pruned_heads = set() 172 | 173 | def prune_heads(self, heads): 174 | if len(heads) == 0: 175 | return 176 | heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) 177 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 178 | 179 | # Prune conv1d layers 180 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 181 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 182 | 183 | # Update hyper params 184 | self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) 185 | self.num_heads = self.num_heads - len(heads) 186 | self.pruned_heads = self.pruned_heads.union(heads) 187 | 188 | def _attn(self, query, key, value, attention_mask=None, head_mask=None): 189 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 190 | 191 | if self.scale_attn_weights: 192 | attn_weights = attn_weights / torch.tensor( 193 | value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device 194 | ) 195 | 196 | # Layer-wise attention scaling 197 | if self.scale_attn_by_inverse_layer_idx: 198 | attn_weights = attn_weights / float(self.layer_idx + 1) 199 | 200 | if not self.is_cross_attention: 201 | # if only "normal" attention layer implements causal mask 202 | query_length, key_length = query.size(-2), key.size(-2) 203 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) 204 | mask_value = torch.finfo(attn_weights.dtype).min 205 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
206 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` 207 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) 208 | attn_weights = torch.where(causal_mask, attn_weights, mask_value) 209 | 210 | if attention_mask is not None: 211 | # Apply the attention mask 212 | attn_weights = attn_weights + attention_mask 213 | 214 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 215 | 216 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise 217 | attn_weights = attn_weights.type(value.dtype) 218 | attn_weights = self.attn_dropout(attn_weights) 219 | 220 | # Mask heads if we want to 221 | if head_mask is not None: 222 | attn_weights = attn_weights * head_mask 223 | 224 | attn_output = torch.matmul(attn_weights, value) 225 | 226 | return attn_output, attn_weights 227 | 228 | def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): 229 | # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) 230 | bsz, num_heads, q_seq_len, dk = query.size() 231 | _, _, k_seq_len, _ = key.size() 232 | 233 | # Preallocate attn_weights for `baddbmm` 234 | attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) 235 | 236 | # Compute Scale Factor 237 | scale_factor = 1.0 238 | if self.scale_attn_weights: 239 | scale_factor /= float(value.size(-1)) ** 0.5 240 | 241 | if self.scale_attn_by_inverse_layer_idx: 242 | scale_factor /= float(self.layer_idx + 1) 243 | 244 | # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) 245 | if is_amp_available: 246 | with autocast(enabled=False): 247 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) 248 | attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) 249 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) 250 | else: 251 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) 252 | attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) 253 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) 254 | 255 | if not self.is_cross_attention: 256 | # if only "normal" attention layer implements causal mask 257 | query_length, key_length = query.size(-2), key.size(-2) 258 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() 259 | mask_value = torch.finfo(attn_weights.dtype).min 260 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
261 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` 262 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) 263 | attn_weights = torch.where(causal_mask, attn_weights, mask_value) 264 | 265 | if attention_mask is not None: 266 | # Apply the attention mask 267 | attn_weights = attn_weights + attention_mask 268 | 269 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 270 | 271 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise 272 | if attn_weights.dtype != torch.float32: 273 | raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") 274 | attn_weights = attn_weights.type(value.dtype) 275 | attn_weights = self.attn_dropout(attn_weights) 276 | 277 | # Mask heads if we want to 278 | if head_mask is not None: 279 | attn_weights = attn_weights * head_mask 280 | 281 | attn_output = torch.matmul(attn_weights, value) 282 | 283 | return attn_output, attn_weights 284 | 285 | def _split_heads(self, tensor, num_heads, attn_head_size): 286 | """ 287 | Splits hidden_size dim into attn_head_size and num_heads 288 | """ 289 | new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) 290 | tensor = tensor.view(new_shape) 291 | return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 292 | 293 | def _merge_heads(self, tensor, num_heads, attn_head_size): 294 | """ 295 | Merges attn_head_size dim and num_attn_heads dim into hidden_size 296 | """ 297 | tensor = tensor.permute(0, 2, 1, 3).contiguous() 298 | new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) 299 | return tensor.view(new_shape) 300 | 301 | def forward( 302 | self, 303 | hidden_states: Optional[Tuple[torch.FloatTensor]], 304 | layer_past: Optional[Tuple[torch.Tensor]] = None, 305 | attention_mask: Optional[torch.FloatTensor] = None, 306 | head_mask: Optional[torch.FloatTensor] = None, 307 | encoder_hidden_states: Optional[torch.Tensor] = None, 308 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 309 | use_cache: Optional[bool] = False, 310 | output_attentions: Optional[bool] = False, 311 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: 312 | if encoder_hidden_states is not None: 313 | if not hasattr(self, "q_attn"): 314 | raise ValueError( 315 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 316 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
317 | ) 318 | 319 | query = self.q_attn(hidden_states) 320 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 321 | attention_mask = encoder_attention_mask 322 | else: 323 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 324 | 325 | query = self._split_heads(query, self.num_heads, self.head_dim) 326 | key = self._split_heads(key, self.num_heads, self.head_dim) 327 | value = self._split_heads(value, self.num_heads, self.head_dim) 328 | 329 | if layer_past is not None: 330 | past_key, past_value = layer_past 331 | key = torch.cat((past_key, key), dim=-2) 332 | value = torch.cat((past_value, value), dim=-2) 333 | 334 | if use_cache is True: 335 | present = (key, value) 336 | else: 337 | present = None 338 | 339 | if self.reorder_and_upcast_attn: 340 | attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) 341 | else: 342 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 343 | 344 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 345 | attn_output = self.c_proj(attn_output) 346 | attn_output = self.resid_dropout(attn_output) 347 | 348 | outputs = (attn_output, present) 349 | if output_attentions: 350 | outputs += (attn_weights,) 351 | 352 | return outputs # a, present, (attentions) 353 | 354 | 355 | class GPT2MLP(nn.Module): 356 | def __init__(self, intermediate_size, config): 357 | super().__init__() 358 | embed_dim = config.hidden_size 359 | self.c_fc = Conv1D(intermediate_size, embed_dim) 360 | self.c_proj = Conv1D(embed_dim, intermediate_size) 361 | self.act = ACT2FN[config.activation_function] 362 | self.dropout = nn.Dropout(config.resid_pdrop) 363 | 364 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 365 | hidden_states = self.c_fc(hidden_states) 366 | hidden_states = self.act(hidden_states) 367 | hidden_states = self.c_proj(hidden_states) 368 | hidden_states = self.dropout(hidden_states) 369 | return hidden_states 370 | 371 | 372 | class GPT2Block(nn.Module): 373 | def __init__(self, config, layer_idx=None): 374 | super().__init__() 375 | hidden_size = config.hidden_size 376 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 377 | 378 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 379 | self.attn = GPT2Attention(config, layer_idx=layer_idx) 380 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 381 | 382 | if config.add_cross_attention: 383 | self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) 384 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 385 | 386 | self.mlp = GPT2MLP(inner_dim, config) 387 | 388 | def forward( 389 | self, 390 | hidden_states: Optional[Tuple[torch.FloatTensor]], 391 | layer_past: Optional[Tuple[torch.Tensor]] = None, 392 | attention_mask: Optional[torch.FloatTensor] = None, 393 | head_mask: Optional[torch.FloatTensor] = None, 394 | encoder_hidden_states: Optional[torch.Tensor] = None, 395 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 396 | use_cache: Optional[bool] = False, 397 | output_attentions: Optional[bool] = False, 398 | ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: 399 | residual = hidden_states 400 | hidden_states = self.ln_1(hidden_states) 401 | attn_outputs = self.attn( 402 | hidden_states, 403 | layer_past=layer_past, 404 | 
attention_mask=attention_mask, 405 | head_mask=head_mask, 406 | use_cache=use_cache, 407 | output_attentions=output_attentions, 408 | ) 409 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 410 | outputs = attn_outputs[1:] 411 | # residual connection 412 | hidden_states = attn_output + residual 413 | 414 | if encoder_hidden_states is not None: 415 | # add one self-attention block for cross-attention 416 | if not hasattr(self, "crossattention"): 417 | raise ValueError( 418 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " 419 | "cross-attention layers by setting `config.add_cross_attention=True`" 420 | ) 421 | residual = hidden_states 422 | hidden_states = self.ln_cross_attn(hidden_states) 423 | cross_attn_outputs = self.crossattention( 424 | hidden_states, 425 | attention_mask=attention_mask, 426 | head_mask=head_mask, 427 | encoder_hidden_states=encoder_hidden_states, 428 | encoder_attention_mask=encoder_attention_mask, 429 | output_attentions=output_attentions, 430 | ) 431 | attn_output = cross_attn_outputs[0] 432 | # residual connection 433 | hidden_states = residual + attn_output 434 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 435 | 436 | residual = hidden_states 437 | hidden_states = self.ln_2(hidden_states) 438 | feed_forward_hidden_states = self.mlp(hidden_states) 439 | # residual connection 440 | hidden_states = residual + feed_forward_hidden_states 441 | 442 | if use_cache: 443 | outputs = (hidden_states,) + outputs 444 | else: 445 | outputs = (hidden_states,) + outputs[1:] 446 | 447 | return outputs # hidden_states, present, (attentions, cross_attentions) 448 | 449 | 450 | class GPT2PreTrainedModel(PreTrainedModel): 451 | """ 452 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 453 | models. 454 | """ 455 | 456 | config_class = GPT2Config 457 | load_tf_weights = load_tf_weights_in_gpt2 458 | base_model_prefix = "transformer" 459 | is_parallelizable = True 460 | supports_gradient_checkpointing = True 461 | _no_split_modules = ["GPT2Block"] 462 | 463 | def __init__(self, *inputs, **kwargs): 464 | super().__init__(*inputs, **kwargs) 465 | 466 | def _init_weights(self, module): 467 | """Initialize the weights.""" 468 | if isinstance(module, (nn.Linear, Conv1D)): 469 | # Slightly different from the TF version which uses truncated_normal for initialization 470 | # cf https://github.com/pytorch/pytorch/pull/5617 471 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 472 | if module.bias is not None: 473 | module.bias.data.zero_() 474 | elif isinstance(module, nn.Embedding): 475 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 476 | if module.padding_idx is not None: 477 | module.weight.data[module.padding_idx].zero_() 478 | elif isinstance(module, nn.LayerNorm): 479 | module.bias.data.zero_() 480 | module.weight.data.fill_(1.0) 481 | 482 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 483 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 484 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
485 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 486 | # 487 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 488 | for name, p in module.named_parameters(): 489 | if name == "c_proj.weight": 490 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 491 | p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) 492 | 493 | def _set_gradient_checkpointing(self, module, value=False): 494 | if isinstance(module, GPT2Model): 495 | module.gradient_checkpointing = value 496 | 497 | 498 | @dataclass 499 | class GPT2DoubleHeadsModelOutput(ModelOutput): 500 | """ 501 | Base class for outputs of models predicting if two sentences are consecutive or not. 502 | 503 | Args: 504 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 505 | Language modeling loss. 506 | mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): 507 | Multiple choice classification loss. 508 | logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): 509 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 510 | mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): 511 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). 512 | past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 513 | Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, 514 | sequence_length, embed_size_per_head)`). 515 | 516 | Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see 517 | `past_key_values` input) to speed up sequential decoding. 518 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 519 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of 520 | shape `(batch_size, sequence_length, hidden_size)`. 521 | 522 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 523 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 524 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 525 | sequence_length)`. 526 | 527 | GPT2Attentions weights after the attention softmax, used to compute the weighted average in the 528 | self-attention heads. 529 | """ 530 | 531 | loss: Optional[torch.FloatTensor] = None 532 | mc_loss: Optional[torch.FloatTensor] = None 533 | logits: torch.FloatTensor = None 534 | mc_logits: torch.FloatTensor = None 535 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 536 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 537 | attentions: Optional[Tuple[torch.FloatTensor]] = None 538 | 539 | 540 | GPT2_START_DOCSTRING = r""" 541 | 542 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 543 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 544 | etc.) 
545 | 546 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 547 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage 548 | and behavior. 549 | 550 | Parameters: 551 | config ([`GPT2Config`]): Model configuration class with all the parameters of the model. 552 | Initializing with a config file does not load the weights associated with the model, only the 553 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 554 | """ 555 | 556 | GPT2_INPUTS_DOCSTRING = r""" 557 | Args: 558 | input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): 559 | `input_ids_length` = `sequence_length` if `past_key_values` is `None` else 560 | `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input 561 | sequence tokens in the vocabulary. 562 | 563 | If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as 564 | `input_ids`. 565 | 566 | Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and 567 | [`PreTrainedTokenizer.__call__`] for details. 568 | 569 | [What are input IDs?](../glossary#input-ids) 570 | past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): 571 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see 572 | `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have 573 | their past given to this model should not be passed as `input_ids` as they have already been computed. 574 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 575 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 576 | 577 | - 1 for tokens that are **not masked**, 578 | - 0 for tokens that are **masked**. 579 | 580 | If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for 581 | `past_key_values`. In other words, the `attention_mask` always has to have the length: 582 | `len(past_key_values) + len(input_ids)` 583 | 584 | [What are attention masks?](../glossary#attention-mask) 585 | token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): 586 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 587 | 1]`: 588 | 589 | - 0 corresponds to a *sentence A* token, 590 | - 1 corresponds to a *sentence B* token. 591 | 592 | [What are token type IDs?](../glossary#token-type-ids) 593 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 594 | Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, 595 | config.max_position_embeddings - 1]`. 596 | 597 | [What are position IDs?](../glossary#position-ids) 598 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 599 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 600 | 601 | - 1 indicates the head is **not masked**, 602 | - 0 indicates the head is **masked**.
603 | 604 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 605 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 606 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 607 | model's internal embedding lookup matrix. 608 | 609 | If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see 610 | `past_key_values`). 611 | use_cache (`bool`, *optional*): 612 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 613 | `past_key_values`). 614 | output_attentions (`bool`, *optional*): 615 | Whether or not to return the attention tensors of all attention layers. See `attentions` under returned 616 | tensors for more detail. 617 | output_hidden_states (`bool`, *optional*): 618 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 619 | more detail. 620 | return_dict (`bool`, *optional*): 621 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 622 | """ 623 | PARALLELIZE_DOCSTRING = r""" 624 | This is an experimental feature and is subject to change at a moment's notice. 625 | 626 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given, 627 | it will evenly distribute blocks across all devices. 628 | 629 | Args: 630 | device_map (`Dict[int, list]`, *optional*, defaults to `None`): 631 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always 632 | automatically mapped to the first device (for esoteric reasons). That means that the first device should 633 | have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the 634 | following number of attention modules: 635 | 636 | - gpt2: 12 637 | - gpt2-medium: 24 638 | - gpt2-large: 36 639 | - gpt2-xl: 48 640 | 641 | Example: 642 | 643 | ```python 644 | # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: 645 | model = GPT2LMHeadModel.from_pretrained("gpt2-xl") 646 | device_map = { 647 | 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], 648 | 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], 649 | 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], 650 | 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], 651 | } 652 | model.parallelize(device_map) 653 | ``` 654 | """ 655 | DEPARALLELIZE_DOCSTRING = r""" 656 | Moves the model to cpu from a model parallel state.
657 | 658 | Example: 659 | 660 | ```python 661 | # On a 4 GPU machine with gpt2-large: 662 | model = GPT2LMHeadModel.from_pretrained("gpt2-large") 663 | device_map = { 664 | 0: [0, 1, 2, 3, 4, 5, 6, 7], 665 | 1: [8, 9, 10, 11, 12, 13, 14, 15], 666 | 2: [16, 17, 18, 19, 20, 21, 22, 23], 667 | 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], 668 | } 669 | model.parallelize(device_map) # Splits the model across several devices 670 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() 671 | ``` 672 | """ 673 | 674 | 675 | @add_start_docstrings( 676 | "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", 677 | GPT2_START_DOCSTRING, 678 | ) 679 | class GPT2Model(GPT2PreTrainedModel): 680 | _keys_to_ignore_on_load_missing = ["attn.masked_bias"] 681 | 682 | def __init__(self, config): 683 | super().__init__(config) 684 | 685 | self.embed_dim = config.hidden_size 686 | 687 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 688 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 689 | 690 | self.drop = nn.Dropout(config.embd_pdrop) 691 | self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) 692 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 693 | 694 | # Model parallel 695 | self.model_parallel = False 696 | self.device_map = None 697 | self.gradient_checkpointing = False 698 | 699 | # Initialize weights and apply final processing 700 | self.post_init() 701 | 702 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 703 | def parallelize(self, device_map=None): 704 | # Check validity of device_map 705 | self.device_map = ( 706 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map 707 | ) 708 | assert_device_map(self.device_map, len(self.h)) 709 | self.model_parallel = True 710 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 711 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 712 | self.wte = self.wte.to(self.first_device) 713 | self.wpe = self.wpe.to(self.first_device) 714 | # Load onto devices 715 | for k, v in self.device_map.items(): 716 | for block in v: 717 | cuda_device = "cuda:" + str(k) 718 | self.h[block] = self.h[block].to(cuda_device) 719 | # ln_f to last 720 | self.ln_f = self.ln_f.to(self.last_device) 721 | 722 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 723 | def deparallelize(self): 724 | self.model_parallel = False 725 | self.device_map = None 726 | self.first_device = "cpu" 727 | self.last_device = "cpu" 728 | self.wte = self.wte.to("cpu") 729 | self.wpe = self.wpe.to("cpu") 730 | for index in range(len(self.h)): 731 | self.h[index] = self.h[index].to("cpu") 732 | self.ln_f = self.ln_f.to("cpu") 733 | torch.cuda.empty_cache() 734 | 735 | def get_input_embeddings(self): 736 | return self.wte 737 | 738 | def set_input_embeddings(self, new_embeddings): 739 | self.wte = new_embeddings 740 | 741 | def _prune_heads(self, heads_to_prune): 742 | """ 743 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 744 | """ 745 | for layer, heads in heads_to_prune.items(): 746 | self.h[layer].attn.prune_heads(heads) 747 | 748 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 749 | @add_code_sample_docstrings( 750 | processor_class=_TOKENIZER_FOR_DOC, 751 | checkpoint=_CHECKPOINT_FOR_DOC, 752 | output_type=BaseModelOutputWithPastAndCrossAttentions, 753 | config_class=_CONFIG_FOR_DOC, 754 | ) 755 | def forward( 756 | self, 757 | input_ids: Optional[torch.LongTensor] = None, 758 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 759 | attention_mask: Optional[torch.FloatTensor] = None, 760 | token_type_ids: Optional[torch.LongTensor] = None, 761 | position_ids: Optional[torch.LongTensor] = None, 762 | head_mask: Optional[torch.FloatTensor] = None, 763 | inputs_embeds: Optional[torch.FloatTensor] = None, 764 | encoder_hidden_states: Optional[torch.Tensor] = None, 765 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 766 | use_cache: Optional[bool] = None, 767 | output_attentions: Optional[bool] = None, 768 | output_hidden_states: Optional[bool] = None, 769 | return_dict: Optional[bool] = None, 770 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 771 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 772 | output_hidden_states = ( 773 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 774 | ) 775 | use_cache = use_cache if use_cache is not None else self.config.use_cache 776 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 777 | 778 | if input_ids is not None and inputs_embeds is not None: 779 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 780 | elif input_ids is not None: 781 | input_shape = input_ids.size() 782 | input_ids = input_ids.view(-1, input_shape[-1]) 783 | batch_size = input_ids.shape[0] 784 | elif inputs_embeds is not None: 785 | input_shape = inputs_embeds.size()[:-1] 786 | batch_size = inputs_embeds.shape[0] 787 | else: 788 | raise ValueError("You have to specify either input_ids or inputs_embeds") 789 | 790 | device = input_ids.device if input_ids is not None else inputs_embeds.device 791 | 792 | if token_type_ids is not None: 793 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 794 | if position_ids is not None: 795 | position_ids = position_ids.view(-1, input_shape[-1]) 796 | 797 | if past_key_values is None: 798 | past_length = 0 799 | past_key_values = tuple([None] * len(self.h)) 800 | else: 801 | past_length = past_key_values[0][0].size(-2) 802 | if position_ids is None: 803 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 804 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 805 | 806 | # GPT2Attention mask. 807 | if attention_mask is not None: 808 | if batch_size <= 0: 809 | raise ValueError("batch_size has to be defined and > 0") 810 | attention_mask = attention_mask.view(batch_size, -1) 811 | # We create a 3D attention mask from a 2D tensor mask. 812 | # Sizes are [batch_size, 1, 1, to_seq_length] 813 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 814 | # this attention mask is more simple than the triangular masking of causal attention 815 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
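To make the transformation below concrete, here is the same arithmetic on a toy mask (values are illustrative; the model code uses `self.dtype` rather than a hard-coded float32):

```python
import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])  # [batch=1, to_seq_length=3]
mask = mask[:, None, None, :]           # -> [1, 1, 1, 3], broadcastable over heads and queries
additive = (1.0 - mask) * torch.finfo(torch.float32).min
# additive == [[[[0., 0., -3.4028e+38]]]]: added to the raw attention scores before
# the softmax, it drives masked positions to ~zero attention weight
```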
816 | attention_mask = attention_mask[:, None, None, :] 817 | 818 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 819 | # masked positions, this operation will create a tensor which is 0.0 for 820 | # positions we want to attend and -10000.0 for masked positions. 821 | # Since we are adding it to the raw scores before the softmax, this is 822 | # effectively the same as removing these entirely. 823 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 824 | attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min 825 | 826 | # If a 2D or 3D attention mask is provided for the cross-attention 827 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 828 | if self.config.add_cross_attention and encoder_hidden_states is not None: 829 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 830 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 831 | if encoder_attention_mask is None: 832 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 833 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 834 | else: 835 | encoder_attention_mask = None 836 | 837 | # Prepare head mask if needed 838 | # 1.0 in head_mask indicate we keep the head 839 | # attention_probs has shape bsz x n_heads x N x N 840 | # head_mask has shape n_layer x batch x n_heads x N x N 841 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 842 | 843 | if inputs_embeds is None: 844 | inputs_embeds = self.wte(input_ids) 845 | # position_embeds = self.wpe(position_ids) 846 | hidden_states = inputs_embeds # + position_embeds 847 | 848 | if token_type_ids is not None: 849 | token_type_embeds = self.wte(token_type_ids) 850 | hidden_states = hidden_states + token_type_embeds 851 | 852 | hidden_states = self.drop(hidden_states) 853 | 854 | output_shape = input_shape + (hidden_states.size(-1),) 855 | 856 | presents = () if use_cache else None 857 | all_self_attentions = () if output_attentions else None 858 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 859 | all_hidden_states = () if output_hidden_states else None 860 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 861 | 862 | # Model parallel 863 | if self.model_parallel: 864 | torch.cuda.set_device(hidden_states.device) 865 | # Ensure layer_past is on same device as hidden_states (might not be correct) 866 | if layer_past is not None: 867 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 868 | # Ensure that attention_mask is always on the same device as hidden_states 869 | if attention_mask is not None: 870 | attention_mask = attention_mask.to(hidden_states.device) 871 | if isinstance(head_mask, torch.Tensor): 872 | head_mask = head_mask.to(hidden_states.device) 873 | if output_hidden_states: 874 | all_hidden_states = all_hidden_states + (hidden_states,) 875 | 876 | if self.gradient_checkpointing and self.training: 877 | 878 | if use_cache: 879 | logger.warning( 880 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
881 | ) 882 | use_cache = False 883 | 884 | def create_custom_forward(module): 885 | def custom_forward(*inputs): 886 | # None for past_key_value 887 | return module(*inputs, use_cache, output_attentions) 888 | 889 | return custom_forward 890 | 891 | outputs = torch.utils.checkpoint.checkpoint( 892 | create_custom_forward(block), 893 | hidden_states, 894 | None, 895 | attention_mask, 896 | head_mask[i], 897 | encoder_hidden_states, 898 | encoder_attention_mask, 899 | ) 900 | else: 901 | outputs = block( 902 | hidden_states, 903 | layer_past=layer_past, 904 | attention_mask=attention_mask, 905 | head_mask=head_mask[i], 906 | encoder_hidden_states=encoder_hidden_states, 907 | encoder_attention_mask=encoder_attention_mask, 908 | use_cache=use_cache, 909 | output_attentions=output_attentions, 910 | ) 911 | 912 | hidden_states = outputs[0] 913 | if use_cache is True: 914 | presents = presents + (outputs[1],) 915 | 916 | if output_attentions: 917 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 918 | if self.config.add_cross_attention: 919 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 920 | 921 | # Model Parallel: If it's the last layer for that device, put things on the next device 922 | if self.model_parallel: 923 | for k, v in self.device_map.items(): 924 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 925 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 926 | 927 | hidden_states = self.ln_f(hidden_states) 928 | 929 | hidden_states = hidden_states.view(output_shape) 930 | # Add last hidden state 931 | if output_hidden_states: 932 | all_hidden_states = all_hidden_states + (hidden_states,) 933 | 934 | if not return_dict: 935 | return tuple( 936 | v 937 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 938 | if v is not None 939 | ) 940 | 941 | return BaseModelOutputWithPastAndCrossAttentions( 942 | last_hidden_state=hidden_states, 943 | past_key_values=presents, 944 | hidden_states=all_hidden_states, 945 | attentions=all_self_attentions, 946 | cross_attentions=all_cross_attentions, 947 | ) 948 | 949 | 950 | @add_start_docstrings( 951 | """ 952 | The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input 953 | embeddings). 
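A minimal shape sketch (toy config and random weights; note that `GPT2Model.forward` in this file deliberately skips the `wpe` position embeddings, so any position or timestep information has to be folded into `inputs_embeds` by the caller):

```python
import torch
from transformers import GPT2Config

config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=100)
lm = GPT2LMHeadModel(config)
logits = lm(torch.randint(0, 100, (1, 8))).logits  # [1, 8, 100]: per-position next-token scores
next_token = logits[:, -1, :].argmax(dim=-1)       # greedy choice for the following token
```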
954 | """, 955 | GPT2_START_DOCSTRING, 956 | ) 957 | class GPT2LMHeadModel(GPT2PreTrainedModel): 958 | _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] 959 | 960 | def __init__(self, config): 961 | super().__init__(config) 962 | self.transformer = GPT2Model(config) 963 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 964 | 965 | # Model parallel 966 | self.model_parallel = False 967 | self.device_map = None 968 | 969 | # Initialize weights and apply final processing 970 | self.post_init() 971 | 972 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 973 | def parallelize(self, device_map=None): 974 | self.device_map = ( 975 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 976 | if device_map is None 977 | else device_map 978 | ) 979 | assert_device_map(self.device_map, len(self.transformer.h)) 980 | self.transformer.parallelize(self.device_map) 981 | self.lm_head = self.lm_head.to(self.transformer.first_device) 982 | self.model_parallel = True 983 | 984 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 985 | def deparallelize(self): 986 | self.transformer.deparallelize() 987 | self.transformer = self.transformer.to("cpu") 988 | self.lm_head = self.lm_head.to("cpu") 989 | self.model_parallel = False 990 | torch.cuda.empty_cache() 991 | 992 | def get_output_embeddings(self): 993 | return self.lm_head 994 | 995 | def set_output_embeddings(self, new_embeddings): 996 | self.lm_head = new_embeddings 997 | 998 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 999 | token_type_ids = kwargs.get("token_type_ids", None) 1000 | # only last token for inputs_ids if past is defined in kwargs 1001 | if past: 1002 | input_ids = input_ids[:, -1].unsqueeze(-1) 1003 | if token_type_ids is not None: 1004 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 1005 | 1006 | attention_mask = kwargs.get("attention_mask", None) 1007 | position_ids = kwargs.get("position_ids", None) 1008 | 1009 | if attention_mask is not None and position_ids is None: 1010 | # create position_ids on the fly for batch generation 1011 | position_ids = attention_mask.long().cumsum(-1) - 1 1012 | position_ids.masked_fill_(attention_mask == 0, 1) 1013 | if past: 1014 | position_ids = position_ids[:, -1].unsqueeze(-1) 1015 | else: 1016 | position_ids = None 1017 | return { 1018 | "input_ids": input_ids, 1019 | "past_key_values": past, 1020 | "use_cache": kwargs.get("use_cache"), 1021 | "position_ids": position_ids, 1022 | "attention_mask": attention_mask, 1023 | "token_type_ids": token_type_ids, 1024 | } 1025 | 1026 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1027 | @add_code_sample_docstrings( 1028 | processor_class=_TOKENIZER_FOR_DOC, 1029 | checkpoint=_CHECKPOINT_FOR_DOC, 1030 | output_type=CausalLMOutputWithCrossAttentions, 1031 | config_class=_CONFIG_FOR_DOC, 1032 | ) 1033 | def forward( 1034 | self, 1035 | input_ids: Optional[torch.LongTensor] = None, 1036 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1037 | attention_mask: Optional[torch.FloatTensor] = None, 1038 | token_type_ids: Optional[torch.LongTensor] = None, 1039 | position_ids: Optional[torch.LongTensor] = None, 1040 | head_mask: Optional[torch.FloatTensor] = None, 1041 | inputs_embeds: Optional[torch.FloatTensor] = None, 1042 | encoder_hidden_states: Optional[torch.Tensor] = None, 1043 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 1044 | labels: Optional[torch.LongTensor] = None, 1045 | use_cache: Optional[bool] 
= None, 1046 | output_attentions: Optional[bool] = None, 1047 | output_hidden_states: Optional[bool] = None, 1048 | return_dict: Optional[bool] = None, 1049 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 1050 | r""" 1051 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1052 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 1053 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` 1054 | are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` 1055 | """ 1056 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1057 | 1058 | transformer_outputs = self.transformer( 1059 | input_ids, 1060 | past_key_values=past_key_values, 1061 | attention_mask=attention_mask, 1062 | token_type_ids=token_type_ids, 1063 | position_ids=position_ids, 1064 | head_mask=head_mask, 1065 | inputs_embeds=inputs_embeds, 1066 | encoder_hidden_states=encoder_hidden_states, 1067 | encoder_attention_mask=encoder_attention_mask, 1068 | use_cache=use_cache, 1069 | output_attentions=output_attentions, 1070 | output_hidden_states=output_hidden_states, 1071 | return_dict=return_dict, 1072 | ) 1073 | hidden_states = transformer_outputs[0] 1074 | 1075 | # Set device for model parallelism 1076 | if self.model_parallel: 1077 | torch.cuda.set_device(self.transformer.first_device) 1078 | hidden_states = hidden_states.to(self.lm_head.weight.device) 1079 | 1080 | lm_logits = self.lm_head(hidden_states) 1081 | 1082 | loss = None 1083 | if labels is not None: 1084 | # Shift so that tokens < n predict n 1085 | shift_logits = lm_logits[..., :-1, :].contiguous() 1086 | shift_labels = labels[..., 1:].contiguous() 1087 | # Flatten the tokens 1088 | loss_fct = CrossEntropyLoss() 1089 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1090 | 1091 | if not return_dict: 1092 | output = (lm_logits,) + transformer_outputs[1:] 1093 | return ((loss,) + output) if loss is not None else output 1094 | 1095 | return CausalLMOutputWithCrossAttentions( 1096 | loss=loss, 1097 | logits=lm_logits, 1098 | past_key_values=transformer_outputs.past_key_values, 1099 | hidden_states=transformer_outputs.hidden_states, 1100 | attentions=transformer_outputs.attentions, 1101 | cross_attentions=transformer_outputs.cross_attentions, 1102 | ) 1103 | 1104 | @staticmethod 1105 | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: 1106 | """ 1107 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1108 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1109 | beam_idx at every generation step. 1110 | """ 1111 | return tuple( 1112 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 1113 | for layer_past in past 1114 | ) 1115 | 1116 | 1117 | @add_start_docstrings( 1118 | """ 1119 | The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for 1120 | RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the 1121 | input embeddings, the classification head takes as input the input of a specified classification token index in the 1122 | input sequence). 
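Concretely, `mc_token_ids` picks out, for each choice, the hidden state at a single position (typically the final `[CLS]`-style token), which the classification head then scores. A shape-only sketch of that gather (illustrative values):

```python
import torch

batch, num_choices, seq_len, hidden = 1, 2, 6, 8
hidden_states = torch.randn(batch, num_choices, seq_len, hidden)
mc_token_ids = torch.tensor([[5, 5]])  # position of the scored token for each choice
idx = mc_token_ids[..., None, None].expand(-1, -1, 1, hidden)  # [1, 2, 1, 8]
picked = hidden_states.gather(-2, idx).squeeze(-2)             # [1, 2, 8]
```

A full end-to-end example appears in the `forward` docstring below.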
1123 | """, 1124 | GPT2_START_DOCSTRING, 1125 | ) 1126 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel): 1127 | _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] 1128 | 1129 | def __init__(self, config): 1130 | super().__init__(config) 1131 | config.num_labels = 1 1132 | self.transformer = GPT2Model(config) 1133 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 1134 | self.multiple_choice_head = SequenceSummary(config) 1135 | 1136 | # Model parallel 1137 | self.model_parallel = False 1138 | self.device_map = None 1139 | 1140 | # Initialize weights and apply final processing 1141 | self.post_init() 1142 | 1143 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 1144 | def parallelize(self, device_map=None): 1145 | self.device_map = ( 1146 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 1147 | if device_map is None 1148 | else device_map 1149 | ) 1150 | assert_device_map(self.device_map, len(self.transformer.h)) 1151 | self.transformer.parallelize(self.device_map) 1152 | self.lm_head = self.lm_head.to(self.transformer.first_device) 1153 | self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) 1154 | self.model_parallel = True 1155 | 1156 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 1157 | def deparallelize(self): 1158 | self.transformer.deparallelize() 1159 | self.transformer = self.transformer.to("cpu") 1160 | self.lm_head = self.lm_head.to("cpu") 1161 | self.multiple_choice_head = self.multiple_choice_head.to("cpu") 1162 | self.model_parallel = False 1163 | torch.cuda.empty_cache() 1164 | 1165 | def get_output_embeddings(self): 1166 | return self.lm_head 1167 | 1168 | def set_output_embeddings(self, new_embeddings): 1169 | self.lm_head = new_embeddings 1170 | 1171 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 1172 | token_type_ids = kwargs.get("token_type_ids", None) 1173 | # only last token for inputs_ids if past is defined in kwargs 1174 | if past: 1175 | input_ids = input_ids[:, -1].unsqueeze(-1) 1176 | if token_type_ids is not None: 1177 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 1178 | 1179 | attention_mask = kwargs.get("attention_mask", None) 1180 | position_ids = kwargs.get("position_ids", None) 1181 | 1182 | if attention_mask is not None and position_ids is None: 1183 | # create position_ids on the fly for batch generation 1184 | position_ids = attention_mask.long().cumsum(-1) - 1 1185 | position_ids.masked_fill_(attention_mask == 0, 1) 1186 | if past: 1187 | position_ids = position_ids[:, -1].unsqueeze(-1) 1188 | else: 1189 | position_ids = None 1190 | 1191 | return { 1192 | "input_ids": input_ids, 1193 | "past_key_values": past, 1194 | "use_cache": kwargs.get("use_cache"), 1195 | "position_ids": position_ids, 1196 | "attention_mask": attention_mask, 1197 | "token_type_ids": token_type_ids, 1198 | } 1199 | 1200 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1201 | @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) 1202 | def forward( 1203 | self, 1204 | input_ids: Optional[torch.LongTensor] = None, 1205 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1206 | attention_mask: Optional[torch.FloatTensor] = None, 1207 | token_type_ids: Optional[torch.LongTensor] = None, 1208 | position_ids: Optional[torch.LongTensor] = None, 1209 | head_mask: Optional[torch.FloatTensor] = None, 1210 | inputs_embeds: Optional[torch.FloatTensor] = None, 1211 | 
mc_token_ids: Optional[torch.LongTensor] = None, 1212 | labels: Optional[torch.LongTensor] = None, 1213 | mc_labels: Optional[torch.LongTensor] = None, 1214 | use_cache: Optional[bool] = None, 1215 | output_attentions: Optional[bool] = None, 1216 | output_hidden_states: Optional[bool] = None, 1217 | return_dict: Optional[bool] = None, 1218 | **kwargs, 1219 | ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: 1220 | r""" 1221 | mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): 1222 | Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - 1223 | 1[`. 1224 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1225 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 1226 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size - 1]` All labels set to 1227 | `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` 1228 | mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): 1229 | Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` 1230 | where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) 1231 | 1232 | Return: 1233 | 1234 | Example: 1235 | 1236 | ```python 1237 | >>> import torch 1238 | >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel 1239 | 1240 | >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 1241 | >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2") 1242 | 1243 | >>> # Add a [CLS] to the vocabulary (we should train it also!) 1244 | >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) 1245 | >>> # Update the model embeddings with the new vocabulary size 1246 | >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) 1247 | 1248 | >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] 1249 | >>> encoded_choices = [tokenizer.encode(s) for s in choices] 1250 | >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] 1251 | 1252 | >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 1253 | >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 1254 | 1255 | >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) 1256 | >>> lm_logits = outputs.logits 1257 | >>> mc_logits = outputs.mc_logits 1258 | ```""" 1259 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1260 | 1261 | transformer_outputs = self.transformer( 1262 | input_ids, 1263 | past_key_values=past_key_values, 1264 | attention_mask=attention_mask, 1265 | token_type_ids=token_type_ids, 1266 | position_ids=position_ids, 1267 | head_mask=head_mask, 1268 | inputs_embeds=inputs_embeds, 1269 | use_cache=use_cache, 1270 | output_attentions=output_attentions, 1271 | output_hidden_states=output_hidden_states, 1272 | return_dict=return_dict, 1273 | ) 1274 | 1275 | hidden_states = transformer_outputs[0] 1276 | 1277 | # Set device for model parallelism 1278 | if self.model_parallel: 1279 | torch.cuda.set_device(self.transformer.first_device) 1280 | hidden_states = hidden_states.to(self.lm_head.weight.device) 1281 | 1282 | lm_logits = self.lm_head(hidden_states) 1283 | mc_logits = self.multiple_choice_head(hidden_states, 
mc_token_ids).squeeze(-1) 1284 | 1285 | mc_loss = None 1286 | if mc_labels is not None: 1287 | loss_fct = CrossEntropyLoss() 1288 | mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) 1289 | lm_loss = None 1290 | if labels is not None: 1291 | shift_logits = lm_logits[..., :-1, :].contiguous() 1292 | shift_labels = labels[..., 1:].contiguous() 1293 | loss_fct = CrossEntropyLoss() 1294 | lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1295 | 1296 | if not return_dict: 1297 | output = (lm_logits, mc_logits) + transformer_outputs[1:] 1298 | if mc_loss is not None: 1299 | output = (mc_loss,) + output 1300 | return ((lm_loss,) + output) if lm_loss is not None else output 1301 | 1302 | return GPT2DoubleHeadsModelOutput( 1303 | loss=lm_loss, 1304 | mc_loss=mc_loss, 1305 | logits=lm_logits, 1306 | mc_logits=mc_logits, 1307 | past_key_values=transformer_outputs.past_key_values, 1308 | hidden_states=transformer_outputs.hidden_states, 1309 | attentions=transformer_outputs.attentions, 1310 | ) 1311 | 1312 | @staticmethod 1313 | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: 1314 | """ 1315 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1316 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1317 | beam_idx at every generation step. 1318 | """ 1319 | return tuple( 1320 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 1321 | for layer_past in past 1322 | ) 1323 | 1324 | 1325 | @add_start_docstrings( 1326 | """ 1327 | The GPT2 Model transformer with a sequence classification head on top (linear layer). 1328 | 1329 | [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1330 | (e.g. GPT-1) do. 1331 | 1332 | Since it does classification on the last token, it requires to know the position of the last token. If a 1333 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1334 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1335 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1336 | each row of the batch). 
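The pooling rule described above, sketched in isolation (toy values; `pad_token_id=0` is an assumption for the example):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])                        # [batch=2, seq=5]
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1   # tensor([2, 1])
logits = torch.randn(2, 5, 3)                                      # [batch, seq, num_labels]
pooled = logits[torch.arange(2), sequence_lengths]                 # scores at the last real token
```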
1337 | """, 1338 | GPT2_START_DOCSTRING, 1339 | ) 1340 | class GPT2ForSequenceClassification(GPT2PreTrainedModel): 1341 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"] 1342 | 1343 | def __init__(self, config): 1344 | super().__init__(config) 1345 | self.num_labels = config.num_labels 1346 | self.transformer = GPT2Model(config) 1347 | self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) 1348 | 1349 | # Model parallel 1350 | self.model_parallel = False 1351 | self.device_map = None 1352 | 1353 | # Initialize weights and apply final processing 1354 | self.post_init() 1355 | 1356 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1357 | @add_code_sample_docstrings( 1358 | processor_class=_TOKENIZER_FOR_DOC, 1359 | checkpoint="microsoft/DialogRPT-updown", 1360 | output_type=SequenceClassifierOutputWithPast, 1361 | config_class=_CONFIG_FOR_DOC, 1362 | expected_output="'LABEL_0'", 1363 | expected_loss=5.28, 1364 | ) 1365 | def forward( 1366 | self, 1367 | input_ids: Optional[torch.LongTensor] = None, 1368 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1369 | attention_mask: Optional[torch.FloatTensor] = None, 1370 | token_type_ids: Optional[torch.LongTensor] = None, 1371 | position_ids: Optional[torch.LongTensor] = None, 1372 | head_mask: Optional[torch.FloatTensor] = None, 1373 | inputs_embeds: Optional[torch.FloatTensor] = None, 1374 | labels: Optional[torch.LongTensor] = None, 1375 | use_cache: Optional[bool] = None, 1376 | output_attentions: Optional[bool] = None, 1377 | output_hidden_states: Optional[bool] = None, 1378 | return_dict: Optional[bool] = None, 1379 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1380 | r""" 1381 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1382 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1383 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1384 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1385 | """ 1386 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1387 | 1388 | transformer_outputs = self.transformer( 1389 | input_ids, 1390 | past_key_values=past_key_values, 1391 | attention_mask=attention_mask, 1392 | token_type_ids=token_type_ids, 1393 | position_ids=position_ids, 1394 | head_mask=head_mask, 1395 | inputs_embeds=inputs_embeds, 1396 | use_cache=use_cache, 1397 | output_attentions=output_attentions, 1398 | output_hidden_states=output_hidden_states, 1399 | return_dict=return_dict, 1400 | ) 1401 | hidden_states = transformer_outputs[0] 1402 | logits = self.score(hidden_states) 1403 | 1404 | if input_ids is not None: 1405 | batch_size, sequence_length = input_ids.shape[:2] 1406 | else: 1407 | batch_size, sequence_length = inputs_embeds.shape[:2] 1408 | 1409 | assert ( 1410 | self.config.pad_token_id is not None or batch_size == 1 1411 | ), "Cannot handle batch sizes > 1 if no padding token is defined." 1412 | if self.config.pad_token_id is None: 1413 | sequence_lengths = -1 1414 | else: 1415 | if input_ids is not None: 1416 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 1417 | else: 1418 | sequence_lengths = -1 1419 | logger.warning( 1420 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " 1421 | "unexpected if using padding tokens in conjunction with `inputs_embeds.`" 1422 | ) 1423 | 1424 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1425 | 1426 | loss = None 1427 | if labels is not None: 1428 | if self.config.problem_type is None: 1429 | if self.num_labels == 1: 1430 | self.config.problem_type = "regression" 1431 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1432 | self.config.problem_type = "single_label_classification" 1433 | else: 1434 | self.config.problem_type = "multi_label_classification" 1435 | 1436 | if self.config.problem_type == "regression": 1437 | loss_fct = MSELoss() 1438 | if self.num_labels == 1: 1439 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1440 | else: 1441 | loss = loss_fct(pooled_logits, labels) 1442 | elif self.config.problem_type == "single_label_classification": 1443 | loss_fct = CrossEntropyLoss() 1444 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1445 | elif self.config.problem_type == "multi_label_classification": 1446 | loss_fct = BCEWithLogitsLoss() 1447 | loss = loss_fct(pooled_logits, labels) 1448 | if not return_dict: 1449 | output = (pooled_logits,) + transformer_outputs[1:] 1450 | return ((loss,) + output) if loss is not None else output 1451 | 1452 | return SequenceClassifierOutputWithPast( 1453 | loss=loss, 1454 | logits=pooled_logits, 1455 | past_key_values=transformer_outputs.past_key_values, 1456 | hidden_states=transformer_outputs.hidden_states, 1457 | attentions=transformer_outputs.attentions, 1458 | ) 1459 | 1460 | 1461 | @add_start_docstrings( 1462 | """ 1463 | GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for 1464 | Named-Entity-Recognition (NER) tasks. 
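Shape-wise this head scores every position, so the loss reduces to ordinary per-token cross-entropy (illustrative tensors):

```python
import torch
from torch.nn import CrossEntropyLoss

num_labels = 5
logits = torch.randn(2, 6, num_labels)         # [batch, seq_len, num_labels]
labels = torch.randint(0, num_labels, (2, 6))  # one label per token
loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
```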
1465 | """, 1466 | GPT2_START_DOCSTRING, 1467 | ) 1468 | class GPT2ForTokenClassification(GPT2PreTrainedModel): 1469 | def __init__(self, config): 1470 | super().__init__(config) 1471 | self.num_labels = config.num_labels 1472 | 1473 | self.transformer = GPT2Model(config) 1474 | if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: 1475 | classifier_dropout = config.classifier_dropout 1476 | elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: 1477 | classifier_dropout = config.hidden_dropout 1478 | else: 1479 | classifier_dropout = 0.1 1480 | self.dropout = nn.Dropout(classifier_dropout) 1481 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1482 | 1483 | # Model parallel 1484 | self.model_parallel = False 1485 | self.device_map = None 1486 | 1487 | # Initialize weights and apply final processing 1488 | self.post_init() 1489 | 1490 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1491 | # fmt: off 1492 | @add_code_sample_docstrings( 1493 | processor_class=_TOKENIZER_FOR_DOC, 1494 | checkpoint="brad1141/gpt2-finetuned-comp2", 1495 | output_type=TokenClassifierOutput, 1496 | config_class=_CONFIG_FOR_DOC, 1497 | expected_loss=0.25, 1498 | expected_output=["Lead", "Lead", "Lead", "Position", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead"], 1499 | ) 1500 | # fmt: on 1501 | def forward( 1502 | self, 1503 | input_ids: Optional[torch.LongTensor] = None, 1504 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1505 | attention_mask: Optional[torch.FloatTensor] = None, 1506 | token_type_ids: Optional[torch.LongTensor] = None, 1507 | position_ids: Optional[torch.LongTensor] = None, 1508 | head_mask: Optional[torch.FloatTensor] = None, 1509 | inputs_embeds: Optional[torch.FloatTensor] = None, 1510 | labels: Optional[torch.LongTensor] = None, 1511 | use_cache: Optional[bool] = None, 1512 | output_attentions: Optional[bool] = None, 1513 | output_hidden_states: Optional[bool] = None, 1514 | return_dict: Optional[bool] = None, 1515 | ) -> Union[Tuple, TokenClassifierOutput]: 1516 | r""" 1517 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1518 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1519 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1520 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1521 | """ 1522 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1523 | 1524 | transformer_outputs = self.transformer( 1525 | input_ids, 1526 | past_key_values=past_key_values, 1527 | attention_mask=attention_mask, 1528 | token_type_ids=token_type_ids, 1529 | position_ids=position_ids, 1530 | head_mask=head_mask, 1531 | inputs_embeds=inputs_embeds, 1532 | use_cache=use_cache, 1533 | output_attentions=output_attentions, 1534 | output_hidden_states=output_hidden_states, 1535 | return_dict=return_dict, 1536 | ) 1537 | 1538 | hidden_states = transformer_outputs[0] 1539 | hidden_states = self.dropout(hidden_states) 1540 | logits = self.classifier(hidden_states) 1541 | 1542 | loss = None 1543 | if labels is not None: 1544 | loss_fct = CrossEntropyLoss() 1545 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1546 | 1547 | if not return_dict: 1548 | output = (logits,) + transformer_outputs[2:] 1549 | return ((loss,) + output) if loss is not None else output 1550 | 1551 | return TokenClassifierOutput( 1552 | loss=loss, 1553 | logits=logits, 1554 | hidden_states=transformer_outputs.hidden_states, 1555 | attentions=transformer_outputs.attentions, 1556 | ) -------------------------------------------------------------------------------- /decision_transformer/decision_transformer/training/act_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from decision_transformer.training.trainer import Trainer 5 | 6 | 7 | class ActTrainer(Trainer): 8 | 9 | def train_step(self): 10 | states, actions, rewards, dones, rtg, _, attention_mask = self.get_batch(self.batch_size) 11 | state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards) 12 | 13 | state_preds, action_preds, reward_preds = self.model.forward( 14 | states, actions, rewards, attention_mask=attention_mask, target_return=rtg[:,0], 15 | ) 16 | 17 | act_dim = action_preds.shape[2] 18 | action_preds = action_preds.reshape(-1, act_dim) 19 | action_target = action_target[:,-1].reshape(-1, act_dim) 20 | 21 | loss = self.loss_fn( 22 | state_preds, action_preds, reward_preds, 23 | state_target, action_target, reward_target, 24 | ) 25 | self.optimizer.zero_grad() 26 | loss.backward() 27 | self.optimizer.step() 28 | 29 | return loss.detach().cpu().item() 30 | -------------------------------------------------------------------------------- /decision_transformer/decision_transformer/training/seq_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from decision_transformer.training.trainer import Trainer 5 | 6 | 7 | class SequenceTrainer(Trainer): 8 | 9 | def train_step(self): 10 | states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size) 11 | action_target = torch.clone(actions) 12 | 13 | state_preds, action_preds, reward_preds = self.model.forward( 14 | states, actions, rewards, rtg[:,:-1], timesteps, attention_mask=attention_mask, 15 | ) 16 | 17 | act_dim = action_preds.shape[2] 18 | action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 19 | action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] 20 | 21 | loss = self.loss_fn( 22 | None, action_preds, None, 23 | None, action_target, None, 24 | ) 25 | 26 | self.optimizer.zero_grad() 27 | loss.backward() 28 | 
torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25) 29 | self.optimizer.step() 30 | 31 | with torch.no_grad(): 32 | self.diagnostics['training/action_error'] = torch.mean((action_preds-action_target)**2).detach().cpu().item() 33 | 34 | return loss.detach().cpu().item() 35 | -------------------------------------------------------------------------------- /decision_transformer/decision_transformer/training/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import time 5 | from tqdm import tqdm 6 | 7 | 8 | class Trainer: 9 | 10 | def __init__(self, model, optimizer, batch_size, get_batch, loss_fn, scheduler=None, eval_fns=None): 11 | self.model = model 12 | self.optimizer = optimizer 13 | self.batch_size = batch_size 14 | self.get_batch = get_batch 15 | self.loss_fn = loss_fn 16 | self.scheduler = scheduler 17 | self.eval_fns = [] if eval_fns is None else eval_fns 18 | self.diagnostics = dict() 19 | 20 | self.start_time = time.time() 21 | 22 | def train_iteration(self, num_steps, iter_num=0, print_logs=False): 23 | 24 | train_losses = [] 25 | logs = dict() 26 | 27 | train_start = time.time() 28 | 29 | self.model.train() 30 | for _ in tqdm(range(num_steps)): 31 | train_loss = self.train_step() 32 | train_losses.append(train_loss) 33 | if self.scheduler is not None: 34 | self.scheduler.step() 35 | 36 | logs['time/training'] = time.time() - train_start 37 | 38 | eval_start = time.time() 39 | 40 | self.model.eval() 41 | for eval_fn in self.eval_fns: 42 | outputs = eval_fn(self.model) 43 | for k, v in outputs.items(): 44 | logs[f'evaluation/{k}'] = v 45 | 46 | logs['time/total'] = time.time() - self.start_time 47 | logs['time/evaluation'] = time.time() - eval_start 48 | logs['training/train_loss_mean'] = np.mean(train_losses) 49 | logs['training/train_loss_std'] = np.std(train_losses) 50 | 51 | for k in self.diagnostics: 52 | logs[k] = self.diagnostics[k] 53 | 54 | if print_logs: 55 | print('=' * 80) 56 | print(f'Iteration {iter_num}') 57 | for k, v in logs.items(): 58 | print(f'{k}: {v}') 59 | 60 | return logs 61 | 62 | def train_step(self): 63 | states, actions, rewards, dones, attention_mask, returns = self.get_batch(self.batch_size) 64 | state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards) 65 | 66 | state_preds, action_preds, reward_preds = self.model.forward( 67 | states, actions, rewards, masks=None, attention_mask=attention_mask, target_return=returns, 68 | ) 69 | 70 | # note: currently indexing & masking is not fully correct 71 | loss = self.loss_fn( 72 | state_preds, action_preds, reward_preds, 73 | state_target[:,1:], action_target, reward_target[:,1:], 74 | ) 75 | self.optimizer.zero_grad() 76 | loss.backward() 77 | self.optimizer.step() 78 | 79 | return loss.detach().cpu().item() 80 | -------------------------------------------------------------------------------- /decision_transformer/decision_transformer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/decision_transformer/decision_transformer/utils/__init__.py -------------------------------------------------------------------------------- /decision_transformer/decision_transformer/utils/convert_dataset.py: -------------------------------------------------------------------------------- 1 | from stoch_rvs.samplers.trajectory_sampler 
import Trajectory 2 | import numpy as np 3 | from copy import deepcopy 4 | 5 | def convert_dataset(trajectories, action_type): 6 | trajs = [] 7 | for path in trajectories: 8 | obs_ = [] 9 | actions_ = [] 10 | rewards_ = [] 11 | infos_ = [] 12 | policy_infos_ = [] 13 | for t in range(len(path['observations'])): 14 | obs_.append(deepcopy(path['observations'][t])) 15 | if action_type == 'discrete': 16 | actions_.append(np.argmax(path['actions'][t])) 17 | else: 18 | actions_.append(path['actions'][t]) 19 | rewards_.append(path['rewards'][t]) 20 | trajs.append(Trajectory(obs=obs_, 21 | actions=actions_, 22 | rewards=rewards_, 23 | infos=infos_, 24 | policy_infos=policy_infos_)) 25 | return trajs -------------------------------------------------------------------------------- /decision_transformer/decision_transformer/utils/preemption.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import pickle 4 | import wandb 5 | import time 6 | 7 | class CheckpointTimer: 8 | 9 | def __init__(self, checkpoint_every): 10 | self.checkpoint_every = checkpoint_every 11 | self.last_chk = 0 12 | 13 | def should(self): 14 | now = time.time() 15 | return now - self.last_chk >= self.checkpoint_every 16 | 17 | def done(self): 18 | now = time.time() 19 | self.last_chk = now 20 | 21 | class PreemptionManager: 22 | 23 | def __init__(self, checkpoint_dir, checkpoint_every=0, checkpoint_timer=None, prefix=''): 24 | self.checkpoint_dir = checkpoint_dir 25 | self._wandb_id = None 26 | self.prefix = prefix 27 | if checkpoint_timer is None: 28 | self.checkpoint_timer = CheckpointTimer(checkpoint_every) 29 | else: 30 | self.checkpoint_timer = checkpoint_timer 31 | self.last_chk = 0 32 | self.stored = dict() 33 | 34 | def _load_data(self, name): 35 | if self.checkpoint_dir is not None: 36 | path = os.path.join(self.checkpoint_dir, f'{self.prefix}_{name}.pkl') 37 | if os.path.exists(path): 38 | with open(path, 'rb') as file: 39 | print(f'Loaded {name}...') 40 | data = pickle.load(file) 41 | return data 42 | return None 43 | 44 | def save(self, name, data, now=False): 45 | if now: 46 | if self.checkpoint_dir is not None: 47 | path = os.path.join(self.checkpoint_dir, f'{self.prefix}_{name}.pkl') 48 | with open(path, 'wb') as f: 49 | pickle.dump(data, f) 50 | else: 51 | self.stored[name] = data 52 | 53 | def wandb_id(self): 54 | if self._wandb_id is not None: 55 | return self._wandb_id 56 | 57 | self._wandb_id = self._load_data('wandb_id') 58 | 59 | if self._wandb_id is None: 60 | self._wandb_id = wandb.util.generate_id() 61 | 62 | self.save('wandb_id', self._wandb_id, now=True) 63 | 64 | return self._wandb_id 65 | 66 | def load_torch(self, name, cl, *args, **kwargs): 67 | state_dict = self._load_data(name) 68 | model = cl(*args, **kwargs) 69 | if state_dict is not None: 70 | model.load_state_dict(state_dict) 71 | 72 | return model 73 | 74 | def exists(self, name): 75 | if self.checkpoint_dir is not None: 76 | path = os.path.join(self.checkpoint_dir, f'{self.prefix}_{name}.pkl') 77 | return os.path.exists(path) 78 | return False 79 | 80 | def save_torch(self, name, model): 81 | self.save(name, model.state_dict()) 82 | 83 | def load_if_exists(self, name, value): 84 | data = self._load_data(name) 85 | if data is None: 86 | return value 87 | return data 88 | 89 | def for_obj(self, prefix): 90 | return PreemptionManager(self.checkpoint_dir, prefix=prefix, checkpoint_timer=self.checkpoint_timer) 91 | 92 | def checkpoint(self): 93 | if
self.checkpoint_timer.should(): 94 | for key in self.stored: 95 | self.save(key, self.stored[key], now=True) 96 | 97 | self.checkpoint_timer.done() 98 | self.stored = dict() 99 | -------------------------------------------------------------------------------- /decision_transformer/experiment.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import wandb 5 | from torch import nn 6 | 7 | import argparse 8 | import pickle 9 | import random 10 | import sys 11 | 12 | from decision_transformer.evaluation.evaluate_episodes import evaluate_episode, evaluate_episode_rtg 13 | from decision_transformer.models.decision_transformer import DecisionTransformer 14 | from decision_transformer.models.mlp_bc import MLPBCModel 15 | from decision_transformer.training.act_trainer import ActTrainer 16 | from decision_transformer.training.seq_trainer import SequenceTrainer 17 | 18 | import d4rl 19 | from decision_transformer.utils.preemption import PreemptionManager 20 | from stoch_rvs.utils.utils import return_labels, learned_labels, set_seed 21 | from stoch_rvs.algos.learn_labels import learn_labels 22 | 23 | from stoch_rvs.datasets.seq_dataset import SeqDataset 24 | from decision_transformer.utils.convert_dataset import convert_dataset 25 | 26 | 27 | def discount_cumsum(x, gamma): 28 | discount_cumsum = np.zeros_like(x) 29 | discount_cumsum[-1] = x[-1] 30 | for t in reversed(range(x.shape[0] - 1)): 31 | discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1] 32 | return discount_cumsum 33 | 34 | 35 | def experiment( 36 | exp_prefix, 37 | variant, 38 | ): 39 | pm = PreemptionManager(variant['checkpoint_dir'], checkpoint_every=600) 40 | 41 | device = variant.get('device', 'cuda') 42 | log_to_wandb = variant.get('log_to_wandb', False) 43 | 44 | env_name, dataset = variant['env'], variant['dataset'] 45 | model_type = variant['model_type'] 46 | group_name = f'{exp_prefix}-{env_name}-{dataset}' 47 | exp_prefix = f'{group_name}-{random.randint(int(1e5), int(1e6) - 1)}' 48 | 49 | ret_postprocess_fn = lambda returns: returns 50 | action_type = 'continuous' 51 | if env_name == 'hopper': 52 | env = gym.make('Hopper-v3') 53 | max_ep_len = 1000 54 | env_targets = [3600, 1800] # evaluation conditioning targets 55 | scale = 1000. # normalization for rewards/returns 56 | name = f'{env_name}-{dataset}-v2' 57 | d4rl_env = gym.make(name) 58 | ret_postprocess_fn = d4rl_env.get_normalized_score 59 | elif env_name == 'halfcheetah': 60 | env = gym.make('HalfCheetah-v3') 61 | max_ep_len = 1000 62 | env_targets = [12000, 6000] 63 | scale = 1000. 64 | name = f'{env_name}-{dataset}-v2' 65 | d4rl_env = gym.make(name) 66 | ret_postprocess_fn = d4rl_env.get_normalized_score 67 | elif env_name == 'walker2d': 68 | env = gym.make('Walker2d-v3') 69 | max_ep_len = 1000 70 | env_targets = [5000, 2500] 71 | scale = 1000. 72 | name = f'{env_name}-{dataset}-v2' 73 | d4rl_env = gym.make(name) 74 | ret_postprocess_fn = d4rl_env.get_normalized_score 75 | elif env_name == 'reacher2d': 76 | from decision_transformer.envs.reacher_2d import Reacher2dEnv 77 | env = Reacher2dEnv() 78 | max_ep_len = 100 79 | env_targets = [76, 40] 80 | scale = 10. 81 | elif env_name == 'gambling': 82 | from stochastic_offline_envs.envs.offline_envs.gambling_offline_env import GamblingOfflineEnv 83 | task = GamblingOfflineEnv() 84 | env = task.env_cls() 85 | max_ep_len = 5 86 | env_targets = list(np.arange(-15, 5, 0.5)) + [5] 87 | # env_targets = [-15, -6, 1, 5] 88 | scale = 5. 
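# Note: np.arange excludes its stop value, so the top target (5 here; 1 for connect_four and tfe below) is appended explicitly.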
89 | action_type = 'discrete' 90 | elif env_name == 'connect_four': 91 | from stochastic_offline_envs.envs.offline_envs.connect_four_offline_env import ConnectFourOfflineEnv 92 | from stochastic_offline_envs.envs.connect_four.connect_four_env import GridWrapper 93 | # TODO: env should just deal with this automatically 94 | task = ConnectFourOfflineEnv() 95 | env_cls = lambda: GridWrapper(task.env_cls()) 96 | env = env_cls() 97 | max_ep_len = 50 98 | env_targets = list(np.arange(-1, 1, 0.25)) + [1] 99 | # env_targets = [-1, 0, 1] 100 | scale = 1. 101 | action_type = 'discrete' 102 | elif env_name == 'tfe': 103 | from stochastic_offline_envs.envs.offline_envs.tfe_offline_env import TFEOfflineEnv 104 | task = TFEOfflineEnv() 105 | env = task.env_cls() 106 | max_ep_len = 500 107 | env_targets = list(np.arange(0, 1, 0.1)) + [1] 108 | scale = 1. 109 | action_type = 'discrete' 110 | else: 111 | raise NotImplementedError 112 | 113 | if model_type == 'bc': 114 | # since BC ignores target, no need for different evaluations 115 | env_targets = env_targets[:1] 116 | 117 | example_state = env.reset() 118 | state_dim = np.prod(env.observation_space.shape) 119 | 120 | if action_type == 'discrete': 121 | act_dim = env.action_space.n 122 | else: 123 | act_dim = env.action_space.shape[0] 124 | 125 | # load dataset 126 | dataset_path = f'data/{env_name}-{dataset}-v2.pkl' 127 | with open(dataset_path, 'rb') as f: 128 | trajectories = pickle.load(f) 129 | 130 | n_data = len(trajectories) 131 | used_data = int(n_data * variant['prop_data']) 132 | trajectories = trajectories[:used_data] 133 | 134 | esper_trajs = convert_dataset(trajectories, action_type) 135 | 136 | # save all path information into separate lists 137 | mode = variant.get('mode', 'normal') 138 | states, traj_lens, returns = [], [], [] 139 | for path in trajectories: 140 | if mode == 'delayed': # delayed: all rewards moved to end of trajectory 141 | path['rewards'][-1] = path['rewards'].sum() 142 | path['rewards'][:-1] = 0. 143 | # Pre-compute the return-to-gos 144 | path['rtg'] = discount_cumsum(path['rewards'], gamma=1.) 145 | states.append(path['observations']) 146 | traj_lens.append(len(path['observations'])) 147 | returns.append(path['rewards'].sum()) 148 | traj_lens, returns = np.array(traj_lens), np.array(returns) 149 | 150 | # used for input normalization 151 | states = np.concatenate(states, axis=0) 152 | state_mean, state_std = np.mean( 153 | states, axis=0), np.std(states, axis=0) + 1e-6 154 | 155 | num_timesteps = sum(traj_lens) 156 | 157 | print('=' * 50) 158 | print(f'Starting new experiment: {env_name} {dataset}') 159 | print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found') 160 | print( 161 | f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}') 162 | print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}') 163 | print('=' * 50) 164 | 165 | K = variant['K'] 166 | batch_size = variant['batch_size'] 167 | num_eval_episodes = variant['num_eval_episodes'] 168 | pct_traj = variant.get('pct_traj', 1.) 
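A quick hand check of `discount_cumsum` defined above (the `gamma=1.0` case is what populates `path['rtg']`; the expected values below are worked by hand):

```python
import numpy as np

x = np.array([1.0, 0.0, 2.0])
# gamma = 0.5: rtg[2] = 2.0; rtg[1] = 0.0 + 0.5 * 2.0 = 1.0; rtg[0] = 1.0 + 0.5 * 1.0 = 1.5
assert np.allclose(discount_cumsum(x, gamma=0.5), [1.5, 1.0, 2.0])
# gamma = 1.0: plain reversed cumulative sum, i.e. the undiscounted return-to-go
assert np.allclose(discount_cumsum(x, gamma=1.0), [3.0, 2.0, 2.0])
```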
169 | 
170 |     # only train on top pct_traj trajectories (for %BC experiment)
171 |     num_timesteps = max(int(pct_traj * num_timesteps), 1)
172 |     sorted_inds = np.argsort(returns)  # lowest to highest
173 |     num_trajectories = 1
174 |     timesteps = traj_lens[sorted_inds[-1]]
175 |     ind = len(trajectories) - 2
176 |     while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
177 |         timesteps += traj_lens[sorted_inds[ind]]
178 |         num_trajectories += 1
179 |         ind -= 1
180 |     sorted_inds = sorted_inds[-num_trajectories:]
181 | 
182 |     # used to reweight sampling so we sample according to timesteps instead of trajectories
183 |     p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])
184 | 
185 |     def get_batch(batch_size=256, max_len=K):
186 |         batch_inds = np.random.choice(
187 |             np.arange(num_trajectories),
188 |             size=batch_size,
189 |             replace=True,
190 |             p=p_sample,  # reweights so we sample according to timesteps
191 |         )
192 | 
193 |         s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
194 |         for i in range(batch_size):
195 |             traj = trajectories[int(sorted_inds[batch_inds[i]])]
196 |             si = random.randint(0, traj['rewards'].shape[0] - 1)
197 | 
198 |             # get sequences from dataset
199 |             s.append(traj['observations']
200 |                      [si:si + max_len].reshape(1, -1, state_dim))
201 |             a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))
202 |             r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))
203 |             if 'terminals' in traj:
204 |                 d.append(traj['terminals'][si:si + max_len].reshape(1, -1))
205 |             else:
206 |                 d.append(traj['dones'][si:si + max_len].reshape(1, -1))
207 |             timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
208 |             timesteps[-1][timesteps[-1] >=
209 |                           max_ep_len] = max_ep_len - 1  # padding cutoff
210 |             rtg.append(traj['rtg'][si:][:s[-1].shape[1] + 1].reshape(1, -1, 1))
211 |             if rtg[-1].shape[1] <= s[-1].shape[1]:
212 |                 rtg[-1] = np.concatenate([rtg[-1],
213 |                                           np.zeros((1, 1, 1))], axis=1)
214 | 
215 |             # padding and state + reward normalization
216 |             tlen = s[-1].shape[1]
217 |             s[-1] = np.concatenate([np.zeros((1, max_len -
218 |                                               tlen, state_dim)), s[-1]], axis=1)
219 |             if variant['normalize_states']:
220 |                 s[-1] = (s[-1] - state_mean) / state_std
221 |             a[-1] = np.concatenate([np.ones((1, max_len -
222 |                                              tlen, act_dim)) * -10., a[-1]], axis=1)
223 |             r[-1] = np.concatenate([np.zeros((1, max_len -
224 |                                               tlen, 1)), r[-1]], axis=1)
225 |             d[-1] = np.concatenate([np.ones((1, max_len - tlen))
226 |                                     * 2, d[-1]], axis=1)
227 |             rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)),
228 |                                       rtg[-1]], axis=1) / scale
229 |             timesteps[-1] = np.concatenate(
230 |                 [np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)
231 |             mask.append(np.concatenate(
232 |                 [np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))
233 | 
234 |         s = torch.from_numpy(np.concatenate(s, axis=0)).to(
235 |             dtype=torch.float32, device=device)
236 |         a = torch.from_numpy(np.concatenate(a, axis=0)).to(
237 |             dtype=torch.float32, device=device)
238 |         r = torch.from_numpy(np.concatenate(r, axis=0)).to(
239 |             dtype=torch.float32, device=device)
240 |         d = torch.from_numpy(np.concatenate(d, axis=0)).to(
241 |             dtype=torch.long, device=device)
242 |         rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(
243 |             dtype=torch.float32, device=device)
244 |         timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(
245 |             dtype=torch.long, device=device)
246 |         mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)
247 | 
248 |         return s, a, r, d, rtg, timesteps, mask
249 | 
250 |     def eval_episodes(target_rew):
251 |         def fn(model):
252 |             returns, lengths = [], []
253 |             for _ in range(num_eval_episodes):
254 |                 with torch.no_grad():
255 |                     if model_type == 'dt':
256 |                         ret, length = evaluate_episode_rtg(
257 |                             env,
258 |                             state_dim,
259 |                             act_dim,
260 |                             model,
261 |                             max_ep_len=max_ep_len,
262 |                             scale=scale,
263 |                             target_return=target_rew / scale,
264 |                             mode=mode,
265 |                             state_mean=state_mean,
266 |                             state_std=state_std,
267 |                             device=device,
268 |                             action_type=action_type
269 |                         )
270 |                     else:
271 |                         ret, length = evaluate_episode(
272 |                             env,
273 |                             state_dim,
274 |                             act_dim,
275 |                             model,
276 |                             max_ep_len=max_ep_len,
277 |                             target_return=target_rew / scale,
278 |                             mode=mode,
279 |                             state_mean=state_mean,
280 |                             state_std=state_std,
281 |                             device=device,
282 |                         )
283 |                 returns.append(ret_postprocess_fn(ret))
284 |                 lengths.append(length)
285 |             return {
286 |                 f'target_{target_rew}_return_mean': np.mean(returns),
287 |                 # f'target_{target_rew}_return_std': np.std(returns),
288 |                 # f'target_{target_rew}_length_mean': np.mean(lengths),
289 |                 # f'target_{target_rew}_length_std': np.std(lengths),
290 |             }
291 |         return fn
292 | 
293 |     if model_type == 'dt':
294 |         model = pm.load_torch('model', DecisionTransformer,
295 |                               state_dim=state_dim,
296 |                               act_dim=act_dim,
297 |                               max_length=K,
298 |                               max_ep_len=max_ep_len,
299 |                               hidden_size=variant['embed_dim'],
300 |                               n_layer=variant['n_layer'],
301 |                               n_head=variant['n_head'],
302 |                               n_inner=4 * variant['embed_dim'],
303 |                               activation_function=variant['activation_function'],
304 |                               n_positions=1024,
305 |                               resid_pdrop=variant['dropout'],
306 |                               attn_pdrop=variant['dropout'],
307 |                               action_tanh=action_type == 'continuous',
308 |                               rtg_seq=variant['rtg_seq'])
309 |     elif model_type == 'bc':
310 |         model = pm.load_torch('model', MLPBCModel,
311 |                               state_dim=state_dim,
312 |                               act_dim=act_dim,
313 |                               max_length=K,
314 |                               hidden_size=variant['embed_dim'],
315 |                               n_layer=variant['n_layer'],)
316 |     else:
317 |         raise NotImplementedError
318 | 
319 |     model = model.to(device=device)
320 | 
321 |     warmup_steps = variant['warmup_steps']
322 |     optimizer = pm.load_torch('optimizer', torch.optim.AdamW,
323 |                               model.parameters(),
324 |                               lr=variant['learning_rate'],
325 |                               weight_decay=variant['weight_decay'])
326 |     scheduler = pm.load_torch('scheduler', torch.optim.lr_scheduler.LambdaLR,
327 |                               optimizer,
328 |                               lambda steps: min((steps + 1) / warmup_steps, 1))
329 | 
330 |     if action_type == 'continuous':
331 |         action_loss = lambda s_hat, a_hat, r_hat, s, a, r: torch.mean(
332 |             (a_hat - a)**2)
333 |     else:
334 |         ce_loss = nn.CrossEntropyLoss()
335 | 
336 |         def action_loss(s_hat, a_hat, r_hat, s, a, r):
337 |             a = torch.argmax(a, dim=-1)
338 |             return ce_loss(a_hat, a)
339 | 
340 |     if model_type == 'dt':
341 |         trainer = SequenceTrainer(
342 |             model=model,
343 |             optimizer=optimizer,
344 |             batch_size=batch_size,
345 |             get_batch=get_batch,
346 |             scheduler=scheduler,
347 |             loss_fn=action_loss,
348 |             eval_fns=[eval_episodes(tar) for tar in env_targets],
349 |         )
350 |     elif model_type == 'bc':
351 |         trainer = ActTrainer(
352 |             model=model,
353 |             optimizer=optimizer,
354 |             batch_size=batch_size,
355 |             get_batch=get_batch,
356 |             scheduler=scheduler,
357 |             loss_fn=action_loss,
358 |             eval_fns=[eval_episodes(tar) for tar in env_targets],
359 |         )
360 | 
361 |     if log_to_wandb:
362 |         wandb_id = pm.wandb_id()
363 |         wandb.init(
364 |             name=exp_prefix,
365 |             # group=group_name,
366 |             project='decision-transformer',
367 |             config=variant,
368 |             id=wandb_id,
369 |             resume='allow',
370 |         )
371 |         # wandb.watch(model)  # wandb has some bug
372 | 
373 |     if variant['rtg']:
374 |         # Load custom return-to-go
375 |         rtg_path = variant['rtg']
376 |         # Load the pickle
377 |         with open(rtg_path, 'rb') as f:
378 |             rtg_dict = pickle.load(f)
379 |         for i, path in enumerate(trajectories):
380 |             path['rtg'] = rtg_dict[i]
381 | 
382 |     # if mode == 'esper':
383 |     #     if variant['normalize_states']:
384 |     #         print('Normalizing states')
385 |     #         for traj in esper_trajs:
386 |     #             for i in range(len(traj.obs)):
387 |     #                 traj.obs[i] = (traj.obs[i] - state_mean) / state_std
388 |     #         print(esper_trajs[0].obs[0])
389 |     #     seq_dataset = SeqDataset(esper_trajs, act_dim, max_ep_len, gamma=1, reward_norm=scale,
390 |     #                              act_type=action_type)
391 |     #     label_model = learn_labels(seq_dataset,
392 |     #                                act_dim,
393 |     #                                batch_size=100,
394 |     #                                learning_rate=5e-4,
395 |     #                                hidden_size=512,
396 |     #                                rep_size=variant['rep_size'],
397 |     #                                rep_groups=variant['rep_groups'],
398 |     #                                device='cuda',
399 |     #                                pm=pm.for_obj('label_model'),
400 |     #                                act_loss_coef=variant['act_loss_coef'],
401 |     #                                adv_loss_coef=variant['adv_loss_coef'],
402 |     #                                pretrain_epochs=0,
403 |     #                                cluster_epochs=variant['cluster_epochs'],
404 |     #                                label_epochs=variant['label_epochs'])
405 |     #     label_fn = lambda traj: learned_labels(
406 |     #         traj, label_model, act_dim, max_ep_len, 'cuda', act_type=action_type)
407 |     #     labs = [label_fn(traj) * scale for traj in esper_trajs]
408 |     #     for i, path in enumerate(trajectories):
409 |     #         path['rtg'] = labs[i]
410 | 
411 |     completed_iters = pm.load_if_exists('completed_iters', 0)
412 |     for iter in range(completed_iters, variant['max_iters']):
413 |         outputs = trainer.train_iteration(
414 |             num_steps=variant['num_steps_per_iter'], iter_num=iter + 1, print_logs=True)
415 |         pm.save_torch('optimizer', optimizer)
416 |         pm.save_torch('scheduler', scheduler)
417 |         pm.save_torch('model', model)
418 |         pm.checkpoint()
419 |         if log_to_wandb:
420 |             wandb.log(outputs)
421 |         completed_iters += 1
422 |         pm.save('completed_iters', completed_iters)
423 | 
424 | 
425 | if __name__ == '__main__':
426 |     parser = argparse.ArgumentParser()
427 |     parser.add_argument('--env', type=str, default='hopper')
428 |     # medium, medium-replay, medium-expert, expert
429 |     parser.add_argument('--dataset', type=str, default='medium')
430 |     # normal for standard setting, delayed for sparse rewards; ESPER-style learned average returns are supplied via --rtg
431 |     parser.add_argument('--mode', type=str, default='normal')
432 |     parser.add_argument('--K', type=int, default=20)
433 |     parser.add_argument('--pct_traj', type=float, default=1.)
434 |     parser.add_argument('--batch_size', type=int, default=64)
435 |     # dt for decision transformer, bc for behavior cloning
436 |     parser.add_argument('--model_type', type=str, default='dt')
437 |     parser.add_argument('--embed_dim', type=int, default=128)
438 |     parser.add_argument('--n_layer', type=int, default=3)
439 |     parser.add_argument('--n_head', type=int, default=1)
440 |     parser.add_argument('--activation_function', type=str, default='relu')
441 |     parser.add_argument('--dropout', type=float, default=0.1)
442 |     parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
443 |     parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
444 |     parser.add_argument('--warmup_steps', type=int, default=10000)
445 |     parser.add_argument('--num_eval_episodes', type=int, default=100)
446 |     parser.add_argument('--max_iters', type=int, default=10)
447 |     parser.add_argument('--num_steps_per_iter', type=int, default=10000)
448 |     parser.add_argument('--device', type=str, default='cuda')
449 |     parser.add_argument('--log_to_wandb', '-w', type=lambda x: str(x).lower() == 'true', default=False)  # plain type=bool would parse '-w False' as True
450 |     parser.add_argument('--checkpoint_dir', type=str, default=None)
451 |     parser.add_argument('--seed', type=int, default=0)  # NOTE: currently unused by experiment()
452 | 
453 |     parser.add_argument('--act_loss_coef', type=float, default=0.01)
454 |     parser.add_argument('--adv_loss_coef', type=float, default=1)
455 |     parser.add_argument('--cluster_epochs', type=int, default=5)
456 |     parser.add_argument('--label_epochs', type=int, default=5)
457 |     parser.add_argument('--rep_size', type=int, default=8)
458 |     parser.add_argument('--rep_groups', type=int, default=1)
459 |     parser.add_argument('--rtg_seq', type=lambda x: str(x).lower() == 'true', default=True)
460 |     parser.add_argument('--normalize_states', action='store_true')
461 | 
462 |     parser.add_argument('--prop_data', type=float, default=1.)
463 | 
464 |     parser.add_argument('--rtg', type=str, default=None)
465 | 
466 |     args = parser.parse_args()
467 | 
468 |     print(vars(args))
469 | 
470 |     experiment('gym-experiment', variant=vars(args))
471 | 
--------------------------------------------------------------------------------
/decision_transformer/readme-gym.md:
--------------------------------------------------------------------------------
1 | 
2 | # OpenAI Gym
3 | 
4 | ## Installation
5 | 
6 | Experiments require MuJoCo.
7 | Follow the instructions in the [mujoco-py repo](https://github.com/openai/mujoco-py) to install it.
8 | Then, dependencies can be installed with the following command:
9 | 
10 | ```
11 | conda env create -f conda_env.yml
12 | ```
13 | 
14 | ## Downloading datasets
15 | 
16 | Datasets are stored in the `data` directory.
17 | Install the [D4RL repo](https://github.com/rail-berkeley/d4rl), following the instructions there.
18 | Then, run the following script to download the datasets and save them in our format:
19 | 
20 | ```
21 | python download_d4rl_datasets.py
22 | ```
23 | 
24 | ## Example usage
25 | 
26 | Experiments can be reproduced with the following:
27 | 
28 | ```
29 | python experiment.py --env hopper --dataset medium --model_type dt
30 | ```
31 | 
32 | Adding `-w True` will log results to Weights and Biases.
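Combining `--model_type bc` with `--pct_traj` gives the %BC baseline, i.e. behavior cloning on only the top fraction of trajectories by return. For instance, a 10%BC run might look like the following (the `0.1` here is illustrative, not a tuned value):

```
python experiment.py --env hopper --dataset medium --model_type bc --pct_traj 0.1
```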
33 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # return_transforms
2 | 
3 | This repository contains code for the following return transformation methods for [Upside Down RL](https://arxiv.org/abs/1912.02875) and [Decision Transformer](https://arxiv.org/abs/2106.01345):
4 | - ESPER - [You Can't Count on Luck: Why Decision Transformers Fail in Stochastic Environments](https://arxiv.org/abs/2205.15967)
5 | 
6 | ## Installation Instructions
7 | 
8 | - Install `stochastic_offline_envs` from [here](https://github.com/keirp/stochastic_offline_envs).
9 | - Install dependencies with `pip install -r requirements.txt`.
10 | - Make sure PyTorch >= 1.10 is installed.
11 | - Install the package with `pip install -e .`.
12 | - Install the included `decision_transformer` package. This is only necessary if you want to use the transformed returns with the included modified Decision Transformer implementation.
13 | 
14 | ## Instructions for Decision Transformer
15 | - Run `download_esper_datasets.py` to save the `stochastic_offline_envs` datasets in a format that Decision Transformer understands.
16 | - Use the `--rtg path/to/returns` flag to use the generated returns, or leave it out to use the original returns.
17 | 
18 | ## Usage
19 | 
20 | `return_transforms` operates on offline RL datasets. It saves a file with the transformed returns at the specified path.
21 | 
22 | To use `return_transforms` on a dataset, run the following command:
23 | 
24 | ```python return_transforms/generate.py --env_name tfe --config configs/esper/tfe.yaml --device cuda --n_cpu 10 --ret_file data/tfe.ret```
25 | 
26 | Then, you can use the included fork of Decision Transformer (in the `decision_transformer` directory) to train on the transformed returns; the command below is run from that directory, hence the `../data` path:
27 | 
28 | ```python experiment.py --env tfe --dataset default -w True --max_iters 2 --num_steps_per_iter 25000 --rtg ../data/tfe.ret```
29 | 
30 | Configurations for all of the included `stochastic_offline_envs` tasks are provided in the `configs/esper` directory.
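The `.ret` file is simply a pickled sequence of per-timestep return labels, one entry per trajectory in the same order as the dataset; `experiment.py` loads it via `--rtg` and overwrites each trajectory's `rtg` field with it. A minimal sanity check, assuming `data/tfe.ret` was generated as above:

```python
import pickle

# Load the transformed returns written by return_transforms/generate.py
with open('data/tfe.ret', 'rb') as f:
    rets = pickle.load(f)

print(len(rets))    # number of trajectories in the dataset
print(rets[0][:5])  # first few return labels of the first trajectory
```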
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fire==0.4.0
2 | gym==0.19.0
3 | numpy==1.21.4
4 | PyYAML==6.0
5 | setuptools==63.4.1
6 | tqdm==4.62.3
--------------------------------------------------------------------------------
/return_transforms/algos/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/return_transforms/algos/__init__.py
--------------------------------------------------------------------------------
/return_transforms/algos/esper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/return_transforms/algos/esper/__init__.py
--------------------------------------------------------------------------------
/return_transforms/algos/esper/esper.py:
--------------------------------------------------------------------------------
1 | from return_transforms.models.esper.cluster_model import ClusterModel
2 | from return_transforms.models.esper.dynamics_model import DynamicsModel
3 | from return_transforms.datasets.esper_dataset import ESPERDataset
4 | from return_transforms.utils.utils import learned_labels
5 | from tqdm.autonotebook import tqdm
6 | import torch
7 | import torch.optim as optim
8 | import torch.nn.functional as F
9 | import numpy as np
10 | import gym
11 | 
12 | 
13 | def esper(trajs,
14 |           action_space,
15 |           dynamics_model_args,
16 |           cluster_model_args,
17 |           train_args,
18 |           device,
19 |           n_cpu=2):
20 | 
21 |     # Check if discrete action space
22 |     if isinstance(action_space, gym.spaces.Discrete):
23 |         action_size = action_space.n
24 |         act_loss_fn = lambda pred, truth: F.cross_entropy(pred.view(-1, pred.shape[-1]), torch.argmax(truth, dim=-1).view(-1),
25 |                                                           reduction='none')
26 |         act_type = 'discrete'
27 |     else:
28 |         action_size = action_space.shape[0]
29 |         act_loss_fn = lambda pred, truth: ((pred - truth) ** 2).mean(dim=-1)
30 |         act_type = 'continuous'
31 | 
32 |     # Get the length of the longest trajectory
33 |     max_len = max([len(traj.obs) for traj in trajs]) + 1
34 | 
35 |     dataset = ESPERDataset(trajs, action_size, max_len,
36 |                            gamma=train_args['gamma'], act_type=act_type)
37 | 
38 |     scale = train_args['scale']
39 | 
40 |     # Get the obs size from the first datapoint
41 |     obs, _, _, _ = next(iter(dataset))
42 |     obs_shape = obs[0].shape
43 |     obs_size = np.prod(obs_shape)
44 | 
45 |     # Set up the models
46 |     print('Creating models...')
47 |     dynamics_model = DynamicsModel(obs_size,
48 |                                    action_size,
49 |                                    cluster_model_args['rep_size'],
50 |                                    dynamics_model_args).to(device)
51 | 
52 |     cluster_model = ClusterModel(obs_size,
53 |                                  action_size,
54 |                                  cluster_model_args['rep_size'],
55 |                                  cluster_model_args,
56 |                                  cluster_model_args['groups']).to(device)
57 | 
58 |     dynamics_optimizer = optim.AdamW(
59 |         dynamics_model.parameters(), lr=float(train_args['dynamics_model_lr']))
60 |     cluster_optimizer = optim.AdamW(
61 |         cluster_model.parameters(), lr=float(train_args['cluster_model_lr']))
62 | 
63 |     dataloader = torch.utils.data.DataLoader(dataset,
64 |                                              batch_size=train_args['batch_size'],
65 |                                              num_workers=n_cpu)
66 | 
67 |     # Calculate epoch markers
68 |     total_epochs = train_args['cluster_epochs'] + train_args['return_epochs']
69 |     ret_stage = train_args['cluster_epochs']
70 | 
71 |     print('Training...')
72 | 
73 |     dynamics_model.train()
74 |     cluster_model.train()
75 |     for epoch in range(total_epochs):
76 |         pbar = tqdm(dataloader, total=len(dataloader))
77 |         total_loss = 0
78 |         total_act_loss = 0
79 |         total_ret_loss = 0
80 |         total_dyn_loss = 0
81 |         total_baseline_dyn_loss = 0
82 |         total_batches = 0
83 |         for obs, acts, ret, seq_len in pbar:
84 |             total_batches += 1
85 |             # Take an optimization step for the cluster model
86 |             cluster_optimizer.zero_grad()
87 |             obs = obs.to(device)
88 |             acts = acts.to(device)
89 |             ret = ret.to(device) / scale
90 |             seq_len = seq_len.to(device)
91 | 
92 |             bsz, t = obs.shape[:2]
93 | 
94 |             act_mask = (acts.sum(dim=-1) == 0)
95 |             obs_mask = (obs.view(bsz, t, -1)[:, :-1].sum(dim=-1) == 0)
96 | 
97 |             # Get the cluster predictions
98 |             clusters, ret_pred, act_pred, _ = cluster_model(
99 |                 obs, acts, seq_len, hard=epoch >= ret_stage)
100 | 
101 |             pred_next_obs, next_obs = dynamics_model(
102 |                 obs, acts, clusters, seq_len)
103 | 
104 |             # Calculate the losses
105 | 
106 |             ret_loss = ((ret_pred.view(bsz, t) - ret.view(bsz, t)) ** 2).mean()
107 |             act_loss = act_loss_fn(act_pred, acts).view(bsz, t)[
108 |                 ~act_mask].mean()
109 |             dynamics_loss = ((pred_next_obs - next_obs) ** 2)[~obs_mask].mean()
110 | 
111 |             # Calculate the total loss
112 |             if epoch < ret_stage:
113 |                 loss = -train_args['adv_loss_weight'] * dynamics_loss + \
114 |                     train_args['act_loss_weight'] * act_loss
115 |             else:
116 |                 loss = ret_loss
117 | 
118 |             loss.backward()
119 |             cluster_optimizer.step()
120 | 
121 |             # Take an optimization step for the dynamics model
122 |             dynamics_optimizer.zero_grad()
123 |             pred_next_obs, next_obs = dynamics_model(
124 |                 obs, acts, clusters.detach(), seq_len)
125 |             baseline_pred_next_obs, _ = dynamics_model(
126 |                 obs, acts, torch.zeros_like(clusters), seq_len)
127 |             dynamics_loss = ((pred_next_obs - next_obs) ** 2)[~obs_mask].mean()
128 |             baseline_dynamics_loss = (
129 |                 (baseline_pred_next_obs - next_obs) ** 2)[~obs_mask].mean()
130 |             total_dynamics_loss = dynamics_loss + baseline_dynamics_loss
131 |             total_dynamics_loss.backward()
132 |             dynamics_optimizer.step()
133 | 
134 |             # Update the progress bar
135 |             total_loss += loss.item()
136 |             total_act_loss += act_loss.item()
137 |             total_ret_loss += ret_loss.item()
138 |             total_dyn_loss += dynamics_loss.item()
139 |             total_baseline_dyn_loss += baseline_dynamics_loss.item()
140 | 
141 |             advantage = total_baseline_dyn_loss - total_dyn_loss
142 | 
143 |             pbar.set_description(
144 |                 f"Epoch {epoch} | Loss: {total_loss / total_batches:.4f} | Act Loss: {total_act_loss / total_batches:.4f} | Ret Loss: {total_ret_loss / total_batches:.4f} | Dyn Loss: {total_dyn_loss / total_batches:.4f} | Adv: {advantage / total_batches:.4f}")
145 | 
146 |     # Get the learned return labels
147 |     avg_returns = []
148 |     for traj in tqdm(trajs):
149 |         labels = learned_labels(traj, cluster_model,
150 |                                 action_size, max_len, device, act_type)
151 |         avg_returns.append(labels * scale)
152 | 
153 |     return avg_returns
--------------------------------------------------------------------------------
/return_transforms/datasets/esper_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import IterableDataset
2 | import torch
3 | import numpy as np
4 | from return_transforms.utils.utils import return_labels
5 | 
6 | 
7 | class ESPERDataset(IterableDataset):
8 | 
9 |     rand: np.random.Generator
10 | 
11 |     def __init__(self, trajs, n_actions, horizon, gamma=1, act_type='discrete',
12 |                  epoch_len=1e5):
13 |         self.trajs = trajs
14 |         self.rets = [return_labels(traj, gamma)
15 |                      for traj in self.trajs]
16 |         self.n_actions = n_actions
17 |         self.horizon = horizon
18 |         self.epoch_len = epoch_len
19 |         self.act_type = act_type
20 | 
21 |     def segment_generator(self, epoch_len):
22 |         for _ in range(epoch_len):
23 |             traj_idx = self.rand.integers(len(self.trajs))
24 |             traj = self.trajs[traj_idx]
25 |             rets = self.rets[traj_idx]
26 |             if self.act_type == 'discrete':
27 |                 a = np.array(traj.actions)
28 |                 actions = np.zeros((a.size, self.n_actions))
29 |                 actions[np.arange(a.size), a] = 1
30 |             else:
31 |                 actions = np.array(traj.actions)
32 |             obs = np.array(traj.obs)
33 | 
34 |             padded_obs = np.zeros((self.horizon, *obs.shape[1:]))
35 |             padded_acts = np.zeros((self.horizon, self.n_actions))
36 |             padded_rets = np.zeros(self.horizon)
37 | 
38 |             padded_obs[-obs.shape[0]:] = obs
39 |             padded_acts[-obs.shape[0]:] = actions
40 |             padded_rets[-obs.shape[0]:] = np.array(rets)
41 |             seq_length = obs.shape[0]
42 | 
43 |             yield torch.tensor(padded_obs).float(), \
44 |                 torch.tensor(padded_acts).float(), \
45 |                 torch.tensor(padded_rets).float(), \
46 |                 torch.tensor(seq_length).long()
47 | 
48 |     def __len__(self):
49 |         return int(self.epoch_len)
50 | 
51 |     def __iter__(self):
52 |         worker_info = torch.utils.data.get_worker_info()
53 |         self.rand = np.random.default_rng(None)
54 |         if worker_info is None:  # single-process data loading, return the full iterator
55 |             gen = self.segment_generator(int(self.epoch_len))
56 |         else:  # in a worker process
57 |             # split workload
58 |             per_worker_time_steps = int(
59 |                 self.epoch_len / float(worker_info.num_workers))
60 |             gen = self.segment_generator(per_worker_time_steps)
61 |         return gen
--------------------------------------------------------------------------------
/return_transforms/generate.py:
--------------------------------------------------------------------------------
1 | """Generate transformed (e.g. ESPER) return labels for an offline RL dataset."""
2 | import pickle
3 | from return_transforms.algos.esper.esper import esper
4 | from fire import Fire
5 | import yaml
6 | from pathlib import Path
7 | import numpy as np
8 | 
9 | 
10 | def load_config(config_path):
11 |     return yaml.safe_load(Path(config_path).read_text())
12 | 
13 | 
14 | def load_env(env_name):
15 |     if env_name == 'connect_four':
16 |         from stochastic_offline_envs.envs.offline_envs.connect_four_offline_env import ConnectFourOfflineEnv
17 |         from stochastic_offline_envs.envs.connect_four.connect_four_env import GridWrapper
18 |         # TODO: env should just deal with this automatically
19 |         task = ConnectFourOfflineEnv()
20 |         env = task.env_cls()
21 |         env = GridWrapper(env)
22 |         trajs = task.trajs
23 |         for traj in trajs:
24 |             for i in range(len(traj.obs)):
25 |                 traj.obs[i] = traj.obs[i]['grid']
26 |         return env, trajs
27 |     elif env_name == 'tfe':
28 |         from stochastic_offline_envs.envs.offline_envs.tfe_offline_env import TFEOfflineEnv
29 |         task = TFEOfflineEnv()
30 |         env = task.env_cls()
31 |         trajs = task.trajs
32 |         return env, trajs
33 |     elif env_name == 'gambling':
34 |         from stochastic_offline_envs.envs.offline_envs.gambling_offline_env import GamblingOfflineEnv
35 |         task = GamblingOfflineEnv()
36 |         env = task.env_cls()
37 |         trajs = task.trajs
38 |         return env, trajs
39 |     # TODO: implement the rest
40 | 
41 | 
42 | def normalize_obs(trajs):
43 |     obs_list = []
44 |     for traj in trajs:
45 |         obs_list.extend(traj.obs)
46 |     obs = np.array(obs_list)
47 |     obs_mean = np.mean(obs, axis=0)
48 |     obs_std = np.std(obs, axis=0) + 1e-8
49 |     for traj in trajs:
50 |         for i in range(len(traj.obs)):
51 |             traj.obs[i] = (traj.obs[i] - obs_mean) / obs_std
52 |     return trajs
53 | 
54 | 
55 | def generate(env_name, config, ret_file, device, n_cpu=2):
56 |     print('Loading config...')
57 |     config = load_config(config)
58 | 
59 |     if config['method'] == 'esper':
60 |         print('Loading offline RL task...')
61 |         env, trajs = load_env(env_name)
62 | 
63 |         if config['normalize']:
64 |             print('Normalizing observations...')
65 |             trajs = normalize_obs(trajs)
66 | 
67 |         print('Creating ESPER returns...')
68 |         rets = esper(trajs,
69 |                      env.action_space,
70 |                      config['dynamics_model_args'],
71 |                      config['cluster_model_args'],
72 |                      config['train_args'],
73 |                      device,
74 |                      n_cpu)
75 | 
76 |         # Save the returns as a pickle
77 |         print('Saving returns...')
78 |         Path(ret_file).parent.mkdir(parents=True, exist_ok=True)
79 |         with open(ret_file, 'wb') as f:
80 |             pickle.dump(rets, f)
81 | 
82 |     else:
83 |         raise NotImplementedError
84 | 
85 | 
86 | if __name__ == '__main__':
87 |     Fire(generate)
--------------------------------------------------------------------------------
/return_transforms/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/return_transforms/models/__init__.py
--------------------------------------------------------------------------------
/return_transforms/models/basic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/return_transforms/models/basic/__init__.py
--------------------------------------------------------------------------------
/return_transforms/models/basic/mlp.py:
--------------------------------------------------------------------------------
1 | """
2 | A set of methods for building MLPs with different features.
3 | - batchnorm/layernorm
4 | - dropout
5 | - different activation functions
6 | - input size, output size, hidden size
7 | - different number of layers
8 | """
9 | 
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | 
13 | 
14 | class MLP(nn.Module):
15 |     def __init__(self, input_size, output_size, hidden_size, num_layers, activation, batchnorm=False, layernorm=False, dropout=0.0):
16 |         super(MLP, self).__init__()
17 |         self.input_size = input_size
18 |         self.output_size = output_size
19 |         self.hidden_size = hidden_size
20 |         self.num_layers = num_layers
21 | 
22 |         if activation == 'relu':
23 |             self.activation = nn.ReLU
24 |         else:
25 |             raise NotImplementedError
26 | 
27 |         self.batchnorm = batchnorm
28 |         self.layernorm = layernorm
29 |         self.dropout = dropout
30 | 
31 |         self.layers = nn.ModuleList()
32 |         self.layers.append(nn.Linear(self.input_size, self.hidden_size))
33 |         if self.batchnorm:
34 |             self.layers.append(nn.BatchNorm1d(self.hidden_size))
35 |         if self.layernorm:
36 |             self.layers.append(nn.LayerNorm(self.hidden_size))
37 |         self.layers.append(self.activation())
38 |         if self.dropout > 0.0:
39 |             self.layers.append(nn.Dropout(self.dropout))
40 |         for i in range(self.num_layers - 1):
41 |             self.layers.append(nn.Linear(self.hidden_size, self.hidden_size))
42 |             if self.batchnorm:
43 |                 self.layers.append(nn.BatchNorm1d(self.hidden_size))
44 |             if self.layernorm:
45 |                 self.layers.append(nn.LayerNorm(self.hidden_size))
46 |             self.layers.append(self.activation())
47 |             if self.dropout > 0.0:
48 |                 self.layers.append(nn.Dropout(self.dropout))
49 |         self.layers.append(nn.Linear(self.hidden_size, self.output_size))
50 | 
51 |     def forward(self, x):
52 |         for layer in self.layers:
53 |             x = layer(x)
54 |         return x
55 | 
--------------------------------------------------------------------------------
/return_transforms/models/esper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/return_transforms/models/esper/__init__.py
--------------------------------------------------------------------------------
/return_transforms/models/esper/cluster_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from return_transforms.models.basic.mlp import MLP
5 | from return_transforms.utils.utils import get_past_indices
6 | 
7 | 
8 | class ClusterModel(nn.Module):
9 |     """
10 |     Model that takes input trajectories in the form of obs and actions
11 |     and returns a cluster assignment (in the form of discrete representations, of
12 |     which there are `groups` many) for each timestep in the trajectory.
13 | 
14 |     In summary, the model performs the following steps (on trajectories that are front-padded):
15 |     1. Concatenate the observation and action at each timestep
16 |     2. Pass the concatenated vector through an MLP to get a representation
17 |     3. Reverse the sequence and pass the sequence through an LSTM to get a representation
18 |     4. Pass the LSTM hidden state through an MLP to get logits for the cluster assignment
19 |     5. Use gumbel softmax to sample a cluster assignment
20 | 
21 |     The model also includes the return-predictor and action-predictor, which predict returns and actions
22 |     given the cluster assignments.
23 |     """
24 | 
25 |     def __init__(self, obs_size, action_size, rep_size, model_args, groups=4):
26 |         super().__init__()
27 |         self.obs_size = obs_size
28 |         self.action_size = action_size
29 |         self.rep_size = rep_size
30 |         self.groups = groups
31 | 
32 |         self.hidden_size = model_args['obs_action_model']['hidden_size']
33 | 
34 |         self.obs_action_model = MLP(
35 |             obs_size + action_size, self.hidden_size, **model_args['obs_action_model'])
36 |         self.lstm_model = nn.LSTM(self.hidden_size,
37 |                                   self.hidden_size,
38 |                                   batch_first=True)
39 |         self.logit_model = MLP(
40 |             self.hidden_size, rep_size, **model_args['logit_model'])
41 | 
42 |         self.ret_obs_action_model = MLP(
43 |             obs_size + action_size, self.hidden_size, **model_args['ret_obs_action_model'])
44 |         self.return_model = MLP(rep_size + self.hidden_size,
45 |                                 1, **model_args['return_model'])
46 |         self.action_model = MLP(rep_size + obs_size,
47 |                                 action_size, **model_args['action_model'])
48 | 
49 |     def forward(self, obs, action, seq_len, hidden=None, hard=False):
50 |         bsz, t = obs.shape[:2]
51 |         obs = obs.view(bsz, t, -1)
52 | 
53 |         # Concatenate observations and actions
54 |         x = torch.cat([obs, action], dim=-1)
55 | 
56 |         # Reverse the sequence in time
57 |         x = torch.flip(x, [1]).view(bsz * t, -1)
58 | 
59 |         # Pass through MLP to get representation
60 |         obs_act_reps = self.obs_action_model(x).view(bsz, t, -1)
61 | 
62 |         # Use LSTM to get the representations for each suffix of the sequence
63 |         if hidden is None:
64 |             hidden = (torch.zeros(1, bsz, self.hidden_size).to(x.device),
65 |                       torch.zeros(1, bsz, self.hidden_size).to(x.device))
66 | 
67 |         x, hidden = self.lstm_model(obs_act_reps, hidden)
68 | 
69 |         # Reverse the sequence in time again
70 |         x = torch.flip(x, [1]).reshape(bsz * t, -1)
71 | 
72 |         # Pass through MLP to get logits for cluster assignment
73 |         logits = self.logit_model(x)
74 | 
75 |         # Some inputs are padding (0), so we mask them out
76 |         logits[obs.view(bsz * t, -1).sum(-1) == 0] = 0
77 | 
78 |         # Sample cluster assignment
79 |         logits = logits.view(bsz * t, self.groups, -1)
80 |         clusters = F.gumbel_softmax(logits, tau=1, hard=hard)
81 |         clusters = clusters.view(bsz, t, -1)
82 | 
83 |         # ================ Compute return prediction ================
84 |         x = torch.cat([obs, action], dim=-1).view(bsz * t, -1)
85 |         ret_obs_act_reps = self.ret_obs_action_model(x).view(bsz, t, -1)
86 | 
87 |         ret_input = torch.cat(
88 |             [clusters.detach(), ret_obs_act_reps], dim=-1).view(bsz * t, -1)
89 | 
90 |         ret_pred = self.return_model(ret_input).view(bsz, t, -1)
91 | 
92 |         # ================ Compute action prediction ================
93 | 
94 |         # First, we need to get the past indices
95 |         idxs = get_past_indices(obs_act_reps, seq_len)
96 |         idxs = idxs.view(bsz, t, 1).expand(bsz, t, self.rep_size)
97 | 
98 |         # Get cluster representations for the past
99 |         past_cluster = torch.gather(clusters, 1, idxs)
100 | 
101 |         obs_context = torch.cat([obs, past_cluster], dim=-1).view(bsz * t, -1)
102 |         act_pred = self.action_model(obs_context).view(bsz, t, -1)
103 | 
104 |         return clusters, ret_pred, act_pred, hidden
105 | 
106 |     def return_preds(self, obs, action, hard=False):
107 |         """
108 |         Returns the return predictions for the given trajectories.
109 |         """
110 |         bsz, t = obs.shape[:2]
111 |         obs = obs.view(bsz, t, -1)
112 | 
113 |         # Concatenate observations and actions
114 |         x = torch.cat([obs, action], dim=-1)
115 | 
116 |         # Reverse the sequence in time
117 |         x = torch.flip(x, [1]).view(bsz * t, -1)
118 | 
119 |         # Pass through MLP to get representation
120 |         obs_act_reps = self.obs_action_model(x).view(bsz, t, -1)
121 | 
122 |         # Use LSTM to get the representations for each suffix of the sequence
123 |         hidden = (torch.zeros(1, bsz, self.hidden_size).to(x.device),
124 |                   torch.zeros(1, bsz, self.hidden_size).to(x.device))
125 | 
126 |         x, hidden = self.lstm_model(obs_act_reps, hidden)
127 | 
128 |         # Reverse the sequence in time again
129 |         x = torch.flip(x, [1]).reshape(bsz * t, -1)
130 | 
131 |         # Pass through MLP to get logits for cluster assignment
132 |         logits = self.logit_model(x)
133 | 
134 |         # Some inputs are padding (0), so we mask them out
135 |         logits[obs.view(bsz * t, -1).sum(-1) == 0] = 0
136 | 
137 |         # Sample cluster assignment
138 |         logits = logits.view(bsz * t, self.groups, -1)
139 |         clusters = F.gumbel_softmax(logits, tau=1, hard=hard)
140 |         clusters = clusters.view(bsz, t, -1)
141 | 
142 |         # ================ Compute return prediction ================
143 |         x = torch.cat([obs, action], dim=-1).view(bsz * t, -1)
144 |         ret_obs_act_reps = self.ret_obs_action_model(x).view(bsz, t, -1)
145 | 
146 |         ret_input = torch.cat(
147 |             [clusters.detach(), ret_obs_act_reps], dim=-1).view(bsz * t, -1)
148 | 
149 |         ret_pred = self.return_model(ret_input).view(bsz, t, -1)
150 | 
151 |         # (No action prediction is needed here, unlike in forward.)
152 | 
153 |         return ret_pred, clusters
--------------------------------------------------------------------------------
/return_transforms/models/esper/dynamics_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from return_transforms.models.basic.mlp import MLP
4 | from return_transforms.utils.utils import get_past_indices
5 | 
6 | 
7 | class DynamicsModel(nn.Module):
8 |     """
9 |     Dynamics predictor that conditions on a trajectory representation.
10 |     Uses MLPs, but can easily be extended to different architectures.
11 | 
12 |     During training, this model conditions itself on a trajectory representation
13 |     from a random timestep in the past.
14 |     """
15 | 
16 |     def __init__(self, obs_size, action_size, rep_size, model_args):
17 |         super().__init__()
18 |         self.obs_size = obs_size
19 |         self.action_size = action_size
20 |         self.rep_size = rep_size
21 |         self.dynamics_model = MLP(
22 |             obs_size + action_size + rep_size, obs_size, **model_args)
23 | 
24 |     def forward(self, obs, action, cluster, seq_len):
25 |         bsz, t = obs.shape[:2]
26 |         obs = obs.view(bsz, t, -1)
27 | 
28 |         x = torch.cat([obs, action], dim=-1)
29 | 
30 |         idxs = get_past_indices(x, seq_len)
31 |         idxs = idxs.view(bsz, t, 1).expand(bsz, t, self.rep_size)
32 | 
33 |         past_cluster = torch.gather(cluster, 1, idxs)
34 | 
35 |         # We don't condition on the last timestep since we don't have a next observation
36 |         context = x[:, :-1]
37 |         context = torch.cat([context, past_cluster[:, :-1]], dim=-1)
38 |         next_obs = obs[:, 1:]
39 | 
40 |         pred_next_obs = self.dynamics_model(
41 |             context.view(bsz * (t - 1), -1)).view(*next_obs.shape)
42 | 
43 |         return pred_next_obs, next_obs
--------------------------------------------------------------------------------
/return_transforms/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keirp/return_transforms/faa35a25aaef9fd8fc47c38a3bb57f0fec8b10e4/return_transforms/utils/__init__.py
--------------------------------------------------------------------------------
/return_transforms/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | """ESPER utils."""
5 | 
6 | 
7 | def get_past_indices(x, seq_len):
8 |     """
9 |     Note: this assumes that padding is actually before the sequence.
10 | 
11 |     Often we want to get a tensor of indices for another tensor of shape
12 |     (bsz, T, ...). These indices (bsz, T) should be between the start of the
13 |     non-padded inputs and T. This function returns such an index tensor.
14 |     """
15 |     bsz, t = x.shape[:2]
16 | 
17 |     idxs = torch.randint(0, t, (bsz, t)).to(x)
18 |     ts = torch.arange(0, t).view(1, t).expand(bsz, t).to(x)
19 |     # Denotes how much padding is before each sequence
20 |     pad_lens = t - seq_len.view(bsz, 1).expand(bsz, t)
21 |     ts = ts - pad_lens + 1  # Shifts the indices so that the first non-padded index is 0
22 | 
23 |     # If ts == 0, then set idxs to 0. Otherwise, use the remainder of the division.
24 |     idxs = torch.where(ts == 0, torch.zeros_like(idxs), idxs % ts)
25 | 
26 |     # Now add back the padding lengths
27 |     idxs = idxs + pad_lens
28 | 
29 |     return idxs.long()
30 | 
31 | 
32 | def return_labels(traj, gamma=1):
33 |     rewards = traj.rewards
34 |     returns = []
35 |     ret = 0
36 |     for reward in reversed(rewards):
37 |         ret *= gamma
38 |         ret += float(reward)
39 |         returns.append(ret)
40 |     returns = list(reversed(returns))
41 |     return returns
42 | 
43 | 
44 | def learned_labels(traj, label_model, n_actions, horizon, device,
45 |                    act_type='discrete'):
46 |     with torch.no_grad():
47 |         label_model.eval()
48 |         obs = np.array(traj.obs)
49 |         if act_type == 'discrete':
50 |             a = np.array(traj.actions)
51 |             actions = np.zeros((a.size, n_actions))
52 |             actions[np.arange(a.size), a] = 1
53 |         else:
54 |             actions = np.array(traj.actions)
55 | 
56 |         labels = []
57 | 
58 |         padded_obs = np.zeros((horizon, *obs.shape[1:]))
59 |         padded_acts = np.zeros((horizon, n_actions))
60 | 
61 |         padded_obs[-obs.shape[0]:] = obs
62 |         padded_acts[-obs.shape[0]:] = actions
63 | 
64 |         padded_obs = torch.tensor(padded_obs).float().unsqueeze(0).to(device)
65 |         padded_acts = torch.tensor(padded_acts).float().unsqueeze(0).to(device)
66 | 
67 |         labels, _ = label_model.return_preds(
68 |             padded_obs, padded_acts, hard=True)
69 |         labels = labels[0, -obs.shape[0]:].view(-1).cpu().detach().numpy()
70 | 
71 |     return np.around(labels, decimals=1)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(
4 |     name='return_transforms',
5 |     version='1.0',
6 |     description='',
7 |     author='Keiran Paster',
8 |     packages=['return_transforms'],
9 | )
10 | 
--------------------------------------------------------------------------------