├── .gitignore
├── LICENSE
├── README.md
├── amtlb
│   ├── __init__.py
│   ├── common.py
│   └── transfer_benchmark.py
├── runtests
├── setup.py
└── tests
    └── test_common.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.egg-info
.cache/
requirements.txt

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright 2017 Josh Kuhn

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# This repository is deprecated!
Further development is continuing at [AI-ON/Multitask-and-Transfer-Learning](https://github.com/AI-ON/Multitask-and-Transfer-Learning) (we just moved the code to the AMTLB subdirectory). Please join us there!

# Atari Multitask & Transfer Learning Benchmark (AMTLB)

[![Join the chat at https://gitter.im/ai-open-network/multitask_and_transfer_learning](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/ai-open-network/multitask_and_transfer_learning)

This is a library to test how a reinforcement learning architecture
performs on all Atari games in OpenAI's gym. It performs two kinds of
tests, one for transfer learning and one for multitask learning.
Crucially, this benchmark tests how an *architecture* performs.
Training the architecture on games is part of the test, so it does not
aim to test how well a pre-trained network does, but rather how
quickly an architecture can learn to play the games (but see the note
on pre-training below for details).

Throughout this document, we'll use **architecture** to mean the
system being tested, irrespective of individual weight values. An
**instance** will refer to the architecture instantiated with a
particular set of weights (either trained or untrained). The benchmark
trains several instances of the architecture to infer how well the
architecture itself learns.

## Transfer learning benchmark

The goal of the transfer learning benchmark is to see how quickly an
architecture can learn a new game it has never seen before, using just
what it has learned from other games (in other words, how much
knowledge is transferred from one game to another).

The way it works is that we first create a fresh instance of the
architecture (call it instance `F`), and then measure its score over
time as it learns on ten million frames of a random Atari game (call
it game `X`). Next, we create another fresh instance of the
architecture, but this one we train on a bunch of other Atari games
(but not on game `X` itself); we'll call it instance `F_b`. Finally,
we let `F_b` play ten million frames of game `X` and measure its score
over time.
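
As a rough illustration of the protocol just described (all names here
are hypothetical, not part of this library: `MyArchitecture` stands in
for the architecture under test and `run_for_frames` for a training
loop that returns the cumulative score at each frame):

    # Illustrative sketch only; MyArchitecture and run_for_frames are hypothetical.
    fresh = MyArchitecture()                       # instance F
    F_curve = run_for_frames(fresh, game_X, frames=10**7)

    pretrained = MyArchitecture()                  # instance F_b
    for game in other_games:                       # every game except X
        run_for_frames(pretrained, game, frames=10**7)
    Fb_curve = run_for_frames(pretrained, game_X, frames=10**7)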

For each time frame, we take the cumulative score of `F_b` and the
cumulative score of `F` and compute the ratio `r = 1 - F / F_b`.

* If `r` is negative, then the architecture actually got worse from seeing other Atari games.
* If `r` is about 0, then the architecture didn't really transfer knowledge well from having seen the other Atari games.
* If `r` is positive, then we're in the sweet spot and the architecture is successfully learning to play a new Atari game from other games.

We're not quite done though, because really this is just a measure of
how well the architecture did on game `X`. Some games may transfer
knowledge well, and other games may be so unlike other Atari games
that it's hard to transfer much knowledge. What we could do to get
around this is to repeat the process above for each game in the
entire collection and average the scores.

That would take a really long time though, so as a compromise, instead
of holding out just one game in the above process, we hold out about
30% of all games as tests and keep 70% of games for training. We then
do the above process to test, except we create a fresh instance for
each test game, and we save the state of the network after it's been
trained on the training set of games. We reset it to that "freshly
trained" state before each test game (so it doesn't learn from the
other testing games). Then we shuffle the training and testing sets
randomly and do this a few more times from scratch.

As an example, let's say there are five games `S`, `U`, `V`, `X`, and `Y`.

We'll measure the performance of a fresh instance on each of the games
for 10 million frames, getting `F(S)`, `F(U)`, `F(V)`, `F(X)`, and `F(Y)`
(`F` is for "fresh").

Then for the first trial, we'll randomly select `X` and `Y` as the test games.
We'll train a new fresh instance on `S`, `U`, and `V` and save its weights as `F_suv`.
Then we train `F_suv` on `X` for ten million frames, getting `F_suv(X)`.
Then we reset to the saved `F_suv` weights and train on `Y` for ten
million frames, getting `F_suv(Y)`.

To get the score for the first trial, we average their ratios:

    r_1 = (F_suv(X)/F(X) + F_suv(Y)/F(Y)) / 2

Now we do a couple more trials, maybe using `S` and `V` as the test
games, then maybe for the third trial `U` and `S` as the tests.

    r_2 = avg(F_uxy(S)/F(S) , F_uxy(V)/F(V))
    r_3 = avg(F_vxy(U)/F(U) , F_vxy(S)/F(S))

Finally, we average the scores from all three trials:

    r = avg(r_1, r_2, r_3)

The resulting `r(t)` is the final transfer learning score for the
architecture at each time step, though we may simply use `r(t_max)` as
a summary.

## Multitask learning benchmark

The multitask learning benchmark is the most similar to existing
benchmarking that's been done on Atari games, in that we are concerned
with an absolute score on the games. Since absolute scores aren't
comparable across games, we have to keep each game's score separate in
the results rather than aggregating them.

Once again, we train a fresh instance of the architecture for 10
million frames on each game separately, obtaining baseline scores for
each game. These instances we call the "specialists", since each is
trained on only one game apiece.
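
In code, gathering those specialist baselines might look roughly like
this (again an illustrative sketch; `MyArchitecture` and
`run_for_frames` are the same hypothetical helpers as above):

    # Illustrative sketch only: one fresh specialist per game.
    specialist_curves = {}
    for game in all_games:
        specialist = MyArchitecture()
        specialist_curves[game] = run_for_frames(specialist, game, frames=10**7)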

Then we train a single instance of the architecture on every game in
random order, so that this new instance has seen 10 million frames of
every game. This instance we call the "generalist".

We then compare the generalist's cumulative scores for each frame
against the specialists' scores for the same game and time step. On
the multitask benchmark, the best case is for the generalist to match
the scores of the specialists. Since the architecture is the same, the
presumption is that a specialist will nearly always perform better,
and we can only minimize how much the generalist loses. In practice,
though, if the architecture transfers knowledge well, the generalist
may actually outperform the specialists in some cases.

On the multitask benchmark, we also output the absolute scores for
comparison with other benchmarks, though note that they should only be
compared with scores that were obtained under the same 10 million
frame limit.

## Note on pre-training

The benchmark doesn't have a strong opinion about how the weights are
initialized in a fresh instance of an architecture. It's reasonable
not to initialize the weights randomly, opting instead to start with
some prior training so that (for example) a deep convolutional network
doesn't waste part of its precious 10 million frames learning to
recognize edges and shapes.

The transfer learning benchmark is somewhat robust to this kind of
pre-training since it relies on measuring the amount of improvement in
the architecture before and after it is able to see other games. If a
"fresh" instance already has extensive training on Atari games
beforehand, we should expect this to simply eat into the improvement.

Nevertheless, as a rule of thumb, it's best if a fresh instance of an
architecture does not include any prior training on Atari games or
images from Atari games, to eliminate confusion.
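
For reference, running the benchmark against the library's bundled
`RandomAgent` might look roughly like the following (an illustrative
sketch: it assumes the package is installed along with `gym` and its
Atari environments, and a full run takes a very long time):

    from amtlb import BenchmarkParms, RandomAgent
    from amtlb.transfer_benchmark import TransferBenchmark

    parms = BenchmarkParms(num_folds=5)          # split the games into 5 folds
    benchmark = TransferBenchmark(parms, RandomAgent)
    benchmark.do_folds()                         # train and test across every fold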

--------------------------------------------------------------------------------
/amtlb/__init__.py:
--------------------------------------------------------------------------------
from common import *
import transfer_benchmark

--------------------------------------------------------------------------------
/amtlb/common.py:
--------------------------------------------------------------------------------
import random
import json
from abc import ABCMeta, abstractmethod

GAMES = ['air_raid', 'alien', 'amidar', 'assault', 'asterix',
         'asteroids', 'atlantis', 'bank_heist', 'battle_zone',
         'beam_rider', 'berzerk', 'bowling', 'boxing', 'breakout',
         'carnival', 'centipede', 'chopper_command', 'crazy_climber',
         'demon_attack', 'double_dunk', 'elevator_action', 'enduro',
         'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar',
         'ice_hockey', 'jamesbond', 'journey_escape', 'kangaroo', 'krull',
         'kung_fu_master', 'montezuma_revenge', 'ms_pacman',
         'name_this_game', 'phoenix', 'pitfall', 'pong', 'pooyan',
         'private_eye', 'qbert', 'riverraid', 'road_runner', 'robotank',
         'seaquest', 'skiing', 'solaris', 'space_invaders', 'star_gunner',
         'tennis', 'time_pilot', 'tutankham', 'up_n_down', 'venture',
         'video_pinball', 'wizard_of_wor', 'yars_revenge', 'zaxxon']

NUM_GAMES = len(GAMES)

# The names above are easy to modify and read if we keep them
# separate, but to load a gym environment you need the name in
# CamelCase with the version suffix.
def game_name(raw_name):
    return ''.join([g.capitalize() for g in raw_name.split('_')]) + '-v0'

GAME_NAMES = [game_name(game) for game in GAMES]


class Agent(object):

    __metaclass__ = ABCMeta  # ensures subclasses implement everything

    @abstractmethod
    def __call__(self, observation, reward):
        '''Called every time a new observation is available.
        Should return an integer from 0 to 17 inclusive.
        '''

    @abstractmethod
    def clone(self):
        '''Returns a deep copy of the agent and its weights.'''

    @classmethod
    @abstractmethod
    def load(cls, filename):
        '''Loads an agent (with weights) from a filename.'''

    @abstractmethod
    def save(self, filename):
        '''Saves the agent (with weights) to a filename.'''


class RandomAgent(Agent):
    '''Simple random agent that has no state.'''

    def __call__(self, observation, reward):
        # The benchmark maps invalid actions to no-op (action 0)
        return random.randint(0, 17)

    def clone(self):
        return self  # RandomAgent has no state

    @classmethod
    def load(cls, filename):
        return cls()

    def save(self, filename):
        pass


class BenchmarkParms(object):
    def __init__(self,
                 num_folds=5,
                 max_rounds_w_no_reward=10000,
                 seed=None,
                 max_rounds_per_game=100000,
                 game_names=GAME_NAMES,
                 ):
        self.num_folds = num_folds
        self.max_rounds_w_no_reward = max_rounds_w_no_reward
        self.seed = random.randint(0, 2**64 - 1) if seed is None else seed
        self.max_rounds_per_game = max_rounds_per_game
        self.game_names = game_names

        num_games = len(game_names)

        games = set(game_names)
        fold_size = num_games // num_folds
        remainder = num_games % num_folds
        self.folds = [None] * num_folds

        for i in range(num_folds):
            if i < remainder:
                # distribute the remainder games evenly among the folds
                self.folds[i] = random.sample(games, fold_size + 1)
            else:
                self.folds[i] = random.sample(games, fold_size)
            games -= set(self.folds[i])

        assert len(games) == 0

    def save(self, filename):
        '''Save the BenchmarkParms to a file'''
        filedata = {
            'num_folds': self.num_folds,
            'folds': self.folds,
            'seed': self.seed,
            'max_rounds_w_no_reward': self.max_rounds_w_no_reward,
            'max_rounds_per_game': self.max_rounds_per_game,
            'game_names': self.game_names,
        }
        with open(filename, 'w') as savefile:
            json.dump(filedata, savefile, sort_keys=True, indent=True)

    @staticmethod
    def load_from_file(filename):
        '''Load a BenchmarkParms from a file'''
        with open(filename, 'r') as savefile:
            filedata = json.load(savefile)
        parms = BenchmarkParms()
        # Just overwrite the freshly-constructed fields with the saved
        # ones. A little wasteful, but harmless.
        parms.folds = filedata['folds']
        parms.num_folds = len(parms.folds)
        parms.max_rounds_w_no_reward = filedata['max_rounds_w_no_reward']
        parms.seed = filedata['seed']
        parms.max_rounds_per_game = filedata['max_rounds_per_game']
        parms.game_names = filedata['game_names']
        return parms

--------------------------------------------------------------------------------
/amtlb/transfer_benchmark.py:
--------------------------------------------------------------------------------
from __future__ import division

import random
import json
from math import ceil
import string
from datetime import datetime
from collections import defaultdict

import gym

from common import Agent, RandomAgent, GAME_NAMES, NUM_GAMES, BenchmarkParms


def fold_name(num):
    name = string.ascii_uppercase[num % 26]
    if num > 25:
        name += str(num - 25)
    return name


class BenchmarkResult(object):
    def __init__(self, agent, game=None):
        self.agent = agent
        self.rewards = []  # list of int rewards, index is round #
        self.dones = []    # list of rounds where the agent died
        self.games = []    # list of what games were played each round
        if game is not None:
            self.record_game(game, 0)
        # TODO: make serializable

    def record_reward(self, reward):
        self.rewards.append(reward)

    def record_done(self, round):
        self.dones.append(round)

    def record_game(self, game_name, round):
        self.games.append((game_name, round))


class EnvMaker(object):
    '''Mixin class'''
    def create_env(self, game_name):
        env = gym.make(game_name)
        # Ensure all agents being tested on this game get the same
        # seed to reduce variability. Crucially, this also means an
        # agent shouldn't see the same environment more than once!
        env.seed(self.parms.seed)
        return env


class TestRun(EnvMaker):
    def __init__(self, agent, game_name, parms):
        self.agent = agent
        self.parms = parms
        self.game = self.create_env(game_name)
        self.result = BenchmarkResult(agent, game_name)

    def __call__(self):
        '''Test an agent on the given game'''
        observation = self.game.reset()
        reward = 0
        for round_num in xrange(self.parms.max_rounds_per_game):
            action = self.agent(observation, reward)
            if action >= self.game.action_space.n:
                action = 0  # Map invalid actions to no-op
            observation, reward, done, _ = self.game.step(action)
            self.result.record_reward(reward)
            if done:
                observation = self.game.reset()
                self.result.record_done(round_num)
        return self.result


class TransferBenchmark(object):
    '''Benchmark for testing knowledge transfer.

    Uses k-fold cross-validation to test an agent's performance on a
    new game. Each fold is a set of games held out from the training
    set; all the other folds together form the training
    set. Performance is compared as a ratio between the cumulative
    score over time of a fresh agent vs. the cumulative score over
    time of an agent trained on the training set. Both agents will be
    seeing the game in the test set for the first time, but one of the
    agents will have preparation and the other will not. The ratio
    measures how much that preparation helps.

    For the purposes of this class, the term `game_agent` is used to
    denote a fresh agent that has no preparation but trains on a
    particular game. Game agents are identified by the name of their
    game. A `fold_agent` is an agent that has been trained on all
    the games in the training set, i.e. everything but the test game
    fold. The folds are indexed by an integer, so the fold_agents are
    also indexed by an integer indicating which fold is their test
    set (all other folds are implicitly their training set).

    In practice, we only care about the results of the game agents,
    and don't need to keep them around once we have their scores,
    since they are intended to be fresh. We do need to keep the fold
    agents' trained weights around, since we should start from the
    same baseline on each test game. In other words, we shouldn't
    allow the agent to train on the other test games first, since then
    results would depend on the order in which we tested the games in
    the fold. So we checkpoint the agent at the time it finishes all
    of its training, and reset to that point before testing on each
    game in the test fold.

    '''
    def __init__(self, parms, AgentClass, dir=None):
        self.parms = parms  # BenchmarkParms
        self.AgentClass = AgentClass
        self.untrained_agent = AgentClass()
        self.game_agents = {}  # indexed by game name, since 1 per game
        self.game_results = {}  # benchmarks for each game agent
        self.fold_agents = [None] * parms.num_folds  # indexed by fold number
        self.fold_results = [None] * parms.num_folds  # each fold gets a result per game
        self.dir = dir or self.default_dir()
        # TODO: make dir recursively
        # TODO: add checkpoints, benchmark is ephemeral now

    def test_set(self, test_index):
        '''Copy the test set of game names from the BenchmarkParms'''
        return set(self.parms.folds[test_index])

    def training_set(self, test_index):
        '''Aggregate the training set from the current BenchmarkParms'''
        return {x
                for i, fold in enumerate(self.parms.folds)
                for x in fold if i != test_index}

    def default_dir(self):
        '''A reasonable default directory to store benchmark results in'''
        time = datetime.now().strftime('%Y-%m-%d_%H.%M')
        return 'benchmarks/' + time + '/'

    def game_agent_filename(self, game_name):
        '''Constructs a save filename for a game agent'''
        return self.dir + 'game_agent_' + game_name

    def fold_agent_filename(self, fold_num):
        '''Constructs a save filename for a fold agent'''
        return self.dir + 'fold_agent_' + str(fold_num)

    def tested_agent_filename(self, fold_num, game_name):
        '''Constructs a filename for a fold agent tested on a particular
        game'''
        return self.dir + 'tested_agent_' + str(fold_num) + '_' + game_name

    def ensure_game_agents(self):
        '''Ensures there is a game agent for every game.

        Game agents don't care about folds, and this function checks
        if an agent already exists, so there's no harm in running this
        function multiple times.
        '''
        for game_name in GAME_NAMES:
            if game_name not in self.game_agents:
                game_agent = self.untrained_agent.clone()
                self.game_results[game_name] = self.test(game_agent, game_name)
                game_agent.save(self.game_agent_filename(game_name))
                self.game_agents[game_name] = game_agent

    def test(self, game_agent, game_name):
        return TestRun(game_agent, game_name, self.parms)()

    def train(self, fold_agent, training_set):
        return TrainingRun(fold_agent, training_set, self.parms)()

    def do_folds(self):
        '''Runs through each fold, and trains an agent for each set.

        The outer loop of this function should run k times.
        '''
        self.ensure_game_agents()

        for fold_num in range(len(self.parms.folds)):
            training_set = self.training_set(fold_num)
            fold_agent = self.untrained_agent.clone()
            self.fold_agents[fold_num] = fold_agent
            self.train(fold_agent, training_set)
            fold_agent.save(self.fold_agent_filename(fold_num))

            test_set = self.test_set(fold_num)
            self.fold_results[fold_num] = fold_results = {}

            for game_name in test_set:
                tested_agent = fold_agent.clone()
                fold_results[game_name] = self.test(tested_agent, game_name)
                tested_agent.save(
                    self.tested_agent_filename(fold_num, game_name))


class TrainingRun(EnvMaker):
    def __init__(self, agent, training_set, parms):
        self.agent = agent
        self.training_set = training_set
        self.parms = parms
        self.envs = {}  # game name -> env, created lazily in sample_env
        self.game_rounds_left = {game: parms.max_rounds_per_game
                                 for game in training_set}
        self.result = BenchmarkResult(agent)

    def total_rounds_left(self):
        return sum(self.game_rounds_left.values())

    def sample_env(self):
        '''Samples from remaining games in the training set, inversely
        proportional to how much they've already been played. This has
        the added benefit of not sampling games that have already
        exhausted their round quota. It also doesn't care what order
        the items of self.game_rounds_left arrive in.
        '''
        total_rounds_left = self.total_rounds_left()
        threshold = random.randint(1, total_rounds_left)
        cumulative_sum = 0
        for game, rounds_left in self.game_rounds_left.items():
            cumulative_sum += rounds_left
            if cumulative_sum >= threshold:
                break
        if game not in self.envs:
            self.envs[game] = self.create_env(game)  # lazily initialize env
        return game, self.envs[game]

    def keep_playing(self, game_name, done, no_reward_turns):
        # Placeholder hook; currently unused.
        return

    def __call__(self):
        round = 0
        while self.total_rounds_left() > 0:
            game_name, game = self.sample_env()
            self.result.record_game(game_name, round)
            observation = game.reset()
            reward = 0
            done = False
            no_reward_turns = 0
            while (self.game_rounds_left[game_name] > 0 and
                   not done and
                   no_reward_turns < self.parms.max_rounds_w_no_reward):
                round += 1
                action = self.agent(observation, reward)
                if action >= game.action_space.n:
                    action = 0  # Map invalid actions to no-op
                observation, reward, done, _ = game.step(action)
                self.result.record_reward(reward)
                no_reward_turns = 0 if reward > 0 else no_reward_turns + 1
                self.game_rounds_left[game_name] -= 1
            self.result.record_done(round)
        return self.result

--------------------------------------------------------------------------------
/runtests:
--------------------------------------------------------------------------------
#!/bin/sh

python -m pytest tests/

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

VERSION = '0.1'
README = open('README.md').read()

setup(
    name="amtlb",
    version=VERSION,
    description="Atari Multitask and Transfer Learning Benchmark",
    author="Josh Kuhn",
    author_email="deontologician@gmail.com",
    url="http://ai-on.org/projects/multitask-and-transfer-learning.html",
    packages=['amtlb'],
    license="Apache License",
    tests_require=[
        'pytest == 3.0.6',
        'mock'
    ],
    long_description=README,
    install_requires=[
        'gym >= 0.5.6',
    ],
    scripts=[
        'runtests',
    ],
)

--------------------------------------------------------------------------------
/tests/test_common.py:
--------------------------------------------------------------------------------
from functools import partial
from random import randint, sample
from string import uppercase, lowercase
import json
import os.path

import pytest
from mock import MagicMock, patch

import amtlb


@pytest.fixture
def game_names():
    return list(sample(uppercase, randint(14, 26)))

@pytest.fixture(params=[1, 3, 5, 7, 11, 13])
def num_folds(request):
    return request.param

@pytest.fixture
def benchmark_parms(game_names, num_folds):
    return partial(
        amtlb.BenchmarkParms,
        game_names=game_names,
        num_folds=num_folds,
    )

@pytest.fixture
def random_benchmark_parms(game_names):
    return partial(amtlb.BenchmarkParms,
                   num_folds=randint(1, 11),
                   max_rounds_w_no_reward=randint(0, 1000),
                   seed=randint(0, 10000),
                   game_names=game_names,
                   )

@pytest.fixture
def random_filename(tmpdir):
    return os.path.join(str(tmpdir), ''.join(sample(lowercase, 10)))

def folds_equal(A, B):
    _A = {frozenset(a) for a in A}
    _B = {frozenset(b) for b in B}
    return _A == _B


class TestBenchmarkParms(object):

    def test_creates_right_num_folds(self, num_folds, benchmark_parms):
        bp = benchmark_parms()
        assert len(bp.folds) == num_folds

    def test_folds_are_all_close_in_size(
            self, game_names, num_folds, benchmark_parms):
        bp = benchmark_parms()

        fold_div = len(game_names) // num_folds
        fold_rem = len(game_names) % num_folds

        for fold in bp.folds:
            assert len(fold) in [fold_div, fold_div + 1]

    def test_all_games_go_in_a_fold(self, game_names, benchmark_parms):
        bp = benchmark_parms()

        all_games_in_folds = set()
        for fold in bp.folds:
            all_games_in_folds.update(set(fold))
        assert set(game_names) == all_games_in_folds

    def test_save_parms(self, random_benchmark_parms, random_filename):
        bp = random_benchmark_parms()
        bp.save(random_filename)

        with open(random_filename, 'r') as fileobj:
            j = json.load(fileobj)

        assert j.pop('num_folds') == bp.num_folds
        assert j.pop('max_rounds_w_no_reward') == bp.max_rounds_w_no_reward
        assert j.pop('seed') == bp.seed
        assert j.pop('max_rounds_per_game') == bp.max_rounds_per_game
        assert j.pop('game_names') == bp.game_names
        assert folds_equal(j.pop('folds'), bp.folds)
        assert not j

    def test_load_parms(self, random_benchmark_parms, random_filename):
        bp = random_benchmark_parms()
        test_data = {
            "num_folds": bp.num_folds,
            "max_rounds_w_no_reward": bp.max_rounds_w_no_reward,
            "seed": bp.seed,
            "max_rounds_per_game": bp.max_rounds_per_game,
            "game_names": bp.game_names,
            "folds": bp.folds,
        }
        with open(random_filename, 'w') as fileobj:
            json.dump(test_data, fileobj)

        bp2 = amtlb.BenchmarkParms.load_from_file(random_filename)

        assert bp.num_folds == bp2.num_folds
        assert bp.max_rounds_w_no_reward == bp2.max_rounds_w_no_reward
        assert bp.seed == bp2.seed
        assert bp.max_rounds_per_game == bp2.max_rounds_per_game
        assert bp.game_names == bp2.game_names
        assert folds_equal(bp.folds, bp2.folds)

--------------------------------------------------------------------------------