├── scripts ├── __init__.py ├── __pycache__ │ ├── agent.cpython-37.pyc │ ├── utils.cpython-37.pyc │ ├── ofenet.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── networks.cpython-37.pyc │ └── replay_buffer.cpython-37.pyc ├── utils.py ├── replay_buffer.py ├── ofenet.py ├── networks.py └── agent.py ├── requirements.txt ├── README.md ├── sac_ofenet.py └── OFENet-REDQ-notebook.ipynb /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.4 2 | argparse 3 | torch 4 | torchvision 5 | tensorboard==2.4.0 6 | gym 7 | pybullet -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OFENet 2 | 3 | PyTorch implementation of the [OFENet Paper](https://arxiv.org/abs/2003.01629). 4 | 5 | ## Work in progress - training runs, but performance is not yet on par with the paper 6 | 7 | If you are interested in the work, check out the notebook. The implementation does not yet reproduce the results described in the paper; if you find errors or bugs, feel free to let me know or correct them. 8 | 9 | # Environment Setup 10 | 11 | 1. Create the environment: `conda create -n OFENet python=3.7` 12 | 2. Enter the environment with `conda activate OFENet`
13 | 3. Install the requirements with: `pip install -r requirements.txt` 14 | 15 | # To run 16 | 17 | To run one experiment, simply type: `python sac_ofenet.py` 18 | 19 | All results are logged with TensorBoard. To check them, run: `tensorboard --logdir=runs` 20 | 21 | ## TODO: 22 | - Fix the hyperparameter saving bug 23 | - OFENet training got worse: the final loss for HalfCheetah should be around 0.005 but is currently 0.12. Before the changes that made the OFENet work, the loss matched the paper or was even lower (the problem might be batch norm?) 24 | - Create plots for HalfCheetah and other environments 25 | - Add target-dim loading or list the values in a table in the README 26 | 27 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def timer(start,end, train_type="Training"): 4 | """ Helper to print training time """ 5 | hours, rem = divmod(end-start, 3600) 6 | minutes, seconds = divmod(rem, 60) 7 | print("\n{} Time: {:0>2}:{:0>2}:{:05.2f}".format(train_type, int(hours),int(minutes),seconds)) 8 | 9 | def fill_buffer(agent, env, samples=1000): 10 | collected_samples = 0 11 | 12 | state_size = env.observation_space.shape[0] 13 | state = env.reset() 14 | 15 | state = state.reshape((1, state_size)) 16 | for i in range(samples): 17 | 18 | action = env.action_space.sample() 19 | next_state, reward, done, info = env.step(action) 20 | next_state = next_state.reshape((1, state_size)) 21 | agent.memory.add(state, action, reward, next_state, done) 22 | collected_samples += 1 23 | state = next_state 24 | if done: 25 | state = env.reset() 26 | state = state.reshape((1, state_size)) 27 | print("Adding random samples to buffer done! Buffer size: ", agent.memory.__len__()) 28 | 29 | def pretrain_ofenet(agent, epochs, writer, target_dim): 30 | for ep in range(epochs): 31 | # ---------------------------- update OFENet ---------------------------- # 32 | ofenet_loss = agent.ofenet.train_ofenet(agent.memory.sample()) 33 | writer.add_scalar("OFENet-pretraining-loss", ofenet_loss, ep) 34 | 35 | def get_target_dim(env_name): 36 | TARGET_DIM_DICT = { 37 | "AntBulletEnv-v0": 27, # originally 28 38 | "HalfCheetahBulletEnv-v0": 17, # originally 26 39 | "Walker2dBulletEnv-v0": 17, 40 | "HopperBulletEnv-v0": 11, # originally 15 41 | "ReacherBulletEnv-v0": 11, # originally 9 42 | "HumanoidBulletEnv-v0": 292, # originally 44 43 | "Pendulum-v0": 3, 44 | "LunarLanderContinuous-v2": 8 45 | } 46 | return TARGET_DIM_DICT[env_name] --------------------------------------------------------------------------------
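For orientation, the following is a condensed sketch of how the helpers above (`fill_buffer`, `pretrain_ofenet`, `get_target_dim`) are wired together before training starts. It mirrors the setup in `sac_ofenet.py` with its default hyperparameters; the run name and the choice of HalfCheetahBulletEnv-v0 here are illustrative only:

```python
import gym
import pybullet_envs  # noqa: F401, registers the *BulletEnv-v0 environments
import torch
from torch.utils.tensorboard import SummaryWriter

from scripts.utils import fill_buffer, pretrain_ofenet, get_target_dim
from scripts.replay_buffer import ReplayBuffer
from scripts.ofenet import OFENet
from scripts.agent import REDQ_Agent

env = gym.make("HalfCheetahBulletEnv-v0")
state_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
target_dim = get_target_dim("HalfCheetahBulletEnv-v0")  # 17, see TARGET_DIM_DICT above

writer = SummaryWriter("runs/pretrain-example")  # illustrative run name
buffer = ReplayBuffer(action_size, state_size, int(1e6), 256, seed=0, device=device)
extractor = OFENet(state_size, action_size, target_dim=target_dim,
                   num_layer=8, hidden_size=30, device=device).to(device)
agent = REDQ_Agent(state_size=state_size, action_size=action_size,
                   replay_buffer=buffer, ofenet=extractor, device=device)

fill_buffer(agent=agent, env=env, samples=10_000)   # random transitions into the buffer
pretrain_ofenet(agent=agent, epochs=10_000,         # one OFENet gradient step per epoch
                writer=writer, target_dim=target_dim)
```

In `sac_ofenet.py` the same steps are driven by the command-line arguments (`--collect_random`, `--ofenet_layer`, `--ofenet_size`, ...).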
/scripts/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer: 6 | """Fixed-size buffer to store experience tuples.""" 7 | 8 | def __init__(self, action_size, state_size, buffer_size, batch_size, seed, device): 9 | """Initialize a ReplayBuffer object. 10 | Params 11 | ====== 12 | buffer_size (int): maximum size of buffer 13 | batch_size (int): size of each training batch 14 | """ 15 | self.action_size = action_size 16 | self.state_size = state_size 17 | self.buffer_size = buffer_size 18 | self.batch_size = batch_size 19 | self.states_array = np.empty((buffer_size, state_size)) 20 | self.next_states_array = np.empty((buffer_size, state_size)) 21 | self.actions_array = np.empty((buffer_size, action_size)) 22 | self.dones_array = np.empty((buffer_size, 1)) 23 | self.rewards_array = np.empty((buffer_size, 1)) 24 | self.n_samples = 0 25 | self.device = device 26 | 27 | 28 | def add(self, state, action, reward, next_state, done): 29 | """Add a new experience to memory.""" 30 | # store the experience row-wise; once the buffer is full, overwrite the oldest entry 31 | idx = self.n_samples % self.buffer_size 32 | self.states_array[idx, ...] = state 33 | self.next_states_array[idx, ...] = next_state 34 | self.actions_array[idx, ...] = action 35 | self.rewards_array[idx, ...] = reward 36 | self.dones_array[idx, ...] = done 37 | self.n_samples += 1 38 | 39 | 40 | def sample(self): 41 | """Randomly sample a batch of experiences from memory.""" 42 | idxs = np.random.randint(low=0, high=min(self.n_samples, self.buffer_size), size=self.batch_size) 43 | 44 | states = torch.tensor(self.states_array[idxs], dtype=torch.float, device=self.device) 45 | next_states = torch.tensor(self.next_states_array[idxs], dtype=torch.float, device=self.device) 46 | actions = torch.tensor(self.actions_array[idxs], dtype=torch.float, device=self.device) 47 | rewards = torch.tensor(self.rewards_array[idxs], dtype=torch.float, device=self.device) 48 | dones = torch.tensor(self.dones_array[idxs], dtype=torch.float, device=self.device) 49 | 50 | return (states, actions, rewards, next_states, dones) 51 | 52 | def __len__(self): 53 | """Return the current size of internal memory.""" 54 | return min(self.n_samples, self.buffer_size) 55 | -------------------------------------------------------------------------------- /scripts/ofenet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from .networks import DenseNetBlock 6 | 7 | class DummyRepresentationLearner(): 8 | def __init__(self, state_size, action_size, target_dim, num_layer=4, hidden_size=40, batch_norm=True, activation="SiLU", device="cpu"): 9 | self.state_size = state_size 10 | self.action_size = action_size 11 | self.hidden_size = hidden_size 12 | self.num_layer = num_layer 13 | self.target_dim = target_dim 14 | 15 | def eval(self, ): 16 | pass 17 | 18 | def train(self, ): 19 | pass 20 | 21 | def forward(self, state, action): 22 | return torch.randn((state.shape[0], self.target_dim)) 23 | 24 | def get_state_features(self, state): 25 | return state 26 | 27 | def get_state_action_features(self, state, action): 28 | return torch.cat((state, action), dim=1) 29 | 30 | def train_ofenet(self, experiences): 31 | return 0.0 32 | 33 | def get_action_state_dim(self,): 34 | return (self.state_size+self.action_size) 35 | 36 | def get_state_dim(self,): 37 | return self.state_size 38 | 39 | class OFENet(nn.Module): 40 | def __init__(self, state_size, action_size, target_dim, num_layer=4, hidden_size=40, batch_norm=True, activation="SiLU", device="cpu"): 41 | super(OFENet, self).__init__() 42 | self.state_size = state_size 43 | self.action_size = action_size 44 | self.hidden_size = hidden_size 45 | self.num_layer = num_layer 46 | self.target_dim = target_dim 47 | 48 | denseblock = DenseNetBlock 49 | state_layer = [] 50 | action_layer = [] 51 | 52 | for i in range(num_layer): 53 | state_layer += [denseblock(input_nodes=state_size+i*hidden_size, 54 | output_nodes=hidden_size, 55 | activation=activation, 56 | batch_norm=batch_norm, 57 | device=device)] 58 | 59 | self.state_layer_block = nn.ModuleList(state_layer) 60 | self.encode_state_out = state_size + (num_layer) * hidden_size 61 | action_block_input = self.encode_state_out + action_size 62 | 63 | for i in range(num_layer): 64 | action_layer += [denseblock(input_nodes=action_block_input+i*hidden_size, 65 | output_nodes=hidden_size, 66 | activation=activation, 67 | batch_norm=batch_norm, 68 | device=device)] 69 | self.action_layer_block = nn.ModuleList(action_layer) 70 | 71 | self.pred_layer = nn.Linear((state_size+(2*num_layer)*hidden_size)+action_size, target_dim) 72 | self.optim = optim.Adam(params=self.parameters(), lr=3e-4) 73 | 74 | def forward(self, state, action): 75 | features = state 76 | for layer in self.state_layer_block: 77 | features = layer(features, trainable=True) 78 | features = torch.cat((features, action), dim=1) 79 | for layer in self.action_layer_block: 80 | features = layer(features, trainable=True) 81 | pred = self.pred_layer(features) 82 | return pred 83 | 84 | def get_state_features(self, state): 85 | with torch.no_grad(): 86 | for layer in self.state_layer_block: 87 | state = layer(state, trainable=False) 88 | return state 89 | 90 | def get_state_action_features(self, state, action): 91 | with torch.no_grad(): 92 | for layer in self.state_layer_block: 93 | state = layer(state, trainable=False) 94 | assert not state.requires_grad 95 | 96 | action_cat = torch.cat((state, action), dim=1) 97 | 98 | for layer in self.action_layer_block: 99 | action_cat = layer(action_cat, trainable=False) 100 | 101 | return action_cat 102 | 103 | def train_ofenet(self, experiences): 104 | states, actions, rewards, next_states, dones = experiences 105 | # ---------------------------- update OFENet ---------------------------- # 106 | pred = self.forward(states, actions) 107 | target_states = next_states[:, :self.target_dim] 108 | ofenet_loss = (pred - target_states).pow(2).mean() 109 | 110 | self.optim.zero_grad() 111 | ofenet_loss.backward() 112 | self.optim.step() 113 | return ofenet_loss.item() 114 | 115 | def get_action_state_dim(self,): 116 | return (self.state_size+(2*self.num_layer)*self.hidden_size)+self.action_size 117 | 118 | def get_state_dim(self,): 119 | return self.encode_state_out 120 | --------------------------------------------------------------------------------
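To make the DenseNet bookkeeping of `OFENet` concrete, here is a small shape-checking sketch. The sizes follow the HalfCheetahBulletEnv-v0 defaults used in `sac_ofenet.py` (26-dimensional state, 6-dimensional action, target_dim 17, 8 blocks of width 30); the random tensors are only for illustrating the feature dimensions:

```python
import torch
from scripts.ofenet import OFENet

# HalfCheetahBulletEnv-v0 sizes with the sac_ofenet.py defaults (8 blocks, width 30)
net = OFENet(state_size=26, action_size=6, target_dim=17, num_layer=8, hidden_size=30)

print(net.get_state_dim())         # 26 + 8*30 = 266       -> actor state features
print(net.get_action_state_dim())  # 266 + 6 + 8*30 = 512  -> critic state-action features

states = torch.randn(5, 26)
actions = torch.randn(5, 6)
print(net(states, actions).shape)                             # torch.Size([5, 17]), predicted next-state part
print(net.get_state_features(states).shape)                   # torch.Size([5, 266])
print(net.get_state_action_features(states, actions).shape)   # torch.Size([5, 512])
```

Each block concatenates its 30 new features onto its input, so the state features grow to 26 + 8*30 = 266 and the state-action features to 266 + 6 + 8*30 = 512, which is exactly what `get_state_dim()` and `get_action_state_dim()` report to the agent.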
/scripts/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Normal 4 | import numpy as np 5 | 6 | def hidden_init(layer): 7 | fan_in = layer.weight.data.size()[0] 8 | lim = 1. / np.sqrt(fan_in) 9 | return (-lim, lim) 10 | 11 | 12 | class Actor(nn.Module): 13 | """Actor (Policy) Model.""" 14 | 15 | def __init__(self, state_size, action_size, seed, hidden_size=256, init_w=3e-3, log_std_min=-20, log_std_max=2): 16 | """Initialize parameters and build model.
17 | Params 18 | ====== 19 | state_size (int): Dimension of each state 20 | action_size (int): Dimension of each action 21 | seed (int): Random seed 22 | fc1_units (int): Number of nodes in first hidden layer 23 | fc2_units (int): Number of nodes in second hidden layer 24 | """ 25 | super(Actor, self).__init__() 26 | torch.manual_seed(seed) 27 | self.log_std_min = log_std_min 28 | self.log_std_max = log_std_max 29 | 30 | self.fc1 = nn.Linear(state_size, hidden_size) 31 | self.fc2 = nn.Linear(hidden_size, hidden_size) 32 | 33 | self.mu = nn.Linear(hidden_size, action_size) 34 | self.log_std_linear = nn.Linear(hidden_size, action_size) 35 | 36 | 37 | def reset_parameters(self): 38 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 39 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 40 | self.mu.weight.data.uniform_(-init_w, init_w) 41 | self.log_std_linear.weight.data.uniform_(-init_w, init_w) 42 | 43 | def forward(self, state): 44 | 45 | x = torch.relu(self.fc1(state)) 46 | x = torch.relu(self.fc2(x)) 47 | 48 | mu = self.mu(x) 49 | 50 | log_std = self.log_std_linear(x) 51 | log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) 52 | return mu, log_std 53 | 54 | def sample(self, state, epsilon=1e-6): 55 | mu, log_std = self.forward(state) 56 | std = log_std.exp() 57 | dist = Normal(mu, std) 58 | e = dist.rsample().to(mu.device) 59 | action = torch.tanh(e) 60 | log_prob = (dist.log_prob(e) - torch.log(1 - action.pow(2) + epsilon)).sum(1, keepdim=True) 61 | 62 | return action, log_prob, torch.tanh(mu) 63 | 64 | class Critic(nn.Module): 65 | """Critic (Value) Model.""" 66 | 67 | def __init__(self, input_size, seed, hidden_size=256): 68 | """Initialize parameters and build model. 69 | Params 70 | ====== 71 | state_size (int): Dimension of each state 72 | action_size (int): Dimension of each action 73 | seed (int): Random seed 74 | hidden_size (int): Number of nodes in the network layers 75 | """ 76 | super(Critic, self).__init__() 77 | torch.manual_seed(seed) 78 | self.fc1 = nn.Linear(input_size, hidden_size) 79 | self.fc2 = nn.Linear(hidden_size, hidden_size) 80 | self.fc3 = nn.Linear(hidden_size, 1) 81 | #self.reset_parameters() 82 | 83 | def reset_parameters(self): 84 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 85 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 86 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 87 | 88 | def forward(self, state_action): 89 | """Build a critic (value) network that maps (state, action) pairs -> Q-values.""" 90 | 91 | x = torch.relu(self.fc1(state_action)) 92 | x = torch.relu(self.fc2(x)) 93 | return self.fc3(x) 94 | 95 | 96 | class DenseNetBlock(nn.Module): 97 | def __init__(self, input_nodes, output_nodes, activation, batch_norm=False, device="cpu"): 98 | super(DenseNetBlock, self).__init__() 99 | self.device = device 100 | self.do_batch_norm = batch_norm 101 | if batch_norm: 102 | self.layer = nn.Linear(input_nodes, output_nodes, bias=True).to(device) 103 | nn.init.xavier_uniform_(self.layer.weight) 104 | nn.init.zeros_(self.layer.bias) 105 | self.batch_norm = nn.BatchNorm1d(output_nodes).to(device) #, momentum=0.99, eps=0.001 106 | else: 107 | self.layer = nn.Linear(input_nodes, output_nodes).to(device) 108 | 109 | if activation == "SiLU": 110 | self.act = nn.SiLU() 111 | elif activation == "ReLU": 112 | self.act = nn.ReLU() 113 | else: 114 | print("Activation Function can not be selected!") 115 | 116 | def forward(self, x, trainable): 117 | 118 | identity_map = x 119 | features = self.layer(x) 120 | # check if this 
is needed! 121 | if trainable == False: 122 | features = features.detach() 123 | assert not features.requires_grad 124 | 125 | if self.do_batch_norm and trainable: 126 | features = self.batch_norm(features) 127 | 128 | features = self.act(features) 129 | assert features.shape[0] == identity_map.shape[0], "features: {} | identity: {}".format(features.shape, identity_map.shape) 130 | 131 | features = torch.cat((features, identity_map), dim=1) 132 | return features 133 | -------------------------------------------------------------------------------- /sac_ofenet.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pybullet_envs 4 | from collections import deque 5 | import argparse 6 | import torch 7 | import json 8 | import gym 9 | import time 10 | 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | from scripts.utils import timer, fill_buffer, pretrain_ofenet, get_target_dim 14 | from scripts.replay_buffer import ReplayBuffer 15 | from scripts.agent import REDQ_Agent 16 | from scripts.ofenet import OFENet, DummyRepresentationLearner 17 | 18 | 19 | 20 | def evaluate(frame, eval_runs=5, capture=False): 21 | """ 22 | Makes an evaluation run with the current epsilon 23 | """ 24 | 25 | reward_batch = [] 26 | for i in range(eval_runs): 27 | state = eval_env.reset() 28 | 29 | rewards = 0 30 | while True: 31 | action = agent.eval_(np.expand_dims(state, axis=0)) 32 | action_v = np.clip(action, action_low, action_high) 33 | state, reward, done, _ = eval_env.step(action_v) 34 | rewards += reward 35 | if done: 36 | break 37 | reward_batch.append(rewards) 38 | if capture == False: 39 | writer.add_scalar("Test_Reward", np.mean(reward_batch), frame) 40 | 41 | 42 | def train(steps, precollected, agent): 43 | scores_deque = deque(maxlen=100) 44 | average_100_scores = [] 45 | scores = [] 46 | losses = [] 47 | 48 | state = env.reset() 49 | state = state.reshape((1, state_size)) 50 | score = 0 51 | i_episode = 1 52 | for step in range(precollected+1, steps+1): 53 | 54 | # eval runs 55 | if step % args.eval_every == 0 or step == precollected+1: 56 | evaluate(step, args.eval_runs) 57 | 58 | action = agent.act(state) 59 | action_v = action.numpy() 60 | action_v = np.clip(action_v, action_low, action_high) 61 | next_state, reward, done, info = env.step(action_v) 62 | next_state = next_state.reshape((1, state_size)) 63 | ofenet_loss, a_loss, c_loss = agent.step(state, action, reward, next_state, done) 64 | state = next_state 65 | score += reward 66 | if done: 67 | scores_deque.append(score) 68 | scores.append(score) 69 | average_100_scores.append(np.mean(scores_deque)) 70 | current_step = step - precollected 71 | writer.add_scalar("Average100", np.mean(scores_deque), current_step) 72 | writer.add_scalar("Train_Reward", score, current_step) 73 | writer.add_scalar("OFENet loss", ofenet_loss, current_step) 74 | writer.add_scalar("Actor loss", a_loss, current_step) 75 | writer.add_scalar("Critic loss", c_loss, current_step) 76 | print('\rEpisode {} Env. 
Step: [{}/{}] Reward: {:.2f} Average100 Score: {:.2f} ofenet_loss: {:.3f}, a_loss: {:.3f}, c_loss: {:.3f}'.format(i_episode, step, steps, score, np.mean(scores_deque), ofenet_loss, a_loss, c_loss)) 77 | state = env.reset() 78 | state = state.reshape((1, state_size)) 79 | score = 0 80 | i_episode += 1 81 | 82 | return scores 83 | 84 | 85 | parser = argparse.ArgumentParser(description="") 86 | parser.add_argument("--env", type=str, default="HalfCheetahBulletEnv-v0", 87 | help="Environment name, default = HalfCheetahBulletEnv-v0") 88 | parser.add_argument("--info", type=str, default="SAC-OFENet", 89 | help="Information or name of the run") 90 | parser.add_argument("--steps", type=int, default=1_000_000, 91 | help="The amount of training interactions with the environment, default is 1,000,000") 92 | parser.add_argument("--N", type=int, default=2, 93 | help="Number of Q-networks in the ensemble, default is 2") 94 | parser.add_argument("--M", type=int, default=2, 95 | help="Number of target critics subsampled from the ensemble for updating the agent, default is 2 (currently only supports 2!)") 96 | parser.add_argument("--G", type=int, default=1, 97 | help="Update-to-Data (UTD) ratio, updates taken per step with the environment, default is 1") 98 | parser.add_argument("--eval_every", type=int, default=10_000, 99 | help="Number of interactions after which the evaluation runs are performed, default = 10,000") 100 | parser.add_argument("--eval_runs", type=int, default=1, 101 | help="Number of evaluation runs performed, default = 1") 102 | parser.add_argument("--seed", type=int, default=0, 103 | help="Seed for the env and torch network weights, default is 0") 104 | parser.add_argument("--lr", type=float, default=3e-4, 105 | help="Learning rate for adapting the network weights, default is 3e-4") 106 | parser.add_argument("--layer_size", type=int, default=256, 107 | help="Number of nodes per neural network layer, default is 256") 108 | parser.add_argument("--replay_memory", type=int, default=int(1e6), 109 | help="Size of the replay memory, default is 1e6") 110 | parser.add_argument("-bs", "--batch_size", type=int, default=256, 111 | help="Batch size, default is 256") 112 | parser.add_argument("-t", "--tau", type=float, default=0.005, 113 | help="Soft update factor tau, default is 0.005") 114 | parser.add_argument("-g", "--gamma", type=float, default=0.99, 115 | help="Discount factor gamma, default is 0.99") 116 | parser.add_argument("--ofenet_layer", type=int, default=8, 117 | help="Number of dense layers in each (state/action) block of the OFENet, (default: 8)") 118 | parser.add_argument("--ofenet_size", type=int, default=30, help="Size of each dense layer, (default: 30)") 119 | parser.add_argument("--collect_random", type=int, default=10_000, 120 | help="Number of randomly collected transitions to pretrain the OFENet, (default: 10,000)") 121 | parser.add_argument("--batch_norm", type=int, default=1, choices=[0,1], 122 | help="Add batch norm to the OFENet, default: 1") 123 | parser.add_argument("--activation", type=str, default="SiLU", choices=["SiLU", "ReLU"], 124 | help="Type of activation function for the OFENet, choose between SiLU and ReLU, default: SiLU") 125 | parser.add_argument("--ofenet", type=int, default=1, choices=[0,1], help="Use the OFENet feature extractor, default: 1") 126 | 127 | args = parser.parse_args() 128 | 129 | 130 | 131 | if __name__ == "__main__": 132 | 133 | writer = SummaryWriter("runs/"+args.info) 134 | env = gym.make(args.env) 135 | eval_env = gym.make(args.env) 136 | action_high
= env.action_space.high[0] 137 | seed = args.seed 138 | action_low = env.action_space.low[0] 139 | torch.manual_seed(seed) 140 | env.seed(seed) 141 | eval_env.seed(seed+1) 142 | np.random.seed(seed) 143 | state_size = env.observation_space.shape[0] 144 | action_size = env.action_space.shape[0] 145 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 146 | target_dim = get_target_dim(args.env) 147 | 148 | replay_buffer = ReplayBuffer(action_size, state_size, args.replay_memory, args.batch_size, seed, device) 149 | 150 | if args.ofenet: 151 | extractor = OFENet(state_size, 152 | action_size, 153 | target_dim=target_dim, 154 | num_layer=args.ofenet_layer, 155 | hidden_size=args.ofenet_size, 156 | batch_norm=args.batch_norm, 157 | activation=args.activation, 158 | device=device).to(device) 159 | print(extractor) 160 | else: 161 | extractor = DummyRepresentationLearner(state_size, 162 | action_size, 163 | target_dim=target_dim, 164 | num_layer=args.ofenet_layer, 165 | hidden_size=30, 166 | batch_norm=args.batch_norm, 167 | activation=args.activation, 168 | device=device) 169 | 170 | agent = REDQ_Agent(state_size=state_size, 171 | action_size=action_size, 172 | replay_buffer=replay_buffer, 173 | ofenet=extractor, 174 | random_seed=seed, 175 | lr=args.lr, 176 | hidden_size=args.layer_size, 177 | gamma=args.gamma, 178 | tau=args.tau, 179 | device=device, 180 | action_prior="uniform", 181 | N=args.N, 182 | M=args.M, 183 | G=args.G) 184 | 185 | fill_buffer(samples=args.collect_random, 186 | agent=agent, 187 | env=env) 188 | if args.ofenet: 189 | t0 = time.time() 190 | pretrain_ofenet(agent=agent, 191 | epochs=args.collect_random, 192 | writer=writer, 193 | target_dim=target_dim) 194 | t1 = time.time() 195 | timer(t0, t1, train_type="Pre-Training") 196 | # untrained eval run 197 | evaluate(0, args.eval_runs) 198 | 199 | t0 = time.time() 200 | final_average100 = train(steps=args.steps, 201 | precollected=args.collect_random, 202 | agent=agent) 203 | t1 = time.time() 204 | env.close() 205 | timer(t0, t1) 206 | 207 | # save parameter 208 | #with open('runs/'+args.info+".json", 'w') as f: 209 | # json.dump(args.__dict__, f, indent=2) 210 | #hparams = vars(args) 211 | #metric = {"final average 100 train reward": final_average100} 212 | #writer.add_hparams(hparams, metric) -------------------------------------------------------------------------------- /scripts/agent.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.optim as optim 5 | from .networks import Actor, Critic 6 | from .ofenet import OFENet 7 | import torch.nn.functional as F 8 | from torch.nn.utils import clip_grad_norm_ 9 | 10 | 11 | class REDQ_Agent(): 12 | """Interacts with and learns from the environment.""" 13 | 14 | def __init__(self, 15 | state_size, 16 | action_size, 17 | replay_buffer, 18 | ofenet, 19 | lr=3e-4, 20 | hidden_size=401, 21 | random_seed=0, 22 | device="cpu", 23 | action_prior="uniform", 24 | gamma=0.99, 25 | tau=0.005, 26 | N=2, 27 | M=2, 28 | G=1): 29 | """Initialize an Agent object. 
30 | 31 | Params 32 | ====== 33 | state_size (int): dimension of each state 34 | action_size (int): dimension of each action 35 | random_seed (int): random seed 36 | """ 37 | self.device = device 38 | feature_size = state_size 39 | self.action_size = action_size 40 | feature_action_size = feature_size+action_size 41 | self.seed = random.seed(random_seed) 42 | self.hidden_size = hidden_size 43 | self.gamma = gamma 44 | self.tau = tau 45 | 46 | self.target_entropy = -action_size # -dim(A) 47 | self.log_alpha = torch.tensor([0.0], requires_grad=True) 48 | self.alpha = self.log_alpha.exp().detach() 49 | self.alpha_optimizer = optim.Adam(params=[self.log_alpha], lr=lr) 50 | self._action_prior = action_prior 51 | self.alphas = [] 52 | print("Using: ", device) 53 | 54 | # REDQ parameter 55 | self.N = N # number of critics in the ensemble 56 | self.M = M # number of target critics that are randomly selected 57 | self.G = G # Updates per step ~ UTD-ratio 58 | 59 | # split state and action ~ weird step but to keep critic inputs consistent 60 | self.ofenet = ofenet 61 | feature_size = self.ofenet.get_state_dim() 62 | feature_action_size = self.ofenet.get_action_state_dim() 63 | 64 | # Actor Network 65 | self.actor_local = Actor(feature_size, action_size, random_seed, hidden_size=self.hidden_size).to(device) 66 | self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=lr) 67 | 68 | # Critic Network (w/ Target Network) 69 | self.critics = [] 70 | self.target_critics = [] 71 | self.optims = [] 72 | for i in range(self.N): 73 | critic = Critic(feature_action_size, i, hidden_size=self.hidden_size).to(device) 74 | 75 | optimizer = optim.Adam(critic.parameters(), lr=lr, weight_decay=0) 76 | self.optims.append(optimizer) 77 | self.critics.append(critic) 78 | target = Critic(feature_action_size, i, hidden_size=self.hidden_size).to(device) 79 | self.target_critics.append(target) 80 | 81 | # Replay memory 82 | self.memory = replay_buffer 83 | 84 | 85 | def step(self, state, action, reward, next_state, done): 86 | """Save experience in replay memory, and use random sample from buffer to learn.""" 87 | # Save experience / reward 88 | self.memory.add(state, action, reward, next_state, done) 89 | 90 | # Learn, if enough samples are available in memory 91 | actor_loss, critic1_loss, ofenet_loss = 0, 0, 0 92 | for update in range(self.G): 93 | if len(self.memory) > self.memory.batch_size: 94 | ofenet_loss = self.ofenet.train_ofenet(self.memory.sample()) 95 | experiences = self.memory.sample() 96 | actor_loss, critic1_loss = self.learn(update, experiences) 97 | return ofenet_loss, actor_loss, critic1_loss 98 | 99 | def act(self, state): 100 | """Returns actions for given state as per current policy.""" 101 | 102 | state = torch.from_numpy(state).float().to(self.device) 103 | self.actor_local.eval() 104 | 105 | with torch.no_grad(): 106 | self.ofenet.eval() 107 | state = self.ofenet.get_state_features(state) 108 | action, _, _ = self.actor_local.sample(state) 109 | self.actor_local.train() 110 | return action.detach().cpu()[0] 111 | 112 | def eval_(self, state): 113 | state = torch.from_numpy(state).float().to(self.device) 114 | self.actor_local.eval() 115 | 116 | with torch.no_grad(): 117 | self.ofenet.eval() 118 | state = self.ofenet.get_state_features(state) 119 | _, _ , action = self.actor_local.sample(state) 120 | self.actor_local.train() 121 | return action.detach().cpu()[0] 122 | 123 | def learn(self, step, experiences): 124 | """Updates actor, critics and entropy_alpha parameters using given batch 
of experience tuples. 125 | Q_targets = r + γ * (min_critic_target(next_state, actor_target(next_state)) - α *log_pi(next_action|next_state)) 126 | Critic_loss = MSE(Q, Q_target) 127 | Actor_loss = α * log_pi(a|s) - Q(s,a) 128 | where: 129 | actor_target(state) -> action 130 | critic_target(state, action) -> Q-value 131 | Params 132 | ====== 133 | experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 134 | gamma (float): discount factor 135 | """ 136 | states, actions, rewards, next_states, dones = experiences 137 | 138 | # sample target critics 139 | idx = np.random.choice(len(self.critics), self.M, replace=False) # replace=False so that the same idx is not picked twice 140 | 141 | 142 | # ---------------------------- update critic ---------------------------- # 143 | 144 | with torch.no_grad(): 145 | # Get predicted next-state actions and Q values from target models 146 | 147 | next_state_features = self.ofenet.get_state_features(next_states) 148 | next_action, next_log_prob, _ = self.actor_local.sample(next_state_features) 149 | next_state_action_features = self.ofenet.get_state_action_features(next_states, next_action) 150 | 151 | # TODO: make this variable for possibly more than two target critics 152 | Q_target1_next = self.target_critics[idx[0]](next_state_action_features) 153 | Q_target2_next = self.target_critics[idx[1]](next_state_action_features) 154 | assert not next_state_features.requires_grad, "next_state_features have gradients but shouldn't!" 155 | # take the min of both critics for updating 156 | Q_target_next = torch.min(Q_target1_next, Q_target2_next) - self.alpha.to(next_action.device) * next_log_prob 157 | 158 | Q_targets = 5.0 * rewards.cpu() + (self.gamma * (1 - dones.cpu()) * Q_target_next.cpu()) # reward scale of 5.0 159 | assert not Q_targets.requires_grad, "Q_targets have gradients but shouldn't!" 160 | # Compute critic losses and update critics 161 | 162 | state_action_features = self.ofenet.get_state_action_features(states, actions) 163 | assert not state_action_features.requires_grad, "state_action_features have gradients but shouldn't!" 164 | for critic, optim, target in zip(self.critics, self.optims, self.target_critics): 165 | Q = critic(state_action_features).cpu() 166 | Q_loss = 0.5 * F.mse_loss(Q, Q_targets) 167 | 168 | # Update critic 169 | optim.zero_grad() 170 | Q_loss.backward() 171 | # add clip gradients? 172 | optim.step() 173 | # soft update of the targets 174 | self.soft_update(critic, target) 175 | 176 | # ---------------------------- update actor ---------------------------- # 177 | if step == self.G-1: 178 | 179 | state_features = self.ofenet.get_state_features(states) 180 | 181 | assert not state_features.requires_grad, "state_features have gradients but shouldn't!" 182 | actions_pred, log_prob, _ = self.actor_local.sample(state_features) 183 | 184 | state_action_features = self.ofenet.get_state_action_features(states, actions_pred) 185 | 186 | assert state_action_features.requires_grad, "state_action_features should have gradients!" 187 | # TODO: make this variable for possibly more than two critics 188 | 189 | Q1 = self.critics[idx[0]](state_action_features) 190 | Q2 = self.critics[idx[1]](state_action_features) 191 | Q = torch.min(Q1,Q2).cpu() 192 | 193 | actor_loss = (self.alpha * log_prob.cpu() - Q).mean() 194 | # Optimize the actor loss 195 | self.actor_optimizer.zero_grad() 196 | actor_loss.backward() 197 | # add clip gradients?
198 | self.actor_optimizer.step() 199 | 200 | # Compute alpha loss 201 | alpha_loss = - (self.log_alpha.exp() * (log_prob.cpu() + self.target_entropy).detach().cpu()).mean() 202 | 203 | self.alpha_optimizer.zero_grad() 204 | alpha_loss.backward() 205 | self.alpha_optimizer.step() 206 | self.alpha = self.log_alpha.exp().detach() 207 | self.alphas.append(self.alpha.detach()) 208 | 209 | return actor_loss.item(), Q_loss.item() 210 | 211 | def soft_update(self, local_model, target_model): 212 | """Soft update model parameters. 213 | θ_target = τ*θ_local + (1 - τ)*θ_target 214 | Params 215 | ====== 216 | local_model: PyTorch model (weights will be copied from) 217 | target_model: PyTorch model (weights will be copied to) 218 | tau (float): interpolation parameter 219 | """ 220 | for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): 221 | target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data) -------------------------------------------------------------------------------- /OFENet-REDQ-notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 95, 6 | "id": "english-mention", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "\n", 11 | "import numpy as np\n", 12 | "import random\n", 13 | "\n", 14 | "import gym\n", 15 | "import pybullet_envs\n", 16 | "from collections import namedtuple, deque\n", 17 | "import torch\n", 18 | "import torch.nn as nn\n", 19 | "import torch.nn.functional as F\n", 20 | "\n", 21 | "from torch.distributions import Normal, MultivariateNormal\n", 22 | "\n", 23 | "import torch.optim as optim\n", 24 | "import time\n", 25 | "from torch.utils.tensorboard import SummaryWriter\n", 26 | "import argparse\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "import tqdm\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "impressed-junction", 34 | "metadata": {}, 35 | "source": [ 36 | "# Networks" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 96, 42 | "id": "renewable-rating", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "\n", 47 | "\n", 48 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 49 | "\n", 50 | "def hidden_init(layer):\n", 51 | " fan_in = layer.weight.data.size()[0]\n", 52 | " lim = 1. 
/ np.sqrt(fan_in)\n", 53 | " return (-lim, lim)\n", 54 | "\n", 55 | "class Actor(nn.Module):\n", 56 | " \"\"\"Actor (Policy) Model.\"\"\"\n", 57 | "\n", 58 | " def __init__(self, state_size, action_size, seed, hidden_size=256, init_w=3e-3, log_std_min=-20, log_std_max=2):\n", 59 | " \"\"\"Initialize parameters and build model.\n", 60 | " Params\n", 61 | " ======\n", 62 | " state_size (int): Dimension of each state\n", 63 | " action_size (int): Dimension of each action\n", 64 | " seed (int): Random seed\n", 65 | " fc1_units (int): Number of nodes in first hidden layer\n", 66 | " fc2_units (int): Number of nodes in second hidden layer\n", 67 | " \"\"\"\n", 68 | " super(Actor, self).__init__()\n", 69 | " torch.manual_seed(seed)\n", 70 | " self.log_std_min = log_std_min\n", 71 | " self.log_std_max = log_std_max\n", 72 | " \n", 73 | " self.fc1 = nn.Linear(state_size, hidden_size)\n", 74 | " self.fc2 = nn.Linear(hidden_size, hidden_size)\n", 75 | " \n", 76 | " self.mu = nn.Linear(hidden_size, action_size)\n", 77 | " self.log_std_linear = nn.Linear(hidden_size, action_size)\n", 78 | "\n", 79 | "\n", 80 | " def reset_parameters(self):\n", 81 | " self.fc1.weight.data.uniform_(*hidden_init(self.fc1))\n", 82 | " self.fc2.weight.data.uniform_(*hidden_init(self.fc2))\n", 83 | " self.mu.weight.data.uniform_(-init_w, init_w)\n", 84 | " self.log_std_linear.weight.data.uniform_(-init_w, init_w)\n", 85 | "\n", 86 | " def forward(self, state):\n", 87 | "\n", 88 | " x = F.relu(self.fc1(state))\n", 89 | " x = F.relu(self.fc2(x))\n", 90 | "\n", 91 | " mu = self.mu(x)\n", 92 | "\n", 93 | " log_std = self.log_std_linear(x)\n", 94 | " log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)\n", 95 | " return mu, log_std\n", 96 | " \n", 97 | " def sample(self, state, epsilon=1e-6):\n", 98 | " mu, log_std = self.forward(state)\n", 99 | " std = log_std.exp()\n", 100 | " dist = Normal(mu, std)\n", 101 | " e = dist.rsample().to(device)\n", 102 | " action = torch.tanh(e)\n", 103 | " log_prob = (dist.log_prob(e) - torch.log(1 - action.pow(2) + epsilon)).sum(1, keepdim=True)\n", 104 | "\n", 105 | " return action, log_prob, torch.tanh(mu)\n", 106 | " \n", 107 | "\n", 108 | "class Critic(nn.Module):\n", 109 | " \"\"\"Critic (Value) Model.\"\"\"\n", 110 | "\n", 111 | " def __init__(self, input_size, seed, hidden_size=256):\n", 112 | " \"\"\"Initialize parameters and build model.\n", 113 | " Params\n", 114 | " ======\n", 115 | " state_size (int): Dimension of each state\n", 116 | " action_size (int): Dimension of each action\n", 117 | " seed (int): Random seed\n", 118 | " hidden_size (int): Number of nodes in the network layers\n", 119 | " \"\"\"\n", 120 | " super(Critic, self).__init__()\n", 121 | " torch.manual_seed(seed)\n", 122 | " self.fc1 = nn.Linear(input_size, hidden_size)\n", 123 | " self.fc2 = nn.Linear(hidden_size, hidden_size)\n", 124 | " self.fc3 = nn.Linear(hidden_size, 1)\n", 125 | " #self.reset_parameters()\n", 126 | "\n", 127 | " def reset_parameters(self):\n", 128 | " self.fc1.weight.data.uniform_(*hidden_init(self.fc1))\n", 129 | " self.fc2.weight.data.uniform_(*hidden_init(self.fc2))\n", 130 | " self.fc3.weight.data.uniform_(-3e-3, 3e-3)\n", 131 | "\n", 132 | " def forward(self, state_action):\n", 133 | " \"\"\"Build a critic (value) network that maps (state, action) pairs -> Q-values.\"\"\"\n", 134 | "\n", 135 | " x = F.relu(self.fc1(state_action))\n", 136 | " x = F.relu(self.fc2(x))\n", 137 | " return self.fc3(x)\n", 138 | " " 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | 
"execution_count": 97, 144 | "id": "powerful-exposure", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "class OFENet(nn.Module):\n", 149 | " def __init__(self, state_size, action_size, target_dim, num_layer=4, hidden_size=40):\n", 150 | " super(OFENet, self).__init__()\n", 151 | " self.state_size = state_size\n", 152 | " self.action_size = action_size\n", 153 | " self.hidden_size = hidden_size\n", 154 | " self.num_layer = num_layer\n", 155 | " self.target_dim = target_dim\n", 156 | " \n", 157 | " denseblock = DenseNetBlock\n", 158 | " state_layer = []\n", 159 | " action_layer = []\n", 160 | " \n", 161 | " for i in range(num_layer):\n", 162 | " state_layer += [denseblock(input_nodes=state_size+i*hidden_size,\n", 163 | " output_nodes=hidden_size,\n", 164 | " activation=\"SiLU\",\n", 165 | " batch_norm=True)]\n", 166 | " \n", 167 | " self.state_layer_block = nn.Sequential(*state_layer)\n", 168 | " self.encode_state_out = state_size + (num_layer) * hidden_size\n", 169 | " action_block_input = self.encode_state_out + action_size\n", 170 | " \n", 171 | " for i in range(num_layer):\n", 172 | " action_layer += [denseblock(input_nodes=action_block_input+i*hidden_size,\n", 173 | " output_nodes=hidden_size,\n", 174 | " activation=\"SiLU\",\n", 175 | " batch_norm=True)]\n", 176 | " self.action_layer_block = nn.Sequential(*action_layer)\n", 177 | "\n", 178 | " self.pred_layer = nn.Linear((state_size+(2*num_layer)*hidden_size)+action_size, target_dim)\n", 179 | " \n", 180 | " def forward(self, state, action):\n", 181 | " features = state\n", 182 | " features = self.state_layer_block(features)\n", 183 | " features = torch.cat((features, action), dim=1)\n", 184 | " features = self.action_layer_block(features)\n", 185 | " pred = self.pred_layer(features)\n", 186 | "\n", 187 | " return pred\n", 188 | " \n", 189 | " def get_state_features(self, state):\n", 190 | " self.state_layer_block.eval()\n", 191 | " with torch.no_grad():\n", 192 | " z0 = self.state_layer_block(state)\n", 193 | " self.state_layer_block.train()\n", 194 | " return z0\n", 195 | " \n", 196 | " def get_state_action_features(self, state, action):\n", 197 | " self.state_layer_block.eval()\n", 198 | " self.action_layer_block.eval()\n", 199 | " with torch.no_grad():\n", 200 | " z0 = self.state_layer_block(state)\n", 201 | " action_cat = torch.cat((z0, action), dim=1)\n", 202 | " z0_a = self.action_layer_block(action_cat)\n", 203 | " self.state_layer_block.train()\n", 204 | " self.action_layer_block.train()\n", 205 | " return z0_a\n", 206 | " \n", 207 | " def train_ofenet(self, experiences, optim):\n", 208 | " states, actions, rewards, next_states, dones = experiences\n", 209 | " # ---------------------------- update OFENet ---------------------------- #\n", 210 | " pred = self.forward(states, actions)\n", 211 | " target_states = next_states[:, :self.target_dim]\n", 212 | " ofenet_loss = (target_states - pred).pow(2).mean()\n", 213 | " \n", 214 | "\n", 215 | " optim.zero_grad()\n", 216 | " ofenet_loss.backward()\n", 217 | " optim.step()\n", 218 | " return ofenet_loss.item()\n", 219 | " \n", 220 | " def get_action_state_dim(self,):\n", 221 | " return (self.state_size+(2*self.num_layer)*self.hidden_size)+self.action_size\n", 222 | " \n", 223 | " def get_state_dim(self,):\n", 224 | " return self.encode_state_out" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 98, 230 | "id": "foreign-blogger", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "class DenseNetBlock(nn.Module):\n", 235 
| " def __init__(self, input_nodes, output_nodes, activation, batch_norm=False):\n", 236 | " super(DenseNetBlock, self).__init__()\n", 237 | " \n", 238 | " \n", 239 | " self.do_batch_norm = batch_norm\n", 240 | " if batch_norm:\n", 241 | " self.layer = nn.Linear(input_nodes, output_nodes, bias=True)\n", 242 | " nn.init.xavier_uniform_(self.layer.weight)\n", 243 | " nn.init.zeros_(self.layer.bias)\n", 244 | " self.batch_norm = nn.BatchNorm1d(output_nodes) #, momentum=0.99, eps=0.001\n", 245 | " else:\n", 246 | " self.layer = nn.Linear(input_nodes, output_nodes)\n", 247 | " if activation == \"SiLU\":\n", 248 | " self.act = nn.SiLU()\n", 249 | " elif activation == \"ReLU\":\n", 250 | " self.act = nn.ReLU()\n", 251 | " else:\n", 252 | " print(\"Activation Function can not be selected!\")\n", 253 | " \n", 254 | " def forward(self, x):\n", 255 | " identity_map = x\n", 256 | " features = self.layer(x)\n", 257 | "\n", 258 | " if self.do_batch_norm:\n", 259 | " features = self.batch_norm(features)\n", 260 | " features = self.act(features)\n", 261 | " assert features.shape[0] == identity_map.shape[0], \"features: {} | identity: {}\".format(features.shape, identity_map.shape)\n", 262 | " features = torch.cat((features, identity_map), dim=1)\n", 263 | " return features" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 99, 269 | "id": "critical-framework", 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "def train(steps, precollected, print_every=10):\n", 274 | " scores_deque = deque(maxlen=100)\n", 275 | " average_100_scores = []\n", 276 | " scores = []\n", 277 | " losses = []\n", 278 | "\n", 279 | " state = env.reset()\n", 280 | " state = state.reshape((1, state_size))\n", 281 | " score = 0\n", 282 | " i_episode = 1\n", 283 | " for step in range(precollected+1, steps+1):\n", 284 | "\n", 285 | " action = agent.act(state)\n", 286 | " action_v = action.numpy()\n", 287 | " action_v = np.clip(action_v, action_low, action_high)\n", 288 | " next_state, reward, done, info = env.step(action_v)\n", 289 | " next_state = next_state.reshape((1, state_size))\n", 290 | " ofenet_loss, a_loss, c_loss = agent.step(state, action, reward, next_state, done)\n", 291 | " state = next_state\n", 292 | " score += reward\n", 293 | " if done:\n", 294 | " scores_deque.append(score)\n", 295 | " scores.append(score)\n", 296 | " average_100_scores.append(np.mean(scores_deque))\n", 297 | " losses.append((ofenet_loss, a_loss, c_loss))\n", 298 | " print('\\rEpisode {} Frame: [{}/{}] Reward: {:.2f} Average100 Score: {:.2f} ofenet_loss: {:.3f}, a_loss: {:.3f}, c_loss: {:.3f}'.format(i_episode, step, steps, score, np.mean(scores_deque), ofenet_loss, a_loss, c_loss))\n", 299 | " state = env.reset()\n", 300 | " state = state.reshape((1, state_size))\n", 301 | " score = 0\n", 302 | " i_episode += 1\n", 303 | " \n", 304 | "\n", 305 | " return scores" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 100, 311 | "id": "important-bruce", 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "data": { 316 | "text/plain": [ 317 | "Linear(in_features=26, out_features=30, bias=True)" 318 | ] 319 | }, 320 | "execution_count": 100, 321 | "metadata": {}, 322 | "output_type": "execute_result" 323 | } 324 | ], 325 | "source": [ 326 | "agent.ofenet.state_layer_block[0].layer" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 101, 332 | "id": "confused-prison", 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "def 
fill_buffer(samples=1000):\n", 337 | " collected_samples = 0\n", 338 | " \n", 339 | " state = env.reset()\n", 340 | " state = state.reshape((1, state_size))\n", 341 | " for i in range(samples):\n", 342 | " \n", 343 | " action = env.action_space.sample()\n", 344 | " next_state, reward, done, info = env.step(action)\n", 345 | " next_state = next_state.reshape((1, state_size))\n", 346 | " agent.memory.add(state, action, reward, next_state, done)\n", 347 | " collected_samples += 1\n", 348 | " state = next_state\n", 349 | " if done:\n", 350 | " state = env.reset()\n", 351 | " state = state.reshape((1, state_size))\n", 352 | " print(\"Adding random samples to buffer done! Buffer size: \", agent.memory.__len__())\n", 353 | " \n", 354 | "def pretrain_ofenet(agent, epochs):\n", 355 | " losses = []\n", 356 | "\n", 357 | " for ep in range(epochs):\n", 358 | " states, actions, rewards, next_states, dones = agent.memory.sample()\n", 359 | " # ---------------------------- update OFENet ---------------------------- #\n", 360 | " pred = agent.ofenet.forward(states, actions)\n", 361 | " targets = next_states[:,:17]\n", 362 | " ofenet_loss = (targets-pred).pow(2).mean()\n", 363 | " agent.ofenet_optim.zero_grad()\n", 364 | " ofenet_loss.backward()\n", 365 | " agent.ofenet_optim.step()\n", 366 | " losses.append(ofenet_loss.item())\n", 367 | " plt.plot(losses)\n", 368 | " plt.show()\n", 369 | " print(losses[-1])\n", 370 | " return agent" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "id": "regulation-pepper", 376 | "metadata": {}, 377 | "source": [ 378 | "# Train SAC" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "id": "convinced-collectible", 384 | "metadata": {}, 385 | "source": [ 386 | "# Train OFENet REDQ" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 102, 392 | "id": "nonprofit-cabin", 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "class REDQ_Agent():\n", 397 | " \"\"\"Interacts with and learns from the environment.\"\"\"\n", 398 | " \n", 399 | " def __init__(self, state_size, action_size, random_seed, device, action_prior=\"uniform\", N=2, M=2, G=1):\n", 400 | " \"\"\"Initialize an Agent object.\n", 401 | " \n", 402 | " Params\n", 403 | " ======\n", 404 | " state_size (int): dimension of each state\n", 405 | " action_size (int): dimension of each action\n", 406 | " random_seed (int): random seed\n", 407 | " \"\"\"\n", 408 | " self.state_size = state_size\n", 409 | " self.action_size = action_size\n", 410 | " self.seed = random.seed(random_seed)\n", 411 | " self.hidden_size = 256\n", 412 | " \n", 413 | " self.target_entropy = -action_size # -dim(A)\n", 414 | " self.log_alpha = torch.tensor([0.0], requires_grad=True)\n", 415 | " self.alpha = self.log_alpha.exp().detach()\n", 416 | " self.alpha_optimizer = optim.Adam(params=[self.log_alpha], lr=lr) \n", 417 | " self._action_prior = action_prior\n", 418 | " self.alphas = []\n", 419 | " print(\"Using: \", device)\n", 420 | " \n", 421 | " # REDQ parameter\n", 422 | " self.N = N # number of critics in the ensemble\n", 423 | " self.M = M # number of target critics that are randomly selected\n", 424 | " self.G = G # Updates per step ~ UTD-ratio\n", 425 | " \n", 426 | " ofenet_size = 30\n", 427 | " self.ofenet = OFENet(state_size, action_size, target_dim=17, num_layer=8, hidden_size=ofenet_size).to(device)\n", 428 | " # TODO: CHECK ADAM PARAMS WITH TF AND PAPER\n", 429 | " self.ofenet_optim = optim.Adam(self.ofenet.parameters(), lr=3e-4, eps=1e-07) \n", 430 | " 
print(self.ofenet)\n", 431 | "\n", 432 | " # split state and action ~ weird step but to keep critic inputs consistent\n", 433 | " feature_size = self.ofenet.get_state_dim()\n", 434 | " feature_action_size = self.ofenet.get_action_state_dim()\n", 435 | " \n", 436 | " # Actor Network \n", 437 | " self.actor_local = Actor(feature_size, action_size, random_seed, hidden_size=self.hidden_size).to(device)\n", 438 | " self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=lr) \n", 439 | " \n", 440 | " # Critic Network (w/ Target Network)\n", 441 | " self.critics = []\n", 442 | " self.target_critics = []\n", 443 | " self.optims = []\n", 444 | " for i in range(self.N):\n", 445 | " critic = Critic(feature_action_size, i, hidden_size=self.hidden_size).to(device)\n", 446 | "\n", 447 | " optimizer = optim.Adam(critic.parameters(), lr=lr, weight_decay=0)\n", 448 | " self.optims.append(optimizer)\n", 449 | " self.critics.append(critic)\n", 450 | " target = Critic(feature_action_size, i, hidden_size=self.hidden_size).to(device)\n", 451 | " self.target_critics.append(target)\n", 452 | "\n", 453 | "\n", 454 | " # Replay memory\n", 455 | " self.memory = ReplayBuffer(action_size, buffer_size, batch_size, random_seed)\n", 456 | " \n", 457 | "\n", 458 | " def step(self, state, action, reward, next_state, done):\n", 459 | " \"\"\"Save experience in replay memory, and use random sample from buffer to learn.\"\"\"\n", 460 | " # Save experience / reward\n", 461 | " self.memory.add(state, action, reward, next_state, done)\n", 462 | "\n", 463 | " # Learn, if enough samples are available in memory\n", 464 | " actor_loss, critic1_loss, ofenet_loss = 0, 0, 0\n", 465 | " for update in range(self.G):\n", 466 | " if len(self.memory) > batch_size:\n", 467 | " ofenet_loss = self.ofenet.train_ofenet(self.memory.sample(), self.ofenet_optim)\n", 468 | " experiences = self.memory.sample()\n", 469 | " actor_loss, critic1_loss = self.learn(update, experiences, gamma)\n", 470 | " return ofenet_loss, actor_loss, critic1_loss # future ofenet_loss\n", 471 | " \n", 472 | " def act(self, state):\n", 473 | " \"\"\"Returns actions for given state as per current policy.\"\"\"\n", 474 | " state = torch.from_numpy(state).float().to(device)\n", 475 | " self.actor_local.eval()\n", 476 | " self.ofenet.eval()\n", 477 | " with torch.no_grad():\n", 478 | " state = self.ofenet.get_state_features(state)\n", 479 | " action, _, _ = self.actor_local.sample(state)\n", 480 | " self.actor_local.train()\n", 481 | " self.ofenet.train()\n", 482 | " return action.detach().cpu()[0]\n", 483 | " \n", 484 | " def eval_(self, state):\n", 485 | " state = torch.from_numpy(state).float().to(device)\n", 486 | " self.actor_local.eval()\n", 487 | " self.ofenet.eval()\n", 488 | " with torch.no_grad():\n", 489 | " state = self.ofenet.get_state_features(state)\n", 490 | " _, _ , action = self.actor_local.sample(state)\n", 491 | " self.actor_local.train()\n", 492 | " self.ofenet.train()\n", 493 | " return action.detach().cpu()[0]\n", 494 | " \n", 495 | " def learn(self, step, experiences, gamma):\n", 496 | " \"\"\"Updates actor, critics and entropy_alpha parameters using given batch of experience tuples.\n", 497 | " Q_targets = r + γ * (min_critic_target(next_state, actor_target(next_state)) - α *log_pi(next_action|next_state))\n", 498 | " Critic_loss = MSE(Q, Q_target)\n", 499 | " Actor_loss = α * log_pi(a|s) - Q(s,a)\n", 500 | " where:\n", 501 | " actor_target(state) -> action\n", 502 | " critic_target(state, action) -> Q-value\n", 503 | " Params\n", 504 | " 
======\n", 505 | " experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples \n", 506 | " gamma (float): discount factor\n", 507 | " \"\"\"\n", 508 | " states, actions, rewards, next_states, dones = experiences\n", 509 | "\n", 510 | " # sample target critics\n", 511 | " idx = np.random.choice(len(self.critics), self.M, replace=False) # replace=False so that not picking the same idx twice\n", 512 | " \n", 513 | "\n", 514 | " # ---------------------------- update critic ---------------------------- #\n", 515 | "\n", 516 | " with torch.no_grad():\n", 517 | " # Get predicted next-state actions and Q values from target models\n", 518 | " next_state_features = self.ofenet.get_state_features(next_states)\n", 519 | " next_action, next_log_prob, _ = self.actor_local.sample(next_state_features)\n", 520 | " next_state_action_features = self.ofenet.get_state_action_features(next_states, next_action) #get_state_action_features\n", 521 | " # TODO: make this variable for possible more than tnext_state_action_featureswo target critics\n", 522 | " Q_target1_next = self.target_critics[idx[0]](next_state_action_features)\n", 523 | " Q_target2_next = self.target_critics[idx[1]](next_state_action_features)\n", 524 | " \n", 525 | " # take the min of both critics for updating\n", 526 | " Q_target_next = torch.min(Q_target1_next, Q_target2_next) - self.alpha.to(device) * next_log_prob\n", 527 | "\n", 528 | " Q_targets = 5.0*rewards.cpu() + (gamma * (1 - dones.cpu()) * Q_target_next.cpu())\n", 529 | "\n", 530 | " # Compute critic losses and update critics \n", 531 | " state_action_features = self.ofenet.get_state_action_features(states, actions)\n", 532 | " for critic, optim, target in zip(self.critics, self.optims, self.target_critics):\n", 533 | " Q = critic(state_action_features).cpu()\n", 534 | " Q_loss = 0.5*F.mse_loss(Q, Q_targets)\n", 535 | " \n", 536 | " # Update critic\n", 537 | " optim.zero_grad()\n", 538 | " Q_loss.backward()\n", 539 | " optim.step()\n", 540 | " # soft update of the targets\n", 541 | " self.soft_update(critic, target)\n", 542 | " \n", 543 | " # ---------------------------- update actor ---------------------------- #\n", 544 | " if step == self.G-1:\n", 545 | " state_features = self.ofenet.get_state_features(states)\n", 546 | " actions_pred, log_prob, _ = self.actor_local.sample(state_features) \n", 547 | " \n", 548 | " state_action_features = self.ofenet.get_state_action_features(states, actions_pred)\n", 549 | " # TODO: make this variable for possible more than two critics\n", 550 | " Q1 = self.critics[idx[0]](state_action_features).cpu()\n", 551 | " Q2 = self.critics[idx[0]](state_action_features).cpu()\n", 552 | " Q = torch.min(Q1,Q2)\n", 553 | " actor_loss = (self.alpha * log_prob.cpu() - Q).mean()\n", 554 | " # Optimize the actor loss\n", 555 | " self.actor_optimizer.zero_grad()\n", 556 | " actor_loss.backward()\n", 557 | " self.actor_optimizer.step()\n", 558 | "\n", 559 | " # Compute alpha loss \n", 560 | " alpha_loss = - (self.log_alpha.exp() * (log_prob.cpu() + self.target_entropy).detach().cpu()).mean()\n", 561 | "\n", 562 | " self.alpha_optimizer.zero_grad()\n", 563 | " alpha_loss.backward()\n", 564 | " self.alpha_optimizer.step()\n", 565 | " self.alpha = self.log_alpha.exp().detach()\n", 566 | " self.alphas.append(self.alpha.detach())\n", 567 | " \n", 568 | " return actor_loss.item(), Q_loss.item()\n", 569 | "\n", 570 | " \n", 571 | " def soft_update(self, local_model, target_model):\n", 572 | " \"\"\"Soft update model parameters.\n", 573 | " θ_target = τ*θ_local 
+ (1 - τ)*θ_target\n", 574 | " Params\n", 575 | " ======\n", 576 | " local_model: PyTorch model (weights will be copied from)\n", 577 | " target_model: PyTorch model (weights will be copied to)\n", 578 | " tau (float): interpolation parameter \n", 579 | " \"\"\"\n", 580 | " for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):\n", 581 | " target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)\n", 582 | "\n", 583 | "class ReplayBuffer:\n", 584 | " \"\"\"Fixed-size buffer to store experience tuples.\"\"\"\n", 585 | "\n", 586 | " def __init__(self, action_size, buffer_size, batch_size, seed):\n", 587 | " \"\"\"Initialize a ReplayBuffer object.\n", 588 | " Params\n", 589 | " ======\n", 590 | " buffer_size (int): maximum size of buffer\n", 591 | " batch_size (int): size of each training batch\n", 592 | " \"\"\"\n", 593 | " self.action_size = action_size\n", 594 | " self.memory = deque(maxlen=buffer_size) # internal memory (deque)\n", 595 | " self.batch_size = batch_size\n", 596 | " self.experience = namedtuple(\"Experience\", field_names=[\"state\", \"action\", \"reward\", \"next_state\", \"done\"])\n", 597 | " self.seed = random.seed(seed)\n", 598 | " \n", 599 | " def add(self, state, action, reward, next_state, done):\n", 600 | " \"\"\"Add a new experience to memory.\"\"\"\n", 601 | " e = self.experience(state, action, reward, next_state, done)\n", 602 | " self.memory.append(e)\n", 603 | " \n", 604 | " def sample(self):\n", 605 | " \"\"\"Randomly sample a batch of experiences from memory.\"\"\"\n", 606 | " experiences = random.sample(self.memory, k=self.batch_size)\n", 607 | " \n", 608 | " states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)\n", 609 | " actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(device)\n", 610 | " rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)\n", 611 | " next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)\n", 612 | " dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)\n", 613 | "\n", 614 | " return (states, actions, rewards, next_states, dones)\n", 615 | "\n", 616 | " def __len__(self):\n", 617 | " \"\"\"Return the current size of internal memory.\"\"\"\n", 618 | " return len(self.memory)" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": 103, 624 | "id": "derived-tsunami", 625 | "metadata": {}, 626 | "outputs": [ 627 | { 628 | "name": "stdout", 629 | "output_type": "stream", 630 | "text": [ 631 | "Using: cpu\n", 632 | "OFENet(\n", 633 | " (state_layer_block): Sequential(\n", 634 | " (0): DenseNetBlock(\n", 635 | " (layer): Linear(in_features=26, out_features=30, bias=True)\n", 636 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 637 | " (act): SiLU()\n", 638 | " )\n", 639 | " (1): DenseNetBlock(\n", 640 | " (layer): Linear(in_features=56, out_features=30, bias=True)\n", 641 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 642 | " (act): SiLU()\n", 643 | " )\n", 644 | " (2): DenseNetBlock(\n", 645 | " (layer): Linear(in_features=86, out_features=30, bias=True)\n", 646 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 647 | " (act): SiLU()\n", 
648 | " )\n", 649 | " (3): DenseNetBlock(\n", 650 | " (layer): Linear(in_features=116, out_features=30, bias=True)\n", 651 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 652 | " (act): SiLU()\n", 653 | " )\n", 654 | " (4): DenseNetBlock(\n", 655 | " (layer): Linear(in_features=146, out_features=30, bias=True)\n", 656 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 657 | " (act): SiLU()\n", 658 | " )\n", 659 | " (5): DenseNetBlock(\n", 660 | " (layer): Linear(in_features=176, out_features=30, bias=True)\n", 661 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 662 | " (act): SiLU()\n", 663 | " )\n", 664 | " (6): DenseNetBlock(\n", 665 | " (layer): Linear(in_features=206, out_features=30, bias=True)\n", 666 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 667 | " (act): SiLU()\n", 668 | " )\n", 669 | " (7): DenseNetBlock(\n", 670 | " (layer): Linear(in_features=236, out_features=30, bias=True)\n", 671 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 672 | " (act): SiLU()\n", 673 | " )\n", 674 | " )\n", 675 | " (action_layer_block): Sequential(\n", 676 | " (0): DenseNetBlock(\n", 677 | " (layer): Linear(in_features=272, out_features=30, bias=True)\n", 678 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 679 | " (act): SiLU()\n", 680 | " )\n", 681 | " (1): DenseNetBlock(\n", 682 | " (layer): Linear(in_features=302, out_features=30, bias=True)\n", 683 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 684 | " (act): SiLU()\n", 685 | " )\n", 686 | " (2): DenseNetBlock(\n", 687 | " (layer): Linear(in_features=332, out_features=30, bias=True)\n", 688 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 689 | " (act): SiLU()\n", 690 | " )\n", 691 | " (3): DenseNetBlock(\n", 692 | " (layer): Linear(in_features=362, out_features=30, bias=True)\n", 693 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 694 | " (act): SiLU()\n", 695 | " )\n", 696 | " (4): DenseNetBlock(\n", 697 | " (layer): Linear(in_features=392, out_features=30, bias=True)\n", 698 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 699 | " (act): SiLU()\n", 700 | " )\n", 701 | " (5): DenseNetBlock(\n", 702 | " (layer): Linear(in_features=422, out_features=30, bias=True)\n", 703 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 704 | " (act): SiLU()\n", 705 | " )\n", 706 | " (6): DenseNetBlock(\n", 707 | " (layer): Linear(in_features=452, out_features=30, bias=True)\n", 708 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 709 | " (act): SiLU()\n", 710 | " )\n", 711 | " (7): DenseNetBlock(\n", 712 | " (layer): Linear(in_features=482, out_features=30, bias=True)\n", 713 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 714 | " (act): SiLU()\n", 715 | " )\n", 716 | " )\n", 717 | " (pred_layer): Linear(in_features=512, out_features=17, bias=True)\n", 718 | ")\n", 719 | "Trainable OFENet Parameter: 132081\n" 720 | ] 721 | } 722 | ], 723 | "source": [ 724 | "\n", 725 | 
"env_name = \"HalfCheetahBulletEnv-v0\" #\"HalfCheetahPyBulletEnv-v0\"#\"Pendulum-v0\"\n", 726 | "max_steps = 1_000_000\n", 727 | "seed = 1\n", 728 | "#Hyperparameter\n", 729 | "lr = 3e-4\n", 730 | "buffer_size = int(1e6)\n", 731 | "batch_size = 256\n", 732 | "tau = 0.005\n", 733 | "gamma = 0.99\n", 734 | "\n", 735 | "random_collect = 10000\n", 736 | "\n", 737 | "# RED-Q Parameter\n", 738 | "N = 2\n", 739 | "M = 2\n", 740 | "G = 1\n", 741 | "\n", 742 | "#writer = SummaryWriter(\"runs/\"+args.info)\n", 743 | "env = gym.make(env_name)\n", 744 | "action_high = env.action_space.high[0]\n", 745 | "action_low = env.action_space.low[0]\n", 746 | "torch.manual_seed(seed)\n", 747 | "env.seed(seed)\n", 748 | "np.random.seed(seed)\n", 749 | "state_size = env.observation_space.shape[0]\n", 750 | "action_size = env.action_space.shape[0]\n", 751 | "agent = REDQ_Agent(state_size=state_size,\n", 752 | " action_size=action_size,\n", 753 | " random_seed=seed,\n", 754 | " device=device,\n", 755 | " action_prior=\"uniform\", N=N, M=M, G=G)\n", 756 | "\n", 757 | "def count_parameters(model):\n", 758 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 759 | "print(\"Trainable OFENet Parameter: \", count_parameters(agent.ofenet))" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 104, 765 | "id": "patent-questionnaire", 766 | "metadata": {}, 767 | "outputs": [ 768 | { 769 | "name": "stdout", 770 | "output_type": "stream", 771 | "text": [ 772 | "Adding random samples to buffer done! Buffer size: 10000\n" 773 | ] 774 | }, 775 | { 776 | "data": { 777 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdlklEQVR4nO3de5RdZZ3m8e9Tt9wIuZBSIBeSQGiNogTLgI1iq1yi9CTajW1QV+MMszIqGRnp6ekwKPTEoQfpXsw0y7TCaGYcRwwo3VpLw6SRi8o4xFQkXBIMqQRIqrikIKECudT1N3+cXXGfS6VOUqc4lZ3ns9ZZ7P3u/e7z7trhqV3vfs95FRGYmVl21VS7AWZmNrIc9GZmGeegNzPLOAe9mVnGOejNzDKurtoNKDRt2rSYPXt2tZthZnZc2bhx4ysR0Vhq26gL+tmzZ9PS0lLtZpiZHVckPT/YNnfdmJllnIPezCzjHPRmZhnnoDczyzgHvZlZxjnozcwyzkFvZpZxmQn6A9293PbPW3ls595qN8XMbFQpK+glLZK0VVKrpBVH2O9PJYWkplTZ9Um9rZIuq0SjSznY3cftD7byZHvnSL2FmdlxachPxkqqBVYBlwBtwAZJzRGxpWC/icC1wPpU2XxgKfAO4HTg55LOjoi+yp1CPs+jYmaWr5w7+oVAa0TsiIhuYA2wpMR+XwO+DhxKlS0B1kREV0Q8C7Qmx6s4SQB4xiwzs3zlBP10YFdqvS0pO0zSecDMiPjZ0dZN6i+T1CKppaOjo6yGFx3jmGqZmWXfsB/GSqoBbgP+4liPERF3RkRTRDQ1Npb88rXyjzWs2mZm2VPOt1e2AzNT6zOSsgETgXcCDyfdJ6cCzZIWl1G3YuRbejOzksq5o98AzJM0R1IDuYerzQMbI6IzIqZFxOyImA08CiyOiJZkv6WSxkiaA8wDflPxs0hxF72ZWb4h7+gjolfScmAdUAusjojNklYCLRHRfIS6myXdA2wBeoFrRmrEjdxLb2ZWUlkTj0TEWmBtQdmNg+z7RwXrNwM3H2P7jppv6M3M8mXmk7EDN/QeXmlmli8zQe+HsWZmpWUm6M3MrLTMBL1v6M3MSstM0A9wF72ZWb7MBP3h77rxuBszszzZCfpqN8DMbJTKTNAPcNeNmVm+zAS9h1eamZWWmaAf4Bt6M7N8mQl6f9eNmVlpmQn6Ae6jNzPLl5mgH+ij9/BKM7N8mQl6MzMrLXNB764bM7N8ZQW9pEWStkpqlbSixPbPS3pS0iZJj0ian5TPlnQwKd8k6VuVPoHft2GkjmxmdnwbcuIRSbXAKuASoA3YIKk5IrakdrsrIr6V7L+Y3GThi5Jt2yPi3Iq22szMylbOHf1CoDUidkREN7AGWJLeISL2pVYnUIXh7B5eaWZWWjlBPx3YlVpvS8rySLpG0nbgVuBLqU1zJD0m6ReSPjCs1pbBM0yZmeWr2MPYiFgVEWcCfwV8JSl+EZgVEQuA64C7JJ1cWFfSMkktklo6OjqO6f0PD690zpuZ5Skn6NuBman1GUnZYNYAHweIiK6IeDVZ3ghsB84urBARd0ZEU0Q0NTY2ltn0fO64MTMrrZyg3wDMkzRHUgOwFGhO7yBpXmr1cmBbUt6YPMxF0lxgHrCjEg0fjG/ozczyDTnqJiJ6JS0H1gG1wOqI2CxpJdASEc3AckkXAz3AXuCqpPpFwEpJPUA/8PmI2DMSJyKPrzQzK2nIoAeIiLXA2oKyG1PL1w5S717g3uE08Gi5j97MLF9mPhk7cD/v77oxM8uXnaB3z42ZWUmZCfoB7roxM8uXmaD3w1gzs9IyE/QDfENvZpYvc0FvZmb5shf07qQ3M8uTqaCX3HVjZlYoW0Ff7QaYmY1CmQp6cM+NmVmhTAW9h1iamRXLVNCDvwLBzKxQpoJeuOvGzKxQtoLePTdm
[remainder of base64-encoded PNG plot data omitted]\n", 778 | "text/plain": [ 779 | "
" 780 | ] 781 | }, 782 | "metadata": { 783 | "needs_background": "light" 784 | }, 785 | "output_type": "display_data" 786 | }, 787 | { 788 | "name": "stdout", 789 | "output_type": "stream", 790 | "text": [ 791 | "0.00331158097833395\n" 792 | ] 793 | } 794 | ], 795 | "source": [ 796 | "import time\n", 797 | "fill_buffer(samples=random_collect)\n", 798 | "start_time = time.time()\n", 799 | "agent = pretrain_ofenet(agent, epochs=random_collect)\n", 800 | "end_time = time.time()\n" 801 | ] 802 | }, 803 | { 804 | "cell_type": "code", 805 | "execution_count": 77, 806 | "id": "racial-metabolism", 807 | "metadata": {}, 808 | "outputs": [ 809 | { 810 | "name": "stdout", 811 | "output_type": "stream", 812 | "text": [ 813 | "pre-training took: 3.44853032430013\n" 814 | ] 815 | } 816 | ], 817 | "source": [ 818 | "print(\"pre-training took: {}\".format((end_time-start_time)/60))" 819 | ] 820 | }, 821 | { 822 | "cell_type": "markdown", 823 | "id": "owned-wilson", 824 | "metadata": {}, 825 | "source": [ 826 | "# paper achieves loss of ~ 0.005" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "execution_count": 105, 832 | "id": "finite-hungary", 833 | "metadata": {}, 834 | "outputs": [], 835 | "source": [ 836 | "torch.save(agent.ofenet.state_dict(), \"ofenet_params_cheetah.pth\")" 837 | ] 838 | }, 839 | { 840 | "cell_type": "code", 841 | "execution_count": 106, 842 | "id": "electoral-bullet", 843 | "metadata": {}, 844 | "outputs": [ 845 | { 846 | "data": { 847 | "text/plain": [ 848 | "OFENet(\n", 849 | " (state_layer_block): Sequential(\n", 850 | " (0): DenseNetBlock(\n", 851 | " (layer): Linear(in_features=26, out_features=30, bias=True)\n", 852 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 853 | " (act): SiLU()\n", 854 | " )\n", 855 | " (1): DenseNetBlock(\n", 856 | " (layer): Linear(in_features=56, out_features=30, bias=True)\n", 857 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 858 | " (act): SiLU()\n", 859 | " )\n", 860 | " (2): DenseNetBlock(\n", 861 | " (layer): Linear(in_features=86, out_features=30, bias=True)\n", 862 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 863 | " (act): SiLU()\n", 864 | " )\n", 865 | " (3): DenseNetBlock(\n", 866 | " (layer): Linear(in_features=116, out_features=30, bias=True)\n", 867 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 868 | " (act): SiLU()\n", 869 | " )\n", 870 | " (4): DenseNetBlock(\n", 871 | " (layer): Linear(in_features=146, out_features=30, bias=True)\n", 872 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 873 | " (act): SiLU()\n", 874 | " )\n", 875 | " (5): DenseNetBlock(\n", 876 | " (layer): Linear(in_features=176, out_features=30, bias=True)\n", 877 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 878 | " (act): SiLU()\n", 879 | " )\n", 880 | " (6): DenseNetBlock(\n", 881 | " (layer): Linear(in_features=206, out_features=30, bias=True)\n", 882 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 883 | " (act): SiLU()\n", 884 | " )\n", 885 | " (7): DenseNetBlock(\n", 886 | " (layer): Linear(in_features=236, out_features=30, bias=True)\n", 887 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 888 | " (act): SiLU()\n", 889 
| " )\n", 890 | " )\n", 891 | " (action_layer_block): Sequential(\n", 892 | " (0): DenseNetBlock(\n", 893 | " (layer): Linear(in_features=272, out_features=30, bias=True)\n", 894 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 895 | " (act): SiLU()\n", 896 | " )\n", 897 | " (1): DenseNetBlock(\n", 898 | " (layer): Linear(in_features=302, out_features=30, bias=True)\n", 899 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 900 | " (act): SiLU()\n", 901 | " )\n", 902 | " (2): DenseNetBlock(\n", 903 | " (layer): Linear(in_features=332, out_features=30, bias=True)\n", 904 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 905 | " (act): SiLU()\n", 906 | " )\n", 907 | " (3): DenseNetBlock(\n", 908 | " (layer): Linear(in_features=362, out_features=30, bias=True)\n", 909 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 910 | " (act): SiLU()\n", 911 | " )\n", 912 | " (4): DenseNetBlock(\n", 913 | " (layer): Linear(in_features=392, out_features=30, bias=True)\n", 914 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 915 | " (act): SiLU()\n", 916 | " )\n", 917 | " (5): DenseNetBlock(\n", 918 | " (layer): Linear(in_features=422, out_features=30, bias=True)\n", 919 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 920 | " (act): SiLU()\n", 921 | " )\n", 922 | " (6): DenseNetBlock(\n", 923 | " (layer): Linear(in_features=452, out_features=30, bias=True)\n", 924 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 925 | " (act): SiLU()\n", 926 | " )\n", 927 | " (7): DenseNetBlock(\n", 928 | " (layer): Linear(in_features=482, out_features=30, bias=True)\n", 929 | " (batch_norm): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 930 | " (act): SiLU()\n", 931 | " )\n", 932 | " )\n", 933 | " (pred_layer): Linear(in_features=512, out_features=17, bias=True)\n", 934 | ")" 935 | ] 936 | }, 937 | "execution_count": 106, 938 | "metadata": {}, 939 | "output_type": "execute_result" 940 | } 941 | ], 942 | "source": [ 943 | "agent.ofenet.load_state_dict(torch.load(\"ofenet_params_cheetah.pth\"))\n", 944 | "agent.ofenet.to(device)" 945 | ] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "execution_count": 107, 950 | "id": "manufactured-adobe", 951 | "metadata": {}, 952 | "outputs": [ 953 | { 954 | "name": "stdout", 955 | "output_type": "stream", 956 | "text": [ 957 | "Episode 1 Frame: [11000/1000000] Reward: -1095.94 Average100 Score: -1095.94 ofenet_loss: 0.003, a_loss: 12.023, c_loss: 9.186\n", 958 | "Episode 2 Frame: [12000/1000000] Reward: -1154.81 Average100 Score: -1125.38 ofenet_loss: 0.003, a_loss: 30.677, c_loss: 14.340\n", 959 | "Episode 3 Frame: [13000/1000000] Reward: -1317.20 Average100 Score: -1189.32 ofenet_loss: 0.003, a_loss: 52.550, c_loss: 18.166\n", 960 | "Episode 4 Frame: [14000/1000000] Reward: -1231.73 Average100 Score: -1199.92 ofenet_loss: 0.004, a_loss: 79.028, c_loss: 19.346\n", 961 | "Episode 5 Frame: [15000/1000000] Reward: -1371.70 Average100 Score: -1234.28 ofenet_loss: 0.003, a_loss: 107.285, c_loss: 26.071\n", 962 | "Episode 6 Frame: [16000/1000000] Reward: -1027.28 Average100 Score: -1199.78 ofenet_loss: 0.004, a_loss: 130.426, c_loss: 19.074\n", 963 | "Episode 7 Frame: [17000/1000000] 
Reward: -1238.18 Average100 Score: -1205.26 ofenet_loss: 0.004, a_loss: 156.055, c_loss: 23.853\n" 964 | ] 965 | }, 966 | { 967 | "ename": "KeyboardInterrupt", 968 | "evalue": "", 969 | "output_type": "error", 970 | "traceback": [ 971 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 972 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 973 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mt0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mscores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrandom_collect\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mt1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training took {} min!\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mt0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m60\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 974 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(steps, precollected, print_every)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maction_v\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mnext_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mofenet_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0magent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mscore\u001b[0m \u001b[0;34m+=\u001b[0m 
\u001b[0mreward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 975 | "\u001b[0;32m\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, state, action, reward, next_state, done)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mofenet_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mofenet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_ofenet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mofenet_optim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0mexperiences\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mactor_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcritic1_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexperiences\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgamma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mofenet_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mactor_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcritic1_loss\u001b[0m \u001b[0;31m# future ofenet_loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 976 | "\u001b[0;32m\u001b[0m in \u001b[0;36mlearn\u001b[0;34m(self, step, experiences, gamma)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;31m# Update critic\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m \u001b[0mQ_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 144\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0;31m# soft update of the targets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 977 | "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m create_graph=create_graph)\n\u001b[0;32m--> 221\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 978 | "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 130\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 131\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 979 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 980 | ] 981 | } 982 | ], 983 | "source": [ 984 | "t0 = time.time()\n", 985 | "scores = train(max_steps, random_collect)\n", 986 | "t1 = time.time()\n", 987 | "env.close()\n", 988 | "print(\"training took {} min!\".format((t1-t0)/60))" 989 | ] 990 | }, 991 | { 992 | "cell_type": "markdown", 993 | "id": "fantastic-assets", 994 | "metadata": {}, 995 | "source": [ 996 | "# Actor loss is the problem!!" 997 | ] 998 | } 999 | ], 1000 | "metadata": { 1001 | "kernelspec": { 1002 | "display_name": "Python 3", 1003 | "language": "python", 1004 | "name": "python3" 1005 | }, 1006 | "language_info": { 1007 | "codemirror_mode": { 1008 | "name": "ipython", 1009 | "version": 3 1010 | }, 1011 | "file_extension": ".py", 1012 | "mimetype": "text/x-python", 1013 | "name": "python", 1014 | "nbconvert_exporter": "python", 1015 | "pygments_lexer": "ipython3", 1016 | "version": "3.7.4" 1017 | } 1018 | }, 1019 | "nbformat": 4, 1020 | "nbformat_minor": 5 1021 | } 1022 | --------------------------------------------------------------------------------
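
On the notebook's closing note that the actor loss is the problem: the actor update in `learn()` is meant to take the minimum over the two critics drawn in `idx`, so both sampled indices (`idx[0]` and `idx[1]`) have to be used. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the two critics are toy `nn.Linear` heads and the feature/action/log-prob tensors are random stand-ins for the OFENet features and the actor's sampled actions.

```python
# Minimal standalone sketch (assumed toy dimensions and modules, random stand-in tensors).
import torch
import torch.nn as nn

feature_dim, action_dim, batch = 8, 2, 4                    # assumed toy sizes
critics = nn.ModuleList([nn.Linear(feature_dim + action_dim, 1) for _ in range(2)])
log_alpha = torch.zeros(1, requires_grad=True)              # SAC temperature, log-parameterised

features = torch.randn(batch, feature_dim)                  # stand-in for OFENet state features
actions = torch.tanh(torch.randn(batch, action_dim))        # stand-in for actor-sampled actions
log_prob = torch.randn(batch, 1)                            # stand-in for their log-probabilities

idx = torch.randperm(len(critics))[:2]                      # two distinct critic indices (like np.random.choice with replace=False)
sa = torch.cat([features, actions], dim=1)
q1 = critics[idx[0]](sa)
q2 = critics[idx[1]](sa)                                    # the *second* sampled critic, not idx[0] twice
q_min = torch.min(q1, q2)                                   # clipped double-Q estimate

alpha = log_alpha.exp().detach()
actor_loss = (alpha * log_prob - q_min).mean()              # entropy-regularised actor objective
print(actor_loss.item())
```

Taking the element-wise minimum over two distinct sampled critics is what gives the clipped double-Q value that SAC/REDQ rely on to curb overestimation; if the same index is fed in twice, the `min` becomes a no-op and the actor can chase an overestimated Q-value.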