├── .github
│   ├── dependabot.yml
│   └── workflows
│       └── model-sanity-test.yml
├── .gitignore
├── LICENSE
├── README.md
├── core
│   ├── __init__.py
│   ├── argparser.py
│   ├── constants.py
│   ├── helpers.py
│   ├── model.py
│   ├── replay_buffer.py
│   ├── train_information.py
│   └── wrappers.py
├── media
│   └── pongnoframeskip-v4.gif
├── models
│   └── PongNoFrameskip-v4.dat
├── requirements.txt
├── test.py
└── train.py
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 |
--------------------------------------------------------------------------------
/.github/workflows/model-sanity-test.yml:
--------------------------------------------------------------------------------
1 | # This workflow ensures no destructive changes have been made to the application that would
2 | # prevent the model from learning Pong. The latest trained model is used and should
3 | # get a score of at least 19 in Pong for a single episode.
4 | name: Pong Retraining Sanity Test
5 |
6 | on: [push, pull_request]
7 |
8 | jobs:
9 | build:
10 |
11 | runs-on: ubuntu-latest
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | python-version: [3.7, 3.8, 3.9]
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install -r requirements.txt
27 | - name: Lint with pycodestyle
28 | run: |
29 | python -m pip install pycodestyle
30 | pycodestyle core *.py
31 | - name: Test transfer learning
32 | run: |
33 | python train.py --environment PongNoFrameskip-v4 --num-episodes 1 --checkpoint models/PongNoFrameskip-v4.dat --epsilon-start 0.0 > results.txt
34 | result=`grep "Best:" results.txt | awk '{print $7}' | cut -d "." -f1`
35 | echo "Result: $result"
36 | if [ $result -lt 19 ]; then
37 | echo "Result was below threshold. Marking failed."
38 | exit -1
39 | else
40 | echo "Result above threshold. Marking passed."
41 | fi
42 |
--------------------------------------------------------------------------------
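
The `Best:` score checked above is parsed out of the per-episode summary that
`train.py` prints. As a minimal illustration of the `awk`/`cut` extraction (the
summary line below is only a sample):

```
echo "Episode 1 - Reward: 21.0, Best: 21.0, Average: 21.0 Epsilon: 0.0" \
  | awk '{print $7}' | cut -d "." -f1
# prints 21, the integer part of the "Best:" value compared against the threshold
```
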
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Robert Clark
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenAI Gym PyTorch
2 |
3 |
4 |
5 |
6 | OpenAI's Gym is an open source toolkit containing several environments which can
7 | be used to compare reinforcement learning algorithms and techniques in a
8 | consistent and repeatable manner, making it easy for developers to benchmark
9 | their solutions.
10 |
11 | This repository aims to create a simple one-stop location for testing
12 | reinforcement learning models without worrying about configuring or maintaining
13 | the environment. Extensive command-line parameters make it easy to tweak
14 | settings and determine an optimal configuration for a particular environment
15 | and model.
16 |
17 | ## Setting up the repository
18 |
19 | ### Creating a virtual environment
20 | After cloning the repository, it is highly recommended to use a virtual
21 | environment (such as `virtualenv`) or Anaconda to isolate this project's
22 | dependencies from other system packages.
23 |
24 | To install `virtualenv`, simply run
25 |
26 | ```
27 | pip install virtualenv
28 | ```
29 |
30 | Once installed, a new virtual environment can be created by running
31 |
32 | ```
33 | virtualenv env
34 | ```
35 |
36 | This will create a virtual environment in the `env` directory in the current
37 | working directory. To change the location and/or name of the environment
38 | directory, change `env` to the desired path in the command above.
39 |
40 | To enter the virtual environment, run
41 |
42 | ```
43 | source env/bin/activate
44 | ```
45 |
46 | You should see `(env)` at the beginning of the terminal prompt, indicating the
47 | environment is active. Again, replace `env` with your desired directory name.
48 |
49 | To get out of the environment, simply run
50 |
51 | ```
52 | deactivate
53 | ```
54 |
55 | ### Installing Dependencies
56 | While the virtual environment is active, install the required dependencies by
57 | running
58 |
59 | ```
60 | pip install -r requirements.txt
61 | ```
62 |
63 | This will install all of the dependencies at specific versions to ensure they
64 | are compatible with one another.
65 |
66 | ## Training a model
67 |
68 | To train a model, use the `train.py` script and specify any parameters that need
69 | to be changed, such as the environment or epsilon decay factors. A list of the
70 | default values for every parameter can be found by running
71 |
72 | ```
73 | python train.py --help
74 | ```
75 |
76 | To run with the default settings, execute the script directly with
77 |
78 | ```
79 | python train.py
80 | ```
81 |
82 | The script will train the default environment over a set number of episodes and
83 | display the training progress after the conclusion of every episode. The updates
84 | indicate the episode number, the reward for the current episode, the best reward
85 | the model has achieved so far, a rolling average of the previous 100 episode
86 | rewards, and the current value for epsilon.
87 |
88 | Any time the model reaches a new best rolling average, the current model weights
89 | are saved as a `.dat` file with the environment's name (such as
90 | `PongNoFrameskip-v4.dat`). This saved model will overwrite any existing model
91 | weight files for the same environment.
92 |
--------------------------------------------------------------------------------
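
As a concrete example, the flags defined in `core/argparser.py` can be combined to
customize a run; the values below are illustrative only:

```
python train.py --environment BreakoutNoFrameskip-v4 --epsilon-decay 200000 --learning-rate 5e-5
```

A previous run's weights can also be used as a starting point via `--checkpoint`,
as the CI workflow does with the committed Pong model:

```
python train.py --checkpoint models/PongNoFrameskip-v4.dat --epsilon-start 0.0 --num-episodes 1
```
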
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/roclark/openai-gym-pytorch/0a8fbd94070877d6dbb14c1733c57af27905df37/core/__init__.py
--------------------------------------------------------------------------------
/core/argparser.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from core.constants import (BATCH_SIZE,
4 | ENVIRONMENT,
5 | EPSILON_START,
6 | EPSILON_FINAL,
7 | EPSILON_DECAY,
8 | GAMMA,
9 | INITIAL_LEARNING,
10 | LEARNING_RATE,
11 | MEMORY_CAPACITY,
12 | NUM_EPISODES,
13 | TARGET_UPDATE_FREQUENCY)
14 | from core.helpers import Range
15 |
16 |
17 | def parse_args():
18 | parser = ArgumentParser(description='')
19 | parser.add_argument('--batch-size', type=int, help='Specify the batch '
20 | 'size to use when updating the replay buffer. '
21 | 'Default: %s' % BATCH_SIZE, default=BATCH_SIZE)
22 | parser.add_argument('--buffer-capacity', type=int, help='The capacity to '
23 | 'use in the experience replay buffer. Default: %s'
24 | % MEMORY_CAPACITY, default=MEMORY_CAPACITY)
25 | parser.add_argument('--checkpoint', type=str, help='Specify a .dat file '
26 | 'to be used as a checkpoint to initialize weights for '
27 | 'a new training run. Defaults to no checkpoint.')
28 | parser.add_argument('--environment', type=str, help='The OpenAI gym '
29 | 'environment to use. Default: %s' % ENVIRONMENT,
30 | default=ENVIRONMENT)
31 | parser.add_argument('--epsilon-start', type=float, help='The initial '
32 | 'value for epsilon to be used in the epsilon-greedy '
33 | 'algorithm. Default: %s' % EPSILON_START,
34 | choices=[Range(0.0, 1.0)], default=EPSILON_START,
35 | metavar='EPSILON_START')
36 | parser.add_argument('--epsilon-final', type=float, help='The final value '
37 | 'for epsilon to be used in the epsilon-greedy '
38 | 'algorithm. Default: %s' % EPSILON_FINAL,
39 | choices=[Range(0.0, 1.0)], default=EPSILON_FINAL,
40 | metavar='EPSILON_FINAL')
41 | parser.add_argument('--epsilon-decay', type=int, help='The decay factor '
42 | 'for epsilon in the epsilon-greedy algorithm. '
43 | 'Default: %s' % EPSILON_DECAY, default=EPSILON_DECAY)
44 | parser.add_argument('--force-cpu', action='store_true', help='By default, '
45 | 'the program will run on the first supported GPU '
46 | 'identified by the system, if applicable. If a '
47 | 'supported GPU is installed, but all computations are '
48 | 'desired to run on the CPU only, specify this '
49 | 'parameter to ignore the GPUs. All actions will run '
50 | 'on the CPU if no supported GPUs are found. Default: '
51 | 'False')
52 | parser.add_argument('--gamma', type=float, help='Specify the discount '
53 | 'factor, gamma, to use in the Q-table formula. '
54 | 'Default: %s' % GAMMA, choices=[Range(0.0, 1.0)],
55 | default=GAMMA, metavar='GAMMA')
56 | parser.add_argument('--initial-learning', type=int, help='The number of '
57 | 'iterations to explore prior to updating the model '
58 | 'and begin the learning process. Default: %s'
59 | % INITIAL_LEARNING, default=INITIAL_LEARNING)
60 | parser.add_argument('--learning-rate', type=float, help='The learning '
61 | 'rate to use for the optimizer. Default: %s'
62 | % LEARNING_RATE, default=LEARNING_RATE)
63 | parser.add_argument('--num-episodes', type=int, help='The number of '
64 | 'episodes to run in the given environment. Default: '
65 | '%s' % NUM_EPISODES, default=NUM_EPISODES)
66 | parser.add_argument('--render', action='store_true', help='Specify to '
67 | 'render a visualization in another window of the '
68 | 'learning process. Note that a Desktop Environment is '
69 | 'required for visualization. Rendering scenes will '
70 | 'lower the learning speed. Default: False')
71 | parser.add_argument('--target-update-frequency', type=int, help='Specify '
72 | 'the number of iterations to run prior to updating '
73 | 'target network with the primary network\'s weights. '
74 | 'Default: %s' % TARGET_UPDATE_FREQUENCY,
75 | default=TARGET_UPDATE_FREQUENCY)
76 | return parser.parse_args()
77 |
--------------------------------------------------------------------------------
/core/constants.py:
--------------------------------------------------------------------------------
1 | BATCH_SIZE = 32
2 | ENVIRONMENT = 'PongNoFrameskip-v4'
3 | EPSILON_START = 1.0
4 | EPSILON_FINAL = 0.01
5 | EPSILON_DECAY = 100000
6 | GAMMA = 0.99
7 | INITIAL_LEARNING = 10000
8 | LEARNING_RATE = 1e-4
9 | MEMORY_CAPACITY = 20000
10 | NUM_EPISODES = 10000
11 | TARGET_UPDATE_FREQUENCY = 1000
12 |
--------------------------------------------------------------------------------
/core/helpers.py:
--------------------------------------------------------------------------------
1 | import atari_py as ap
2 | import math
3 | import numpy as np
4 | import re
5 | import torch
6 | from .model import CNNDQN
7 | from torch import FloatTensor, LongTensor
8 | from torch.autograd import Variable
9 |
10 |
11 | class Range:
12 | def __init__(self, start, end):
13 | self._start = start
14 | self._end = end
15 |
16 | def __eq__(self, input_num):
17 | return self._start <= input_num <= self._end
18 |
19 |
20 | def compute_td_loss(model, target_net, batch, gamma, device):
21 | state, action, reward, next_state, done = batch
22 |
23 | state = Variable(FloatTensor(np.float32(state))).to(device)
24 | next_state = Variable(FloatTensor(np.float32(next_state))).to(device)
25 | action = Variable(LongTensor(action)).to(device)
26 | reward = Variable(FloatTensor(reward)).to(device)
27 | done = Variable(FloatTensor(done)).to(device)
28 |
29 | q_values = model(state)
30 | next_q_values = target_net(next_state)
31 |
32 | q_value = q_values.gather(1, action.unsqueeze(-1)).squeeze(-1)
33 | next_q_value = next_q_values.max(1)[0]
34 | expected_q_value = reward + gamma * next_q_value * (1 - done)
35 |
36 | loss = (q_value - Variable(expected_q_value.data).to(device)).pow(2).mean()
37 | loss.backward()
38 |
39 |
40 | def update_epsilon(episode, args):
41 | eps_final = args.epsilon_final
42 | eps_start = args.epsilon_start
43 | decay = args.epsilon_decay
44 | epsilon = eps_final + (eps_start - eps_final) * \
45 | math.exp(-1 * ((episode + 1) / decay))
46 | return epsilon
47 |
48 |
49 | def set_device(force_cpu):
50 | device = torch.device('cpu')
51 | if not force_cpu and torch.cuda.is_available():
52 | device = torch.device('cuda')
53 | return device
54 |
55 |
56 | def load_model(checkpoint, model, target_model, device):
57 | model.load_state_dict(torch.load(checkpoint, map_location=device))
58 | target_model.load_state_dict(model.state_dict())
59 | return model, target_model
60 |
61 |
62 | def initialize_models(env, device, checkpoint):
63 | model = CNNDQN(env.observation_space.shape,
64 | env.action_space.n).to(device)
65 | target_model = CNNDQN(env.observation_space.shape,
66 | env.action_space.n).to(device)
67 | if checkpoint:
68 | model, target_model = load_model(checkpoint, model, target_model,
69 | device)
70 | return model, target_model
71 |
72 |
73 | def camel_to_snake_case(string):
74 | s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string)
75 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
76 |
77 |
78 | def is_atari(environment):
79 | for field in ['ramDeterministic', 'ramNoFrameSkip', 'NoFrameskip',
80 | 'Deterministic', 'ram']:
81 | environment = environment.replace(field, '')
82 | environment = re.sub(r'-v\d+', '', environment)
83 | environment = camel_to_snake_case(environment)
84 | if environment in ap.list_games():
85 | return True
86 | else:
87 | return False
88 |
--------------------------------------------------------------------------------
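
For intuition, a minimal sketch of the epsilon schedule implemented by
`update_epsilon`, run from the repository root with the pinned requirements
installed; the `SimpleNamespace` stands in for the parsed command-line arguments:

```
from types import SimpleNamespace

from core.constants import EPSILON_DECAY, EPSILON_FINAL, EPSILON_START
from core.helpers import update_epsilon

# Stand-in for the parsed CLI arguments, using the defaults from core/constants.py
args = SimpleNamespace(epsilon_start=EPSILON_START,
                       epsilon_final=EPSILON_FINAL,
                       epsilon_decay=EPSILON_DECAY)

for index in (0, 10000, 100000, 500000):
    print(index, round(update_epsilon(index, args), 4))
# Epsilon anneals from roughly 1.0 toward the 0.01 floor as the frame index grows.
```
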
/core/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from random import random, randrange
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class CNNDQN(nn.Module):
9 | def __init__(self, input_shape, num_actions):
10 | super(CNNDQN, self).__init__()
11 | self._input_shape = input_shape
12 | self._num_actions = num_actions
13 |
14 | self.features = nn.Sequential(
15 | nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
16 | nn.ReLU(),
17 | nn.Conv2d(32, 64, kernel_size=4, stride=2),
18 | nn.ReLU(),
19 | nn.Conv2d(64, 64, kernel_size=3, stride=1),
20 | nn.ReLU()
21 | )
22 |
23 | self.fc = nn.Sequential(
24 | nn.Linear(self.feature_size, 512),
25 | nn.ReLU(),
26 | nn.Linear(512, num_actions)
27 | )
28 |
29 | def forward(self, x):
30 | x = self.features(x).view(x.size()[0], -1)
31 | return self.fc(x)
32 |
33 | @property
34 | def feature_size(self):
35 | x = self.features(torch.zeros(1, *self._input_shape))
36 | return x.view(1, -1).size(1)
37 |
38 | def act(self, state, device, epsilon=0.0):
39 | if random() > epsilon:
40 | state = torch.FloatTensor(np.float32(state)) \
41 | .unsqueeze(0).to(device)
42 | q_value = self.forward(state)
43 | action = q_value.max(1)[1].item()
44 | else:
45 | action = randrange(self._num_actions)
46 | return action
47 |
--------------------------------------------------------------------------------
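
A minimal usage sketch of the network, assuming the (4, 84, 84) observation shape
produced by the wrappers and Pong's six-action space; the random array stands in
for a real frame stack:

```
import numpy as np
import torch

from core.model import CNNDQN

device = torch.device('cpu')
# (4, 84, 84) matches the observation shape produced by core/wrappers.py
model = CNNDQN((4, 84, 84), num_actions=6)

state = np.random.rand(4, 84, 84).astype(np.float32)  # stand-in for a frame stack
print(model.act(state, device))                # greedy action (epsilon defaults to 0.0)
print(model.act(state, device, epsilon=1.0))   # uniformly random action
```
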
/core/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from random import sample
3 | from collections import deque
4 |
5 |
6 | class ReplayBuffer:
7 | def __init__(self, capacity):
8 | self._buffer = deque(maxlen=capacity)
9 |
10 | def push(self, state, action, reward, next_state, done):
11 | self._buffer.append((state, action, reward, next_state, done))
12 |
13 | def sample(self, batch_size):
14 | indices = np.random.choice(len(self._buffer),
15 | batch_size,
16 | replace=False)
17 | batch = zip(*[self._buffer[i] for i in indices])
18 | state, action, reward, next_state, done = batch
19 | return (np.array(state),
20 | np.array(action),
21 | np.array(reward, dtype=np.float32),
22 | np.array(next_state),
23 | np.array(done, dtype=np.uint8))
24 |
25 | def __len__(self):
26 | return len(self._buffer)
27 |
--------------------------------------------------------------------------------
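
A small sketch of the push/sample cycle with toy transitions; note that `sample`
draws without replacement, so at least `batch_size` transitions must be stored
before sampling:

```
import numpy as np

from core.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(capacity=1000)
for _ in range(64):
    state = np.zeros((4, 84, 84), dtype=np.float32)   # toy observations
    next_state = np.ones((4, 84, 84), dtype=np.float32)
    buffer.push(state, 0, 1.0, next_state, False)

states, actions, rewards, next_states, dones = buffer.sample(32)
print(states.shape, actions.shape, rewards.dtype, dones.dtype)
# (32, 4, 84, 84) (32,) float32 uint8
```
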
/core/train_information.py:
--------------------------------------------------------------------------------
1 | class TrainInformation:
2 | def __init__(self):
3 | self._average = 0.0
4 | self._best_reward = -float('inf')
5 | self._best_average = -float('inf')
6 | self._rewards = []
7 | self._average_range = 100
8 | self._index = 0
9 |
10 | @property
11 | def best_reward(self):
12 | return self._best_reward
13 |
14 | @property
15 | def best_average(self):
16 | return self._best_average
17 |
18 | @property
19 | def average(self):
20 | avg_range = self._average_range * -1
21 | return sum(self._rewards[avg_range:]) / len(self._rewards[avg_range:])
22 |
23 | @property
24 | def index(self):
25 | return self._index
26 |
27 | def _update_best_reward(self, episode_reward):
28 | if episode_reward > self.best_reward:
29 | self._best_reward = episode_reward
30 |
31 | def _update_best_average(self):
32 | if self.average > self.best_average:
33 | self._best_average = self.average
34 | return True
35 | return False
36 |
37 | def update_rewards(self, episode_reward):
38 | self._rewards.append(episode_reward)
39 | self._update_best_reward(episode_reward)
40 | return self._update_best_average()
41 |
42 | def update_index(self):
43 | self._index += 1
44 |
--------------------------------------------------------------------------------
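
A brief sketch of the bookkeeping: `update_rewards` returns True whenever the
rolling average over the last 100 episodes reaches a new best, which is the
signal `train.py` uses to save the model weights:

```
from core.train_information import TrainInformation

info = TrainInformation()
for reward in (-21.0, -19.0, -18.0):
    new_best = info.update_rewards(reward)
    print(round(info.average, 3), info.best_reward, new_best)
# Each improving episode lifts the rolling average, so every call reports a new best.
```
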
/core/wrappers.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from collections import deque
4 | from core.helpers import is_atari
5 | from gym import make, ObservationWrapper, wrappers, Wrapper
6 | from gym.spaces import Box
7 |
8 |
9 | class ClassicControl(Wrapper):
10 | def __init__(self, env, atari):
11 | super(ClassicControl, self).__init__(env)
12 | self._atari = atari
13 |
14 | def reset(self):
15 | if not self._atari:
16 | self.env.reset()
17 | return self.env.render(mode='rgb_array')
18 | else:
19 | return self.env.reset()
20 |
21 |
22 | class FrameDownsample(ObservationWrapper):
23 | def __init__(self, env):
24 | super(FrameDownsample, self).__init__(env)
25 | self.observation_space = Box(low=0,
26 | high=255,
27 | shape=(84, 84, 1),
28 | dtype=np.uint8)
29 | self._width = 84
30 | self._height = 84
31 |
32 | def observation(self, observation):
33 | frame = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
34 | frame = cv2.resize(frame,
35 | (self._width, self._height),
36 | interpolation=cv2.INTER_AREA)
37 | return frame[:, :, None]
38 |
39 |
40 | class MaxAndSkipEnv(Wrapper):
41 | def __init__(self, env, atari, skip=4):
42 | super(MaxAndSkipEnv, self).__init__(env)
43 | self._obs_buffer = deque(maxlen=2)
44 | self._skip = skip
45 | self._atari = atari
46 |
47 | def step(self, action):
48 | total_reward = 0.0
49 | done = None
50 | for _ in range(self._skip):
51 | obs, reward, done, info = self.env.step(action)
52 | if not self._atari:
53 | obs = self.env.render(mode='rgb_array')
54 | self._obs_buffer.append(obs)
55 | total_reward += reward
56 | if done:
57 | break
58 | max_frame = np.max(np.stack(self._obs_buffer), axis=0)
59 | return max_frame, total_reward, done, info
60 |
61 | def reset(self):
62 | self._obs_buffer.clear()
63 | obs = self.env.reset()
64 | self._obs_buffer.append(obs)
65 | return obs
66 |
67 |
68 | class FireResetEnv(Wrapper):
69 | def __init__(self, env):
70 | Wrapper.__init__(self, env)
71 | if len(env.unwrapped.get_action_meanings()) < 3:
72 | raise ValueError('Expected an action space of at least 3!')
73 |
74 | def reset(self, **kwargs):
75 | self.env.reset(**kwargs)
76 | obs, _, done, _ = self.env.step(1)
77 | if done:
78 | self.env.reset(**kwargs)
79 | obs, _, done, _ = self.env.step(2)
80 | if done:
81 | self.env.reset(**kwargs)
82 | return obs
83 |
84 | def step(self, action):
85 | return self.env.step(action)
86 |
87 |
88 | class FrameBuffer(ObservationWrapper):
89 | def __init__(self, env, num_steps, dtype=np.float32):
90 | super(FrameBuffer, self).__init__(env)
91 | obs_space = env.observation_space
92 | self._dtype = dtype
93 | self.observation_space = Box(obs_space.low.repeat(num_steps, axis=0),
94 | obs_space.high.repeat(num_steps, axis=0),
95 | dtype=self._dtype)
96 |
97 | def reset(self):
98 | self.buffer = np.zeros_like(self.observation_space.low,
99 | dtype=self._dtype)
100 | return self.observation(self.env.reset())
101 |
102 | def observation(self, observation):
103 | self.buffer[:-1] = self.buffer[1:]
104 | self.buffer[-1] = observation
105 | return self.buffer
106 |
107 |
108 | class ImageToPyTorch(ObservationWrapper):
109 | def __init__(self, env):
110 | super(ImageToPyTorch, self).__init__(env)
111 | obs_shape = self.observation_space.shape
112 | self.observation_space = Box(low=0.0,
113 | high=1.0,
114 | shape=(obs_shape[::-1]),
115 | dtype=np.float32)
116 |
117 | def observation(self, observation):
118 | return np.moveaxis(observation, 2, 0)
119 |
120 |
121 | class NormalizeFloats(ObservationWrapper):
122 | def observation(self, obs):
123 | return np.array(obs).astype(np.float32) / 255.0
124 |
125 |
126 | def wrap_environment(environment, monitor=False):
127 | env = make(environment)
128 | atari = is_atari(environment)
129 | env = ClassicControl(env, atari)
130 | env = MaxAndSkipEnv(env, atari)
131 | try:
132 | if 'FIRE' in env.unwrapped.get_action_meanings():
133 | env = FireResetEnv(env)
134 | except AttributeError:
135 | # Some environments, such as the classic control environments, don't
136 | # have a get_action_meanings method. Since these environments don't
137 | # contain a 'FIRE' action, this wrapper is irrelevant and can be safely
138 | # ignored if the attribute doesn't exist.
139 | pass
140 | env = FrameDownsample(env)
141 | env = ImageToPyTorch(env)
142 | env = FrameBuffer(env, 4)
143 | env = NormalizeFloats(env)
144 | if monitor:
145 | env = wrappers.Monitor(env, 'videos', force=True)
146 | return env
147 |
--------------------------------------------------------------------------------
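
For reference, a short sketch of the full observation pipeline, assuming the
pinned requirements (gym 0.19 and atari-py) are installed: frames are downsampled
to 84x84 grayscale, converted to channel-first order, stacked four deep, and
scaled to [0, 1]:

```
from core.wrappers import wrap_environment

env = wrap_environment('PongNoFrameskip-v4')
state = env.reset()
print(env.observation_space.shape)            # (4, 84, 84)
print(state.shape, state.min(), state.max())  # frame stack scaled to [0.0, 1.0]
env.close()
```
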
/media/pongnoframeskip-v4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/roclark/openai-gym-pytorch/0a8fbd94070877d6dbb14c1733c57af27905df37/media/pongnoframeskip-v4.gif
--------------------------------------------------------------------------------
/models/PongNoFrameskip-v4.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/roclark/openai-gym-pytorch/0a8fbd94070877d6dbb14c1733c57af27905df37/models/PongNoFrameskip-v4.dat
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | atari-py==0.2.6
2 | cloudpickle==1.6.0
3 | future==0.18.2
4 | gym==0.19.0
5 | numpy==1.21.2
6 | opencv-python==4.5.3.56
7 | Pillow==8.3.2
8 | pyglet==1.5.15
9 | scipy==1.7.1
10 | six==1.16.0
11 | torch==1.9.0
12 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from core.helpers import (initialize_models,
3 | set_device)
4 | from core.wrappers import wrap_environment
5 |
6 |
7 | def parse_args():
8 | parser = ArgumentParser()
9 | parser.add_argument('--checkpoint', type=str, help='Specify the trained '
10 | 'model to test.')
11 | parser.add_argument('--environment', type=str, help='Specify the '
12 | 'environment to test against.',
13 | default='PongNoFrameskip-v4')
14 | parser.add_argument('--force-cpu', action='store_true', help='Force '
15 | 'computation to be done on the CPU. This may result '
16 | 'in longer processing time.')
17 | return parser.parse_args()
18 |
19 |
20 | def main():
21 | args = parse_args()
22 | env = wrap_environment(args.environment, monitor=True)
23 | device = set_device(args.force_cpu)
24 | model, target_model = initialize_models(env, device, args.checkpoint)
25 |
26 | done = False
27 | state = env.reset()
28 | episode_reward = 0.0
29 |
30 | while not done:
31 | action = model.act(state, device)
32 | next_state, reward, done, _ = env.step(action)
33 | episode_reward += reward
34 | state = next_state
35 |
36 | print(f'Episode Reward: {round(episode_reward, 3)}')
37 | env.close()
38 |
39 |
40 | if __name__ == '__main__':
41 | main()
42 |
--------------------------------------------------------------------------------
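
For example, the committed Pong checkpoint can be evaluated directly (the Monitor
wrapper writes a recording to a `videos` directory):

```
python test.py --checkpoint models/PongNoFrameskip-v4.dat --environment PongNoFrameskip-v4
```
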
/train.py:
--------------------------------------------------------------------------------
1 | from core.argparser import parse_args
2 | from core.helpers import (compute_td_loss,
3 | initialize_models,
4 | is_atari,
5 | set_device,
6 | update_epsilon)
7 | from core.replay_buffer import ReplayBuffer
8 | from core.train_information import TrainInformation
9 | from core.wrappers import wrap_environment
10 |
11 | from torch import save
12 | from torch.optim import Adam
13 |
14 |
15 | def update_graph(model, target_model, optimizer, replay_buffer, args, device,
16 | info):
17 | if len(replay_buffer) > args.initial_learning:
18 | if not info.index % args.target_update_frequency:
19 | target_model.load_state_dict(model.state_dict())
20 | optimizer.zero_grad()
21 | batch = replay_buffer.sample(args.batch_size)
22 | compute_td_loss(model, target_model, batch, args.gamma, device)
23 | optimizer.step()
24 |
25 |
26 | def complete_episode(model, environment, info, episode_reward, episode,
27 | epsilon):
28 | new_best = info.update_rewards(episode_reward)
29 | if new_best:
30 | print('New best average reward of %s! Saving model'
31 | % round(info.best_average, 3))
32 | save(model.state_dict(), '%s.dat' % environment)
33 | print('Episode %s - Reward: %s, Best: %s, Average: %s '
34 | 'Epsilon: %s' % (episode, episode_reward, info.best_reward,
35 | round(info.average, 3), round(epsilon, 4)))
36 |
37 |
38 | def run_episode(env, model, target_model, optimizer, replay_buffer, args,
39 | device, info, episode):
40 | episode_reward = 0.0
41 | state = env.reset()
42 |
43 | while True:
44 | epsilon = update_epsilon(info.index, args)
45 | action = model.act(state, device, epsilon)
46 | if args.render:
47 | env.render()
48 | next_state, reward, done, _ = env.step(action)
49 | replay_buffer.push(state, action, reward, next_state, done)
50 | state = next_state
51 | episode_reward += reward
52 | info.update_index()
53 | update_graph(model, target_model, optimizer, replay_buffer, args,
54 | device, info)
55 | if done:
56 | complete_episode(model, args.environment, info, episode_reward,
57 | episode, epsilon)
58 | break
59 |
60 |
61 | def train(env, model, target_model, optimizer, replay_buffer, args, device):
62 | info = TrainInformation()
63 |
64 | for episode in range(args.num_episodes):
65 | run_episode(env, model, target_model, optimizer, replay_buffer, args,
66 | device, info, episode)
67 |
68 |
69 | def main():
70 | args = parse_args()
71 | env = wrap_environment(args.environment)
72 | device = set_device(args.force_cpu)
73 | model, target_model = initialize_models(env, device, args.checkpoint)
74 | optimizer = Adam(model.parameters(), lr=args.learning_rate)
75 | replay_buffer = ReplayBuffer(args.buffer_capacity)
76 | train(env, model, target_model, optimizer, replay_buffer, args, device)
77 | env.close()
78 |
79 |
80 | if __name__ == '__main__':
81 | main()
82 |
--------------------------------------------------------------------------------