├── src ├── __init__.py ├── config.py ├── pirate.py ├── island.py ├── dataset.py ├── model_chunks.py ├── hyperopt_trainer.py └── ship.py ├── tests ├── __init__.py ├── profiler_tester.py ├── test_pirate.py ├── memory_leak_tester.py ├── test_model_chunks.py └── test_ship.py ├── notebooks ├── __init__.py ├── pirate_analyzer.ipynb └── data_creator.ipynb ├── docs └── images │ ├── arch.PNG │ ├── top5.png │ ├── island.PNG │ ├── pirate.PNG │ ├── treasure.PNG │ ├── arch_table.png │ ├── data_table.png │ ├── opt_table.png │ ├── val_acc_tensorboard.png │ └── val_loss_tensorboard.png ├── clean_local.sh ├── .gitignore ├── run_tests.sh ├── LICENSE ├── run.py └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /notebooks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/images/arch.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/arch.PNG -------------------------------------------------------------------------------- /docs/images/top5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/top5.png -------------------------------------------------------------------------------- /docs/images/island.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/island.PNG -------------------------------------------------------------------------------- /docs/images/pirate.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/pirate.PNG -------------------------------------------------------------------------------- /docs/images/treasure.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/treasure.PNG -------------------------------------------------------------------------------- /docs/images/arch_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/arch_table.png -------------------------------------------------------------------------------- /docs/images/data_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/data_table.png -------------------------------------------------------------------------------- /docs/images/opt_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/opt_table.png -------------------------------------------------------------------------------- /clean_local.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | rm -rf local/logs/* 4 | rm -rf local/models/* 5 | rm -rf 
local/ships/* -------------------------------------------------------------------------------- /docs/images/val_acc_tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/val_acc_tensorboard.png -------------------------------------------------------------------------------- /docs/images/val_loss_tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hu-po/pirateAI/HEAD/docs/images/val_loss_tensorboard.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Exclude local files (data, models, logs, config) 2 | /local/datasets/* 3 | /local/models/* 4 | /local/logs/* 5 | /local/ships/* 6 | /local/unityenv/* 7 | 8 | # Generated files 9 | .idea* 10 | __pycache__* 11 | .ipynb_checkpoints* -------------------------------------------------------------------------------- /tests/profiler_tester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cProfile 4 | 5 | mod_path = os.path.abspath(os.path.join('..')) 6 | sys.path.append(mod_path) 7 | import src.config as config 8 | from run import train 9 | 10 | """ 11 | This python file is used to profile training 12 | """ 13 | 14 | # Create a test ship and run training 15 | test_ship_name = 'TestShip' 16 | 17 | cProfile.run('train(ship_name=test_ship_name, full_cycles = 1, maroon_cycles = 1, max_pirates_in_ship = 2, ' 18 | 'min_pirates_in_ship = 2)') 19 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run python tests in tests module 4 | python -m unittest discover 5 | 6 | # Individual tests 7 | #python -m unittest tests.test_pirate 8 | #python -m unittest tests.test_ship 9 | #python -m unittest tests.test_island 10 | #python -m unittest tests.test_model_chunks 11 | 12 | # Might need context file in tests directory 13 | ## Context allows testing of modules without building solution 14 | #import os 15 | #import sys 16 | #mod_path = os.path.abspath(os.path.join('..')) 17 | #sys.path.append(mod_path) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /tests/test_pirate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from uuid import uuid4 4 | from keras.models import Sequential 5 | from keras.layers import Dense, Flatten 6 | from src.pirate import Pirate 7 | 8 | class TestPirate(unittest.TestCase): 9 | 10 | def test_initialization(self): 11 | # No model will be found for a blank pirate 12 | with self.assertRaises(FileNotFoundError): 13 | blank_pirate = Pirate() 14 | 15 | # Make a blank keras model to try and load 16 | test_dna = str(uuid4()) 17 | model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'local', 'models')) 18 | test_model = Sequential() 19 | test_model.add(Flatten(input_shape=(128, 128, 3))) 20 | test_model.add(Dense(1)) 21 | test_model.save(model_path + '/' + test_dna + '.h5') 22 | test_pirate = Pirate(dna=test_dna) 23 | self.assertEqual(test_pirate.dna, test_dna) 24 | self.assertTrue(isinstance(test_pirate.name, str)) 25 | 26 | # Delete the test model 27 | os.remove(model_path + '/' + test_dna + '.h5') 28 | 29 | def test_act(self): 30 | pass 31 | 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /tests/memory_leak_tester.py: -------------------------------------------------------------------------------- 1 | import linecache 2 | import os 3 | import tracemalloc 4 | import sys 5 | mod_path = os.path.abspath(os.path.join('..')) 6 | sys.path.append(mod_path) 7 | import src.config as config 8 | from run import train 9 | 10 | """ 11 | This python file is used to test for memory leaks using 12 | tracemalloc while running training. 
13 | """ 14 | 15 | def display_top(snapshot, key_type='lineno', limit=10): 16 | snapshot = snapshot.filter_traces(( 17 | tracemalloc.Filter(False, ""), 18 | tracemalloc.Filter(False, ""), 19 | )) 20 | top_stats = snapshot.statistics(key_type) 21 | 22 | print("Top %s lines" % limit) 23 | for index, stat in enumerate(top_stats[:limit], 1): 24 | frame = stat.traceback[0] 25 | # replace "/path/to/module/file.py" with "module/file.py" 26 | filename = os.sep.join(frame.filename.split(os.sep)[-2:]) 27 | print("#%s: %s:%s: %.1f KiB" 28 | % (index, filename, frame.lineno, stat.size / 1024)) 29 | line = linecache.getline(frame.filename, frame.lineno).strip() 30 | if line: 31 | print(' %s' % line) 32 | 33 | other = top_stats[limit:] 34 | if other: 35 | size = sum(stat.size for stat in other) 36 | print("%s other: %.1f KiB" % (len(other), size / 1024)) 37 | total = sum(stat.size for stat in top_stats) 38 | print("Total allocated size: %.1f KiB" % (total / 1024)) 39 | 40 | # use tracemalloc to find memory leaks 41 | tracemalloc.start() 42 | 43 | test_ship_name = 'TestShip' 44 | 45 | # Change some configs to only train a single model 46 | config.NUM_PIRATES_PER_TRAIN = 1 47 | config.MAX_TRAIN_TRIES = 1 48 | train(ship_name=test_ship_name, 49 | full_cycles=1, 50 | maroon_cycles=1, 51 | max_pirates_in_ship=2, 52 | min_pirates_in_ship=2) 53 | 54 | snapshot_before = tracemalloc.take_snapshot() 55 | import keras 56 | keras.backend.clear_session() 57 | snapshot_after = tracemalloc.take_snapshot() 58 | stats = snapshot_after.compare_to(snapshot_before, 'lineno') 59 | 60 | print('TOP FOR AFTER SNAPSHOT --------------') 61 | display_top(snapshot_after) 62 | 63 | print('TOP FOR BEFORE SNAPSHOT --------------') 64 | display_top(snapshot_before) 65 | 66 | print("[ Top 10 ]") 67 | for stat in stats[:10]: 68 | print(stat) 69 | 70 | try: 71 | os.remove(os.path.join(config.SHIP_DIR, test_ship_name)) 72 | except FileNotFoundError: 73 | pass -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from src.ship import Ship 3 | from src.island import Island 4 | import src.config as config 5 | 6 | 7 | # PLEASE GO TO config.py FILE TO CHANGE CONFIGS 8 | 9 | def train(ship_name=config.SHIP_NAME, 10 | full_cycles=config.FULL_CYCLES, 11 | maroon_cycles=config.MAROON_CYCLES, 12 | max_pirates_in_ship=config.MAX_PIRATES_IN_SHIP, 13 | min_pirates_in_ship=config.MIN_PIRATES_IN_SHIP): 14 | """ 15 | Train mode: runs a local UnityEnvironment, general loop is: 16 | - Generate pirates until you have some minimum amount in the ship 17 | - Run matches to rank pirates 18 | - Cull the worse performing pirates in the ship 19 | :param ship_name: (str) name of the ship 20 | :param full_cycles: (int) number of full pirateAI cycles to run 21 | :param maroon_cycles: (int) number of maroonings to run for each hyperopt training 22 | :param max_pirates_in_ship: (int) max number of pirates in ship 23 | :param min_pirates_in_ship: (int) min number of pirates in ship 24 | """ 25 | with Ship(ship_name=ship_name) as ship: 26 | for _ in range(full_cycles): 27 | pirates_in_ship = ship.headcount() 28 | while pirates_in_ship < max_pirates_in_ship: 29 | ship.more_pirates() # Not enough hands on deck 30 | pirates_in_ship = ship.headcount() 31 | with Island(brain='PirateBrain', file_name='local/unityenv/Island') as island: 32 | for _ in range(maroon_cycles): 33 | island.maroon(ship=ship) 34 | if pirates_in_ship > 
min_pirates_in_ship: 35 | ship.less_pirates() 36 | 37 | 38 | def test(ship_name=config.SHIP_NAME, 39 | full_cycles=config.FULL_CYCLES, 40 | maroon_cycles=config.MAROON_CYCLES): 41 | """ 42 | Test Mode: runs an external UnityEnvironment, general loop is: 43 | - Run matches to rank pirates 44 | :param ship_name: (str) name of the ship 45 | :param full_cycles: (int) number of full pirateAI cycles to run 46 | :param maroon_cycles: (int) number of maroonings to run for each hyperopt training 47 | """ 48 | with Ship(ship_name=ship_name) as ship: 49 | for _ in range(full_cycles): 50 | with Island(host_ip=config.WINDOWS_IP, 51 | host_port=config.WINDOWS_PORT, 52 | brain='PirateBrain') as island: 53 | for _ in range(maroon_cycles): 54 | island.maroon(ship=ship) 55 | 56 | 57 | if __name__ == '__main__': 58 | 59 | logger = logging.getLogger(__name__) 60 | 61 | if config.MODE == 'train': 62 | logger.info("Running in training mode") 63 | train() 64 | 65 | if config.MODE == 'test': 66 | logger.info("Running in Testing mode") 67 | test() 68 | -------------------------------------------------------------------------------- /tests/test_model_chunks.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from src.hyperopt_trainer import HyperoptTrainer 3 | 4 | class TestModelChunks(unittest.TestCase): 5 | def setUp(self): 6 | self.trainer = HyperoptTrainer() 7 | self.hyperparams = {'dataset' : 'test.pickle', 8 | 'dataset_size' : 80, 9 | 'head' : 'fc', 10 | 'dim_reduction': 'flatten', 11 | 'fc params' : 12 | { 13 | 'dense_layers' : [32], 14 | 'dense_activations' : 'relu', 15 | 'dropout_percentage': 0.3, 16 | }, 17 | 'batch_size' : 16, 18 | 'epochs' : 1, 19 | 'optimizer' : 'rmsprop', 20 | 'learning_rate': 0.0005, 21 | 'decay' : 0.0001, 22 | 'clipnorm' : 1.0 23 | } 24 | 25 | def tearDown(self): 26 | del self.trainer, self.hyperparams 27 | 28 | def test_custom_model(self): 29 | pass # Gets tested in functions below 30 | 31 | def test_model_chunk(self): 32 | pass # Gets tested in functions below 33 | 34 | def test_a3c(self): 35 | self.hyperparams['base'] = 'a3c' 36 | self.trainer.model(self.hyperparams, test_mode=True) 37 | 38 | def test_a3c_sepconv(self): 39 | self.hyperparams['base'] = 'a3c_sepconv' 40 | self.trainer.model(self.hyperparams, test_mode=True) 41 | 42 | def test_simpleconv(self): 43 | self.hyperparams['base'] = 'simpleconv' 44 | self.trainer.model(self.hyperparams, test_mode=True) 45 | 46 | def test_minires(self): 47 | self.hyperparams['base'] = 'minires' 48 | self.trainer.model(self.hyperparams, test_mode=True) 49 | 50 | def test_tall_kernel(self): 51 | self.hyperparams['base'] = 'tall_kernel' 52 | self.trainer.model(self.hyperparams, test_mode=True) 53 | 54 | def test_inception_res_v2(self): 55 | self.hyperparams['base'] = 'inception_res_v2' 56 | self.hyperparams['inception_res_v2 params'] = \ 57 | { 58 | 'trainable' : True, 59 | 'pre_trained': True, 60 | 'input_shape': (128, 128, 3) 61 | } 62 | self.trainer.model(self.hyperparams, test_mode=True) 63 | 64 | def test_res_net_50(self): 65 | self.hyperparams['base'] = 'res_net_50' 66 | self.hyperparams['res_net_50 params'] = \ 67 | { 68 | 'trainable' : True, 69 | 'pre_trained': True, 70 | 'input_shape': (128, 128, 3) 71 | } 72 | self.trainer.model(self.hyperparams, test_mode=True) 73 | 74 | def test_xception(self): 75 | self.hyperparams['base'] = 'xception' 76 | self.hyperparams['xception params'] = \ 77 | { 78 | 'trainable' : True, 79 | 'pre_trained': True, 80 | 'input_shape': (128, 128, 3) 81 | 
} 82 | self.trainer.model(self.hyperparams, test_mode=True) 83 | 84 | 85 | if __name__ == '__main__': 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /tests/test_ship.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import logging 4 | from uuid import uuid4 5 | from keras.models import Sequential 6 | from keras.layers import Dense, Flatten 7 | from src.pirate import Pirate 8 | from src.ship import Ship 9 | 10 | class TestShip(unittest.TestCase): 11 | def setUp(self): 12 | self.ship = Ship(ship_name='TestBoat') 13 | self.test_dnas = [str(uuid4()) for _ in range(10)] 14 | 15 | # Make a blank keras model for test pirates 16 | self.model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'local', 'models')) 17 | test_model = Sequential() 18 | test_model.add(Flatten(input_shape=(128, 128, 3))) 19 | test_model.add(Dense(1)) 20 | test_model.compile(optimizer='rmsprop', loss='mse') 21 | 22 | # Save models and create pirates 23 | self.test_pirates = [] 24 | for i, dna in enumerate(self.test_dnas): 25 | test_model.save(self.model_path + '/' + dna + '.h5') 26 | self.test_pirates.append(Pirate(dna=dna)) 27 | if i > len(self.test_dnas) - 5: 28 | break # Leave some models un-saved 29 | 30 | # Add pirates to the ship 31 | # Add some pirates 32 | for i in range(2): 33 | self.ship._add_pirate(dna=self.test_dnas[i]) 34 | 35 | def tearDown(self): 36 | # Delete all the test models we created 37 | for i, pirate in enumerate(self.test_pirates): 38 | os.remove(self.model_path + '/' + pirate.dna + '.h5') 39 | self.ship.sink() 40 | del self.ship 41 | 42 | def test_get_set(self): 43 | with self.assertRaises(ValueError): 44 | self.ship._get_prop(dna=None, prop=['test']) # dna should be string 45 | self.ship._get_prop(dna=self.test_dnas[0], prop=[None]) 46 | self.ship._set_prop(dna=None, prop=[]) # dna should be string 47 | self.ship._set_prop(dna=self.test_dnas[0], prop=[]) # prop should be a dict 48 | self.ship._set_prop(dna=self.test_dnas[0], prop={1: 'a'}) # prop dict keys should be strings 49 | 50 | with self.assertLogs(level=logging.WARNING): 51 | # Column does not exist 52 | self.assertTrue(self.ship._set_prop(dna=self.test_dnas[0], prop={'test': 4})) 53 | err, _ = self.ship._get_prop(dna=self.test_dnas[0], prop=['test']) 54 | self.assertTrue(err) 55 | 56 | # Valid conditions 57 | self.assertFalse(self.ship._set_prop(dna=self.test_dnas[0], prop={'saltyness': 2, 'loss': 1})) 58 | self.assertFalse(self.ship._set_prop(dna=self.test_dnas[0], prop={'win': 'win + 1'})) 59 | err, _ = self.ship._get_prop(dna=self.test_dnas[0], prop=['win']) 60 | self.assertFalse(err) 61 | 62 | def test_add_remove(self): 63 | with self.assertRaises(ValueError): 64 | self.ship._add_pirate(dna=None) 65 | self.ship._walk_the_plank(dna=None) 66 | 67 | def test_create_pirate(self): 68 | with self.assertRaises(ValueError): 69 | self.ship.create_pirate(dna=None) 70 | 71 | with self.assertLogs(level=logging.WARNING): 72 | # Give it a pirate with no model 73 | err, _ = self.ship.create_pirate(dna=self.test_dnas[-2]) 74 | self.assertTrue(err) 75 | # Give it a pirate not on the ship 76 | err, _ = self.ship.create_pirate(dna=str(uuid4())) 77 | self.assertTrue(err) 78 | 79 | err, _ = self.ship.create_pirate(dna=self.test_dnas[0]) 80 | self.assertFalse(err) 81 | 82 | def test_get_best_pirates(self): 83 | # Should only return the 2 pirates in ship 84 | pirates = self.ship.get_best_pirates(n=3) 85 | 
self.assertEqual(len(pirates), 2) 86 | self.assertTrue(all(isinstance(pirate, Pirate) for pirate in pirates)) 87 | 88 | 89 | if __name__ == '__main__': 90 | unittest.main() 91 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from hyperopt import hp 2 | import logging 3 | import os 4 | 5 | # This file contains all the parameters associated with pirateAI. 6 | # Many of the classes and files import from this. 7 | # TODO: Make this a JSON or YAML? 8 | 9 | # MODE = 'test' 10 | MODE = 'train' 11 | SHIP_NAME = 'Krusty Krab' 12 | 13 | # Training (hyperopt) 14 | NUM_PIRATES_PER_TRAIN = 3 15 | MAX_TRAIN_TRIES = 3 16 | NUM_PIRATES_PER_CULLING = 2 17 | SALT_PER_WIN = 8 18 | SALT_PER_LOSS = 1 19 | STARTING_SALT = 100 20 | TRAIN_PATIENCE = 3 21 | LABEL_DICT = {0: 'W', 1: 'S', 2: 'A', 3: 'D'} 22 | 23 | # Run duration 24 | FULL_CYCLES = 100 25 | 26 | # Ship 27 | MAX_PIRATES_IN_SHIP = 20 28 | MIN_PIRATES_IN_SHIP = 10 29 | 30 | # Evaluation (marooning) 31 | MAROON_CYCLES = 10 32 | N_BEST_PIRATES = 10 33 | BOUNTY = 1 34 | MAX_ROUNDS = 3 35 | 36 | # Connection information (for an external unity environment) 37 | WINDOWS_IP = '192.168.2.3' 38 | WINDOWS_PORT = 5008 39 | 40 | # 0-1 percent of data to have in training set 41 | TRAIN_TEST_SPLIT = 0.9 42 | 43 | # Debug Tools 44 | INPUT_DEBUG = False # Blocking plot showing sample training image and histogram 45 | 46 | # Logging 47 | logging.getLogger("tensorflow").setLevel(logging.ERROR) 48 | logging.basicConfig(level=logging.INFO) 49 | 50 | # Directories 51 | base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 52 | MODEL_DIR = os.path.abspath(os.path.join(base_path, 'local', 'models')) 53 | LOGS_DIR = os.path.abspath(os.path.join(base_path, 'local', 'logs')) 54 | DATA_DIR = os.path.abspath(os.path.join(base_path, 'local', 'datasets')) 55 | SHIP_DIR = os.path.abspath(os.path.join(base_path, 'local', 'ships')) 56 | 57 | # Hyperparameter space to explore 58 | # TODO: For now comment out the larger nets as they are unfair 59 | 60 | 61 | SPACE = {'dataset' : hp.choice('dataset', ['1203_d1.pickle', 62 | '1203_d2.pickle', 63 | '1203_d3.pickle', 64 | '1203_d4.pickle']), 65 | 'dataset_size' : hp.choice('dataset_size', [2000, 3000, 4000]), 66 | 'base' : hp.choice('base', [ 67 | 'xception', 68 | 'inception_res_v2', 69 | 'res_net_50', 70 | 'a3c', 71 | 'tall_kernel', 72 | 'a3c_sepconv', 73 | 'minires', 74 | 'simpleconv']), 75 | 'head' : hp.choice('head', ['fc']), 76 | 'dim_reduction': hp.choice('dim_reduction', ['global_average', 77 | 'global_max', 78 | 'flatten']), 79 | 'xception params' : { 80 | 'trainable' : hp.choice('xception trainable', [True, False]), 81 | 'pre_trained': hp.choice('xception pre_trained', [True, False]), 82 | }, 83 | 'inception_res_v2 params': { 84 | 'trainable' : hp.choice('iresv2 trainable', [True, False]), 85 | 'pre_trained': hp.choice('iresv2 pre_trained', [True, False]), 86 | }, 87 | 'res_net_50 params' : { 88 | 'trainable' : hp.choice('rn50 trainable', [True, False]), 89 | 'pre_trained': hp.choice('rn50 pre_trained', [True, False]), 90 | }, 91 | 'fc params' : { 92 | 'dense_layers' : hp.choice('dense_layers', 93 | [[64, 32], [256], [256, 64], [64], [32, 16], [256, 32]]), 94 | 'dense_activations' : hp.choice('dense_activations', ['relu']), 95 | 'dropout_percentage': hp.uniform('dropout_percentage', 0, 0.5), 96 | }, 97 | 98 | 'batch_size' : hp.choice('batch_size', [16, 32]), 99 | 'epochs' : 
hp.choice('epochs', [25]), 100 | 'optimizer' : hp.choice('optimizer', ['rmsprop', 'sgd', 'adam', 'nadam']), 101 | 'learning_rate': hp.choice('learning_rate', [0.0001, 0.00005]), 102 | 'decay' : hp.choice('decay', [0.0, 0.004, 0.0001]), 103 | 'clipnorm' : hp.choice('clipnorm', [0., 1.]), 104 | } 105 | -------------------------------------------------------------------------------- /notebooks/pirate_analyzer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Pirate Analyzer\n", 8 | "\n", 9 | "The aim of this notebook is to analyze an individual pirate, or the top pirates in a ship. Mainly comparing hyperparameters and looking at the output of the pirate model on some test images\n", 10 | "\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "import sys\n", 21 | "import logging\n", 22 | "import json\n", 23 | "# Allow for importing Pirate class\n", 24 | "cwd = os.getcwd()\n", 25 | "sys.path.append(os.path.join(cwd, '..'))\n", 26 | "from src.dataset import Dataset, plot_images\n", 27 | "from src.pirate import Pirate\n", 28 | "from src.ship import Ship\n", 29 | "from src.config import *\n", 30 | "\n", 31 | "# Define logger\n", 32 | "logging.basicConfig(level=logging.ERROR)\n", 33 | "logger = logging.getLogger(__name__)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### Individual Pirate Analysis\n", 41 | "\n", 42 | "This section looks at a single pirate. You will need to provide the dna for a given pirate. The easiest way to do this is to run\n", 43 | "\n", 44 | "`sqlitebrowser local/ships/yourshipname.db`\n", 45 | "\n", 46 | "scroll through the list of pirates and pick one.\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# We need the dna of the pirate in question\n", 56 | "DNA = '60e43a29-2c95-4342-8b5d-b86015a96ac3'\n", 57 | "\n", 58 | "# Note that initializing a pirate this way means you lose the name/stats :(\n", 59 | "pirate = Pirate(dna=DNA)\n", 60 | "\n", 61 | "# Make a quick test dataset\n", 62 | "test_path = os.path.join(DATA_DIR, '0312cleaner')\n", 63 | "test = Dataset.from_path(test_path, max_n = 100)\n", 64 | "images, true_labels = test.load_data_sample(n = 18)\n", 65 | "\n", 66 | "# Get the pirate action associated with each image\n", 67 | "pred_labels = []\n", 68 | "for image in images:\n", 69 | " label = pirate.act(image)\n", 70 | " pred_labels.append(label)\n", 71 | " \n", 72 | "# Plot the images along with true and predicted labels\n", 73 | "plot_images(images, true_labels, pred_label=pred_labels)\n", 74 | "\n", 75 | "# Print out the hyperparameters for this model\n", 76 | "print('Hyperparams %s' % json.dumps(pirate.description(), indent=4))" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "### Ship Analysis\n", 84 | "\n", 85 | "This section will analyze the pirates within an entire ship. The main interest here is to find the hyperparameters that seem to be producing the best pirates. 
" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "SHIP_NAME = 'The Black Pearl'\n", 95 | "N_BEST_PIRATES = 6\n", 96 | "\n", 97 | "# Get the best pirates\n", 98 | "with Ship(ship_name=SHIP_NAME) as ship:\n", 99 | " pirates = ship.get_best_pirates(N_BEST_PIRATES)\n", 100 | "\n", 101 | "# Aggregate dictionary of hyperparameters:\n", 102 | "agg_params = {}\n", 103 | "for pirate in pirates:\n", 104 | " for key, value in pirate.description().items():\n", 105 | " # Add the pirate hyperparameter to aggregate\n", 106 | " if key not in agg_params.keys():\n", 107 | " agg_params[key] = []\n", 108 | " # Add parameter value\n", 109 | " agg_params[key].append(value)\n", 110 | "\n", 111 | "# Print out the hyperparameters for this model\n", 112 | "print('Aggregate Hyperparameters %s' % json.dumps(agg_params, indent=4))" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.6.2" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 1 137 | } 138 | -------------------------------------------------------------------------------- /src/pirate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | import numpy as np 5 | import pickle 6 | from uuid import uuid4 7 | from keras.models import load_model 8 | from src.dataset import image_analysis 9 | import src.config as config 10 | 11 | class Pirate(object): 12 | """ 13 | Pirates are the agents on the island. When instantiated, pirates load a model 14 | into GPU memory. 15 | """ 16 | 17 | def __init__(self, dna=None, name='Unborn', win=0, loss=0, saltyness=0, rank=None): 18 | """ 19 | :param dna: (string) identifier uuid4 string for a pirate 20 | :param name: (string) the pirate's name 21 | :param win: (int) number of wins 22 | :param loss: (int) number of losses 23 | :param saltyness: (int) an estimate of a Pirate's performance (think ELO) 24 | :param rank: (int) rank of this pirate from their training batch 25 | :raises FileNotFoundError: Can't load pirate model 26 | :raises ValueError: no input given 27 | """ 28 | self.log = logging.getLogger(__name__) 29 | self.dna = dna or str(uuid4()) 30 | if name == 'Unborn': 31 | self.name = self._generate_name(rank=rank) 32 | else: 33 | self.name = name 34 | self.win = win 35 | self.loss = loss 36 | self.saltyness = saltyness 37 | # Model contains weights and graph 38 | self._model = self._load_model() 39 | 40 | def act(self, input, visualize=False): 41 | """ 42 | Runs the pirate model on the given input (image, etc). 43 | :param input: input tensor, format matches model 44 | :param visualize: (bool) display incoming image and metadata 45 | :return: (int) action resulting from model. 
 46 |         :raises ValueError: no input given
 47 |         """
 48 |         if input is None:
 49 |             raise ValueError("Please provide an input image to generate an action")
 50 |         if len(input.shape) == 3:
 51 |             input = np.expand_dims(input, axis=0)
 52 |         norm_input_image = input  # observations are assumed to arrive already scaled to [0, 1]
 53 |         output = self._model.predict(norm_input_image)
 54 |         # Classification model outputs action probabilities
 55 |         action = np.argmax(output)
 56 |         if visualize or config.INPUT_DEBUG:  # This blocks the GIL to visualize
 57 |             image_analysis(image=input[0, :, :, :], label=action)
 58 |         return action
 59 | 
 60 |     def description(self):
 61 |         """
 62 |         Finds and returns the hyperparameter info stored in the pirate's pickle file
 63 |         :return: (dict) hyperparameter description
 64 |         :raises FileNotFoundError: Can't find hyperparameter pickle file in path
 65 |         """
 66 |         for dirpath, _, files in os.walk(config.MODEL_DIR):
 67 |             if "{dna}.pickle".format(dna=self.dna) in files:
 68 |                 with open(os.path.join(dirpath, self.dna + '.pickle'), 'rb') as file:
 69 |                     data = pickle.load(file)
 70 |                 assert isinstance(data, dict), 'Pirate description is corrupted'
 71 |                 # Pirate description printed out to logger
 72 |                 self.log.info('--- Pirate %s (dna: %s) ---' % (self.name, self.dna))
 73 |                 model_summary = data.pop('model_summary', None) or []  # tolerate a missing summary
 74 |                 for line in model_summary:
 75 |                     self.log.info(line)
 76 |                 for key, val in data.items():
 77 |                     self.log.info('%s : %s' % (str(key), str(val)))
 78 |                 return data
 79 |         raise FileNotFoundError('Could not find description in path using given dna string')
 80 | 
 81 |     def _load_model(self):
 82 |         """
 83 |         Tries to find pirate model in the model path
 84 |         :return: (keras.model) loaded model
 85 |         :raises FileNotFoundError: Can't find model in path
 86 |         """
 87 |         for dirpath, _, files in os.walk(config.MODEL_DIR):
 88 |             if self.dna + '.h5' in files:
 89 |                 return load_model(os.path.join(dirpath, self.dna + '.h5'))
 90 |         raise FileNotFoundError('Could not find model in path using given dna string')
 91 | 
 92 |     def _generate_name(self, rank=-1):
 93 |         """
 94 |         Generates a proper pirate name
 95 |         :param rank: (int) rank with respect to training batch
 96 |         :return: (string) name
 97 |         """
 98 |         name = ''
 99 |         # Titles are ordered based on rank
100 |         titles = ['Salty ', 'Admiral ', 'Captain ', 'Don ', 'First Mate ', 'Gunmaster ',
101 |                   'Sailor ', 'Deckhand ', 'Mc', 'Cookie ', 'Lil', '']
102 |         if rank in range(len(titles)):
103 |             name += titles[rank]
104 |         # The real part of the name is chosen randomly
105 |         real_names = ['Jack', 'Haddock', 'Blackbeard', 'Will', 'Long', 'Simon', 'Barbossa']
106 |         name += random.choice(real_names)
107 |         self.log.debug('The Pirate %s has been created' % name)
108 |         return name
109 | 
110 |     def __eq__(self, pirate):
111 |         """
112 |         Compare pirates using their dna
113 |         :param pirate: (Pirate) the other pirate
114 |         :return: (bool) True if dna matches
115 |         """
116 |         return self.dna == pirate.dna
117 | 
--------------------------------------------------------------------------------
/notebooks/data_creator.ipynb:
--------------------------------------------------------------------------------
 1 | {
 2 |  "cells": [
 3 |   {
 4 |    "cell_type": "markdown",
 5 |    "metadata": {},
 6 |    "source": [
 7 |     "# Data Creator\n",
 8 |     "\n",
 9 |     "This notebook reads, modifies, and concatenates datasets into pickle files of image paths and corresponding targets that can be used for training models."
10 |    ]
11 |   },
12 |   {
13 |    "cell_type": "markdown",
14 |    "metadata": {},
15 |    "source": [
16 |     "### Preliminary Data Analysis\n",
17 |     "\n",
18 |     "Plot some samples from each dataset, as well as the action histograms. 
Do the samples look reasonable? Is the histogram too lopsided? " 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "collapsed": true 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "import os\n", 30 | "import sys\n", 31 | "mod_path = os.path.abspath(os.path.join('..'))\n", 32 | "sys.path.append(mod_path)\n", 33 | "\n", 34 | "from src.dataset import Dataset\n", 35 | "from src.config import DATA_DIR" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# 12-3-2017\n", 45 | "\n", 46 | "#beach_messy\n", 47 | "beach_messy_path = os.path.join(DATA_DIR, '0312beaches_messy')\n", 48 | "beach_messy = Dataset.from_path(beach_messy_path)\n", 49 | "\n", 50 | "#forest_messy\n", 51 | "forest_messy_path = os.path.join(DATA_DIR, '0312forest_messy')\n", 52 | "forest_messy = Dataset.from_path(forest_messy_path)\n", 53 | "\n", 54 | "#clean_mix\n", 55 | "clean_mix_path = os.path.join(DATA_DIR, '0312cleaner')\n", 56 | "clean_mix = Dataset.from_path(clean_mix_path)\n", 57 | "\n", 58 | "#test_dataset\n", 59 | "test_path = os.path.join(DATA_DIR, '0312cleaner')\n", 60 | "test = Dataset.from_path(test_path, max_n = 100)\n", 61 | "\n", 62 | "# Plot sample images and histograms\n", 63 | "beach_messy.analyze()\n", 64 | "forest_messy.analyze()\n", 65 | "clean_mix.analyze()\n", 66 | "test.analyze()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "### Mix and Match\n", 74 | "\n", 75 | "Here we combine our raw datasets into custom datasets we will use in training." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "# 12-3-2017\n", 85 | "\n", 86 | "# Dataset 1: beach messy + forest messy\n", 87 | "d1 = Dataset.from_datasets([beach_messy,\n", 88 | " forest_messy])\n", 89 | "\n", 90 | "# Dataset 2: messy but with better balance\n", 91 | "d2 = Dataset.from_datasets([beach_messy.only_label('A'),\n", 92 | " beach_messy.only_label('D'),\n", 93 | " forest_messy.only_label('A'),\n", 94 | " forest_messy.only_label('D'),\n", 95 | " beach_messy.only_label('A'),\n", 96 | " beach_messy.only_label('D'),\n", 97 | " forest_messy.only_label('A'),\n", 98 | " forest_messy.only_label('D'),\n", 99 | " beach_messy,\n", 100 | " forest_messy])\n", 101 | "\n", 102 | "# Dataset 3: clean only\n", 103 | "d3 = Dataset.from_datasets([clean_mix])\n", 104 | "\n", 105 | "# Dataset 4: clean with better balance\n", 106 | "d4 = Dataset.from_datasets([clean_mix.only_label('A'),\n", 107 | " clean_mix.only_label('D'),\n", 108 | " clean_mix.only_label('A'),\n", 109 | " clean_mix.only_label('D'),\n", 110 | " clean_mix])\n", 111 | "\n", 112 | "# Plot sample images and histograms\n", 113 | "d1.analyze()\n", 114 | "d2.analyze()\n", 115 | "d3.analyze()\n", 116 | "d4.analyze()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## Dump to pickle\n", 124 | "\n", 125 | "For our keras hyperopt model, lets dump the image paths and labels to a pickle file." 
126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "collapsed": true 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "import pickle\n", 137 | "\n", 138 | "# Paths to local directories containing data\n", 139 | "d1_path = os.path.join(DATA_DIR, '1203_d1.pickle')\n", 140 | "d2_path = os.path.join(DATA_DIR, '1203_d2.pickle')\n", 141 | "d3_path = os.path.join(DATA_DIR, '1203_d3.pickle')\n", 142 | "d4_path = os.path.join(DATA_DIR, '1203_d4.pickle')\n", 143 | "test_path = os.path.join(DATA_DIR, 'test.pickle')\n", 144 | "\n", 145 | "d1.save_to_pickle(d1_path)\n", 146 | "d2.save_to_pickle(d2_path)\n", 147 | "d3.save_to_pickle(d3_path)\n", 148 | "d4.save_to_pickle(d4_path)\n", 149 | "test.save_to_pickle(test_path)" 150 | ] 151 | } 152 | ], 153 | "metadata": { 154 | "anaconda-cloud": {}, 155 | "kernelspec": { 156 | "display_name": "Python 3", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.6.2" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 1 175 | } 176 | -------------------------------------------------------------------------------- /src/island.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import random 4 | import keras 5 | from statistics import median 6 | from unityagents import UnityEnvironment 7 | from src.pirate import Pirate 8 | import src.config as config 9 | 10 | 11 | class Island(object): 12 | """ 13 | The Island is where pirates are marooned, testing their saltyness. 14 | Holds the unity environment. Used as a context manager. 15 | """ 16 | 17 | def __init__(self, host_ip=None, host_port=None, brain=None, file_name=None): 18 | """ 19 | :param host_ip: (string) host ip, if not provided runs locally 20 | :param host_port: (string) host port, if not provided runs locally 21 | :param brain: (string) name of the external brain in unity environment 22 | :param file_name: (string) name of the unity environment executable 23 | """ 24 | self.log = logging.getLogger(__name__) 25 | if not host_ip or not host_port: 26 | self.log.info('No host ip or port provided, running in local training mode') 27 | self._train_mode = True 28 | else: 29 | self.log.info('Running in external testing mode') 30 | self._train_mode = False 31 | self._host_ip = host_ip 32 | self._host_port = host_port 33 | self._brain_name = brain 34 | self.file_name = file_name 35 | 36 | def __enter__(self): 37 | # Connect to the Unity environment 38 | self.unity_env = UnityEnvironment(file_name=self.file_name, 39 | host_ip=self._host_ip, 40 | base_port=self._host_port) 41 | return self 42 | 43 | def __exit__(self, exception_type, exception_value, traceback): 44 | # Kill the Unity environment 45 | self.unity_env.close() 46 | del self.unity_env 47 | 48 | def maroon(self, ship=None, num_best_pirates=config.N_BEST_PIRATES): 49 | """ 50 | Maroon some pirates. Figure out which one is truly saltiest. 
 51 |         :param ship: (Ship) the ship the pirates live on
 52 |         :param num_best_pirates: (int) size of the top-pirate pool to sample from
 53 |         :return:
 54 |         """
 55 |         assert ship, "No ship specified when marooning"
 56 |         # Randomly select 2 from N best pirates
 57 |         pirates = random.sample(ship.get_best_pirates(num_best_pirates), 2)
 58 |         # Run the marooning rounds
 59 |         self.log.info('Marooning the pirates %s' % ', '.join([pirate.name for pirate in pirates]))
 60 |         err, winners, losers = self._run_rounds(pirates=pirates)
 61 |         if not err:  # If no error occurred during the marooning
 62 |             # Update the ship accordingly
 63 |             ship.marooning_update(winners, losers)
 64 |         # Delete the session to prevent GPU memory from getting full
 65 |         keras.backend.clear_session()
 66 | 
 67 |     def _run_rounds(self, pirates=None, bounty=config.BOUNTY, max_rounds=config.MAX_ROUNDS):
 68 |         """
 69 |         Runs rounds between a list of pirates
 70 |         :param pirates: [pirates] list of N pirates to maroon
 71 |         :param bounty: (int) how many wins to be the winner
 72 |         :param max_rounds: (int) maximum number of rounds in one marooning
 73 |         :return: (bool),(string),[string,] error, winning pirate dna, losing pirates dna
 74 |         """
 75 |         if any([not isinstance(pirate, Pirate) for pirate in pirates]):
 76 |             raise ValueError('Some of the pirates you provided are not pirates')
 77 |         # tracking variables for the match
 78 |         score = [0] * len(pirates)
 79 |         round_idx = 0
 80 |         winner = False  # Is there a winning pirate?
 81 |         while round_idx < max_rounds:
 82 |             self.log.info("-------------- Round %s" % str(round_idx + 1))
 83 |             try:
 84 |                 winner_idx, times = self._round(pirates)
 85 |                 # times contains execution times for each step
 86 |                 self.log.info("%d steps taken." % len(times))
 87 |                 self.log.info("python execution time [median: %.3fs, max: %.3fs, min: %.3fs] "
 88 |                               % (median(times), max(times), min(times)))
 89 |                 score[winner_idx] += 1
 90 |             except ValueError:
 91 |                 self.log.warning('Bad values passed within a round, discarding results...')
 92 |             except TimeoutError:
 93 |                 self.log.info('Round complete! But no clear winner')
 94 |             round_idx += 1
 95 |             if any(s >= bounty for s in score):  # a pirate reached the bounty
 96 |                 winner = True
 97 |                 break  # Break when a pirate reaches the max score
 98 |         if winner:
 99 |             winning_idx = score.index(max(score))
100 |             self.log.info('Match complete! %s claims victory' % pirates[winning_idx].name)
101 |             winning_pirate = pirates.pop(winning_idx)
102 |             return False, winning_pirate.dna, [pirate.dna for pirate in pirates]
103 |         else:
104 |             self.log.info('Match complete! 
No pirate was able to demonstrate superior saltyness') 105 | return False, '', [pirate.dna for pirate in pirates] 106 | 107 | def _round(self, pirates=None, max_steps=10000): 108 | """ 109 | Carries out a single round of pirate on pirate action 110 | :param pirates: [pirates] list of N pirates in the round 111 | :param max_steps: (int) maximum number of steps in round 112 | :return: (int),[int] index of winner, list of step execution times 113 | :raises TimeoutError: no done signal, max steps reached 114 | :raises ValueError: unity agents logic is having trouble 115 | """ 116 | # Reset the environment 117 | env_info = self.unity_env.reset(train_mode=self._train_mode) 118 | # Time python code each step, interesting and a good sanity checker 119 | py_t0, py_t1 = None, None 120 | episode_times = [] 121 | # Execute steps until environment sends done signal 122 | while True: 123 | if len(episode_times) > max_steps: 124 | raise TimeoutError('Unity environment never sent done signal, perhaps it disconnected?') 125 | # TODO: [0] index works because we only have one camera per pirate 126 | observation = env_info[self._brain_name].observations[0] 127 | agents_done = env_info[self._brain_name].local_done 128 | if all(agents_done): # environment finished first 129 | raise TimeoutError('Neither pirate was able to find treasure') 130 | actions = [] 131 | for i, pirate in enumerate(pirates): 132 | if agents_done[i]: 133 | self.log.info("Round complete! %s got to the treasure first!" % pirate.name) 134 | return i, episode_times 135 | # Get the action for each pirate based on its observation 136 | actions.append(pirate.act(observation[i, :, :, :])) 137 | if py_t0: # timing 138 | episode_times.append(py_t1 - py_t0) # timing 139 | py_t1 = time.time() # timing 140 | env_info = self.unity_env.step(actions) # Step in unity environment 141 | py_t0 = time.time() # timing 142 | 143 | 144 | if __name__ == '__main__': 145 | island = Island(brain='PirateBrain', file_name='local/unityenv/BootyFind') 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ** NO LONGER MAINTAINED, USE AT YOUR OWN RISK ** 2 | 3 | # PirateAI 4 | 5 | PirateAI is a personal project that trains autonomous agents (pirates) in a simulated environment (island). 6 | This repo runs a training pipeline that alternates between a game (find the treasure) and model training 7 | sessions (Keras + hyperopt). 8 | 9 | [YouTube Video](https://youtu.be/P3grJ5LHb8k) 10 | 11 | ## Environment 12 | 13 | ![alt text](docs/images/pirate.PNG "Get to the treasure first") 14 | 15 | The agents in this island are pirates. Pirates are marooned (randomly spawned) within the island 16 | in pairs (pirate vs pirate). The goal of each pirate when marooned is to find the treasure (walk up and touch it), 17 | after which the game will end. Winning pirates rise in the ranks within their Ship, and slowly but surely the best 18 | pirates bubble to the top. 19 | 20 | The Ship is a local sqlite database that holds pirates. It keeps track of wins, losses, ranking, randomly generated 21 | pirate names, etc. Each pirate has a specific uuid4 string called its dna, which is used to connect it with saved 22 | models, logs, and database entries. 23 | 24 | ![alt text](docs/images/treasure.PNG "Keep your eye on the prize") 25 | 26 | Pirates have a camera in the center of their head that records 128 x 128 x 3 images. 
These images are used as input
27 | for the pirate models, which will output an instantaneous action for a given image. Data sets of image and action
28 | pairs are created using a heuristic or through a human player.
29 | 
30 | ![alt text](docs/images/island.PNG "The island environment")
31 | 
32 | The island environment is created using [Unity](https://unity3d.com/), a popular game-development engine, and Google
33 | [Poly](https://poly.google.com/). In particular, I use [Unity ML Agents](https://github.com/Unity-Technologies/ml-agents)
34 | to interface and communicate with the unity environment. Note I use a custom [fork](https://github.com/HugoCMU/ml-agents),
35 | which I do not keep up to date with the latest repo.
36 | 
37 | ## Learning
38 | 
39 | ![alt text](docs/images/arch.PNG "Encoder section")
40 | 
41 | Pirates make decisions by evaluating inputs with deep learning models. I focus on smaller,
42 | simpler models since my GPU resources are limited.
43 | 
44 | Some models are custom, and some are loosely based on things I read in papers. For example the _ff_a3c_ model is
45 | inspired by the encoder in _Learning to Navigate Complex Environments_. I also include larger pre-trained model bases
46 | such as _Xception_ and _Resnet50_.
47 | 
48 | The model architecture is just one of many different possible hyperparameters when it comes to training
49 | a model. Others include optimizer, learning rate, batch size, number of fc layers, batchnorm, etc. The space of
50 | possible hyperparameters is huge, and searching it is expensive. To best utilize the low GPU resources
51 | available, pirateAI uses a two-part training process:
52 | 
53 | 
54 | ***Step 1*** Make use of [hyperopt](https://github.com/hyperopt/hyperopt) to train a series of models while searching
55 | through
56 | the hyperparameter space. I use validation loss (cross entropy between the true label and predicted label for an input image)
57 | as the optimization metric for the hyperopt training sessions. Graduate the best models (lowest validation loss) to
58 | full pirates, and add them to the ship (db).
59 | 
60 | ***Step 2*** Select two pirates from the ship and maroon them on the island. This starts a marooning game, where the
61 | first
62 | pirate to get to the treasure wins. Pirates accumulate saltyness (rating system) with every win, and the pirates with the
63 | lowest saltyness are permanently removed from the ship. Running models in inference is quicker than training, and
64 | environment performance is a better validation metric than validation loss.
65 | 
66 | By alternating between training models and evaluating them in a competitive game within the environment, you can
67 | quickly arrive at a well-performing model and set of hyperparameters.
68 | 
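To make Step 1 concrete: the hyperopt side is a standard `fmin` loop over a search space like the `SPACE` dict in
[`config.py`](src/config.py). A minimal sketch, assuming the real wiring in `hyperopt_trainer.py` looks roughly like
this (`train_and_validate` is a hypothetical stand-in for building and fitting a Keras model from the sampled
hyperparameters):

```python
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK

space = {  # tiny slice of the full SPACE in src/config.py
    'batch_size':    hp.choice('batch_size', [16, 32]),
    'learning_rate': hp.choice('learning_rate', [0.0001, 0.00005]),
}

def train_and_validate(hyperparams):
    # Hypothetical stand-in: build and fit a Keras model here and
    # return its best validation loss (see src/hyperopt_trainer.py)
    return 1.0 / hyperparams['batch_size']  # dummy loss for illustration

def objective(hyperparams):
    return {'loss': train_and_validate(hyperparams), 'status': STATUS_OK}

trials = Trials()  # keeps a record of every model tried
best = fmin(fn=objective, space=space, algo=tpe.suggest,
            max_evals=10, trials=trials)
```

The lowest-loss trials are the ones that graduate into pirates.
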
69 | ## Results
70 | 
71 | For the final experiment, all the models had relatively low capacity (weights) and only a limited amount of training
72 | data. A true test of resourcefulness. The results here are from running the pirateAI loop for 12 hours. Looking into
73 | our ship database we get the following top 5 pirates (lots of Wills apparently):
74 | 
75 | ![alt text](docs/images/top5.png "The saltiest")
76 | 
77 | Below are the associated TensorBoard plots for validation accuracy and validation loss. The shortness of the lines is due to
78 | aggressive early stopping (to further speed up the training process). From the plots, we can see that the
79 | metrics are fairly spread out, confirming our intuition that validation loss & accuracy are inappropriate for
80 | judging final agent performance.
81 | 
82 | ![alt text](docs/images/val_acc_tensorboard.png "Validation Accuracy")
83 | ![alt text](docs/images/val_loss_tensorboard.png "Validation Loss")
84 | 
85 | 
86 | ***Architecture***
87 | 
88 | ![alt text](docs/images/arch_table.png "Top 5 Model Architectures")
89 | 
90 | The hyperparameter space for model architecture boiled down to:
91 | 
92 | - 5 different convolution bases.
93 | 
94 | - 3 different dimensionality reduction choices. This is the part of the model that takes the 3D tensors coming
95 | out of the conv block and reduces them to a 2D tensor which we can feed to the model head.
96 | 
97 | - 6 different head choices (with and without dropout, whose percentage also varied). These are the fully connected
98 | layers right before the final output layer of the net.
99 | 
100 | You can look up the exact makeup of the layers for the model parts above in [`model_chunks.py`](src/model_chunks.py).
101 | It seemed
102 | like the _256, 32_ head (a 256x1 fully connected layer followed by a 32x1 fc layer) was the most popular. Due to the
103 | black box
104 | nature of deep learning, it's hard to give a reason other than _it seems best for this problem_.
105 | 
106 | The interesting result for me is that the number of parameters, or model capacity, did not seem to be important when
107 | determining agent performance. In fact, the top pirate actually had the _least_ capacity.
108 | 
109 | ***Dataset***
110 | 
111 | ![alt text](docs/images/data_table.png "Top 5 Data Parameters")
112 | 
113 | The first thing to talk about is the size of the training data. We see large variability here in the top 5. To me
114 | this suggests that for the given problem of finding the treasure, you don't actually need that much data. Having more
115 | data thus doesn't really help you.
116 | 
117 | Another hyperparameter tested was the composition of each data set. There were four dataset _flavors_
118 | available for training; these were:
119 | 
120 | 1. Noisy data (some wrong actions) with a bias towards moving forward (W action)
121 | 2. Noisy data with a balanced distribution of actions
122 | 3. Clean data with a bias towards moving forward
123 | 4. Clean data with a balanced distribution of actions
124 | 
125 | Every single one of the top 5 models was trained using dataset 3 above. Clean data being better than noisy data makes
126 | sense, since there will be fewer examples that push the gradient in the wrong direction. As for why the forward bias
127 | matters, my speculation is that pirates that move forward slightly more often get to the treasure slightly quicker.
128 | 
129 | ***Optimization***
130 | 
131 | ![alt text](docs/images/opt_table.png "Top 5 Optimization Parameters")
132 | 
133 | There were a couple of different hyperparameters related to optimization and training that were tested. The results
134 | here were the least interesting, with lots of variability in the top 5. This makes sense intuitively, as the best
135 | optimization parameters will be dependent on the model architecture and dataset.
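
For reference, the optimizer-related entries in `SPACE` map onto Keras optimizers roughly as follows. This is a
hedged sketch of the wiring (the real version lives in `hyperopt_trainer.py`, which is not shown here); Nadam is
left out because it takes a `schedule_decay` argument rather than `decay`:

```python
from keras import optimizers

def build_optimizer(params):
    # Map the sampled 'optimizer' choice to a Keras optimizer class, passing
    # through the sampled learning rate, learning-rate decay, and clipnorm
    opts = {'rmsprop': optimizers.RMSprop,
            'sgd': optimizers.SGD,
            'adam': optimizers.Adam}
    return opts[params['optimizer']](lr=params['learning_rate'],
                                     decay=params['decay'],
                                     clipnorm=params['clipnorm'])
```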
136 | 
137 | 
138 | 
139 | 
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import logging
  3 | import random
  4 | import pickle
  5 | import csv
  6 | import json
  7 | from matplotlib.image import imread
  8 | import matplotlib.pyplot as plt
  9 | import src.config as config
 10 | import os
 11 | 
 12 | def histogram_of_labels(labels, label_dict=None):
 13 |     """
 14 |     Creates and displays a histogram plot for the given labels
 15 |     :param labels: [int] labels
 16 |     :param label_dict: {} dictionary of possible labels for human readability
 17 |     """
 18 |     if label_dict is None:
 19 |         label_dict = config.LABEL_DICT
 20 |     plt.hist(labels, bins=len(label_dict), normed=True)
 21 |     plt.xticks(range(len(label_dict)), label_dict.values())
 22 |     plt.ylabel('% of data')
 23 |     plt.show()
 24 | 
 25 | 
 26 | def plot_images(images, label, pred_label=None, label_dict=None):
 27 |     """
 28 |     Creates and displays a 3x6 plot of sample images with their labels.
 29 |     :param images: [ndarray] input images
 30 |     :param label: [ndarray] true labels
 31 |     :param pred_label: [ndarray] predicted labels
 32 |     :param label_dict: {} dictionary of possible targets for human readability
 33 |     """
 34 |     if label_dict is None:
 35 |         label_dict = config.LABEL_DICT
 36 |     assert len(images) == len(label), "Dimension mismatch between images and labels given"
 37 |     fig, axes = plt.subplots(3, 6, figsize=(12, 5))
 38 |     fig.subplots_adjust(hspace=0.3, wspace=0.1)
 39 |     for i, ax in enumerate(axes.flat):
 40 |         if i < len(images):  # Fewer than 18 images may be given
 41 |             ax.imshow(images[i], cmap='binary')
 42 |             true_label = label_dict[label[i]] if label_dict else label[i]
 43 |             if pred_label is None:
 44 |                 xlabel = "True: %s" % true_label
 45 |             else:
 46 |                 predict_label = label_dict[pred_label[i]] if label_dict else pred_label[i]
 47 |                 xlabel = "True: %s, Pred: %s" % (true_label, predict_label)
 48 |             ax.set_xlabel(xlabel)
 49 |             # Remove x and y ticks
 50 |             ax.set_xticks([])
 51 |             ax.set_yticks([])
 52 |     plt.show()
 53 | 
 54 | 
 55 | def image_analysis(image, label):
 56 |     """
 57 |     Displays image and label along with histogram of values
 58 |     :param image: [ndarray] input image
 59 |     :param label: [ndarray] true label
 60 |     """
 61 |     assert len(image.shape) == 3, 'Wrong size image given to image_analysis'
 62 |     fig, axes = plt.subplots(1, 2)
 63 |     axes[0].imshow(image)
 64 |     axes[1].hist(image.flatten())
 65 |     axes[0].set_xlabel('Label %s' % str(label))
 66 |     plt.show()
 67 | 
 68 | 
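# A quick, self-contained way to exercise the plotting helpers above.
# Hedged example: the random images/labels here are made up purely for
# illustration, using the W/S/A/D encoding from config.LABEL_DICT.
#
#   fake_labels = [random.randrange(len(config.LABEL_DICT)) for _ in range(100)]
#   histogram_of_labels(fake_labels)              # distribution over W/S/A/D
#   images = np.random.rand(18, 128, 128, 3)
#   plot_images(images, np.zeros(18, dtype=int))  # 3x6 grid labeled 'True: W'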
 69 | class Dataset(object):
 70 |     """
 71 |     The Dataset class holds parallel lists of image paths and action labels, with helpers to build, combine, filter, sample, and pickle datasets.
 72 |     """
 73 | 
 74 |     def __init__(self, label_dict=None):
 75 |         if label_dict is None:
 76 |             label_dict = config.LABEL_DICT
 77 |         self.log = logging.getLogger(__name__)
 78 |         self.log.info('New dataset, label dictionary: %s' % json.dumps(config.LABEL_DICT))
 79 |         self.label_dict = label_dict
 80 |         self.input_paths = []
 81 |         self.label_list = []
 82 | 
 83 |     @classmethod
 84 |     def from_datasets(cls, datasets=None, label_dict=None, max_n=None):
 85 |         """
 86 |         Create dataset from list of other datasets
 87 |         :param datasets: [Dataset,] list of datasets
 88 |         :param max_n: (int) maximum number of datapoints
 89 |         :return: (Dataset) Dataset object
 90 |         """
 91 |         if label_dict is None:
 92 |             label_dict = config.LABEL_DICT
 93 |         assert all([isinstance(d, Dataset) for d in datasets]), "Please provide only datasets"
 94 |         d = cls(label_dict)
 95 |         d.input_paths, d.label_list = d.combine_datasets(datasets)
 96 |         if max_n:
 97 |             d.input_paths = d.input_paths[:max_n]
 98 |             d.label_list = d.label_list[:max_n]
 99 |         return d
100 | 
101 |     @classmethod
102 |     def from_path(cls, dir_path=None, label_dict=None, max_n=None):
103 |         """
104 |         Create dataset from folder
105 |         :param dir_path: (string) path to data folder
106 |         :param max_n: (int) maximum number of datapoints
107 |         :return: (Dataset) Dataset object
108 |         """
109 |         if label_dict is None:
110 |             label_dict = config.LABEL_DICT
111 |         assert dir_path, "Please provide a path to a folder of data"
112 |         d = cls(label_dict)
113 |         d.dir_path = dir_path
114 |         d.input_paths, d.label_list = d._get_data_paths()
115 |         if max_n:
116 |             d.input_paths = d.input_paths[:max_n]
117 |             d.label_list = d.label_list[:max_n]
118 |         return d
119 | 
120 |     def _get_data_paths(self):
121 |         """
122 |         Puts together the path for each image and its corresponding target.
123 |         Rows whose image file is missing on disk are skipped.
124 |         :return: [string],[int] image paths, labels
125 |         """
126 |         assert self.dir_path is not None, "No filepath for the dataset"
127 |         input_paths = []
128 |         targets = []
129 |         # Open CSV file with targets and image filenames
130 |         with open(self.dir_path + '/targets.csv', newline='\n') as csvfile:
131 |             spamreader = csv.reader(csvfile, delimiter=',')
132 |             for row in spamreader:
133 |                 image_path = self.dir_path + '/0_' + row[0] + '.png'
134 |                 if os.path.exists(image_path):  # skip rows whose image frame is missing
135 |                     input_paths.append(image_path)
136 |                     targets.append(int(row[1]))
137 | 
138 |         return input_paths, targets
139 | 
140 |     def load_data_sample(self, n=1):
141 |         """
142 |         Random sample of images and targets as np arrays
143 |         :param n: (int) number of datapoints to sample
144 |         :return: [ndarray],[ndarray] images, labels
145 |         """
146 |         if n:  # Get random sample of size n first
147 |             idx = random.sample(range(len(self.input_paths)), n)
148 |             image_paths = [self.input_paths[i] for i in idx]
149 |             targets = [self.label_list[i] for i in idx]
150 |         # Load images and targets as ndarrays
151 |         images = np.asarray([imread(path) for path in image_paths])
152 |         labels = np.asarray(targets)
153 |         return images, labels
154 | 
155 |     def only_label(self, label=''):
156 |         """
157 |         Returns dataset with only instances of this label
158 |         :param label: (string) label
159 |         :return: (Dataset) new dataset containing only this label
160 |         """
161 |         assert label in self.label_dict.values(), "Label not in label dictionary"
162 |         # Get indices for labels corresponding to the given label in label_dict
163 |         label_idx = [i for i, x in enumerate(self.label_list) if self.label_dict[x] == label]
164 |         input_paths = [self.input_paths[i] for i in label_idx]
165 |         label_list = [self.label_list[i] for i in label_idx]
166 |         # Create a new dataset
167 |         d = Dataset(label_dict=self.label_dict)
168 |         d.input_paths = input_paths
169 |         d.label_list = label_list
170 |         return d
171 | 
172 |     @staticmethod
173 |     def combine_datasets(datasets=None):
174 |         """
175 |         Combine datasets to get list of paths and labels
176 |         :param datasets: [Dataset,] list of datasets
177 |         :return: [string],[int] image paths, labels
178 |         """
179 |         if datasets is None:
180 |             datasets = []
181 |         input_paths = []
182 |         targets = []
183 |         for d in datasets:
184 |             input_paths += d.input_paths
185 |             targets += d.label_list
186 |         return input_paths, targets
187 | 
188 |     def analyze(self):
189 |         """
190 |         Plots some sample images with their labels, plus a histogram of all labels.
191 |         """
192 |         print("Size of dataset is %s" % len(self.label_list))
193 |         images, labels = self.load_data_sample(n=18)
194 |         plot_images(images, labels, label_dict=self.label_dict)
195 |         histogram_of_labels(self.label_list, label_dict=self.label_dict)
196 | 
197 |     def save_to_pickle(self, path=None):
198 |         """
199 |         Saves dataset image and label lists to the given path
200 |         :param path: (string) path to save location
201 |         """
202 |         assert path, "Please provide a path when saving to pickle file"
203 |         with open(path, 'wb') as f:
204 |             pickle.dump((self.input_paths, self.label_list), f)
205 | 
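For orientation: `Dataset.from_path` expects a directory of frames named `0_<id>.png` plus a `targets.csv` of
`<id>,<label>` rows, as read by `_get_data_paths` above. A hedged usage sketch (the `demo` folder name is
hypothetical; the chaining mirrors `notebooks/data_creator.ipynb`):

```python
import os
from src.dataset import Dataset
from src.config import DATA_DIR

demo_dir = os.path.join(DATA_DIR, 'demo')  # hypothetical raw data folder
d = Dataset.from_path(demo_dir)            # pairs each 0_<id>.png with its label

# Over-sample the rarer turn actions, as the notebook does for 'balance'
balanced = Dataset.from_datasets([d.only_label('A'), d.only_label('D'), d])
balanced.analyze()                         # needs >= 18 frames for the sample grid
balanced.save_to_pickle(os.path.join(DATA_DIR, 'demo.pickle'))
```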
191 | """ 192 | print("Size of dataset is %s" % len(self.label_list)) 193 | images, labels = self.load_data_sample(n=18) 194 | plot_images(images, labels, label_dict=self.label_dict) 195 | histogram_of_labels(self.label_list, label_dict=self.label_dict) 196 | 197 | def save_to_pickle(self, path=None): 198 | """ 199 | Saves dataset image and label lists to the given path 200 | :param path: (string) path to save location 201 | """ 202 | assert path, "Please provide a path when saving to pickle file" 203 | with open(path, 'wb') as f: 204 | pickle.dump((self.input_paths, self.label_list), f) 205 | -------------------------------------------------------------------------------- /src/model_chunks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import keras 3 | from keras import layers, applications 4 | 5 | # Define logger 6 | logger = logging.getLogger(__name__) 7 | 8 | """ 9 | This file contains chunks of models. Functions that define different 10 | model bases (or encoders), different final dim reduction layers, and 11 | different model heads. These are put in this separate file to prevent cluttering. 12 | """ 13 | 14 | 15 | def custom_model(x, params, run_doc): 16 | """ 17 | Returns the output tensor after going through a custom base 18 | :param x: input tensor 19 | :param params: dictionary of params 20 | :param run_doc: {Ordered dict} run documentation 21 | :return: output tensor 22 | """ 23 | # Get parameters from params dictionary, throw errors if not found 24 | base = params.get('base', None) 25 | dimreduc = params.get('dim_reduction', None) 26 | head = params.get('head', None) 27 | assert base, "No base parameter provided" 28 | assert dimreduc, "No dim reduction layer parameter provided" 29 | assert head, "No head parameter provided " 30 | # Get all possible base functions from globals 31 | possible = globals().copy() 32 | possible.update(locals()) 33 | # Base -> Dim Reduction -> Head 34 | x, run_doc = model_chunk(base, x, params, possible, run_doc) 35 | x, run_doc = model_chunk(dimreduc, x, params, possible, run_doc) 36 | x, run_doc = model_chunk(head, x, params, possible, run_doc) 37 | return x, run_doc 38 | 39 | 40 | def model_chunk(chunk, x, params, possible, run_doc): 41 | """ 42 | Runs input through a model chunk (function) 43 | :param chunk: (string) name of model chunk 44 | :param x: input tensor 45 | :param params: {dict} hyperparams (sub-selection) 46 | :param possible: {dict} local function names 47 | :param run_doc: {Ordered dict} run documentation 48 | :return: output tensor, run documentation 49 | :raises ValueError: could not find parameter 50 | """ 51 | func = possible.get(chunk) 52 | sub_param = params.get(chunk + ' params', None) 53 | if not func: 54 | raise ValueError('Could not find %s layer' % chunk) 55 | x = func(x, sub_param) 56 | run_doc[chunk] = sub_param 57 | return x, run_doc 58 | 59 | 60 | def _big_base(func): 61 | """ 62 | Decorator for big bases (xception, inception, etc) 63 | :raise ValueError: missing parameters 64 | """ 65 | 66 | def wrapper(x, params): 67 | assert params, 'Model chunk needs params' 68 | pre_trained = params.get('pre_trained', None) 69 | trainable = params.get('trainable', None) 70 | input_shape = params.get('input_shape', None) 71 | if any(p is None for p in [pre_trained, trainable, input_shape]): 72 | raise ValueError('xception missing argument') 73 | # Option for pre-trained weights from imagenet 74 | weights = 'imagenet' if pre_trained else None 75 | base = func(weights, 
--------------------------------------------------------------------------------
/src/model_chunks.py:
--------------------------------------------------------------------------------
import logging
import keras
from keras import layers, applications

# Define logger
logger = logging.getLogger(__name__)

"""
This file contains chunks of models: functions that define different
model bases (or encoders), different final dim reduction layers, and
different model heads. They are kept in this separate file to prevent clutter.
"""


def custom_model(x, params, run_doc):
    """
    Returns the output tensor after going through a custom base
    :param x: input tensor
    :param params: {dict} hyperparams
    :param run_doc: {OrderedDict} run documentation
    :return: output tensor, run documentation
    """
    # Get parameters from params dictionary, throw errors if not found
    base = params.get('base', None)
    dimreduc = params.get('dim_reduction', None)
    head = params.get('head', None)
    assert base, "No base parameter provided"
    assert dimreduc, "No dim reduction layer parameter provided"
    assert head, "No head parameter provided"
    # Get all possible chunk functions from globals
    possible = globals().copy()
    possible.update(locals())
    # Base -> Dim Reduction -> Head
    x, run_doc = model_chunk(base, x, params, possible, run_doc)
    x, run_doc = model_chunk(dimreduc, x, params, possible, run_doc)
    x, run_doc = model_chunk(head, x, params, possible, run_doc)
    return x, run_doc


def model_chunk(chunk, x, params, possible, run_doc):
    """
    Runs input through a model chunk (function)
    :param chunk: (string) name of model chunk
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection)
    :param possible: {dict} local function names
    :param run_doc: {OrderedDict} run documentation
    :return: output tensor, run documentation
    :raises ValueError: could not find the named chunk function
    """
    func = possible.get(chunk)
    sub_param = params.get(chunk + ' params', None)
    if not func:
        raise ValueError('Could not find %s layer' % chunk)
    x = func(x, sub_param)
    run_doc[chunk] = sub_param
    return x, run_doc
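
# Hedged sketch of the params dictionary custom_model expects. Chunk names
# must match function names defined in this module, and sub-parameters are
# looked up under the '<chunk> params' key; the values below are illustrative:
#
#   params = {
#       'base': 'simpleconv',
#       'dim_reduction': 'flatten',
#       'head': 'fc',
#       'fc params': {'dense_layers': [64, 32],
#                     'dense_activations': 'relu',
#                     'dropout_percentage': 0.5},
#   }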


def _big_base(func):
    """
    Decorator for big bases (xception, inception, etc)
    :raise ValueError: missing parameters
    """

    def wrapper(x, params):
        assert params, 'Model chunk needs params'
        pre_trained = params.get('pre_trained', None)
        trainable = params.get('trainable', None)
        input_shape = params.get('input_shape', None)
        if any(p is None for p in [pre_trained, trainable, input_shape]):
            raise ValueError('%s missing argument' % func.__name__)
        # Option for pre-trained weights from imagenet
        weights = 'imagenet' if pre_trained else None
        base = func(weights, input_shape)
        base.trainable = trainable  # optionally freeze the base
        return base(x)

    return wrapper


@_big_base
def xception(weights, input_shape):
    """
    Chollet's Xception architecture, supposedly better than InceptionV4
    :param weights: (string or None) 'imagenet' for pre-trained weights
    :param input_shape: tuple(3)
    :return: base
    """
    base = applications.xception.Xception(weights=weights,
                                          input_shape=input_shape,
                                          include_top=False)
    return base


@_big_base
def inception_res_v2(weights, input_shape):
    """
    Google's Inception Resnet V2
    :param weights: (string or None) 'imagenet' for pre-trained weights
    :param input_shape: tuple(3)
    :return: base
    """
    base = applications.inception_resnet_v2.InceptionResNetV2(weights=weights,
                                                              input_shape=input_shape,
                                                              include_top=False)
    return base


@_big_base
def res_net_50(weights, input_shape):
    """
    The ResNet 50. Residual Connections.
    :param weights: (string or None) 'imagenet' for pre-trained weights
    :param input_shape: tuple(3)
    :return: base
    """
    base = applications.resnet50.ResNet50(weights=weights,
                                          input_shape=input_shape,
                                          include_top=False)
    return base


def a3c(x, params):
    """
    Feed forward model used in the A3C paper
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection, unused)
    :return: output tensor
    """
    x = layers.Conv2D(filters=16, kernel_size=8, strides=4, activation='relu')(x)
    x = layers.Conv2D(filters=32, kernel_size=4, strides=2, activation='relu')(x)
    return x


def a3c_sepconv(x, params):
    """
    Feed forward model used in the A3C paper but with separable convolutions
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection, unused)
    :return: output tensor
    """
    x = layers.SeparableConv2D(filters=16, kernel_size=8, strides=4, activation='relu')(x)
    x = layers.SeparableConv2D(filters=32, kernel_size=4, strides=2, activation='relu')(x)
    return x


def minires(x, params):
    """
    Small net with residual connections
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection, unused)
    :return: output tensor
    """
    x_a = layers.Conv2D(16, 4, activation='relu', padding='same')(x)
    x_b = layers.Conv2D(16, 8, activation='relu', padding='same')(x)
    x_1 = layers.concatenate([x_a, x_b])
    x_c = layers.Conv2D(32, 4, activation='relu', padding='same')(x_1)
    x_d = layers.Conv2D(32, 8, activation='relu', padding='same')(x_1)
    x = layers.concatenate([x_c, x_d, x_1])
    return x


def tall_kernel(x, params):
    """
    Small net with residual connections and tall kernels
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection, unused)
    :return: output tensor
    """
    x_a = layers.Conv2D(16, kernel_size=(2, 8), activation='relu', padding='same')(x)
    x_b = layers.Conv2D(16, kernel_size=(8, 2), activation='relu', padding='same')(x)
    x_1 = layers.concatenate([x_a, x_b])
    x_c = layers.Conv2D(32, kernel_size=(2, 8), activation='relu', padding='same')(x_1)
    x_d = layers.Conv2D(32, kernel_size=(8, 2), activation='relu', padding='same')(x_1)
    x = layers.concatenate([x_c, x_d, x_1])
    return x


def simpleconv(x, params):
    """
    Simple CNN base.
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection, unused)
    :return: output tensor
    """
    x = layers.Conv2D(16, 3, activation='relu')(x)
    x = layers.Conv2D(32, 3, activation='relu')(x)
    return x


def fc(x, params):
    """
    Fully connected net head
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection)
    :return: output tensor
    """
    assert params, 'Model chunk needs params'
    # Head consists of a couple dense layers with dropout
    for units in params["dense_layers"]:
        x = layers.Dense(units, activation=params["dense_activations"])(x)
        x = layers.Dropout(params['dropout_percentage'])(x)
    return x


def global_average(x, params):
    """
    Global average pooling
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection, unused)
    :return: output tensor
    """
    return layers.GlobalAveragePooling2D()(x)


def global_max(x, params):
    """
    Global max pooling
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection, unused)
    :return: output tensor
    """
    return layers.GlobalMaxPool2D()(x)


def flatten(x, params):
    """
    Plain ol' 2D flatten
    :param x: input tensor
    :param params: {dict} hyperparams (sub-selection, unused)
    :return: output tensor
    """
    return layers.Flatten()(x)
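
# Hedged smoke test (not in the original repo): wires a tiny
# base -> dim reduction -> head through custom_model. Every hyperparameter
# value here is illustrative.
if __name__ == '__main__':
    from collections import OrderedDict
    from keras import Input

    demo_params = {'base': 'simpleconv',
                   'dim_reduction': 'flatten',
                   'head': 'fc',
                   'fc params': {'dense_layers': [64],
                                 'dense_activations': 'relu',
                                 'dropout_percentage': 0.5}}
    demo_x, demo_doc = custom_model(Input(shape=(128, 128, 3)),
                                    params=demo_params,
                                    run_doc=OrderedDict())
    logger.info('Demo output tensor: %s, run doc: %s', demo_x, demo_doc)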
--------------------------------------------------------------------------------
/src/hyperopt_trainer.py:
--------------------------------------------------------------------------------
import logging
import os
import numpy as np
import time
from datetime import datetime
from uuid import uuid4
from PIL import Image
import random
import pickle
from collections import OrderedDict
import keras
from hyperopt import STATUS_OK, fmin, tpe
from keras import layers, Input, optimizers
from keras.models import Model
import src.config as config

from .model_chunks import custom_model
from .dataset import image_analysis


class HyperoptTrainer(object):
    """
    This class uses Hyperopt and Keras to generate pirates.
    """

    def __enter__(self):
        self.log = logging.getLogger(__name__)
        # Timing and record keeping
        self._start_time = time.time()
        self._results = {}
        self._max_eval = config.MAX_TRAIN_TRIES
        self._eval_idx = 0
        # Create directory for the logs and models
        folder_name = datetime.now().strftime('%Y%m%d')
        self._logs_dir = os.path.join(config.LOGS_DIR, folder_name)
        self._models_dir = os.path.join(config.MODEL_DIR, folder_name)
        os.makedirs(self._logs_dir, exist_ok=True)
        os.makedirs(self._models_dir, exist_ok=True)
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        pass

    def data_loader(self, dataset=None, shuffle=True, size=None):
        """
        Data providing function. Reads from the local datasets folder.
        :param dataset: (string) name of the pickled dataset (image paths and targets)
        :param shuffle: (bool) shuffle training data
        :param size: (int) size of final training set
        :return: x_train, y_train, x_test, y_test
        """
        self.log.info("Loading data")
        assert dataset is not None, "Please provide a dataset (file name)"
        data_path = os.path.join(config.DATA_DIR, dataset)
        # Load dataset using pickle
        with open(data_path, 'rb') as file:
            image_paths, labels = pickle.load(file)
        one_hot_labels = []
        images = []
        for path, label in zip(image_paths, labels):
            try:
                # One-hot encode the labels (4 possible actions)
                one_hot = [0, 0, 0, 0]
                one_hot[label] = 1.0
                # Clean up image, normalize values to [0, 1]
                image = Image.open(path)
                image = np.asarray(image, dtype=np.float32) / 255
            except Exception:  # If there is some issue reading in a datapoint, skip it
                continue
            one_hot_labels.append(np.asarray(one_hot, dtype=np.float32))
            images.append(image)
        # Shuffle data before cutting it into train and test
        if shuffle:
            x = list(zip(images, one_hot_labels))
            random.shuffle(x)
            images, one_hot_labels = zip(*x)
        self.log.info("Separating data into train and test.")
        split_idx = int(config.TRAIN_TEST_SPLIT * len(one_hot_labels))
        train_input = images[:split_idx]
        train_target = one_hot_labels[:split_idx]
        test_input = images[split_idx:]
        test_target = one_hot_labels[split_idx:]
        if size:
            assert size < len(train_input), "Final dataset size too big, not enough data"
            train_input = train_input[:size]
            train_target = train_target[:size]
        self.log.info(" -- test : {}".format(len(test_target)))
        self.log.info(" -- train: {}".format(len(train_target)))
        # Convert to ndarray before sending over
        return np.array(train_input), \
            np.array(train_target), \
            np.array(test_input), \
            np.array(test_target)
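
    # Hedged note (file name illustrative): the 'dataset' argument is the
    # pickle written by Dataset.save_to_pickle, resolved against
    # config.DATA_DIR, e.g.
    #   x_tr, y_tr, x_te, y_te = trainer.data_loader(dataset='my_run.pickle')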

    def model(self, hyperparams, test_mode=False):
        """
        Builds and runs a model given a dictionary of hyperparameters
        :param hyperparams: {dict} sampled hyperparameters for this run
        :param test_mode: (bool) if True, don't write logs/models and re-raise model errors
        :return: {dict}
            - loss: validation loss (to be minimized)
            - status: STATUS_OK (see hyperopt documentation)
        """
        run_doc = OrderedDict()  # Document important hyperparameters
        run_start_time = time.time()
        run_id = str(uuid4())
        # TODO: Not ideal: loads everything into memory every time. Use a generator?
        train_data, train_targets, test_data, test_targets = \
            self.data_loader(dataset=hyperparams['dataset'], size=hyperparams['dataset_size'])
        run_doc['dataset'] = hyperparams['dataset']
        run_doc['data_size'] = len(train_targets)
        # Visualization tools
        if config.INPUT_DEBUG:
            image_analysis(image=train_data[0, :, :, :], label=train_targets[0, :])
        # Input shape comes from image shape
        img_width = train_data[0].shape[0]
        img_height = train_data[0].shape[1]
        num_channels = train_data[0].shape[2]
        input_shape = (img_width, img_height, num_channels)
        run_doc['input_shape'] = '(%d, %d, %d)' % input_shape
        input_tensor = Input(shape=input_shape, dtype='float32', name='input_image')
        try:  # Model creation is in a separate file
            x, run_doc = custom_model(input_tensor, params=hyperparams, run_doc=run_doc)
        except ValueError:
            if not test_mode:  # If not testing, skip error-causing models with a large loss
                return {'loss': 100, 'status': STATUS_OK}
            else:
                raise
        # Final layer classifies into 4 possible actions
        output = layers.Dense(4, activation='softmax')(x)
        # File names for the model and logs
        log_file = os.path.join(self._logs_dir, run_id)
        model_file = os.path.join(self._models_dir, run_id + '.h5')
        # Add some callbacks so we can track progress using Tensorboard
        callbacks = [keras.callbacks.EarlyStopping('val_loss', patience=config.TRAIN_PATIENCE, mode="min")]
        if not test_mode:  # Don't save models/logs if in testing mode
            callbacks += [keras.callbacks.TensorBoard(log_dir=log_file),
                          keras.callbacks.ModelCheckpoint(model_file, save_best_only=True)]
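        # Hedged note: with the TensorBoard callback above, training curves can
        # be monitored by pointing the standard CLI at config.LOGS_DIR, e.g.
        # `tensorboard --logdir local/logs` (path assumed from this repo's layout).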
        # Choice of optimizer and optimization parameters
        if hyperparams['optimizer'] == 'sgd':
            optimizer = optimizers.SGD(lr=hyperparams["learning_rate"],
                                       decay=hyperparams["decay"],
                                       clipnorm=hyperparams["clipnorm"])
        elif hyperparams['optimizer'] == 'rmsprop':
            optimizer = optimizers.RMSprop(lr=hyperparams["learning_rate"],
                                           decay=hyperparams["decay"],
                                           clipnorm=hyperparams["clipnorm"])
        elif hyperparams['optimizer'] == 'nadam':
            optimizer = optimizers.Nadam(lr=hyperparams["learning_rate"],
                                         schedule_decay=hyperparams["decay"],
                                         clipnorm=hyperparams["clipnorm"])
        elif hyperparams['optimizer'] == 'adam':
            optimizer = optimizers.Adam(lr=hyperparams["learning_rate"],
                                        decay=hyperparams["decay"],
                                        clipnorm=hyperparams["clipnorm"])
        else:
            raise ValueError('Unknown optimizer: %s' % hyperparams['optimizer'])
        # Save optimizer parameters to run doc
        run_doc['optimizer'] = hyperparams['optimizer']
        run_doc['opt_learning_rate'] = hyperparams["learning_rate"]
        run_doc['opt_decay'] = hyperparams["decay"]
        run_doc['opt_clipnorm'] = hyperparams["clipnorm"]
        # Create and compile the model
        model = Model(input_tensor, output)
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
        # Print out model summary and store inside run documentation as list of strings
        model.summary()
        run_doc['model_summary'] = []
        model.summary(print_fn=(lambda a: run_doc['model_summary'].append(a)))
        # Fit the model to the datasets
        self.log.info("Fitting model (eval %d of %d) ..." % (self._eval_idx + 1, self._max_eval))
        self._eval_idx += 1
        model.fit(x=train_data, y=train_targets,
                  batch_size=hyperparams['batch_size'],
                  epochs=hyperparams['epochs'],
                  validation_data=(test_data, test_targets),
                  callbacks=callbacks,
                  verbose=1)
        val_loss, val_acc = model.evaluate(x=test_data, y=test_targets, verbose=2)
        self.log.info(" .... Completed!")
        self.log.info(" -- Evaluation time %ds" % (time.time() - run_start_time))
        self.log.info(" -- Total time %ds" % (time.time() - self._start_time))
        # Save training parameters to run doc
        run_doc['batch_size'] = hyperparams['batch_size']
        run_doc['epochs'] = hyperparams['epochs']
        run_doc['val_loss'] = val_loss
        run_doc['val_acc'] = val_acc
        # Results are used to pick the best pirate
        self._results[run_id] = val_loss
        # Save run_doc to pickle file in model directory
        run_doc_file_name = run_id + '.pickle'
        if not test_mode:  # Don't save docs if in testing mode
            with open(os.path.join(self._models_dir, run_doc_file_name), 'wb') as f:
                pickle.dump(run_doc, f)
        self.log.info('Run Dictionary %s' % str(run_doc))
        # Delete the session to prevent GPU memory from getting full
        keras.backend.clear_session()
        # Optimizer minimizes validation loss
        return {'loss': val_loss, 'status': STATUS_OK}

    def run_hyperopt(self, max_eval, space):
        """
        Runs the hyperopt trainer
        :param max_eval: (int) max evaluations to carry out when running hyperopt
        :param space: {dict} dictionary of hyperparameter space to explore
        :return: {dict} run id (dna) -> validation loss for each evaluated model
        """
        # Reset run parameters
        self._max_eval = max_eval
        self._results = {}
        self._eval_idx = 0

        # Hyperopt is picky about the function handle
        def model_handle(params):
            return self.model(params)

        # Run the hyperparameter optimization
        _ = fmin(fn=model_handle, space=space, algo=tpe.suggest, max_evals=max_eval)
        return self._results
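
# Hedged sketch of a search space for run_hyperopt (the real space is
# config.SPACE; every name and range below is illustrative):
#   from hyperopt import hp
#   space = {
#       'dataset': 'my_run.pickle',
#       'dataset_size': 1000,
#       'base': hp.choice('base', ['simpleconv', 'minires']),
#       'dim_reduction': hp.choice('dim_reduction', ['flatten', 'global_average']),
#       'head': 'fc',
#       'fc params': {'dense_layers': [64],
#                     'dense_activations': 'relu',
#                     'dropout_percentage': hp.uniform('dropout', 0.2, 0.6)},
#       'optimizer': hp.choice('optimizer', ['adam', 'sgd']),
#       'learning_rate': hp.loguniform('lr', -9, -3),
#       'decay': 0.0,
#       'clipnorm': 1.0,
#       'batch_size': 32,
#       'epochs': 10,
#   }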
--------------------------------------------------------------------------------
/src/ship.py:
--------------------------------------------------------------------------------
import logging
import os
import shutil
import sqlite3
from src.hyperopt_trainer import HyperoptTrainer
from src.pirate import Pirate
import src.config as config


def _check_dna(func):
    """
    Decorator makes sure dna is a string and not None
    :raises ValueError: if dna is not a string
    """

    def wrapper(*args, **kwargs):
        dna = kwargs.get('dna', None)
        # TODO: Check if it is actually a uuid4, not just if it's a string
        if not isinstance(dna, str):
            raise ValueError('dna must be a string UUID4')
        return func(*args, **kwargs)

    return wrapper


class Ship(object):
    """
    The Ship is where the Pirates live. It contains the interface to the sqlite database
    which is where scores and meta-data for each pirate are kept.
    """

    def __init__(self, ship_name='Boat'):
        """
        :param ship_name: (string) name this vessel!
        """
        self.log = logging.getLogger(__name__)
        self.name = ship_name
        # Database connection params
        self._c = None
        self._conn = None
        self._database_path = os.path.abspath(os.path.join(config.SHIP_DIR, self.name + '.db'))

    def __enter__(self):
        """
        Set off to sea! Connects to local sqlite db. Creates table if database does not yet exist.
        """
        # Connect to the database, set up cursor
        self.log.debug('Starting up database at %s' % self._database_path)
        self._conn = sqlite3.connect(self._database_path)
        self._conn.row_factory = sqlite3.Row
        self._c = self._conn.cursor()
        # Create the table (no-op if it already exists)
        with self._conn:
            self._c.execute("""CREATE TABLE IF NOT EXISTS pirates (
                            dna TEXT,
                            name TEXT DEFAULT 'Unborn',
                            rank INTEGER DEFAULT 0,
                            win INTEGER DEFAULT 0,
                            loss INTEGER DEFAULT 0,
                            saltyness INTEGER DEFAULT 0
                            )""")
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Close the database connection on the way out
        if self._conn:
            self._conn.close()

    @_check_dna
    def _add_pirate(self, dna=None):
        """
        Adds a pirate to the ship
        :param dna: (string) identifier for the pirate
        :raises ValueError: if dna is not a string
        """
        with self._conn:
            try:
                self._c.execute('INSERT INTO pirates(dna) VALUES (:dna)', {'dna': dna})
            except sqlite3.Error as e:
                self.log.warning('Could not add pirate to ship. Error: %s' % e)

    @_check_dna
    def _walk_the_plank(self, dna=None):
        """
        Removes a pirate from the ship
        :param dna: (string) identifier for the pirate
        :raises ValueError: if dna is not a string
        """
        with self._conn:
            try:
                self._c.execute('DELETE FROM pirates WHERE dna=:dna', {'dna': dna})
            except sqlite3.Error as e:
                self.log.warning('Could not remove pirate from ship. Error: %s' % e)

    @_check_dna
    def _set_prop(self, dna=None, prop=None):
        """
        Updates properties of a pirate on the ship
        :param dna: (string) identifier for the pirate
        :param prop: {string: val,} name:value of the properties; values are SQL
                     fragments, so strings must carry their own quotes
        :return: (bool) error
        """
        if not isinstance(prop, dict) or not all(isinstance(p, str) for p in prop.keys()):
            raise ValueError('Must give a dictionary of properties with string keys to set')
        with self._conn:
            try:
                prop_str = ' , '.join('%s = %s' % (k, v) for k, v in prop.items())
                # TODO: This splices values straight into the SQL (unsafe practice)
                query = "UPDATE pirates SET " + prop_str + " WHERE dna = '" + dna + "'"
                self._c.execute(query)
                return False
            except sqlite3.Error as e:
                self.log.warning('Could not set pirate properties. Error: %s' % e)
                return True

    @_check_dna
    def _get_prop(self, dna=None, prop=None):
        """
        Returns properties of a pirate on the ship
        :param dna: (string) identifier for the pirate
        :param prop: [string,] name(s) of the property
        :return: (bool), [val,] error, values of the properties
        """
        if not isinstance(prop, list) or not all(isinstance(p, str) for p in prop):
            raise ValueError('Must give a list of string properties to find')
        with self._conn:
            try:
                query = "SELECT " + ','.join(prop) + " FROM pirates WHERE dna = '" + dna + "'"
                self._c.execute(query)
                sql_row = [dict(a) for a in self._c.fetchall()]
                # One row per matching pirate; read each requested property from the first
                return False, [sql_row[0][p] for p in prop]
            except (TypeError, IndexError, sqlite3.Error) as e:
                self.log.warning('Could not get pirate properties. Error: %s' % e)
                return True, None
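
    # Hedged examples of the property conventions used by _set_prop/_get_prop
    # (the name is illustrative): values are SQL fragments, so strings carry
    # their own quotes and counters are expressions:
    #   self._set_prop(dna=dna, prop={'name': "'Blackbeard'"})
    #   self._set_prop(dna=dna, prop={'saltyness': 'saltyness + 10'})
    #   err, (rank, wins) = self._get_prop(dna=dna, prop=['rank', 'win'])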

    @_check_dna
    def create_pirate(self, dna=None):
        """
        Creates a pirate on the ship. Watch out: this loads the pirate model into memory.
        :param dna: (string) identifier for the pirate
        :return: (bool), (Pirate) error, the pirate object
        :raises ValueError: if dna is not a string
        """
        with self._conn:
            try:
                self._c.execute('SELECT * FROM pirates WHERE dna=:dna', {'dna': dna})
                pirate_info = dict(self._c.fetchone())
            except (TypeError, sqlite3.Error) as e:
                self.log.warning('Could not find pirate in ship. Error: %s' % e)
                return True, None
            try:
                pirate = Pirate(dna=pirate_info.get('dna', None),
                                name=pirate_info.get('name', None),
                                rank=pirate_info.get('rank', None),
                                win=pirate_info.get('win', None),
                                loss=pirate_info.get('loss', None),
                                saltyness=pirate_info.get('saltyness', None))
                # Update the name for the pirate
                self._set_prop(dna=dna, prop={'name': '\'' + pirate.name + '\''})
            except FileNotFoundError:
                self.log.warning('Could not create pirate. Could not find model associated with it')
                return True, None
        return False, pirate

    def get_best_pirates(self, n=1):
        """
        The (up-to) N saltiest pirates on board the ship.
        :param n: (int) up to this number of pirates, less if not many pirates in db
        :return: [Pirate,] list of pirates
        """
        with self._conn:
            self._c.execute('SELECT dna FROM pirates ORDER BY saltyness DESC LIMIT ?', (n,))
            sql_row = [dict(a) for a in self._c.fetchall()]
        pirates = []
        for d in sql_row:
            err, pirate = self.create_pirate(dna=d['dna'])
            if not err:  # Don't add pirates that throw an error on creation
                pirates.append(pirate)
        return pirates

    def marooning_update(self, winner, losers):
        """
        Updates the ship with the results from a marooning
        :param winner: (string) dna of the winning pirate (empty string for no winner)
        :param losers: [string,] list of string dnas for losing pirates
        :return: (bool) error
        """
        # Update wins, losses, and saltyness for the winner and the losers
        error = False
        if winner:  # Empty string is falsy in python
            # The 'win + 1' formulation works because values are spliced in as SQL expressions
            error |= self._set_prop(dna=winner, prop={'win': 'win + 1'})
            error |= self._set_prop(dna=winner, prop={'saltyness': 'saltyness + ' + str(config.SALT_PER_WIN)})
        for dna in losers:
            error |= self._set_prop(dna=dna, prop={'loss': 'loss + 1'})
            error |= self._set_prop(dna=dna, prop={'saltyness': 'saltyness - ' + str(config.SALT_PER_LOSS)})
        return error

    def headcount(self):
        """
        How many pirates are on this ship?
        :return: (int) number of pirates on this ship (or 0 if error)
        """
        with self._conn:
            try:
                self._c.execute('SELECT count() FROM pirates')
                sql_row = [dict(a) for a in self._c.fetchall()]
                num_pirates = sql_row[0]['count()']
                self.log.info('There are currently %s pirates on the ship' % num_pirates)
                return num_pirates
            except (TypeError, sqlite3.Error) as e:
                self.log.warning('Failed to perform headcount on ship. Error: %s' % e)
                return 0

    @_check_dna
    def delete_local_pirate_files(self, dna=None):
        """
        Deletes local files associated with a pirate (model, docs, logs).
        Logs a warning if any of the files cannot be found.
        :param dna: (string) identifier for the pirate
        """
        removed = {'model': False, 'doc': False, 'log': False}
        for dirpath, _, files in os.walk(config.MODEL_DIR):
            if dna + '.h5' in files:
                os.remove(os.path.join(dirpath, dna + '.h5'))
                removed['model'] = True
            if dna + '.pickle' in files:
                os.remove(os.path.join(dirpath, dna + '.pickle'))
                removed['doc'] = True
        for dirpath, dirs, files in os.walk(config.LOGS_DIR):
            if dna in dirs:
                # Tensorboard logs are a folder
                shutil.rmtree(os.path.join(dirpath, dna))
                removed['log'] = True
        if not all(removed.values()):  # All of the files should be removed
            self.log.warning('When removing local files for %s, could not find %s' % (dna, removed))

    def less_pirates(self, n=config.NUM_PIRATES_PER_CULLING):
        """
        Removes the N pirates with lowest saltyness from the ship (and associated local files)
        :param n: (int) how many pirates to be removed
        """
        with self._conn:
            self._c.execute('SELECT dna FROM pirates ORDER BY saltyness ASC LIMIT ?', (n,))
            sql_row = [dict(a) for a in self._c.fetchall()]
        for d in sql_row:
            self.delete_local_pirate_files(dna=d['dna'])
            self._walk_the_plank(dna=d['dna'])

    def more_pirates(self, num_pirates=config.NUM_PIRATES_PER_TRAIN, max_tries=config.MAX_TRAIN_TRIES, space=config.SPACE):
        """
        Creates pirates using hyperopt and adds them to the ship
        :param num_pirates: (int) number of pirates to generate
        :param max_tries: (int) max number of hyperopt runs before choosing best pirates
        :param space: {dict} hyperparameter space to explore
        """
        assert space, 'Please provide a hyperparameter space for creating pirate models'
        with HyperoptTrainer() as trainer:
            results = trainer.run_hyperopt(max_tries, space)
        # Sort results by lowest validation loss (best models first)
        top = sorted(results.items(), key=lambda e: e[1])
        self.log.info('Making %s more pirates' % num_pirates)
        for idx, (dna, _) in enumerate(top):
            if idx < num_pirates:  # Only add the best N pirates
                self._add_pirate(dna=dna)
                self._set_prop(dna=dna, prop={'rank': idx})
                self._set_prop(dna=dna, prop={'saltyness': config.STARTING_SALT})
            else:
                self.delete_local_pirate_files(dna=dna)  # Delete the leftover run's local files
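
# Hedged usage sketch (not in the original repo; the ship name and counts are
# illustrative):
#   with Ship(ship_name='TestShip') as ship:
#       if ship.headcount() < 2:
#           ship.more_pirates(num_pirates=2)
#       crew = ship.get_best_pirates(n=2)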
--------------------------------------------------------------------------------