├── .github ├── dependabot.yml └── workflows │ └── model-sanity-test.yml ├── .gitignore ├── LICENSE ├── README.md ├── core ├── __init__.py ├── argparser.py ├── constants.py ├── helpers.py ├── model.py ├── replay_buffer.py ├── train_information.py └── wrappers.py ├── media └── pongnoframeskip-v4.gif ├── models └── PongNoFrameskip-v4.dat ├── requirements.txt ├── test.py └── train.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | -------------------------------------------------------------------------------- /.github/workflows/model-sanity-test.yml: -------------------------------------------------------------------------------- 1 | # This workflow ensures that no destructive changes are made to the application 2 | # which would prevent the model from learning Pong. The latest trained model is used and should 3 | # get a score of at least 19 in Pong for a single episode. 4 | name: Pong Retraining Sanity Test 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: [3.7, 3.8, 3.9] 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -r requirements.txt 27 | - name: Lint with pycodestyle 28 | run: | 29 | python -m pip install pycodestyle 30 | pycodestyle core *.py 31 | - name: Test transfer learning 32 | run: | 33 | python train.py --environment PongNoFrameskip-v4 --num-episodes 1 --checkpoint models/PongNoFrameskip-v4.dat --epsilon-start 0.0 > results.txt 34 | result=`grep "Best:" results.txt | awk '{print $7}' | cut -d "." -f1` 35 | echo "Result: $result" 36 | if [ $result -lt 19 ]; then 37 | echo "Result was below threshold. Marking failed." 38 | exit 1 39 | else 40 | echo "Result was above threshold. Marking passed." 41 | fi 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Robert Clark 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenAI Gym PyTorch 2 |

3 | 4 |

5 | 6 | OpenAI's Gym is an open source toolkit containing several environments which can 7 | be used to compare reinforcement learning algorithms and techniques in a 8 | consistent and repeatable manner, easily allowing developers to benchmark their 9 | solutions. 10 | 11 | This repository aims to create a simple one-stop location for testing 12 | reinforcement learning models without worrying about configuring or maintaining 13 | the environment. With extensive command-line parameters, various tweaks to 14 | settings can easily be made to determine an optimal configuration for a 15 | particular environment and model. 16 | 17 | ## Setting up the repository 18 | 19 | ### Creating a virtual environment 20 | After cloning the repository, it is highly recommended to use a virtual 21 | environment (such as `virtualenv`) or Anaconda to isolate the dependencies of 22 | this project from other system dependencies. 23 | 24 | To install `virtualenv`, simply run 25 | 26 | ``` 27 | pip install virtualenv 28 | ``` 29 | 30 | Once installed, a new virtual environment can be created by running 31 | 32 | ``` 33 | virtualenv env 34 | ``` 35 | 36 | This will create a virtual environment in the `env` directory in the current 37 | working directory. To change the location and/or name of the environment 38 | directory, change `env` to the desired path in the command above. 39 | 40 | To enter the virtual environment, run 41 | 42 | ``` 43 | source env/bin/activate 44 | ``` 45 | 46 | You should see `(env)` at the beginning of the terminal prompt, indicating the 47 | environment is active. Again, replace `env` with your desired directory name. 48 | 49 | To get out of the environment, simply run 50 | 51 | ``` 52 | deactivate 53 | ``` 54 | 55 | ### Installing Dependencies 56 | While the virtual environment is active, install the required dependencies by 57 | running 58 | 59 | ``` 60 | pip install -r requirements.txt 61 | ``` 62 | 63 | This will install all of the dependencies at specific versions to ensure they 64 | are compatible with one another. 65 | 66 | ## Training a model 67 | 68 | To train a model, use the `train.py` script and specify any parameters that need 69 | to be changed, such as the environment or epsilon decay factors. A list of the 70 | default values for every parameter can be found by running 71 | 72 | ``` 73 | python train.py --help 74 | ``` 75 | 76 | To run with the default settings, execute the script directly with 77 | 78 | ``` 79 | python train.py 80 | ``` 81 | 82 | The script will train the default environment over a set number of episodes and 83 | display the training progress after the conclusion of every episode. The updates 84 | indicate the episode number, the reward for the current episode, the best reward 85 | the model has achieved so far, a rolling average of the previous 100 episode 86 | rewards, and the current value for epsilon. 87 | 88 | Any time the model reaches a new best rolling average, the current model weights 89 | are saved as a `.dat` file with the environment's name (such as 90 | `PongNoFrameskip-v4.dat`). This saved model will overwrite any existing model 91 | weight files for the same environment.
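
Since the `--checkpoint` flag accepts a previously saved `.dat` file, training can also be resumed from existing weights rather than starting from scratch. As an example (the exact flag values here are only an illustration), the bundled Pong model can be used as a starting point with

```
python train.py --checkpoint models/PongNoFrameskip-v4.dat --epsilon-start 0.01
```

A saved model can likewise be evaluated for a single episode with `test.py`, which also records a video of the run to a `videos` directory:

```
python test.py --checkpoint models/PongNoFrameskip-v4.dat --environment PongNoFrameskip-v4
```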
92 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roclark/openai-gym-pytorch/0a8fbd94070877d6dbb14c1733c57af27905df37/core/__init__.py -------------------------------------------------------------------------------- /core/argparser.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from core.constants import (BATCH_SIZE, 4 | ENVIRONMENT, 5 | EPSILON_START, 6 | EPSILON_FINAL, 7 | EPSILON_DECAY, 8 | GAMMA, 9 | INITIAL_LEARNING, 10 | LEARNING_RATE, 11 | MEMORY_CAPACITY, 12 | NUM_EPISODES, 13 | TARGET_UPDATE_FREQUENCY) 14 | from core.helpers import Range 15 | 16 | 17 | def parse_args(): 18 | parser = ArgumentParser(description='') 19 | parser.add_argument('--batch-size', type=int, help='Specify the batch ' 20 | 'size to use when updating the replay buffer. ' 21 | 'Default: %s' % BATCH_SIZE, default=BATCH_SIZE) 22 | parser.add_argument('--buffer-capacity', type=int, help='The capacity to ' 23 | 'use in the experience replay buffer. Default: %s' 24 | % MEMORY_CAPACITY, default=MEMORY_CAPACITY) 25 | parser.add_argument('--checkpoint', type=str, help='Specify a .dat file ' 26 | 'to be used as a checkpoint to initialize weights for ' 27 | 'a new training run. Defaults to no checkpoint.') 28 | parser.add_argument('--environment', type=str, help='The OpenAI gym ' 29 | 'environment to use. Default: %s' % ENVIRONMENT, 30 | default=ENVIRONMENT) 31 | parser.add_argument('--epsilon-start', type=float, help='The initial ' 32 | 'value for epsilon to be used in the epsilon-greedy ' 33 | 'algorithm. Default: %s' % EPSILON_START, 34 | choices=[Range(0.0, 1.0)], default=EPSILON_START, 35 | metavar='EPSILON_START') 36 | parser.add_argument('--epsilon-final', type=float, help='The final value ' 37 | 'for epsilon to be used in the epsilon-greedy ' 38 | 'algorithm. Default: %s' % EPSILON_FINAL, 39 | choices=[Range(0.0, 1.0)], default=EPSILON_FINAL, 40 | metavar='EPSILON_FINAL') 41 | parser.add_argument('--epsilon-decay', type=int, help='The decay factor ' 42 | 'for epsilon in the epsilon-greedy algorithm. ' 43 | 'Default: %s' % EPSILON_DECAY, default=EPSILON_DECAY) 44 | parser.add_argument('--force-cpu', action='store_true', help='By default, ' 45 | 'the program will run on the first supported GPU ' 46 | 'identified by the system, if applicable. If a ' 47 | 'supported GPU is installed, but all computations are ' 48 | 'desired to run on the CPU only, specify this ' 49 | 'parameter to ignore the GPUs. All actions will run ' 50 | 'on the CPU if no supported GPUs are found. Default: ' 51 | 'False') 52 | parser.add_argument('--gamma', type=float, help='Specify the discount ' 53 | 'factor, gamma, to use in the Q-table formula. ' 54 | 'Default: %s' % GAMMA, choices=[Range(0.0, 1.0)], 55 | default=GAMMA, metavar='GAMMA') 56 | parser.add_argument('--initial-learning', type=int, help='The number of ' 57 | 'iterations to explore prior to updating the model ' 58 | 'and beginning the learning process. Default: %s' 59 | % INITIAL_LEARNING, default=INITIAL_LEARNING) 60 | parser.add_argument('--learning-rate', type=float, help='The learning ' 61 | 'rate to use for the optimizer. Default: %s' 62 | % LEARNING_RATE, default=LEARNING_RATE) 63 | parser.add_argument('--num-episodes', type=int, help='The number of ' 64 | 'episodes to run in the given environment. 
Default: ' 65 | '%s' % NUM_EPISODES, default=NUM_EPISODES) 66 | parser.add_argument('--render', action='store_true', help='Specify to ' 67 | 'render a visualization in another window of the ' 68 | 'learning process. Note that a Desktop Environment is ' 69 | 'required for visualization. Rendering scenes will ' 70 | 'lower the learning speed. Default: False') 71 | parser.add_argument('--target-update-frequency', type=int, help='Specify ' 72 | 'the number of iterations to run prior to updating ' 73 | 'target network with the primary network\'s weights. ' 74 | 'Default: %s' % TARGET_UPDATE_FREQUENCY, 75 | default=TARGET_UPDATE_FREQUENCY) 76 | return parser.parse_args() 77 | -------------------------------------------------------------------------------- /core/constants.py: -------------------------------------------------------------------------------- 1 | BATCH_SIZE = 32 2 | ENVIRONMENT = 'PongNoFrameskip-v4' 3 | EPSILON_START = 1.0 4 | EPSILON_FINAL = 0.01 5 | EPSILON_DECAY = 100000 6 | GAMMA = 0.99 7 | INITIAL_LEARNING = 10000 8 | LEARNING_RATE = 1e-4 9 | MEMORY_CAPACITY = 20000 10 | NUM_EPISODES = 10000 11 | TARGET_UPDATE_FREQUENCY = 1000 12 | -------------------------------------------------------------------------------- /core/helpers.py: -------------------------------------------------------------------------------- 1 | import atari_py as ap 2 | import math 3 | import numpy as np 4 | import re 5 | import torch 6 | from .model import CNNDQN 7 | from torch import FloatTensor, LongTensor 8 | from torch.autograd import Variable 9 | 10 | 11 | class Range: 12 | def __init__(self, start, end): 13 | self._start = start 14 | self._end = end 15 | 16 | def __eq__(self, input_num): 17 | return self._start <= input_num <= self._end 18 | 19 | 20 | def compute_td_loss(model, target_net, batch, gamma, device): 21 | state, action, reward, next_state, done = batch 22 | 23 | state = Variable(FloatTensor(np.float32(state))).to(device) 24 | next_state = Variable(FloatTensor(np.float32(next_state))).to(device) 25 | action = Variable(LongTensor(action)).to(device) 26 | reward = Variable(FloatTensor(reward)).to(device) 27 | done = Variable(FloatTensor(done)).to(device) 28 | 29 | q_values = model(state) 30 | next_q_values = target_net(next_state) 31 | 32 | q_value = q_values.gather(1, action.unsqueeze(-1)).squeeze(-1) 33 | next_q_value = next_q_values.max(1)[0] 34 | expected_q_value = reward + gamma * next_q_value * (1 - done) 35 | 36 | loss = (q_value - Variable(expected_q_value.data).to(device)).pow(2).mean() 37 | loss.backward() 38 | 39 | 40 | def update_epsilon(episode, args): 41 | eps_final = args.epsilon_final 42 | eps_start = args.epsilon_start 43 | decay = args.epsilon_decay 44 | epsilon = eps_final + (eps_start - eps_final) * \ 45 | math.exp(-1 * ((episode + 1) / decay)) 46 | return epsilon 47 | 48 | 49 | def set_device(force_cpu): 50 | device = torch.device('cpu') 51 | if not force_cpu and torch.cuda.is_available(): 52 | device = torch.device('cuda') 53 | return device 54 | 55 | 56 | def load_model(checkpoint, model, target_model, device): 57 | model.load_state_dict(torch.load(checkpoint, map_location=device)) 58 | target_model.load_state_dict(model.state_dict()) 59 | return model, target_model 60 | 61 | 62 | def initialize_models(env, device, checkpoint): 63 | model = CNNDQN(env.observation_space.shape, 64 | env.action_space.n).to(device) 65 | target_model = CNNDQN(env.observation_space.shape, 66 | env.action_space.n).to(device) 67 | if checkpoint: 68 | model, target_model = load_model(checkpoint, 
model, target_model, 69 | device) 70 | return model, target_model 71 | 72 | 73 | def camel_to_snake_case(string): 74 | s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string) 75 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() 76 | 77 | 78 | def is_atari(environment): 79 | for field in ['ramDeterministic', 'ramNoFrameSkip', 'NoFrameskip', 80 | 'Deterministic', 'ram']: 81 | environment = environment.replace(field, '') 82 | environment = re.sub(r'-v\d+', '', environment) 83 | environment = camel_to_snake_case(environment) 84 | if environment in ap.list_games(): 85 | return True 86 | else: 87 | return False 88 | -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from random import random, randrange 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class CNNDQN(nn.Module): 9 | def __init__(self, input_shape, num_actions): 10 | super(CNNDQN, self).__init__() 11 | self._input_shape = input_shape 12 | self._num_actions = num_actions 13 | 14 | self.features = nn.Sequential( 15 | nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), 16 | nn.ReLU(), 17 | nn.Conv2d(32, 64, kernel_size=4, stride=2), 18 | nn.ReLU(), 19 | nn.Conv2d(64, 64, kernel_size=3, stride=1), 20 | nn.ReLU() 21 | ) 22 | 23 | self.fc = nn.Sequential( 24 | nn.Linear(self.feature_size, 512), 25 | nn.ReLU(), 26 | nn.Linear(512, num_actions) 27 | ) 28 | 29 | def forward(self, x): 30 | x = self.features(x).view(x.size()[0], -1) 31 | return self.fc(x) 32 | 33 | @property 34 | def feature_size(self): 35 | x = self.features(torch.zeros(1, *self._input_shape)) 36 | return x.view(1, -1).size(1) 37 | 38 | def act(self, state, device, epsilon=0.0): 39 | if random() > epsilon: 40 | state = torch.FloatTensor(np.float32(state)) \ 41 | .unsqueeze(0).to(device) 42 | q_value = self.forward(state) 43 | action = q_value.max(1)[1].item() 44 | else: 45 | action = randrange(self._num_actions) 46 | return action 47 | -------------------------------------------------------------------------------- /core/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from random import sample 3 | from collections import deque 4 | 5 | 6 | class ReplayBuffer: 7 | def __init__(self, capacity): 8 | self._buffer = deque(maxlen=capacity) 9 | 10 | def push(self, state, action, reward, next_state, done): 11 | self._buffer.append((state, action, reward, next_state, done)) 12 | 13 | def sample(self, batch_size): 14 | indices = np.random.choice(len(self._buffer), 15 | batch_size, 16 | replace=False) 17 | batch = zip(*[self._buffer[i] for i in indices]) 18 | state, action, reward, next_state, done = batch 19 | return (np.array(state), 20 | np.array(action), 21 | np.array(reward, dtype=np.float32), 22 | np.array(next_state), 23 | np.array(done, dtype=np.uint8)) 24 | 25 | def __len__(self): 26 | return len(self._buffer) 27 | -------------------------------------------------------------------------------- /core/train_information.py: -------------------------------------------------------------------------------- 1 | class TrainInformation: 2 | def __init__(self): 3 | self._average = 0.0 4 | self._best_reward = -float('inf') 5 | self._best_average = -float('inf') 6 | self._rewards = [] 7 | self._average_range = 100 8 | self._index = 0 9 | 10 | @property 11 | def best_reward(self): 12 | return self._best_reward 13 | 14 | @property 15 | def 
best_average(self): 16 | return self._best_average 17 | 18 | @property 19 | def average(self): 20 | avg_range = self._average_range * -1 21 | return sum(self._rewards[avg_range:]) / len(self._rewards[avg_range:]) 22 | 23 | @property 24 | def index(self): 25 | return self._index 26 | 27 | def _update_best_reward(self, episode_reward): 28 | if episode_reward > self.best_reward: 29 | self._best_reward = episode_reward 30 | 31 | def _update_best_average(self): 32 | if self.average > self.best_average: 33 | self._best_average = self.average 34 | return True 35 | return False 36 | 37 | def update_rewards(self, episode_reward): 38 | self._rewards.append(episode_reward) 39 | self._update_best_reward(episode_reward) 40 | return self._update_best_average() 41 | 42 | def update_index(self): 43 | self._index += 1 44 | -------------------------------------------------------------------------------- /core/wrappers.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from collections import deque 4 | from core.helpers import is_atari 5 | from gym import make, ObservationWrapper, wrappers, Wrapper 6 | from gym.spaces import Box 7 | 8 | 9 | class ClassicControl(Wrapper): 10 | def __init__(self, env, atari): 11 | super(ClassicControl, self).__init__(env) 12 | self._atari = atari 13 | 14 | def reset(self): 15 | if not self._atari: 16 | self.env.reset() 17 | return self.env.render(mode='rgb_array') 18 | else: 19 | return self.env.reset() 20 | 21 | 22 | class FrameDownsample(ObservationWrapper): 23 | def __init__(self, env): 24 | super(FrameDownsample, self).__init__(env) 25 | self.observation_space = Box(low=0, 26 | high=255, 27 | shape=(84, 84, 1), 28 | dtype=np.uint8) 29 | self._width = 84 30 | self._height = 84 31 | 32 | def observation(self, observation): 33 | frame = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY) 34 | frame = cv2.resize(frame, 35 | (self._width, self._height), 36 | interpolation=cv2.INTER_AREA) 37 | return frame[:, :, None] 38 | 39 | 40 | class MaxAndSkipEnv(Wrapper): 41 | def __init__(self, env, atari, skip=4): 42 | super(MaxAndSkipEnv, self).__init__(env) 43 | self._obs_buffer = deque(maxlen=2) 44 | self._skip = skip 45 | self._atari = atari 46 | 47 | def step(self, action): 48 | total_reward = 0.0 49 | done = None 50 | for _ in range(self._skip): 51 | obs, reward, done, info = self.env.step(action) 52 | if not self._atari: 53 | obs = self.env.render(mode='rgb_array') 54 | self._obs_buffer.append(obs) 55 | total_reward += reward 56 | if done: 57 | break 58 | max_frame = np.max(np.stack(self._obs_buffer), axis=0) 59 | return max_frame, total_reward, done, info 60 | 61 | def reset(self): 62 | self._obs_buffer.clear() 63 | obs = self.env.reset() 64 | self._obs_buffer.append(obs) 65 | return obs 66 | 67 | 68 | class FireResetEnv(Wrapper): 69 | def __init__(self, env): 70 | Wrapper.__init__(self, env) 71 | if len(env.unwrapped.get_action_meanings()) < 3: 72 | raise ValueError('Expected an action space of at least 3!') 73 | 74 | def reset(self, **kwargs): 75 | self.env.reset(**kwargs) 76 | obs, _, done, _ = self.env.step(1) 77 | if done: 78 | self.env.reset(**kwargs) 79 | obs, _, done, _ = self.env.step(2) 80 | if done: 81 | self.env.reset(**kwargs) 82 | return obs 83 | 84 | def step(self, action): 85 | return self.env.step(action) 86 | 87 | 88 | class FrameBuffer(ObservationWrapper): 89 | def __init__(self, env, num_steps, dtype=np.float32): 90 | super(FrameBuffer, self).__init__(env) 91 | obs_space = 
env.observation_space 92 | self._dtype = dtype 93 | self.observation_space = Box(obs_space.low.repeat(num_steps, axis=0), 94 | obs_space.high.repeat(num_steps, axis=0), 95 | dtype=self._dtype) 96 | 97 | def reset(self): 98 | self.buffer = np.zeros_like(self.observation_space.low, 99 | dtype=self._dtype) 100 | return self.observation(self.env.reset()) 101 | 102 | def observation(self, observation): 103 | self.buffer[:-1] = self.buffer[1:] 104 | self.buffer[-1] = observation 105 | return self.buffer 106 | 107 | 108 | class ImageToPyTorch(ObservationWrapper): 109 | def __init__(self, env): 110 | super(ImageToPyTorch, self).__init__(env) 111 | obs_shape = self.observation_space.shape 112 | self.observation_space = Box(low=0.0, 113 | high=1.0, 114 | shape=(obs_shape[::-1]), 115 | dtype=np.float32) 116 | 117 | def observation(self, observation): 118 | return np.moveaxis(observation, 2, 0) 119 | 120 | 121 | class NormalizeFloats(ObservationWrapper): 122 | def observation(self, obs): 123 | return np.array(obs).astype(np.float32) / 255.0 124 | 125 | 126 | def wrap_environment(environment, monitor=False): 127 | env = make(environment) 128 | atari = is_atari(environment) 129 | env = ClassicControl(env, atari) 130 | env = MaxAndSkipEnv(env, atari) 131 | try: 132 | if 'FIRE' in env.unwrapped.get_action_meanings(): 133 | env = FireResetEnv(env) 134 | except AttributeError: 135 | # Some environments, such as the classic control environments, don't 136 | # have a get_action_meanings method. Since these environments don't 137 | # contain a 'FIRE' action, this wrapper is irrelevant and can be safely 138 | # ignored if the attribute doesn't exist. 139 | pass 140 | env = FrameDownsample(env) 141 | env = ImageToPyTorch(env) 142 | env = FrameBuffer(env, 4) 143 | env = NormalizeFloats(env) 144 | if monitor: 145 | env = wrappers.Monitor(env, 'videos', force=True) 146 | return env 147 | -------------------------------------------------------------------------------- /media/pongnoframeskip-v4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roclark/openai-gym-pytorch/0a8fbd94070877d6dbb14c1733c57af27905df37/media/pongnoframeskip-v4.gif -------------------------------------------------------------------------------- /models/PongNoFrameskip-v4.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roclark/openai-gym-pytorch/0a8fbd94070877d6dbb14c1733c57af27905df37/models/PongNoFrameskip-v4.dat -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | atari-py==0.2.6 2 | cloudpickle==1.6.0 3 | future==0.18.2 4 | gym==0.19.0 5 | numpy==1.21.2 6 | opencv-python==4.5.3.56 7 | Pillow==8.3.2 8 | pyglet==1.5.15 9 | scipy==1.7.1 10 | six==1.16.0 11 | torch==1.9.0 12 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from core.helpers import (initialize_models, 3 | set_device) 4 | from core.wrappers import wrap_environment 5 | 6 | 7 | def parse_args(): 8 | parser = ArgumentParser() 9 | parser.add_argument('--checkpoint', type=str, help='Specify the trained ' 10 | 'model to test.') 11 | parser.add_argument('--environment', type=str, help='Specify the ' 12 | 'environment to test against.', 13 | 
default='PongNoFrameskip-v4') 14 | parser.add_argument('--force-cpu', action='store_true', help='Force ' 15 | 'computation to be done on the CPU. This may result ' 16 | 'in longer processing time.') 17 | return parser.parse_args() 18 | 19 | 20 | def main(): 21 | args = parse_args() 22 | env = wrap_environment(args.environment, monitor=True) 23 | device = set_device(args.force_cpu) 24 | model, target_model = initialize_models(env, device, args.checkpoint) 25 | 26 | done = False 27 | state = env.reset() 28 | episode_reward = 0.0 29 | 30 | while not done: 31 | action = model.act(state, device) 32 | next_state, reward, done, _ = env.step(action) 33 | episode_reward += reward 34 | state = next_state 35 | 36 | print(f'Episode Reward: {round(episode_reward, 3)}') 37 | env.close() 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from core.argparser import parse_args 2 | from core.helpers import (compute_td_loss, 3 | initialize_models, 4 | is_atari, 5 | set_device, 6 | update_epsilon) 7 | from core.replay_buffer import ReplayBuffer 8 | from core.train_information import TrainInformation 9 | from core.wrappers import wrap_environment 10 | 11 | from torch import save 12 | from torch.optim import Adam 13 | 14 | 15 | def update_graph(model, target_model, optimizer, replay_buffer, args, device, 16 | info): 17 | if len(replay_buffer) > args.initial_learning: 18 | if not info.index % args.target_update_frequency: 19 | target_model.load_state_dict(model.state_dict()) 20 | optimizer.zero_grad() 21 | batch = replay_buffer.sample(args.batch_size) 22 | compute_td_loss(model, target_model, batch, args.gamma, device) 23 | optimizer.step() 24 | 25 | 26 | def complete_episode(model, environment, info, episode_reward, episode, 27 | epsilon): 28 | new_best = info.update_rewards(episode_reward) 29 | if new_best: 30 | print('New best average reward of %s! 
Saving model' 31 | % round(info.best_average, 3)) 32 | save(model.state_dict(), '%s.dat' % environment) 33 | print('Episode %s - Reward: %s, Best: %s, Average: %s ' 34 | 'Epsilon: %s' % (episode, episode_reward, info.best_reward, 35 | round(info.average, 3), round(epsilon, 4))) 36 | 37 | 38 | def run_episode(env, model, target_model, optimizer, replay_buffer, args, 39 | device, info, episode): 40 | episode_reward = 0.0 41 | state = env.reset() 42 | 43 | while True: 44 | epsilon = update_epsilon(info.index, args) 45 | action = model.act(state, device, epsilon) 46 | if args.render: 47 | env.render() 48 | next_state, reward, done, _ = env.step(action) 49 | replay_buffer.push(state, action, reward, next_state, done) 50 | state = next_state 51 | episode_reward += reward 52 | info.update_index() 53 | update_graph(model, target_model, optimizer, replay_buffer, args, 54 | device, info) 55 | if done: 56 | complete_episode(model, args.environment, info, episode_reward, 57 | episode, epsilon) 58 | break 59 | 60 | 61 | def train(env, model, target_model, optimizer, replay_buffer, args, device): 62 | info = TrainInformation() 63 | 64 | for episode in range(args.num_episodes): 65 | run_episode(env, model, target_model, optimizer, replay_buffer, args, 66 | device, info, episode) 67 | 68 | 69 | def main(): 70 | args = parse_args() 71 | env = wrap_environment(args.environment) 72 | device = set_device(args.force_cpu) 73 | model, target_model = initialize_models(env, device, args.checkpoint) 74 | optimizer = Adam(model.parameters(), lr=args.learning_rate) 75 | replay_buffer = ReplayBuffer(args.buffer_capacity) 76 | train(env, model, target_model, optimizer, replay_buffer, args, device) 77 | env.close() 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | --------------------------------------------------------------------------------