├── tests ├── __init__.py ├── fixtures │ ├── feature-1-1.json │ ├── feature-2-4.json │ ├── feature-4-6.json │ ├── feature-5-7.json │ ├── feature-7-9.json │ ├── feature-10-12.json │ ├── feature-10-13.json │ ├── feature-10-15.json │ ├── feature-10-16.json │ ├── feature-10-17.json │ ├── feature-10-18.json │ ├── feature-10-20.json │ ├── feature-10-21.json │ ├── feature-9-10.json │ ├── feature-1-1-hit.json │ ├── feature-1-1-stick.json │ ├── feature-10-15-hit.json │ ├── feature-7-12-hit.json │ └── feature-9-10-stick.json ├── environment_tests.py └── function_approximation_tests.py ├── Q_opt.pkl ├── vis ├── V_mc_1000000_episodes.pdf ├── V_mc_1000000_episodes.png ├── lambda_mse_lfa_gamma_1.0_episodes_20000.pdf ├── lambda_mse_lfa_gamma_1.0_episodes_20000.png ├── policy_gradient_rewards_episodes_100000.pdf ├── policy_gradient_rewards_episodes_100000.png ├── V_lfa_lambda_0.0_gamma_1.0_episodes_20000.pdf ├── V_lfa_lambda_0.0_gamma_1.0_episodes_20000.png ├── V_lfa_lambda_1.0_gamma_1.0_episodes_20000.pdf ├── V_lfa_lambda_1.0_gamma_1.0_episodes_20000.png ├── lambda_mse_sarsa_gamma_1.0_episodes_20000.pdf ├── lambda_mse_sarsa_gamma_1.0_episodes_20000.png ├── V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.pdf ├── V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.png ├── V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.pdf ├── V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.png ├── V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf ├── V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.png ├── V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf └── V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.png ├── requirements.txt ├── agents ├── __init__.py ├── monte_carlo.py ├── sarsa.py ├── policy_gradient.py └── function_approximation.py ├── utils.py ├── .gitignore ├── vis.py ├── environment.py ├── easy21.py └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Q_opt.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/Q_opt.pkl -------------------------------------------------------------------------------- /vis/V_mc_1000000_episodes.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_mc_1000000_episodes.pdf -------------------------------------------------------------------------------- /vis/V_mc_1000000_episodes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_mc_1000000_episodes.png -------------------------------------------------------------------------------- /vis/lambda_mse_lfa_gamma_1.0_episodes_20000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/lambda_mse_lfa_gamma_1.0_episodes_20000.pdf -------------------------------------------------------------------------------- /vis/lambda_mse_lfa_gamma_1.0_episodes_20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/lambda_mse_lfa_gamma_1.0_episodes_20000.png -------------------------------------------------------------------------------- 
/vis/policy_gradient_rewards_episodes_100000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/policy_gradient_rewards_episodes_100000.pdf -------------------------------------------------------------------------------- /vis/policy_gradient_rewards_episodes_100000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/policy_gradient_rewards_episodes_100000.png -------------------------------------------------------------------------------- /vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.pdf -------------------------------------------------------------------------------- /vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.png -------------------------------------------------------------------------------- /vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.pdf -------------------------------------------------------------------------------- /vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.png -------------------------------------------------------------------------------- /vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.pdf -------------------------------------------------------------------------------- /vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.png -------------------------------------------------------------------------------- /vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.pdf -------------------------------------------------------------------------------- /vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.png -------------------------------------------------------------------------------- /vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.pdf -------------------------------------------------------------------------------- 
/vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.png -------------------------------------------------------------------------------- /vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf -------------------------------------------------------------------------------- /vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.png -------------------------------------------------------------------------------- /vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf -------------------------------------------------------------------------------- /vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.3 2 | cycler==0.10.0 3 | matplotlib==2.0.0 4 | nose==1.3.7 5 | numpy==1.12.1 6 | packaging==16.8 7 | pyparsing==2.2.0 8 | python-dateutil==2.6.0 9 | pytz==2017.2 10 | six==1.10.0 11 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- 1 | from agents.monte_carlo import MonteCarloAgent 2 | from agents.sarsa import SarsaAgent 3 | from agents.function_approximation import FunctionApproximationAgent 4 | from agents.policy_gradient import PolicyGradientAgent 5 | -------------------------------------------------------------------------------- /tests/fixtures/feature-1-1.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 1, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 0, 26 | 0 27 | ] 28 | ], 29 | "state": [ 30 | 1, 31 | 1 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-2-4.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 1, 6 | 1, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 0, 26 | 0 27 | ] 28 | ], 29 | "state": [ 30 | 2, 31 | 4 32 | ] 33 | } 
-------------------------------------------------------------------------------- /tests/fixtures/feature-4-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 1, 6 | 1, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 1, 14 | 1, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 0, 26 | 0 27 | ] 28 | ], 29 | "state": [ 30 | 4, 31 | 6 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-5-7.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 1, 15 | 1, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 0, 26 | 0 27 | ] 28 | ], 29 | "state": [ 30 | 5, 31 | 7 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-7-9.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 1, 15 | 1, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 1, 23 | 1, 24 | 0, 25 | 0, 26 | 0 27 | ] 28 | ], 29 | "state": [ 30 | 7, 31 | 9 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-12.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 1, 24 | 1, 25 | 0, 26 | 0 27 | ] 28 | ], 29 | "state": [ 30 | 10, 31 | 12 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-13.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 1, 25 | 1, 26 | 0 27 | ] 28 | ], 29 | "state": [ 30 | 10, 31 | 13 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-15.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 1, 26 | 1 27 | ] 28 | ], 29 | "state": [ 30 | 10, 31 | 16 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 1, 26 | 1 27 | ] 28 | ], 29 | "state": [ 30 | 10, 31 | 16 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-17.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 1, 26 | 1 27 | ] 28 | ], 29 | "state": [ 30 | 10, 31 | 17 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-18.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 1, 26 | 1 27 | ] 28 | ], 29 | "state": [ 30 | 10, 31 | 18 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-20.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 0, 26 | 1 27 | ] 28 | ], 29 | "state": [ 30 | 10, 31 | 20 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-21.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 0, 24 | 0, 25 | 0, 26 | 1 27 | ] 28 | ], 29 | "state": [ 30 | 10, 31 | 21 32 | ] 33 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-9-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": null, 3 | "expected_feats": [ 4 | [ 5 | 0, 6 | 0, 7 | 0, 8 | 0, 9 | 0, 10 | 0 11 | ], 12 | [ 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0 19 | ], 20 | [ 21 | 0, 22 | 0, 23 | 1, 24 | 1, 25 | 0, 26 | 0 27 | ] 28 | ], 29 | "state": [ 30 | 9, 31 | 10 32 | ] 33 | } -------------------------------------------------------------------------------- /agents/monte_carlo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils import epsilon_greedy_policy 4 | from environment import TERMINAL_STATE, STATE_SPACE_SHAPE 5 | 6 | 7 | class MonteCarloAgent: 8 | def __init__(self, env, num_episodes=1000, **kwargs): 9 | self.num_episodes = num_episodes 10 | self.env = env 11 | self.reset() 12 | 13 | 14 | def reset(self): 15 | self.Q = np.zeros(STATE_SPACE_SHAPE) 16 | 17 | 18 | def learn(self): 19 | env = self.env 20 | Q = self.Q 21 | N = np.zeros(STATE_SPACE_SHAPE) 22 | 23 | for episode in range(1, self.num_episodes+1): 24 | env.reset() 25 | state = env.observe() 26 | E = [] # experience from the episode 27 | 28 | while state != TERMINAL_STATE: 29 | action = epsilon_greedy_policy(Q, N, state) 30 | state_, reward = env.step(action) 31 | 32 | E.append([state, action, reward]) 33 | state = state_ 34 | 35 | for (dealer, player), action, reward in E: 36 | idx = dealer-1, player-1, action 37 | N[idx] += 1 38 | alpha = 1.0 / N[idx] 39 | Q[idx] += alpha * (reward - Q[idx]) 40 | 41 | return Q 42 | 
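As a quick orientation to the agent API above, here is a minimal usage sketch (an editor-added illustration assuming the module layout of this repository, run from the repository root):

```
from environment import Easy21Env
from agents.monte_carlo import MonteCarloAgent

env = Easy21Env()
agent = MonteCarloAgent(env, num_episodes=100000)
Q = agent.learn()    # Q has shape (10, 21, 2): dealer card x player sum x action
V = Q.max(axis=2)    # greedy state values, as plotted by vis.plot_V
```

The ε-greedy schedule comes from `utils.epsilon_greedy_policy`, and the 1/N(s,a) step-size is computed inside `learn`, so no further configuration is needed for the Monte-Carlo agent.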
-------------------------------------------------------------------------------- /tests/fixtures/feature-1-1-hit.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": "hit", 3 | "expected_feats": [ 4 | [ 5 | [ 6 | 1, 7 | 0 8 | ], 9 | [ 10 | 0, 11 | 0 12 | ], 13 | [ 14 | 0, 15 | 0 16 | ], 17 | [ 18 | 0, 19 | 0 20 | ], 21 | [ 22 | 0, 23 | 0 24 | ], 25 | [ 26 | 0, 27 | 0 28 | ] 29 | ], 30 | [ 31 | [ 32 | 0, 33 | 0 34 | ], 35 | [ 36 | 0, 37 | 0 38 | ], 39 | [ 40 | 0, 41 | 0 42 | ], 43 | [ 44 | 0, 45 | 0 46 | ], 47 | [ 48 | 0, 49 | 0 50 | ], 51 | [ 52 | 0, 53 | 0 54 | ] 55 | ], 56 | [ 57 | [ 58 | 0, 59 | 0 60 | ], 61 | [ 62 | 0, 63 | 0 64 | ], 65 | [ 66 | 0, 67 | 0 68 | ], 69 | [ 70 | 0, 71 | 0 72 | ], 73 | [ 74 | 0, 75 | 0 76 | ], 77 | [ 78 | 0, 79 | 0 80 | ] 81 | ] 82 | ], 83 | "state": [ 84 | 1, 85 | 1 86 | ] 87 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-1-1-stick.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": "stick", 3 | "expected_feats": [ 4 | [ 5 | [ 6 | 0, 7 | 1 8 | ], 9 | [ 10 | 0, 11 | 0 12 | ], 13 | [ 14 | 0, 15 | 0 16 | ], 17 | [ 18 | 0, 19 | 0 20 | ], 21 | [ 22 | 0, 23 | 0 24 | ], 25 | [ 26 | 0, 27 | 0 28 | ] 29 | ], 30 | [ 31 | [ 32 | 0, 33 | 0 34 | ], 35 | [ 36 | 0, 37 | 0 38 | ], 39 | [ 40 | 0, 41 | 0 42 | ], 43 | [ 44 | 0, 45 | 0 46 | ], 47 | [ 48 | 0, 49 | 0 50 | ], 51 | [ 52 | 0, 53 | 0 54 | ] 55 | ], 56 | [ 57 | [ 58 | 0, 59 | 0 60 | ], 61 | [ 62 | 0, 63 | 0 64 | ], 65 | [ 66 | 0, 67 | 0 68 | ], 69 | [ 70 | 0, 71 | 0 72 | ], 73 | [ 74 | 0, 75 | 0 76 | ], 77 | [ 78 | 0, 79 | 0 80 | ] 81 | ] 82 | ], 83 | "state": [ 84 | 1, 85 | 1 86 | ] 87 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-10-15-hit.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": "hit", 3 | "expected_feats": [ 4 | [ 5 | [ 6 | 0, 7 | 0 8 | ], 9 | [ 10 | 0, 11 | 0 12 | ], 13 | [ 14 | 0, 15 | 0 16 | ], 17 | [ 18 | 0, 19 | 0 20 | ], 21 | [ 22 | 0, 23 | 0 24 | ], 25 | [ 26 | 0, 27 | 0 28 | ] 29 | ], 30 | [ 31 | [ 32 | 0, 33 | 0 34 | ], 35 | [ 36 | 0, 37 | 0 38 | ], 39 | [ 40 | 0, 41 | 0 42 | ], 43 | [ 44 | 0, 45 | 0 46 | ], 47 | [ 48 | 0, 49 | 0 50 | ], 51 | [ 52 | 0, 53 | 0 54 | ] 55 | ], 56 | [ 57 | [ 58 | 0, 59 | 0 60 | ], 61 | [ 62 | 0, 63 | 0 64 | ], 65 | [ 66 | 0, 67 | 0 68 | ], 69 | [ 70 | 1, 71 | 0 72 | ], 73 | [ 74 | 1, 75 | 0 76 | ], 77 | [ 78 | 0, 79 | 0 80 | ] 81 | ] 82 | ], 83 | "state": [ 84 | 10, 85 | 15 86 | ] 87 | } -------------------------------------------------------------------------------- /tests/fixtures/feature-7-12-hit.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": "hit", 3 | "expected_feats": [ 4 | [ 5 | [ 6 | 0, 7 | 0 8 | ], 9 | [ 10 | 0, 11 | 0 12 | ], 13 | [ 14 | 0, 15 | 0 16 | ], 17 | [ 18 | 0, 19 | 0 20 | ], 21 | [ 22 | 0, 23 | 0 24 | ], 25 | [ 26 | 0, 27 | 0 28 | ] 29 | ], 30 | [ 31 | [ 32 | 0, 33 | 0 34 | ], 35 | [ 36 | 0, 37 | 0 38 | ], 39 | [ 40 | 1, 41 | 0 42 | ], 43 | [ 44 | 1, 45 | 0 46 | ], 47 | [ 48 | 0, 49 | 0 50 | ], 51 | [ 52 | 0, 53 | 0 54 | ] 55 | ], 56 | [ 57 | [ 58 | 0, 59 | 0 60 | ], 61 | [ 62 | 0, 63 | 0 64 | ], 65 | [ 66 | 1, 67 | 0 68 | ], 69 | [ 70 | 1, 71 | 0 72 | ], 73 | [ 74 | 0, 75 | 0 76 | ], 77 | [ 78 | 0, 79 | 0 80 | ] 81 | ] 82 | ], 83 | "state": [ 84 | 7, 85 | 12 86 | ] 87 | } 
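These fixtures encode expected outputs of the coarse-coding feature map `phi` defined in `agents/function_approximation.py` and are exercised by `tests/function_approximation_tests.py`. Purely as an illustration (assuming the repository layout above), the `feature-7-12-hit` case corresponds to a call like:

```
import numpy as np
from environment import ACTIONS
from agents.function_approximation import phi

HIT, STICK = ACTIONS
feats = phi((7, 12), HIT)         # state (dealer=7, player=12), action=hit
print(feats.shape)                # (3, 6, 2): dealer cuboids x player cuboids x actions
print(np.argwhere(feats == 1))    # indices of the overlapping cuboids containing (7, 12)
```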
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from functools import reduce 3 | from operator import mul 4 | from collections import Iterable 5 | 6 | from environment import ACTIONS 7 | 8 | 9 | def mse(A, B): 10 | return np.sum((A - B) ** 2) / np.size(A) 11 | 12 | 13 | def get_step_size(N): 14 | non_zeros = np.where(N > 0) 15 | steps = np.zeros_like(N) 16 | steps[non_zeros] = 1.0 / N[non_zeros] 17 | 18 | return steps 19 | 20 | 21 | N_0 = 100 22 | def get_epsilon(N): 23 | return N_0 / (N_0 + N) 24 | 25 | 26 | def epsilon_greedy_policy(Q, N, state): 27 | dealer, player = state 28 | epsilon = get_epsilon(np.sum(N[dealer-1, player-1, :])) 29 | if np.random.rand() < (1 - epsilon): 30 | action = np.argmax(Q[dealer-1, player-1, :]) 31 | else: 32 | action = np.random.choice(ACTIONS) 33 | return action 34 | 35 | 36 | def policy_wrapper(Q, N): 37 | def policy(state): 38 | dealer, player = state 39 | assert(0 < dealer < 11 and 0 < player < 22) 40 | eps = get_epsilon(np.sum(N[dealer-1, player-1, :])) 41 | 42 | if np.random.rand() < (1 - eps): 43 | action = np.argmax(Q[dealer-1, player-1, :]) 44 | else: 45 | action = np.random.choice(ACTIONS) 46 | return action 47 | return policy 48 | -------------------------------------------------------------------------------- /tests/fixtures/feature-9-10-stick.json: -------------------------------------------------------------------------------- 1 | { 2 | "action": "stick", 3 | "expected_feats": [ 4 | [ 5 | [ 6 | 0, 7 | 0 8 | ], 9 | [ 10 | 0, 11 | 0 12 | ], 13 | [ 14 | 0, 15 | 0 16 | ], 17 | [ 18 | 0, 19 | 0 20 | ], 21 | [ 22 | 0, 23 | 0 24 | ], 25 | [ 26 | 0, 27 | 0 28 | ] 29 | ], 30 | [ 31 | [ 32 | 0, 33 | 0 34 | ], 35 | [ 36 | 0, 37 | 0 38 | ], 39 | [ 40 | 0, 41 | 0 42 | ], 43 | [ 44 | 0, 45 | 0 46 | ], 47 | [ 48 | 0, 49 | 0 50 | ], 51 | [ 52 | 0, 53 | 0 54 | ] 55 | ], 56 | [ 57 | [ 58 | 0, 59 | 0 60 | ], 61 | [ 62 | 0, 63 | 0 64 | ], 65 | [ 66 | 0, 67 | 1 68 | ], 69 | [ 70 | 0, 71 | 1 72 | ], 73 | [ 74 | 0, 75 | 0 76 | ], 77 | [ 78 | 0, 79 | 0 80 | ] 81 | ] 82 | ], 83 | "state": [ 84 | 9, 85 | 10 86 | ] 87 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | .venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | .DS_Store -------------------------------------------------------------------------------- /tests/environment_tests.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | from nose.tools import assert_equal 3 | 4 | from environment import ( 5 | Easy21Env, TERMINAL_STATE, ACTIONS 6 | ) 7 | 8 | HIT, STICK = ACTIONS 9 | 10 | def mock_draw_card(result): 11 | def fn(color=None): 12 | return result 13 | return fn 14 | 15 | 16 | class TestEnvironment(): 17 | def setUp(self): 18 | self.env = Easy21Env() 19 | 20 | 21 | def tearDown(self): 22 | self.env = None 23 | 24 | def test_step_player_should_bust_if_sum_exceeds_max(self): 25 | CARD = { 'value': 8, 'color': "black" } 26 | card_mock = mock_draw_card(CARD) 27 | PLAYER_START = 0 28 | self.env.reset(dealer=5, player=PLAYER_START) 29 | 30 | with patch("environment.draw_card", card_mock): 31 | state = self.env.observe() 32 | for i in range(3): 33 | player = state[1] 34 | assert_equal(self.env.player, PLAYER_START + i * CARD["value"]) 35 | state, reward = self.env.step(HIT) 36 | 37 | assert_equal(state, TERMINAL_STATE) 38 | assert_equal(reward, -1) 39 | 40 | def test_step_player_should_bust_if_sum_below_min(self): 41 | CARD = { 'value': 8, 'color': "red" } 42 | card_mock = mock_draw_card(CARD) 43 | PLAYER_START = 10 44 | self.env.reset(dealer=5, player=PLAYER_START) 45 | 46 | with patch("environment.draw_card", card_mock): 47 | state = self.env.observe() 48 | for i in range(2): 49 | assert_equal(self.env.player, PLAYER_START - i * CARD["value"]) 50 | state, reward = self.env.step(HIT) 51 | 52 | assert_equal(state, TERMINAL_STATE) 53 | assert_equal(reward, -1) 54 | 55 | def test_step_dealer_finishes_between_17_21(self): 56 | CARD = { 'value': 8, 'color': "black" } 57 | pass 58 | -------------------------------------------------------------------------------- /tests/function_approximation_tests.py: -------------------------------------------------------------------------------- 1 | from nose.tools import assert_equal 2 | import glob 3 | import json 4 | import os 5 | 6 | from environment import ACTIONS 7 | from agents.function_approximation import phi 8 | 9 | HIT, STICK = ACTIONS 10 | NAME_TO_ACTION = { 11 | "hit": HIT, 12 | "stick": STICK 13 | } 14 | ACTION_TO_NAME = { 15 | v: k for k, v in NAME_TO_ACTION.items() 16 | } 17 | 18 | FEATURE_PATH_TEMPLATE = "./tests/fixtures/feature-{}.json" 19 | 20 | def write_fixture(filepath, result): 21 | with open(filepath, "w") as f: 22 | json.dump(result, f, separators=(',', ': '), sort_keys=True, indent=2) 23 | 24 | 25 | class 
TestFunctionApproximationAgent: 26 | def verify_features(self, filepath): 27 | with open(filepath) as f: 28 | test_case = json.load(f) 29 | 30 | state = test_case["state"] 31 | 32 | action = None 33 | if test_case.get("action", None) is not None: 34 | action = NAME_TO_ACTION[test_case["action"]] 35 | 36 | feats = phi(test_case["state"], action).tolist() 37 | expected = test_case["expected_feats"] 38 | 39 | if expected != feats and os.environ.get("TESTS_UPDATE", False): 40 | result = { 41 | "state": test_case["state"], 42 | "action": ACTION_TO_NAME.get(action, None), 43 | "expected_feats": feats 44 | } 45 | write_fixture(filepath, result) 46 | 47 | assert_equal(expected, feats) 48 | 49 | 50 | def test_features(self): 51 | features_path = FEATURE_PATH_TEMPLATE.format("*") 52 | for filepath in glob.glob(features_path): 53 | yield self.verify_features, filepath 54 | 55 | 56 | def create_feature_fixtures(self, state_actions): 57 | for state, action in state_actions: 58 | fixture = { 59 | "state": state, 60 | "action": action, 61 | "expected_feats": [] 62 | } 63 | 64 | filename = "{}-{}".format(state[0], state[1]) 65 | if action is not None: 66 | filename += "-{}".format(action) 67 | 68 | filepath = FEATURE_PATH_TEMPLATE.format(filename) 69 | write_fixture(filepath, fixture) 70 | pass 71 | -------------------------------------------------------------------------------- /agents/sarsa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | from utils import epsilon_greedy_policy, mse 5 | from vis import plot_V 6 | from environment import Easy21Env, TERMINAL_STATE, STATE_SPACE_SHAPE 7 | 8 | 9 | GAMMA = 1.0 10 | LAMBDA = 0 11 | 12 | class SarsaAgent: 13 | def __init__(self, env, num_episodes=1000, 14 | gamma=GAMMA, lmbd=LAMBDA, 15 | save_error_history=False, 16 | **kwargs): 17 | self.num_episodes = num_episodes 18 | self.env = env 19 | self.gamma = gamma 20 | self.lmbd = lmbd 21 | 22 | self.save_error_history = save_error_history 23 | if self.save_error_history: 24 | with open("./Q_opt.pkl", "rb") as f: 25 | self.opt_Q = pickle.load(f) 26 | 27 | self.reset() 28 | 29 | 30 | def reset(self): 31 | self.Q = np.zeros(STATE_SPACE_SHAPE) 32 | 33 | if self.save_error_history: 34 | self.error_history = [] 35 | 36 | 37 | def learn(self): 38 | env = self.env 39 | Q = self.Q 40 | N = np.zeros(STATE_SPACE_SHAPE) 41 | 42 | for episode in range(1, self.num_episodes+1): 43 | env.reset() 44 | state1 = env.observe() 45 | E = np.zeros(STATE_SPACE_SHAPE) # eligibility traces 46 | 47 | while state1 != TERMINAL_STATE: 48 | action1 = epsilon_greedy_policy(Q, N, state1) 49 | state2, reward = env.step(action1) 50 | 51 | dealer1, player1 = state1 52 | idx1 = (dealer1-1, player1-1, action1) 53 | Q1 = Q[idx1] 54 | 55 | if state2 == TERMINAL_STATE: 56 | Q2 = 0.0 57 | else: 58 | action2 = epsilon_greedy_policy(Q, N, state2) 59 | dealer2, player2 = state2 60 | idx2 = (dealer2-1, player2-1, action2) 61 | Q2 = Q[idx2] 62 | 63 | N[idx1] += 1 64 | E[idx1] += 1 65 | 66 | alpha = 1.0 / N[idx1] 67 | delta = reward + self.gamma * Q2 - Q1 68 | Q += alpha * delta * E 69 | E *= self.gamma * self.lmbd 70 | 71 | state1 = state2 72 | 73 | if self.save_error_history: 74 | self.error_history.append((episode, mse(self.Q, self.opt_Q))) 75 | 76 | return Q 77 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 
matplotlib.use("TkAgg") 3 | 4 | from mpl_toolkits.mplot3d import Axes3D 5 | from matplotlib import cm, rc 6 | rc('font', **{'family': 'serif', 'serif': ['Computer Modern']}) 7 | rc('text', usetex=True) 8 | import matplotlib.pyplot as plt 9 | 10 | import numpy as np 11 | from pprint import pprint 12 | 13 | 14 | def create_surf_plot(X, Y, Z, fig_idx=1): 15 | fig = plt.figure(fig_idx) 16 | ax = fig.add_subplot(111, projection="3d") 17 | 18 | surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm, 19 | linewidth=0, antialiased=False) 20 | # surf = ax.plot_wireframe(X, Y, Z) 21 | 22 | return surf 23 | 24 | 25 | from environment import DEALER_RANGE, PLAYER_RANGE 26 | def plot_V(Q, save=None, fig_idx=0): 27 | V = np.max(Q, axis=2) 28 | X, Y = np.mgrid[DEALER_RANGE, PLAYER_RANGE] 29 | 30 | surf = create_surf_plot(X, Y, V) 31 | 32 | plt.title("V*") 33 | plt.ylabel('player sum', size=18) 34 | plt.xlabel('dealer', size=18) 35 | 36 | if save is not None: 37 | plt.savefig(save, format='pdf', transparent=True) 38 | else: 39 | plt.show() 40 | 41 | plt.clf() 42 | 43 | 44 | def plot_learning_curve(learning_curves, save=None, agent_args={}, fig_idx=2): 45 | fig = plt.figure(fig_idx) 46 | 47 | plt.title("Mean-squared error vs. 'true' Q values against episode number") 48 | plt.ylabel(r'$\frac{1}{|S||A|}\sum_{s,a}{(Q(s,a) - Q^{*}(s,a))^2}$', size=18) 49 | plt.xlabel(r'$episode$', size=18) 50 | 51 | colors = iter(cm.rainbow(np.linspace(0, 1, len(learning_curves)))) 52 | for lmbd, D in learning_curves.items(): 53 | X, Y = zip(*D) 54 | plt.plot(X, Y, label="lambda={:.1f}".format(lmbd), 55 | linewidth=1.0, color=next(colors)) 56 | 57 | plt.legend() 58 | 59 | if save is not None: 60 | plt.savefig(save, format='pdf', transparent=True) 61 | else: 62 | plt.show() 63 | 64 | plt.clf() 65 | 66 | def plot_pg_rewards(mean_rewards, save=None, fig_idx=3): 67 | fig = plt.figure(fig_idx) 68 | 69 | plt.title("Policy Gradient running average rewards (1000 epsiode window)") 70 | plt.ylabel(r'average reward', size=18) 71 | plt.xlabel(r'episode', size=18) 72 | 73 | Y = mean_rewards 74 | X = range(1, len(Y)+1) 75 | 76 | plt.plot(X, Y, linewidth=1.0) 77 | 78 | if save is not None: 79 | plt.savefig(save, format='pdf', transparent=True) 80 | else: 81 | plt.show() 82 | 83 | plt.clf() 84 | -------------------------------------------------------------------------------- /agents/policy_gradient.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import numpy as np 3 | 4 | from utils import epsilon_greedy_policy 5 | from environment import ACTIONS, TERMINAL_STATE, STATE_SPACE_SHAPE 6 | 7 | 8 | class TwoLayerNet: 9 | def __init__(self, input_size, hidden_size, output_size, std=1e-4): 10 | self.params = {} 11 | self.params['W1'] = std * np.random.randn(input_size, hidden_size) 12 | self.params['b1'] = np.zeros(hidden_size) 13 | self.params['W2'] = std * np.random.randn(hidden_size, output_size) 14 | self.params['b2'] = np.zeros(output_size) 15 | 16 | 17 | def forward(self, X): 18 | if len(X.shape) == 1: X = X.reshape(-1, X.shape[0]) 19 | 20 | W1, b1 = self.params['W1'], self.params['b1'] 21 | W2, b2 = self.params['W2'], self.params['b2'] 22 | N, D = 1, len(X) 23 | 24 | z = np.dot(X, W1) + b1 25 | h1 = np.maximum(z, 0) 26 | scores = np.dot(h1, W2) + b2 27 | 28 | exp_scores = np.exp(scores) 29 | sum_exp_scores = np.sum(exp_scores, axis=1, keepdims=True) 30 | probs = exp_scores / sum_exp_scores 31 | 32 | cache = (probs, h1, z) 33 | 34 | return probs, cache 35 
| 36 | 37 | def backward(self, X, y, cache): 38 | if len(X.shape) == 1: X = X.reshape(-1, X.shape[0]) 39 | 40 | probs, h1, z = cache 41 | 42 | W1, b1 = self.params['W1'], self.params['b1'] 43 | W2, b2 = self.params['W2'], self.params['b2'] 44 | N, D = 1, len(X) 45 | 46 | dscores = probs.copy() 47 | dscores[np.arange(N), y] -= 1.0 48 | dscores /= float(N) 49 | 50 | dW2 = np.dot(h1.T, dscores) + W2 51 | db2 = np.sum(dscores, axis=0) 52 | 53 | dh1 = np.dot(dscores, W2.T) 54 | 55 | dz = (z > 0.0) * dh1 56 | 57 | dW1 = np.dot(X.T, dz) + W1 58 | db1 = np.sum(dz, axis=0) 59 | 60 | grads = { 61 | 'W2': dW2, 'b2': db2, 62 | 'W1': dW1, 'b1': db1 63 | } 64 | 65 | return grads 66 | 67 | 68 | def update_params(self, grads): 69 | learning_rate = 5e-4 70 | 71 | for param in self.params.keys(): 72 | self.params[param] -= learning_rate * grads[param] 73 | 74 | 75 | GAMMA = 1.0 76 | LAMBDA = 0 77 | 78 | 79 | class PolicyGradientAgent: 80 | def __init__(self, env, num_episodes=1000, 81 | gamma=GAMMA, lmbd=LAMBDA, 82 | save_error_history=False, 83 | **kwargs): 84 | self.num_episodes = num_episodes 85 | self.env = env 86 | self.gamma = gamma 87 | self.lmbd = lmbd 88 | 89 | self.reset() 90 | 91 | 92 | def reset(self): 93 | self.reward_history = [] 94 | 95 | 96 | def learn(self): 97 | env = self.env 98 | net = TwoLayerNet(2, 20, 2) 99 | reward_window = deque([0] * 1000, 1000) 100 | 101 | for episode in range(1, self.num_episodes+1): 102 | env.reset() 103 | state = env.observe() 104 | E = [] # experiences 105 | 106 | while state != TERMINAL_STATE: 107 | (probs,), cache = net.forward(state) 108 | action = np.random.choice(ACTIONS, p=probs) 109 | state_, reward = env.step(action) 110 | 111 | E.append([state, action, reward, cache]) 112 | state = state_ 113 | 114 | G = np.cumsum([e[2] for e in reversed(E)])[::-1] 115 | for (state, action, reward, probs), G_t in zip(E, G): 116 | grads = net.backward(state, action, cache) 117 | grads = { k: v * G_t for k, v in grads.items() } 118 | net.update_params(grads) 119 | 120 | reward_window.append(sum(e[2] for e in E)) 121 | self.reward_history.append(np.mean(reward_window)) 122 | 123 | return self.reward_history 124 | -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | 4 | DECK = range(1, 11) 5 | ACTIONS = (HIT, STICK) = (0, 1) 6 | 7 | DEALER_RANGE = range(1, 11) 8 | PLAYER_RANGE = range(1, 22) 9 | STATE_SPACE_SHAPE = (len(DEALER_RANGE), len(PLAYER_RANGE), len(ACTIONS)) 10 | 11 | TERMINAL_STATE = "TERMINAL" 12 | COLOR_PROBS = { 'red': 1/3, 'black': 2/3 } 13 | COLOR_COEFFS = { 'red': -1, 'black': 1 } 14 | 15 | 16 | def draw_card(color=None): 17 | value = np.random.choice(DECK) 18 | if color is None: 19 | colors, probs = zip(*COLOR_PROBS.items()) 20 | color = np.random.choice(colors, p=probs) 21 | return { 'value': value, 'color': color } 22 | 23 | 24 | def bust(x): 25 | return (x < 1 or 21 < x) 26 | 27 | 28 | class Easy21Env: 29 | """ Easy21 environment 30 | 31 | Easy21 is a simple card game similar to Blackjack The rules of the game are as 32 | follows: 33 | 34 | - The game is played with an infinite deck of cards (i.e. cards are sampled 35 | with replacement) 36 | - Each draw from the deck results in a value between 1 and 10 (uniformly 37 | distributed) with a colour of red (probability 1/3) or black (probability 38 | 2/3). 
39 | - There are no aces or picture (face) cards in this game 40 | - At the start of the game both the player and the dealer draw one black 41 | card (fully observed) 42 | - Each turn the player may either stick or hit 43 | - If the player hits then she draws another card from the deck 44 | - If the player sticks she receives no further cards 45 | - The values of the player's cards are added (black cards) or subtracted (red 46 | cards) 47 | - If the player's sum exceeds 21, or becomes less than 1, then she "goes 48 | bust" and loses the game (reward -1) 49 | - If the player sticks then the dealer starts taking turns. The dealer always 50 | sticks on any sum of 17 or greater, and hits otherwise. If the dealer goes 51 | bust, then the player wins; otherwise, the outcome - win (reward +1), 52 | lose (reward -1), or draw (reward 0) - is the player with the largest sum. 53 | """ 54 | def __init__(self): 55 | self.reset() 56 | 57 | 58 | def reset(self, dealer=None, player=None): 59 | if dealer is None: dealer = draw_card()['value'] 60 | self.dealer = dealer 61 | if player is None: player = draw_card()['value'] 62 | self.player = player 63 | 64 | 65 | def observe(self): 66 | if not (self.dealer in DEALER_RANGE and self.player in PLAYER_RANGE): 67 | return TERMINAL_STATE 68 | return np.array((self.dealer, self.player)) 69 | 70 | 71 | def step(self, action): 72 | """ Step function 73 | 74 | Inputs: 75 | - action: hit or stick 76 | 77 | Returns: 78 | - next_state: a sample of the next state (which may be terminal if the 79 | game is finished) 80 | - reward 81 | """ 82 | 83 | if action == HIT: 84 | card = draw_card() 85 | self.player += COLOR_COEFFS[card['color']] * card['value'] 86 | 87 | if bust(self.player): 88 | next_state, reward = TERMINAL_STATE, -1 89 | else: 90 | next_state, reward = (self.dealer, self.player), 0 91 | elif action == STICK: 92 | while 0 < self.dealer < 17: 93 | card = draw_card() 94 | self.dealer += COLOR_COEFFS[card['color']] * card['value'] 95 | 96 | next_state = TERMINAL_STATE 97 | if bust(self.dealer): 98 | reward = 1 99 | else: 100 | reward = int(self.player > self.dealer) - int(self.player < self.dealer) 101 | else: 102 | raise ValueError("Action not in action space") 103 | 104 | return np.array(next_state), reward 105 | -------------------------------------------------------------------------------- /agents/function_approximation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | from utils import epsilon_greedy_policy, mse 5 | from vis import plot_V 6 | from environment import ( 7 | Easy21Env, TERMINAL_STATE, STATE_SPACE_SHAPE, ACTIONS, 8 | DEALER_RANGE, PLAYER_RANGE 9 | ) 10 | 11 | 12 | HIT, STICK = ACTIONS 13 | 14 | GAMMA = 1 15 | LAMBDA = 0 16 | EPSILON = 0.05 17 | ALPHA = 0.01 18 | 19 | CUBOID_INTERVALS = { 20 | "dealer": ((1, 4), (4, 7), (7, 10)), 21 | "player": ((1, 6), (4, 9), (7, 12), (10, 15), (13, 18), (16, 21)), 22 | "action": ((HIT,), (STICK,)) 23 | } 24 | 25 | FEATS_SHAPE = tuple( 26 | len(CUBOID_INTERVALS[key]) for key in ("dealer", "player", "action") 27 | ) 28 | 29 | 30 | def phi(state, action=None): 31 | if state == TERMINAL_STATE: return 0 32 | 33 | dealer, player = state 34 | 35 | state_features = np.array([ 36 | (di[0] <= dealer <= di[1]) and (pi[0] <= player <= pi[1]) 37 | for di in CUBOID_INTERVALS['dealer'] 38 | for pi in CUBOID_INTERVALS['player'] 39 | ]).astype(int).reshape(FEATS_SHAPE[:2]) 40 | 41 | if action is None: return state_features 42 | 43 | features = 
np.zeros(FEATS_SHAPE) 44 | for i, ai in enumerate(CUBOID_INTERVALS['action']): 45 | if action in ai: 46 | features[:, :, i] = state_features 47 | 48 | return features.astype(int) 49 | 50 | 51 | def expand_Q(w): 52 | Q = np.zeros(STATE_SPACE_SHAPE) 53 | 54 | for dealer in DEALER_RANGE: 55 | for player in PLAYER_RANGE: 56 | for action in ACTIONS: 57 | state = (dealer, player) 58 | feats = phi(state, action) 59 | Q[dealer-1, player-1][action] = np.sum(feats * w) 60 | 61 | return Q 62 | 63 | 64 | class FunctionApproximationAgent: 65 | def __init__(self, env, num_episodes=1000, 66 | gamma=GAMMA, lmbd=LAMBDA, 67 | epsilon=EPSILON, alpha=ALPHA, 68 | save_error_history=False, 69 | **kwargs): 70 | self.num_episodes = num_episodes 71 | self.env = env 72 | 73 | self.gamma = gamma 74 | self.lmbd = lmbd 75 | self.epsilon = epsilon 76 | self.alpha = alpha 77 | 78 | self.save_error_history = save_error_history 79 | if self.save_error_history: 80 | with open("./Q_opt.pkl", "rb") as f: 81 | self.opt_Q = pickle.load(f) 82 | 83 | self.reset() 84 | 85 | 86 | def reset(self): 87 | self.Q = np.zeros(STATE_SPACE_SHAPE) 88 | self.w = (np.random.rand(*FEATS_SHAPE) - 0.5) * 0.001 89 | 90 | if self.save_error_history: 91 | self.error_history = [] 92 | 93 | def policy(self, state): 94 | if state == TERMINAL_STATE: 95 | return 0.0, None 96 | 97 | if np.random.rand() < (1 - self.epsilon): 98 | Qhat, action = max( 99 | # same as dotproduct in our case 100 | ((np.sum(phi(state, a) * self.w), a) for a in ACTIONS), 101 | key=lambda x: x[0] 102 | ) 103 | else: 104 | action = np.random.choice(ACTIONS) 105 | Qhat = np.sum(phi(state, action) * self.w) 106 | 107 | return Qhat, action 108 | 109 | 110 | def learn(self): 111 | env = self.env 112 | 113 | for episode in range(1, self.num_episodes+1): 114 | env.reset() 115 | state1 = env.observe() 116 | E = np.zeros_like(self.w) 117 | 118 | while state1 != TERMINAL_STATE: 119 | Qhat1, action1 = self.policy(state1) 120 | state2, reward = env.step(action1) 121 | Qhat2, action2 = self.policy(state2) 122 | 123 | feats1 = phi(state1, action1) 124 | grad_w_Qhat1 = feats1 125 | 126 | delta = reward + self.gamma * Qhat2 - Qhat1 127 | E = self.gamma * self.lmbd * E + grad_w_Qhat1 128 | dw = self.alpha * delta * E 129 | 130 | self.w += dw 131 | state1 = state2 132 | 133 | if self.save_error_history: 134 | self.Q = expand_Q(self.w) 135 | self.error_history.append((episode, mse(self.Q, self.opt_Q))) 136 | 137 | self.Q = expand_Q(self.w) 138 | return self.Q 139 | -------------------------------------------------------------------------------- /easy21.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import numpy as np 4 | from distutils.util import strtobool 5 | 6 | from environment import Easy21Env, ACTIONS, DEALER_RANGE, PLAYER_RANGE 7 | from agents import ( 8 | MonteCarloAgent, SarsaAgent, FunctionApproximationAgent, PolicyGradientAgent 9 | ) 10 | from vis import plot_V, plot_learning_curve, plot_pg_rewards 11 | from utils import mse 12 | 13 | 14 | def range_float_type(s): 15 | """ Custom range float type for arg parser 16 | """ 17 | try: 18 | parts = list(map(float, s.split(","))) 19 | if len(parts) == 1: 20 | return parts 21 | elif len(parts) == 3: 22 | return np.arange(*parts) 23 | except: 24 | raise argparse.ArgumentTypeError( 25 | "range_float must be a string that, when split and parts then mapped to " 26 | "floats, can be passed to np.arange as arguments. E.g. '0,1.1,0.1'." 
27 | ) 28 | 29 | 30 | def bool_type(x): 31 | return bool(strtobool(x)) 32 | 33 | 34 | parser = argparse.ArgumentParser( 35 | description="Simple Reinforcement Learning Environment") 36 | 37 | parser.add_argument("-v", "--verbose", default=False, type=bool_type, 38 | help="Verbose") 39 | 40 | parser.add_argument("-a", "--agent", default="mc", 41 | choices=['mc', 'sarsa', 'lfa', 'pg'], 42 | help=("Agent Type: " 43 | "mc (monte carlo), " 44 | "sarsa, " 45 | "lfa (linear function approximation)")) 46 | parser.add_argument("--num-episodes", default=1000, type=int, 47 | help="Number of episodes") 48 | parser.add_argument("--lmbd", default=[1.0], type=range_float_type, help="Lambda") 49 | parser.add_argument("--gamma", default=1, type=float, help="Gamma") 50 | 51 | parser.add_argument("--plot-v", default=False, type=bool_type, 52 | help="Plot the value function") 53 | parser.add_argument("--dump-q", default=False, type=bool_type, 54 | help="Dump the Q values to file") 55 | parser.add_argument("--plot-lambda-mse", default=False, type=bool_type, 56 | help=("Plot mean-squared error compared to the 'true' Q " 57 | "values obtained with monte-carlo")) 58 | parser.add_argument("--plot-learning-curve", default=False, type=bool_type, 59 | help=("Plot the learning curve of mean-squared error " 60 | "compared to the 'true' Q values obtained from " 61 | "monte-carlo against episode number")) 62 | 63 | 64 | AGENTS = { 65 | "mc": MonteCarloAgent, 66 | "sarsa": SarsaAgent, 67 | "lfa": FunctionApproximationAgent, 68 | "pg": PolicyGradientAgent 69 | } 70 | 71 | Q_DUMP_BASE_NAME = "Q_dump" 72 | def dump_Q(Q, args): 73 | filename = ("./{}_{}_lambda_{}_gamma_{}_episodes_{}.pkl" 74 | "".format(Q_DUMP_BASE_NAME, 75 | args["agent_type"], args["lmbd"], 76 | args.get("gamma", None), args["num_episodes"])) 77 | 78 | print("dumping Q: ", filename) 79 | 80 | with open(filename, "wb") as f: 81 | pickle.dump(Q, f) 82 | 83 | 84 | def get_agent_args(args): 85 | agent_type = args.agent 86 | agent_args = { 87 | "agent_type": agent_type, 88 | "num_episodes": args.num_episodes 89 | } 90 | 91 | if agent_type == "mc": 92 | return agent_args 93 | elif agent_type == "sarsa" or agent_type == "lfa": 94 | agent_args.update({ 95 | key: getattr(args, key) for key in ["gamma"] 96 | if key in args 97 | }) 98 | agent_args["save_error_history"] = getattr( 99 | args, "plot_learning_curve", False ) 100 | 101 | return agent_args 102 | 103 | 104 | Q_OPT_FILE = "./Q_opt.pkl" 105 | def main(args): 106 | env = Easy21Env() 107 | 108 | if args.plot_learning_curve: 109 | learning_curves = {} 110 | 111 | for i, lmbd in enumerate(args.lmbd): 112 | agent_args = get_agent_args(args) 113 | agent_args["lmbd"] = lmbd 114 | agent = AGENTS[args.agent](env, **agent_args) 115 | 116 | agent.learn() 117 | 118 | if agent_args["agent_type"] == "pg": 119 | plot_file = ("./vis/policy_gradient_rewards_episodes_{}.pdf" 120 | "".format(args.num_episodes)) 121 | plot_pg_rewards(agent.reward_history, save=plot_file) 122 | return 123 | 124 | if args.dump_q: 125 | dump_Q(agent.Q, agent_args) 126 | 127 | if args.plot_v: 128 | plot_file = ("./vis/V_{}_lambda_{}_gamma_{}_episodes_{}.pdf" 129 | "".format(agent_args["agent_type"], 130 | lmbd, 131 | args.gamma, 132 | args.num_episodes)) 133 | plot_V(agent.Q, save=plot_file) 134 | 135 | if args.plot_learning_curve: 136 | learning_curves[lmbd] = agent.error_history 137 | 138 | if args.plot_learning_curve: 139 | plot_file = ("./vis/lambda_mse_{}_gamma_{}_episodes_{}.pdf" 140 | "".format(agent_args["agent_type"], 141 | args.gamma, 
args.num_episodes)) 142 | plot_learning_curve(learning_curves, save=plot_file) 143 | 144 | if __name__ == "__main__": 145 | args = parser.parse_args() 146 | main(args) 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # easy21 2 | This repository implements the assignment requirements for the reinforcement learning course given by David Silver [1]. It implements a reinforcement learning environment and four different agents, namely Monte-Carlo, Sarsa(λ), linear value function approximation, and neural network policy gradient, for the simple card game Easy21 presented in [2]. 3 | 4 | 5 | ## Setting up the environment 6 | To set up the Python environment you need Python v3.x, pip, and virtualenv: 7 | ``` 8 | git clone https://github.com/hartikainen/easy21.git 9 | cd easy21 10 | 11 | pip install virtualenv 12 | virtualenv .venv 13 | source .venv/bin/activate 14 | 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | 19 | ## Running the game 20 | To run the game and test the agents, run the `easy21.py` file as follows: 21 | ``` 22 | python easy21.py [-h] [-v VERBOSE] [-a {mc,sarsa,lfa,pg}] 23 | [--num-episodes NUM_EPISODES] [--lmbd LMBD] [--gamma GAMMA] 24 | [--plot-v PLOT_V] [--dump-q DUMP_Q] 25 | [--plot-lambda-mse PLOT_LAMBDA_MSE] 26 | [--plot-learning-curve PLOT_LEARNING_CURVE] 27 | ``` 28 | See `easy21.py` for more information about running the game and testing the agents. All the agents are found in the `/agents` folder. 29 | 30 | 31 | ## Easy21 Environment 32 | The Easy21 environment is implemented in the `Easy21Env` class found in `environment.py`. The environment keeps track of the game state (dealer card and sum of player cards), and exposes a `step` method, which, given an action (hit or stick), updates its state and returns the observed state (in our case the observation is equivalent to the game state) and the reward. 33 | 34 | 35 | ## Monte-Carlo Control in Easy21 36 | Monte-Carlo control for Easy21 is implemented in the file `agents/monte_carlo.py`. The default implementation uses a time-varying scalar step-size αt = 1/N(st, at) and an ε-greedy exploration strategy with εt = N0 / (N0 + N(st)), where N0 = 100, N(s) is the number of times that state s has been visited, N(s,a) is the number of times that action a has been selected from state s, and t is the time-step. 37 | The figure below presents the optimal value function V\* against the game state (player sum and dealer hand). 38 | 39 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/V_mc_1000000_episodes.png) 40 | 41 | 42 | ## TD Learning in Easy21 43 | File `agents/sarsa.py` implements Sarsa(λ) control for Easy21. It uses the same step-size and exploration schedules as the Monte-Carlo agent described in the previous section. The agent is tested with parameter values λ ∈ {0, 0.1, 0.2, ..., 1}, each run for 20000 episodes. The first figure below presents the learning curve, i.e. the mean-squared error vs. the 'true' Q values against episode number, for each lambda. The next two figures plot the function V\* (same as in the Monte-Carlo section) for λ=0.0 and λ=1.0.
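For reference, the core of the Sarsa(λ) update implemented in `agents/sarsa.py` can be condensed to the excerpt below, where `idx1` and `idx2` denote the (dealer, player, action) indices of the current and next state-action pair (the learning-curve and value-function figures follow after it):

```
N[idx1] += 1                                 # visit count for the current state-action pair
E[idx1] += 1                                 # accumulate the eligibility trace
alpha = 1.0 / N[idx1]                        # same step-size schedule as Monte-Carlo
delta = reward + gamma * Q[idx2] - Q[idx1]   # TD error; Q[idx2] is 0 at the terminal state
Q += alpha * delta * E                       # move every traced pair toward the TD target
E *= gamma * lmbd                            # traces decay by gamma * lambda each step
```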
44 | 45 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.png) 46 | 47 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.png) 48 | 49 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.png) 50 | 51 | 52 | ### Bootstrapping in Easy21 53 | 54 | As in other settings, bootstrapping in Easy21 reduces the variance of the learned policy at the cost of increased bias. The Sarsa(λ) agent learns a reasonable policy faster (i.e. in a smaller number of episodes) than the Monte-Carlo agent. In a game as simple as Easy21, however, it is feasible to run enough episodes for the Monte-Carlo agent to converge to the optimal unbiased policy. 55 | 56 | Episodes in Easy21 last longer on average than in a traditional Blackjack game because of the subtractive effect of red cards. Because of this, bootstrapping is likely to be more useful in Easy21 than it would be in a traditional Blackjack game. 57 | 58 | 59 | ## Linear Function Approximation in Easy21 60 | File `agents/function_approximation.py` implements a value function approximator with coarse coding for Easy21, using a binary feature vector φ(state, action) with 36 (3\*6\*2) features. Each binary feature takes the value 1 if the state lies in the state-space cuboid corresponding to that feature and the action matches the action corresponding to that feature, and 0 otherwise. The cuboids are defined in the variable `CUBOID_INTERVALS` in `agents/function_approximation.py`. 61 | 62 | As with Sarsa(λ) in the previous section, we run tests with 20000 episodes for parameter values λ ∈ {0, 0.1, 0.2, ..., 1}, with a constant step-size α=0.01 and exploration parameter ε=0.05. The figures below plot the learning curve for each lambda, and the function V\* for λ=0.0 and λ=1.0. 63 | 64 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/lambda_mse_lfa_gamma_1.0_episodes_20000.png) 65 | 66 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.png) 67 | 68 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.png) 69 | 70 | 71 | ### Notes for function approximation 72 | Approximating the state and action space reduces the time and space complexity of the algorithm\*\*, because the agent needs to learn far fewer parameters than there are state-action pairs. However, this comes at the cost of reduced accuracy of the learned action-value function Q (and thus of the value function V and the policy π). The overlapping regions of the cuboid intervals seem to result in more extreme values in some states. This happens because each entry of the expanded Q function can be affected by multiple states and actions through the shared approximation weights. 73 | 74 | 75 | ### The effect of constant step-size α 76 | One thing to notice is the effect of the constant step-size used for the linear function approximation. Because the step-size is kept constant during learning, some regions of the value function receive much less training than others, which results in an inaccurate value function in the rarely visited extreme regions. This effect is tested with cuboid intervals corresponding to "identity" features, i.e. features that should result in exactly the same Q as learning with Sarsa(λ).
Even after 50k episodes, the agent still has spots in the state space (low dealer and low player; low dealer and high player) where it doesn't match the Sarsa results. This effect is presented in the figure below (for λ=0). 77 | 78 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.png) 79 | 80 | When using a dynamic step-size α, as in the Sarsa(λ) section above, this effect disappears and the function approximation produces the expected result, as shown in the figure below (again, for λ=0). 81 | 82 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.png) 83 | 84 | \*\* The actual running time is worse than with Sarsa(λ) because my function approximation implementation does not fully utilize numpy vectorization. 85 | 86 | ## Policy Gradient in Easy21 87 | File `agents/policy_gradient.py` implements a simple two-layer feed-forward neural network and a policy gradient agent that utilizes the network. The network has 20 hidden neurons with ReLU non-linearities and a learning rate α=5×10⁻⁴. The weights of the network are initialized from a Gaussian distribution with mean μ=0 and standard deviation σ=10⁻⁴, and the biases are initially set to 0. The figure below shows the running average (over a 1000-episode window) of the reward received by the policy gradient agent over a total of 100,000 episodes. 88 | 89 | ![alt text](https://github.com/hartikainen/easy21/blob/master/vis/policy_gradient_rewards_episodes_100000.png) 90 | 91 | [1] http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html 92 | [2] http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching_files/Easy21-Johannes.pdf 93 | --------------------------------------------------------------------------------