├── .gitignore ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── agent.py ├── env.py ├── environment.yml ├── main.py ├── memory.py ├── model.py ├── requirements.txt ├── results └── README.md └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | 92 | # Saved models 93 | *.pth 94 | # Plots 95 | *.html 96 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributing Guidelines 2 | ======================= 3 | 4 | This project is an open source version of code that I use in my work, released (under this [license](LICENSE.md)) for the benefit of others. It is developed and maintained in my own time, and hence I cannot guarantee responses. While contributions are welcome, please keep the following points in mind: 5 | 6 | - Please be civil to myself and other contributors. 7 | - Do raise issues for bugs and other implementation problems/improvements. 8 | - If you have studied the repo and have a question, raise an issue (it could possibly be a bug). 9 | - Bug fixes and small improvements are very welcome. 10 | - Raise an issue before developing a large contribution to a) discuss b) see if I would be willing to merge it (you're obviously free to keep your own fork) c) prevent overlap with others. 11 | - All code contributions should adhere to the existing style (e.g., 2-space indent, no max line length, etc.). 
12 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Kai Arulkumaran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Rainbow 2 | ======= 3 | [![MIT License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md) 4 | 5 | Rainbow: Combining Improvements in Deep Reinforcement Learning [[1]](#references). 6 | 7 | Results and pretrained models can be found in the [releases](https://github.com/Kaixhin/Rainbow/releases). 8 | 9 | - [x] DQN [[2]](#references) 10 | - [x] Double DQN [[3]](#references) 11 | - [x] Prioritised Experience Replay [[4]](#references) 12 | - [x] Dueling Network Architecture [[5]](#references) 13 | - [x] Multi-step Returns [[6]](#references) 14 | - [x] Distributional RL [[7]](#references) 15 | - [x] Noisy Nets [[8]](#references) 16 | 17 | Run the original Rainbow with the default arguments: 18 | 19 | ``` 20 | python main.py 21 | ``` 22 | 23 | Data-efficient Rainbow [[9]](#references) can be run using the following options (note that the "unbounded" memory is implemented here in practice by manually setting the memory capacity to be the same as the maximum number of timesteps): 24 | 25 | ``` 26 | python main.py --target-update 2000 \ 27 | --T-max 100000 \ 28 | --learn-start 1600 \ 29 | --memory-capacity 100000 \ 30 | --replay-frequency 1 \ 31 | --multi-step 20 \ 32 | --architecture data-efficient \ 33 | --hidden-size 256 \ 34 | --learning-rate 0.0001 \ 35 | --evaluation-interval 10000 36 | ``` 37 | 38 | Note that pretrained models from the [`1.3`](https://github.com/Kaixhin/Rainbow/releases/tag/1.3) release used a (slightly) incorrect network architecture. To use these, change the padding in the first convolutional layer from 0 to 1 (DeepMind uses "valid" (no) padding). 39 | 40 | Requirements 41 | ------------ 42 | 43 | - [atari-py](https://github.com/openai/atari-py) 44 | - [OpenCV Python](https://pypi.python.org/pypi/opencv-python) 45 | - [Plotly](https://plot.ly/) 46 | - [PyTorch](http://pytorch.org/) 47 | 48 | To install all dependencies with Anaconda run `conda env create -f environment.yml` and use `source activate rainbow` to activate the environment. 
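If Anaconda is not available, the same packages are listed in [`requirements.txt`](requirements.txt) and can be installed with pip instead (an alternative route not covered by the original instructions; a recent Python 3 environment is assumed):

```
pip install -r requirements.txt
```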
49 | 50 | Available Atari games can be found in the [`atari-py` ROMs folder](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms). 51 | 52 | Acknowledgements 53 | ---------------- 54 | 55 | - [@floringogianu](https://github.com/floringogianu) for [categorical-dqn](https://github.com/floringogianu/categorical-dqn) 56 | - [@jvmancuso](https://github.com/jvmancuso) for [Noisy layer](https://github.com/pytorch/pytorch/pull/2103) 57 | - [@jaara](https://github.com/jaara) for [AI-blog](https://github.com/jaara/AI-blog) 58 | - [@openai](https://github.com/openai) for [Baselines](https://github.com/openai/baselines) 59 | - [@mtthss](https://github.com/mtthss) for [implementation details](https://github.com/Kaixhin/Rainbow/wiki/Matteo's-Notes) 60 | 61 | References 62 | ---------- 63 | 64 | [1] [Rainbow: Combining Improvements in Deep Reinforcement Learning](https://arxiv.org/abs/1710.02298) 65 | [2] [Playing Atari with Deep Reinforcement Learning](http://arxiv.org/abs/1312.5602) 66 | [3] [Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461) 67 | [4] [Prioritized Experience Replay](http://arxiv.org/abs/1511.05952) 68 | [5] [Dueling Network Architectures for Deep Reinforcement Learning](http://arxiv.org/abs/1511.06581) 69 | [6] [Reinforcement Learning: An Introduction](http://www.incompleteideas.net/sutton/book/ebook/the-book.html) 70 | [7] [A Distributional Perspective on Reinforcement Learning](https://arxiv.org/abs/1707.06887) 71 | [8] [Noisy Networks for Exploration](https://arxiv.org/abs/1706.10295) 72 | [9] [When to Use Parametric Models in Reinforcement Learning?](https://arxiv.org/abs/1906.05243) 73 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import numpy as np 5 | import torch 6 | from torch import optim 7 | from torch.nn.utils import clip_grad_norm_ 8 | 9 | from model import DQN 10 | 11 | 12 | class Agent(): 13 | def __init__(self, args, env): 14 | self.action_space = env.action_space() 15 | self.atoms = args.atoms 16 | self.Vmin = args.V_min 17 | self.Vmax = args.V_max 18 | self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device) # Support (range) of z 19 | self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) 20 | self.batch_size = args.batch_size 21 | self.n = args.multi_step 22 | self.discount = args.discount 23 | self.norm_clip = args.norm_clip 24 | 25 | self.online_net = DQN(args, self.action_space).to(device=args.device) 26 | if args.model: # Load pretrained model if provided 27 | if os.path.isfile(args.model): 28 | state_dict = torch.load(args.model, map_location='cpu') # Always load tensors onto CPU by default, will shift to GPU if necessary 29 | if 'conv1.weight' in state_dict.keys(): 30 | for old_key, new_key in (('conv1.weight', 'convs.0.weight'), ('conv1.bias', 'convs.0.bias'), ('conv2.weight', 'convs.2.weight'), ('conv2.bias', 'convs.2.bias'), ('conv3.weight', 'convs.4.weight'), ('conv3.bias', 'convs.4.bias')): 31 | state_dict[new_key] = state_dict[old_key] # Re-map state dict for old pretrained models 32 | del state_dict[old_key] # Delete old keys for strict load_state_dict 33 | self.online_net.load_state_dict(state_dict) 34 | print("Loading pretrained model: " + args.model) 35 | else: # Raise error if incorrect model path provided 36 | raise FileNotFoundError(args.model) 37 | 38 | 
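    # Keep the online network in training mode so its NoisyLinear layers sample exploration noise (there is no ε-greedy exploration during training); the target network below is frozen and has its noise resampled explicitly in learn()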
self.online_net.train() 39 | 40 | self.target_net = DQN(args, self.action_space).to(device=args.device) 41 | self.update_target_net() 42 | self.target_net.train() 43 | for param in self.target_net.parameters(): 44 | param.requires_grad = False 45 | 46 | self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps) 47 | 48 | # Resets noisy weights in all linear layers (of online net only) 49 | def reset_noise(self): 50 | self.online_net.reset_noise() 51 | 52 | # Acts based on single state (no batch) 53 | def act(self, state): 54 | with torch.no_grad(): 55 | return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item() 56 | 57 | # Acts with an ε-greedy policy (used for evaluation only) 58 | def act_e_greedy(self, state, epsilon=0.001): # High ε can reduce evaluation scores drastically 59 | return np.random.randint(0, self.action_space) if np.random.random() < epsilon else self.act(state) 60 | 61 | def learn(self, mem): 62 | # Sample transitions 63 | idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size) 64 | 65 | # Calculate current state probabilities (online network noise already sampled) 66 | log_ps = self.online_net(states, log=True) # Log probabilities log p(s_t, ·; θonline) 67 | log_ps_a = log_ps[range(self.batch_size), actions] # log p(s_t, a_t; θonline) 68 | 69 | with torch.no_grad(): 70 | # Calculate nth next state probabilities 71 | pns = self.online_net(next_states) # Probabilities p(s_t+n, ·; θonline) 72 | dns = self.support.expand_as(pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) 73 | argmax_indices_ns = dns.sum(2).argmax(1) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] 74 | self.target_net.reset_noise() # Sample new target net noise 75 | pns = self.target_net(next_states) # Probabilities p(s_t+n, ·; θtarget) 76 | pns_a = pns[range(self.batch_size), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) 77 | 78 | # Compute Tz (Bellman operator T applied to z) 79 | Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0) # Tz = R^n + (γ^n)z (accounting for terminal states) 80 | Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values 81 | # Compute L2 projection of Tz onto fixed support z 82 | b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz 83 | l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) 84 | # Fix disappearing probability mass when l = b = u (b is int) 85 | l[(u > 0) * (l == u)] -= 1 86 | u[(l < (self.atoms - 1)) * (l == u)] += 1 87 | 88 | # Distribute probability of Tz 89 | m = states.new_zeros(self.batch_size, self.atoms) 90 | offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions) 91 | m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) 92 | m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) 93 | 94 | loss = -torch.sum(m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) 95 | self.online_net.zero_grad() 96 | (weights * loss).mean().backward() # Backpropagate importance-weighted minibatch loss 97 | clip_grad_norm_(self.online_net.parameters(), self.norm_clip) # Clip gradients by L2 norm 98 | self.optimiser.step() 99 | 100 | mem.update_priorities(idxs, 
loss.detach().cpu().numpy()) # Update priorities of sampled transitions 101 | 102 | def update_target_net(self): 103 | self.target_net.load_state_dict(self.online_net.state_dict()) 104 | 105 | # Save model parameters on current device (don't move model between devices) 106 | def save(self, path, name='model.pth'): 107 | torch.save(self.online_net.state_dict(), os.path.join(path, name)) 108 | 109 | # Evaluates Q-value based on single state (no batch) 110 | def evaluate_q(self, state): 111 | with torch.no_grad(): 112 | return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item() 113 | 114 | def train(self): 115 | self.online_net.train() 116 | 117 | def eval(self): 118 | self.online_net.eval() 119 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from collections import deque 3 | import random 4 | import atari_py 5 | import cv2 6 | import torch 7 | 8 | 9 | class Env(): 10 | def __init__(self, args): 11 | self.device = args.device 12 | self.ale = atari_py.ALEInterface() 13 | self.ale.setInt('random_seed', args.seed) 14 | self.ale.setInt('max_num_frames_per_episode', args.max_episode_length) 15 | self.ale.setFloat('repeat_action_probability', 0) # Disable sticky actions 16 | self.ale.setInt('frame_skip', 0) 17 | self.ale.setBool('color_averaging', False) 18 | self.ale.loadROM(atari_py.get_game_path(args.game)) # ROM loading must be done after setting options 19 | actions = self.ale.getMinimalActionSet() 20 | self.actions = dict([i, e] for i, e in zip(range(len(actions)), actions)) 21 | self.lives = 0 # Life counter (used in DeepMind training) 22 | self.life_termination = False # Used to check if resetting only from loss of life 23 | self.window = args.history_length # Number of frames to concatenate 24 | self.state_buffer = deque([], maxlen=args.history_length) 25 | self.training = True # Consistent with model training mode 26 | 27 | def _get_state(self): 28 | state = cv2.resize(self.ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_LINEAR) 29 | return torch.tensor(state, dtype=torch.float32, device=self.device).div_(255) 30 | 31 | def _reset_buffer(self): 32 | for _ in range(self.window): 33 | self.state_buffer.append(torch.zeros(84, 84, device=self.device)) 34 | 35 | def reset(self): 36 | if self.life_termination: 37 | self.life_termination = False # Reset flag 38 | self.ale.act(0) # Use a no-op after loss of life 39 | else: 40 | # Reset internals 41 | self._reset_buffer() 42 | self.ale.reset_game() 43 | # Perform up to 30 random no-ops before starting 44 | for _ in range(random.randrange(30)): 45 | self.ale.act(0) # Assumes raw action 0 is always no-op 46 | if self.ale.game_over(): 47 | self.ale.reset_game() 48 | # Process and return "initial" state 49 | observation = self._get_state() 50 | self.state_buffer.append(observation) 51 | self.lives = self.ale.lives() 52 | return torch.stack(list(self.state_buffer), 0) 53 | 54 | def step(self, action): 55 | # Repeat action 4 times, max pool over last 2 frames 56 | frame_buffer = torch.zeros(2, 84, 84, device=self.device) 57 | reward, done = 0, False 58 | for t in range(4): 59 | reward += self.ale.act(self.actions.get(action)) 60 | if t == 2: 61 | frame_buffer[0] = self._get_state() 62 | elif t == 3: 63 | frame_buffer[1] = self._get_state() 64 | done = self.ale.game_over() 65 | if done: 66 | break 67 | observation = frame_buffer.max(0)[0] 68 | 
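    # The element-wise max over the last two frames removes flicker from objects that Atari games only draw on alternate frames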
self.state_buffer.append(observation) 69 | # Detect loss of life as terminal in training mode 70 | if self.training: 71 | lives = self.ale.lives() 72 | if lives < self.lives and lives > 0: # Lives > 0 for Q*bert 73 | self.life_termination = not done # Only set flag when not truly done 74 | done = True 75 | self.lives = lives 76 | # Return state, reward, done 77 | return torch.stack(list(self.state_buffer), 0), reward, done 78 | 79 | # Uses loss of life as terminal signal 80 | def train(self): 81 | self.training = True 82 | 83 | # Uses standard terminal signal 84 | def eval(self): 85 | self.training = False 86 | 87 | def action_space(self): 88 | return len(self.actions) 89 | 90 | def render(self): 91 | cv2.imshow('screen', self.ale.getScreenRGB()[:, :, ::-1]) 92 | cv2.waitKey(1) 93 | 94 | def close(self): 95 | cv2.destroyAllWindows() 96 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: rainbow 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - plotly 6 | - pytorch 7 | - tqdm 8 | - pip: 9 | - atari-py 10 | - opencv-python 11 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import argparse 4 | import bz2 5 | from datetime import datetime 6 | import os 7 | import pickle 8 | 9 | import atari_py 10 | import numpy as np 11 | import torch 12 | from tqdm import trange 13 | 14 | from agent import Agent 15 | from env import Env 16 | from memory import ReplayMemory 17 | from test import test 18 | 19 | 20 | # Note that hyperparameters may originally be reported in ATARI game frames instead of agent steps 21 | parser = argparse.ArgumentParser(description='Rainbow') 22 | parser.add_argument('--id', type=str, default='default', help='Experiment ID') 23 | parser.add_argument('--seed', type=int, default=123, help='Random seed') 24 | parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA') 25 | parser.add_argument('--game', type=str, default='space_invaders', choices=atari_py.list_games(), help='ATARI game') 26 | parser.add_argument('--T-max', type=int, default=int(50e6), metavar='STEPS', help='Number of training steps (4x number of frames)') 27 | parser.add_argument('--max-episode-length', type=int, default=int(108e3), metavar='LENGTH', help='Max episode length in game frames (0 to disable)') 28 | parser.add_argument('--history-length', type=int, default=4, metavar='T', help='Number of consecutive states processed') 29 | parser.add_argument('--architecture', type=str, default='canonical', choices=['canonical', 'data-efficient'], metavar='ARCH', help='Network architecture') 30 | parser.add_argument('--hidden-size', type=int, default=512, metavar='SIZE', help='Network hidden size') 31 | parser.add_argument('--noisy-std', type=float, default=0.1, metavar='σ', help='Initial standard deviation of noisy linear layers') 32 | parser.add_argument('--atoms', type=int, default=51, metavar='C', help='Discretised size of value distribution') 33 | parser.add_argument('--V-min', type=float, default=-10, metavar='V', help='Minimum of value distribution support') 34 | parser.add_argument('--V-max', type=float, default=10, metavar='V', help='Maximum of value distribution support') 35 | parser.add_argument('--model', type=str, metavar='PARAMS', help='Pretrained model (state 
dict)') 36 | parser.add_argument('--memory-capacity', type=int, default=int(1e6), metavar='CAPACITY', help='Experience replay memory capacity') 37 | parser.add_argument('--replay-frequency', type=int, default=4, metavar='k', help='Frequency of sampling from memory') 38 | parser.add_argument('--priority-exponent', type=float, default=0.5, metavar='ω', help='Prioritised experience replay exponent (originally denoted α)') 39 | parser.add_argument('--priority-weight', type=float, default=0.4, metavar='β', help='Initial prioritised experience replay importance sampling weight') 40 | parser.add_argument('--multi-step', type=int, default=3, metavar='n', help='Number of steps for multi-step return') 41 | parser.add_argument('--discount', type=float, default=0.99, metavar='γ', help='Discount factor') 42 | parser.add_argument('--target-update', type=int, default=int(8e3), metavar='τ', help='Number of steps after which to update target network') 43 | parser.add_argument('--reward-clip', type=int, default=1, metavar='VALUE', help='Reward clipping (0 to disable)') 44 | parser.add_argument('--learning-rate', type=float, default=0.0000625, metavar='η', help='Learning rate') 45 | parser.add_argument('--adam-eps', type=float, default=1.5e-4, metavar='ε', help='Adam epsilon') 46 | parser.add_argument('--batch-size', type=int, default=32, metavar='SIZE', help='Batch size') 47 | parser.add_argument('--norm-clip', type=float, default=10, metavar='NORM', help='Max L2 norm for gradient clipping') 48 | parser.add_argument('--learn-start', type=int, default=int(20e3), metavar='STEPS', help='Number of steps before starting training') 49 | parser.add_argument('--evaluate', action='store_true', help='Evaluate only') 50 | parser.add_argument('--evaluation-interval', type=int, default=100000, metavar='STEPS', help='Number of training steps between evaluations') 51 | parser.add_argument('--evaluation-episodes', type=int, default=10, metavar='N', help='Number of evaluation episodes to average over') 52 | # TODO: Note that DeepMind's evaluation method is running the latest agent for 500K frames every 1M steps 53 | parser.add_argument('--evaluation-size', type=int, default=500, metavar='N', help='Number of transitions to use for validating Q') 54 | parser.add_argument('--render', action='store_true', help='Display screen (testing only)') 55 | parser.add_argument('--enable-cudnn', action='store_true', help='Enable cuDNN (faster but nondeterministic)') 56 | parser.add_argument('--checkpoint-interval', type=int, default=0, help='How often to checkpoint the model, defaults to 0 (never checkpoint)') 57 | parser.add_argument('--memory', help='Path to save/load the memory from') 58 | parser.add_argument('--disable-bzip-memory', action='store_true', help='Don\'t zip the memory file. 
Not recommended (zipping is a bit slower and much, much smaller)') 59 | 60 | # Setup 61 | args = parser.parse_args() 62 | 63 | print(' ' * 26 + 'Options') 64 | for k, v in vars(args).items(): 65 | print(' ' * 26 + k + ': ' + str(v)) 66 | results_dir = os.path.join('results', args.id) 67 | if not os.path.exists(results_dir): 68 | os.makedirs(results_dir) 69 | metrics = {'steps': [], 'rewards': [], 'Qs': [], 'best_avg_reward': -float('inf')} 70 | np.random.seed(args.seed) 71 | torch.manual_seed(np.random.randint(1, 10000)) 72 | if torch.cuda.is_available() and not args.disable_cuda: 73 | args.device = torch.device('cuda') 74 | torch.cuda.manual_seed(np.random.randint(1, 10000)) 75 | torch.backends.cudnn.enabled = args.enable_cudnn 76 | else: 77 | args.device = torch.device('cpu') 78 | 79 | 80 | # Simple ISO 8601 timestamped logger 81 | def log(s): 82 | print('[' + str(datetime.now().strftime('%Y-%m-%dT%H:%M:%S')) + '] ' + s) 83 | 84 | 85 | def load_memory(memory_path, disable_bzip): 86 | if disable_bzip: 87 | with open(memory_path, 'rb') as pickle_file: 88 | return pickle.load(pickle_file) 89 | else: 90 | with bz2.open(memory_path, 'rb') as zipped_pickle_file: 91 | return pickle.load(zipped_pickle_file) 92 | 93 | 94 | def save_memory(memory, memory_path, disable_bzip): 95 | if disable_bzip: 96 | with open(memory_path, 'wb') as pickle_file: 97 | pickle.dump(memory, pickle_file) 98 | else: 99 | with bz2.open(memory_path, 'wb') as zipped_pickle_file: 100 | pickle.dump(memory, zipped_pickle_file) 101 | 102 | 103 | # Environment 104 | env = Env(args) 105 | env.train() 106 | action_space = env.action_space() 107 | 108 | # Agent 109 | dqn = Agent(args, env) 110 | 111 | # If a model is provided, and evaluate is false, presumably we want to resume, so try to load memory 112 | if args.model is not None and not args.evaluate: 113 | if not args.memory: 114 | raise ValueError('Cannot resume training without memory save path. Aborting...') 115 | elif not os.path.exists(args.memory): 116 | raise ValueError('Could not find memory file at {path}. Aborting...'.format(path=args.memory)) 117 | 118 | mem = load_memory(args.memory, args.disable_bzip_memory) 119 | 120 | else: 121 | mem = ReplayMemory(args, args.memory_capacity) 122 | 123 | priority_weight_increase = (1 - args.priority_weight) / (args.T_max - args.learn_start) 124 | 125 | 126 | # Construct validation memory 127 | val_mem = ReplayMemory(args, args.evaluation_size) 128 | T, done = 0, True 129 | while T < args.evaluation_size: 130 | if done: 131 | state = env.reset() 132 | 133 | next_state, _, done = env.step(np.random.randint(0, action_space)) 134 | val_mem.append(state, -1, 0.0, done) 135 | state = next_state 136 | T += 1 137 | 138 | if args.evaluate: 139 | dqn.eval() # Set DQN (online network) to evaluation mode 140 | avg_reward, avg_Q = test(args, 0, dqn, val_mem, metrics, results_dir, evaluate=True) # Test 141 | print('Avg. reward: ' + str(avg_reward) + ' | Avg. 
Q: ' + str(avg_Q)) 142 | else: 143 | # Training loop 144 | dqn.train() 145 | done = True 146 | for T in trange(1, args.T_max + 1): 147 | if done: 148 | state = env.reset() 149 | 150 | if T % args.replay_frequency == 0: 151 | dqn.reset_noise() # Draw a new set of noisy weights 152 | 153 | action = dqn.act(state) # Choose an action greedily (with noisy weights) 154 | next_state, reward, done = env.step(action) # Step 155 | if args.reward_clip > 0: 156 | reward = max(min(reward, args.reward_clip), -args.reward_clip) # Clip rewards 157 | mem.append(state, action, reward, done) # Append transition to memory 158 | 159 | # Train and test 160 | if T >= args.learn_start: 161 | mem.priority_weight = min(mem.priority_weight + priority_weight_increase, 1) # Anneal importance sampling weight β to 1 162 | 163 | if T % args.replay_frequency == 0: 164 | dqn.learn(mem) # Train with n-step distributional double-Q learning 165 | 166 | if T % args.evaluation_interval == 0: 167 | dqn.eval() # Set DQN (online network) to evaluation mode 168 | avg_reward, avg_Q = test(args, T, dqn, val_mem, metrics, results_dir) # Test 169 | log('T = ' + str(T) + ' / ' + str(args.T_max) + ' | Avg. reward: ' + str(avg_reward) + ' | Avg. Q: ' + str(avg_Q)) 170 | dqn.train() # Set DQN (online network) back to training mode 171 | 172 | # If memory path provided, save it 173 | if args.memory is not None: 174 | save_memory(mem, args.memory, args.disable_bzip_memory) 175 | 176 | # Update target network 177 | if T % args.target_update == 0: 178 | dqn.update_target_net() 179 | 180 | # Checkpoint the network 181 | if (args.checkpoint_interval != 0) and (T % args.checkpoint_interval == 0): 182 | dqn.save(results_dir, 'checkpoint.pth') 183 | 184 | state = next_state 185 | 186 | env.close() 187 | -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import numpy as np 4 | import torch 5 | 6 | 7 | Transition_dtype = np.dtype([('timestep', np.int32), ('state', np.uint8, (84, 84)), ('action', np.int32), ('reward', np.float32), ('nonterminal', np.bool_)]) 8 | blank_trans = (0, np.zeros((84, 84), dtype=np.uint8), 0, 0.0, False) 9 | 10 | 11 | # Segment tree data structure where parent node values are sum/max of children node values 12 | class SegmentTree(): 13 | def __init__(self, size): 14 | self.index = 0 15 | self.size = size 16 | self.full = False # Used to track actual capacity 17 | self.tree_start = 2**(size-1).bit_length()-1 # Put all used node leaves on last tree level 18 | self.sum_tree = np.zeros((self.tree_start + self.size,), dtype=np.float32) 19 | self.data = np.array([blank_trans] * size, dtype=Transition_dtype) # Build structured array 20 | self.max = 1 # Initial max value to return (1 = 1^ω) 21 | 22 | # Updates nodes values from current tree 23 | def _update_nodes(self, indices): 24 | children_indices = indices * 2 + np.expand_dims([1, 2], axis=1) 25 | self.sum_tree[indices] = np.sum(self.sum_tree[children_indices], axis=0) 26 | 27 | # Propagates changes up tree given tree indices 28 | def _propagate(self, indices): 29 | parents = (indices - 1) // 2 30 | unique_parents = np.unique(parents) 31 | self._update_nodes(unique_parents) 32 | if parents[0] != 0: 33 | self._propagate(parents) 34 | 35 | # Propagates single value up tree given a tree index for efficiency 36 | def _propagate_index(self, index): 37 | parent = (index - 1) // 2 38 | left, right = 
2 * parent + 1, 2 * parent + 2 39 | self.sum_tree[parent] = self.sum_tree[left] + self.sum_tree[right] 40 | if parent != 0: 41 | self._propagate_index(parent) 42 | 43 | # Updates values given tree indices 44 | def update(self, indices, values): 45 | self.sum_tree[indices] = values # Set new values 46 | self._propagate(indices) # Propagate values 47 | current_max_value = np.max(values) 48 | self.max = max(current_max_value, self.max) 49 | 50 | # Updates single value given a tree index for efficiency 51 | def _update_index(self, index, value): 52 | self.sum_tree[index] = value # Set new value 53 | self._propagate_index(index) # Propagate value 54 | self.max = max(value, self.max) 55 | 56 | def append(self, data, value): 57 | self.data[self.index] = data # Store data in underlying data structure 58 | self._update_index(self.index + self.tree_start, value) # Update tree 59 | self.index = (self.index + 1) % self.size # Update index 60 | self.full = self.full or self.index == 0 # Save when capacity reached 61 | self.max = max(value, self.max) 62 | 63 | # Searches for the location of values in sum tree 64 | def _retrieve(self, indices, values): 65 | children_indices = (indices * 2 + np.expand_dims([1, 2], axis=1)) # Make matrix of children indices 66 | # If indices correspond to leaf nodes, return them 67 | if children_indices[0, 0] >= self.sum_tree.shape[0]: 68 | return indices 69 | # If children indices correspond to leaf nodes, bound rare outliers in case total slightly overshoots 70 | elif children_indices[0, 0] >= self.tree_start: 71 | children_indices = np.minimum(children_indices, self.sum_tree.shape[0] - 1) 72 | left_children_values = self.sum_tree[children_indices[0]] 73 | successor_choices = np.greater(values, left_children_values).astype(np.int32) # Classify which values are in left or right branches 74 | successor_indices = children_indices[successor_choices, np.arange(indices.size)] # Use classification to index into the indices matrix 75 | successor_values = values - successor_choices * left_children_values # Subtract the left branch values when searching in the right branch 76 | return self._retrieve(successor_indices, successor_values) 77 | 78 | # Searches for values in sum tree and returns values, data indices and tree indices 79 | def find(self, values): 80 | indices = self._retrieve(np.zeros(values.shape, dtype=np.int32), values) 81 | data_index = indices - self.tree_start 82 | return (self.sum_tree[indices], data_index, indices) # Return values, data indices, tree indices 83 | 84 | # Returns data given a data index 85 | def get(self, data_index): 86 | return self.data[data_index % self.size] 87 | 88 | def total(self): 89 | return self.sum_tree[0] 90 | 91 | class ReplayMemory(): 92 | def __init__(self, args, capacity): 93 | self.device = args.device 94 | self.capacity = capacity 95 | self.history = args.history_length 96 | self.discount = args.discount 97 | self.n = args.multi_step 98 | self.priority_weight = args.priority_weight # Initial importance sampling weight β, annealed to 1 over course of training 99 | self.priority_exponent = args.priority_exponent 100 | self.t = 0 # Internal episode timestep counter 101 | self.n_step_scaling = torch.tensor([self.discount ** i for i in range(self.n)], dtype=torch.float32, device=self.device) # Discount-scaling vector for n-step returns 102 | self.transitions = SegmentTree(capacity) # Store transitions in a wrap-around cyclic buffer within a sum tree for querying priorities 103 | 104 | # Adds state and action at time t, reward and 
terminal at time t + 1 105 | def append(self, state, action, reward, terminal): 106 | state = state[-1].mul(255).to(dtype=torch.uint8, device=torch.device('cpu')) # Only store last frame and discretise to save memory 107 | self.transitions.append((self.t, state, action, reward, not terminal), self.transitions.max) # Store new transition with maximum priority 108 | self.t = 0 if terminal else self.t + 1 # Start new episodes with t = 0 109 | 110 | # Returns the transitions with blank states where appropriate 111 | def _get_transitions(self, idxs): 112 | transition_idxs = np.arange(-self.history + 1, self.n + 1) + np.expand_dims(idxs, axis=1) 113 | transitions = self.transitions.get(transition_idxs) 114 | transitions_firsts = transitions['timestep'] == 0 115 | blank_mask = np.zeros_like(transitions_firsts, dtype=np.bool_) 116 | for t in range(self.history - 2, -1, -1): # e.g. 2 1 0 117 | blank_mask[:, t] = np.logical_or(blank_mask[:, t + 1], transitions_firsts[:, t + 1]) # True if future frame has timestep 0 118 | for t in range(self.history, self.history + self.n): # e.g. 4 5 6 119 | blank_mask[:, t] = np.logical_or(blank_mask[:, t - 1], transitions_firsts[:, t]) # True if current or past frame has timestep 0 120 | transitions[blank_mask] = blank_trans 121 | return transitions 122 | 123 | # Returns a valid sample from each segment 124 | def _get_samples_from_segments(self, batch_size, p_total): 125 | segment_length = p_total / batch_size # Batch size number of segments, based on sum over all probabilities 126 | segment_starts = np.arange(batch_size) * segment_length 127 | valid = False 128 | while not valid: 129 | samples = np.random.uniform(0.0, segment_length, [batch_size]) + segment_starts # Uniformly sample from within all segments 130 | probs, idxs, tree_idxs = self.transitions.find(samples) # Retrieve samples from tree with un-normalised probability 131 | if np.all((self.transitions.index - idxs) % self.capacity > self.n) and np.all((idxs - self.transitions.index) % self.capacity >= self.history) and np.all(probs != 0): 132 | valid = True # Note that conditions are valid but extra conservative around buffer index 0 133 | # Retrieve all required transition data (from t - h to t + n) 134 | transitions = self._get_transitions(idxs) 135 | # Create un-discretised states and nth next states 136 | all_states = transitions['state'] 137 | states = torch.tensor(all_states[:, :self.history], device=self.device, dtype=torch.float32).div_(255) 138 | next_states = torch.tensor(all_states[:, self.n:self.n + self.history], device=self.device, dtype=torch.float32).div_(255) 139 | # Discrete actions to be used as index 140 | actions = torch.tensor(np.copy(transitions['action'][:, self.history - 1]), dtype=torch.int64, device=self.device) 141 | # Calculate truncated n-step discounted returns R^n = Σ_k=0->n-1 (γ^k)R_t+k+1 (note that invalid nth next states have reward 0) 142 | rewards = torch.tensor(np.copy(transitions['reward'][:, self.history - 1:-1]), dtype=torch.float32, device=self.device) 143 | R = torch.matmul(rewards, self.n_step_scaling) 144 | # Mask for non-terminal nth next states 145 | nonterminals = torch.tensor(np.expand_dims(transitions['nonterminal'][:, self.history + self.n - 1], axis=1), dtype=torch.float32, device=self.device) 146 | return probs, idxs, tree_idxs, states, actions, R, next_states, nonterminals 147 | 148 | def sample(self, batch_size): 149 | p_total = self.transitions.total() # Retrieve sum of all priorities (used to create a normalised probability distribution) 150 | 
probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = self._get_samples_from_segments(batch_size, p_total) # Get batch of valid samples 151 | probs = probs / p_total # Calculate normalised probabilities 152 | capacity = self.capacity if self.transitions.full else self.transitions.index 153 | weights = (capacity * probs) ** -self.priority_weight # Compute importance-sampling weights w 154 | weights = torch.tensor(weights / weights.max(), dtype=torch.float32, device=self.device) # Normalise by max importance-sampling weight from batch 155 | return tree_idxs, states, actions, returns, next_states, nonterminals, weights 156 | 157 | def update_priorities(self, idxs, priorities): 158 | priorities = np.power(priorities, self.priority_exponent) 159 | self.transitions.update(idxs, priorities) 160 | 161 | # Set up internal state for iterator 162 | def __iter__(self): 163 | self.current_idx = 0 164 | return self 165 | 166 | # Return valid states for validation 167 | def __next__(self): 168 | if self.current_idx == self.capacity: 169 | raise StopIteration 170 | transitions = self.transitions.data[np.arange(self.current_idx - self.history + 1, self.current_idx + 1)] 171 | transitions_firsts = transitions['timestep'] == 0 172 | blank_mask = np.zeros_like(transitions_firsts, dtype=np.bool_) 173 | for t in reversed(range(self.history - 1)): 174 | blank_mask[t] = np.logical_or(blank_mask[t + 1], transitions_firsts[t + 1]) # If future frame has timestep 0 175 | transitions[blank_mask] = blank_trans 176 | state = torch.tensor(transitions['state'], dtype=torch.float32, device=self.device).div_(255) # Agent will turn into batch 177 | self.current_idx += 1 178 | return state 179 | 180 | next = __next__ # Alias __next__ for Python 2 compatibility 181 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | # Factorised NoisyLinear layer with bias 10 | class NoisyLinear(nn.Module): 11 | def __init__(self, in_features, out_features, std_init=0.5): 12 | super(NoisyLinear, self).__init__() 13 | self.in_features = in_features 14 | self.out_features = out_features 15 | self.std_init = std_init 16 | self.weight_mu = nn.Parameter(torch.empty(out_features, in_features)) 17 | self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features)) 18 | self.register_buffer('weight_epsilon', torch.empty(out_features, in_features)) 19 | self.bias_mu = nn.Parameter(torch.empty(out_features)) 20 | self.bias_sigma = nn.Parameter(torch.empty(out_features)) 21 | self.register_buffer('bias_epsilon', torch.empty(out_features)) 22 | self.reset_parameters() 23 | self.reset_noise() 24 | 25 | def reset_parameters(self): 26 | mu_range = 1 / math.sqrt(self.in_features) 27 | self.weight_mu.data.uniform_(-mu_range, mu_range) 28 | self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features)) 29 | self.bias_mu.data.uniform_(-mu_range, mu_range) 30 | self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features)) 31 | 32 | def _scale_noise(self, size): 33 | x = torch.randn(size, device=self.weight_mu.device) 34 | return x.sign().mul_(x.abs().sqrt_()) 35 | 36 | def reset_noise(self): 37 | epsilon_in = self._scale_noise(self.in_features) 38 | epsilon_out = self._scale_noise(self.out_features) 39 | 
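    # Factorised Gaussian noise: the weight noise is the outer product (ger) of the output and input noise vectors, so only in_features + out_features noise samples are drawn per reset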
self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in)) 40 | self.bias_epsilon.copy_(epsilon_out) 41 | 42 | def forward(self, input): 43 | if self.training: 44 | return F.linear(input, self.weight_mu + self.weight_sigma * self.weight_epsilon, self.bias_mu + self.bias_sigma * self.bias_epsilon) 45 | else: 46 | return F.linear(input, self.weight_mu, self.bias_mu) 47 | 48 | 49 | class DQN(nn.Module): 50 | def __init__(self, args, action_space): 51 | super(DQN, self).__init__() 52 | self.atoms = args.atoms 53 | self.action_space = action_space 54 | 55 | if args.architecture == 'canonical': 56 | self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 8, stride=4, padding=0), nn.ReLU(), 57 | nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), 58 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU()) 59 | self.conv_output_size = 3136 60 | elif args.architecture == 'data-efficient': 61 | self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 5, stride=5, padding=0), nn.ReLU(), 62 | nn.Conv2d(32, 64, 5, stride=5, padding=0), nn.ReLU()) 63 | self.conv_output_size = 576 64 | self.fc_h_v = NoisyLinear(self.conv_output_size, args.hidden_size, std_init=args.noisy_std) 65 | self.fc_h_a = NoisyLinear(self.conv_output_size, args.hidden_size, std_init=args.noisy_std) 66 | self.fc_z_v = NoisyLinear(args.hidden_size, self.atoms, std_init=args.noisy_std) 67 | self.fc_z_a = NoisyLinear(args.hidden_size, action_space * self.atoms, std_init=args.noisy_std) 68 | 69 | def forward(self, x, log=False): 70 | x = self.convs(x) 71 | x = x.view(-1, self.conv_output_size) 72 | v = self.fc_z_v(F.relu(self.fc_h_v(x))) # Value stream 73 | a = self.fc_z_a(F.relu(self.fc_h_a(x))) # Advantage stream 74 | v, a = v.view(-1, 1, self.atoms), a.view(-1, self.action_space, self.atoms) 75 | q = v + a - a.mean(1, keepdim=True) # Combine streams 76 | if log: # Use log softmax for numerical stability 77 | q = F.log_softmax(q, dim=2) # Log probabilities with action over second dimension 78 | else: 79 | q = F.softmax(q, dim=2) # Probabilities with action over second dimension 80 | return q 81 | 82 | def reset_noise(self): 83 | for name, module in self.named_children(): 84 | if 'fc' in name: 85 | module.reset_noise() 86 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | atari-py 2 | opencv-python 3 | plotly 4 | torch 5 | tqdm 6 | -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | Contains plots of evaluation rewards and Q-values, plus saved model weights 2 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import plotly 5 | from plotly.graph_objs import Scatter 6 | from plotly.graph_objs.scatter import Line 7 | import torch 8 | 9 | from env import Env 10 | 11 | 12 | # Test DQN 13 | def test(args, T, dqn, val_mem, metrics, results_dir, evaluate=False): 14 | env = Env(args) 15 | env.eval() 16 | metrics['steps'].append(T) 17 | T_rewards, T_Qs = [], [] 18 | 19 | # Test performance over several episodes 20 | done = True 21 | for _ in range(args.evaluation_episodes): 22 | while True: 23 | if done: 24 | state, reward_sum, done = env.reset(), 0, False 25 | 26 | 
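      # Act ε-greedily with the small default ε (0.001) of Agent.act_e_greedy; as noted there, a high ε can reduce evaluation scores drastically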
action = dqn.act_e_greedy(state) # Choose an action ε-greedily 27 | state, reward, done = env.step(action) # Step 28 | reward_sum += reward 29 | if args.render: 30 | env.render() 31 | 32 | if done: 33 | T_rewards.append(reward_sum) 34 | break 35 | env.close() 36 | 37 | # Test Q-values over validation memory 38 | for state in val_mem: # Iterate over valid states 39 | T_Qs.append(dqn.evaluate_q(state)) 40 | 41 | avg_reward, avg_Q = sum(T_rewards) / len(T_rewards), sum(T_Qs) / len(T_Qs) 42 | if not evaluate: 43 | # Save model parameters if improved 44 | if avg_reward > metrics['best_avg_reward']: 45 | metrics['best_avg_reward'] = avg_reward 46 | dqn.save(results_dir) 47 | 48 | # Append to results and save metrics 49 | metrics['rewards'].append(T_rewards) 50 | metrics['Qs'].append(T_Qs) 51 | torch.save(metrics, os.path.join(results_dir, 'metrics.pth')) 52 | 53 | # Plot 54 | _plot_line(metrics['steps'], metrics['rewards'], 'Reward', path=results_dir) 55 | _plot_line(metrics['steps'], metrics['Qs'], 'Q', path=results_dir) 56 | 57 | # Return average reward and Q-value 58 | return avg_reward, avg_Q 59 | 60 | 61 | # Plots min, max and mean + standard deviation bars of a population over time 62 | def _plot_line(xs, ys_population, title, path=''): 63 | max_colour, mean_colour, std_colour, transparent = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)', 'rgba(0, 0, 0, 0)' 64 | 65 | ys = torch.tensor(ys_population, dtype=torch.float32) 66 | ys_min, ys_max, ys_mean, ys_std = ys.min(1)[0].squeeze(), ys.max(1)[0].squeeze(), ys.mean(1).squeeze(), ys.std(1).squeeze() 67 | ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std 68 | 69 | trace_max = Scatter(x=xs, y=ys_max.numpy(), line=Line(color=max_colour, dash='dash'), name='Max') 70 | trace_upper = Scatter(x=xs, y=ys_upper.numpy(), line=Line(color=transparent), name='+1 Std. Dev.', showlegend=False) 71 | trace_mean = Scatter(x=xs, y=ys_mean.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour), name='Mean') 72 | trace_lower = Scatter(x=xs, y=ys_lower.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=transparent), name='-1 Std. Dev.', showlegend=False) 73 | trace_min = Scatter(x=xs, y=ys_min.numpy(), line=Line(color=max_colour, dash='dash'), name='Min') 74 | 75 | plotly.offline.plot({ 76 | 'data': [trace_upper, trace_mean, trace_lower, trace_min, trace_max], 77 | 'layout': dict(title=title, xaxis={'title': 'Step'}, yaxis={'title': title}) 78 | }, filename=os.path.join(path, title + '.html'), auto_open=False) 79 | --------------------------------------------------------------------------------
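For orientation, below is a minimal sketch (not part of the repository) of how the components above fit together. It mirrors the training loop in `main.py`, hard-codes the argparse defaults into a `SimpleNamespace`, and omits evaluation, importance-sampling β annealing, checkpointing and error handling:

```python
from types import SimpleNamespace

import torch

from agent import Agent
from env import Env
from memory import ReplayMemory

# Hyperparameters copied from the defaults in main.py
args = SimpleNamespace(
    device=torch.device('cpu'), seed=123, game='space_invaders',
    max_episode_length=int(108e3), history_length=4,
    architecture='canonical', hidden_size=512, noisy_std=0.1,
    atoms=51, V_min=-10, V_max=10, model=None,
    memory_capacity=int(1e6), replay_frequency=4,
    priority_exponent=0.5, priority_weight=0.4,
    multi_step=3, discount=0.99, target_update=int(8e3),
    learning_rate=0.0000625, adam_eps=1.5e-4, batch_size=32, norm_clip=10,
    learn_start=int(20e3), T_max=int(50e6))

env = Env(args)
env.train()
dqn = Agent(args, env)
mem = ReplayMemory(args, args.memory_capacity)

state, done = env.reset(), False
for T in range(1, args.T_max + 1):
  if done:
    state = env.reset()
  if T % args.replay_frequency == 0:
    dqn.reset_noise()  # Draw a new set of noisy weights
  action = dqn.act(state)  # Greedy action under the current noisy weights
  next_state, reward, done = env.step(action)
  mem.append(state, action, max(min(reward, 1), -1), done)  # Clipped reward, as in main.py
  if T >= args.learn_start and T % args.replay_frequency == 0:
    dqn.learn(mem)  # n-step distributional double-Q update
  if T % args.target_update == 0:
    dqn.update_target_net()
  state = next_state
env.close()
```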