├── tests
│   ├── __init__.py
│   ├── fixtures
│   │   ├── feature-1-1.json
│   │   ├── feature-2-4.json
│   │   ├── feature-4-6.json
│   │   ├── feature-5-7.json
│   │   ├── feature-7-9.json
│   │   ├── feature-10-12.json
│   │   ├── feature-10-13.json
│   │   ├── feature-10-15.json
│   │   ├── feature-10-16.json
│   │   ├── feature-10-17.json
│   │   ├── feature-10-18.json
│   │   ├── feature-10-20.json
│   │   ├── feature-10-21.json
│   │   ├── feature-9-10.json
│   │   ├── feature-1-1-hit.json
│   │   ├── feature-1-1-stick.json
│   │   ├── feature-10-15-hit.json
│   │   ├── feature-7-12-hit.json
│   │   └── feature-9-10-stick.json
│   ├── environment_tests.py
│   └── function_approximation_tests.py
├── Q_opt.pkl
├── vis
│   ├── V_mc_1000000_episodes.pdf
│   ├── V_mc_1000000_episodes.png
│   ├── lambda_mse_lfa_gamma_1.0_episodes_20000.pdf
│   ├── lambda_mse_lfa_gamma_1.0_episodes_20000.png
│   ├── policy_gradient_rewards_episodes_100000.pdf
│   ├── policy_gradient_rewards_episodes_100000.png
│   ├── V_lfa_lambda_0.0_gamma_1.0_episodes_20000.pdf
│   ├── V_lfa_lambda_0.0_gamma_1.0_episodes_20000.png
│   ├── V_lfa_lambda_1.0_gamma_1.0_episodes_20000.pdf
│   ├── V_lfa_lambda_1.0_gamma_1.0_episodes_20000.png
│   ├── lambda_mse_sarsa_gamma_1.0_episodes_20000.pdf
│   ├── lambda_mse_sarsa_gamma_1.0_episodes_20000.png
│   ├── V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.pdf
│   ├── V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.png
│   ├── V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.pdf
│   ├── V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.png
│   ├── V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf
│   ├── V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.png
│   ├── V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf
│   └── V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.png
├── requirements.txt
├── agents
│   ├── __init__.py
│   ├── monte_carlo.py
│   ├── sarsa.py
│   ├── policy_gradient.py
│   └── function_approximation.py
├── utils.py
├── .gitignore
├── vis.py
├── environment.py
├── easy21.py
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Q_opt.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/Q_opt.pkl
--------------------------------------------------------------------------------
/vis/V_mc_1000000_episodes.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_mc_1000000_episodes.pdf
--------------------------------------------------------------------------------
/vis/V_mc_1000000_episodes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_mc_1000000_episodes.png
--------------------------------------------------------------------------------
/vis/lambda_mse_lfa_gamma_1.0_episodes_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/lambda_mse_lfa_gamma_1.0_episodes_20000.pdf
--------------------------------------------------------------------------------
/vis/lambda_mse_lfa_gamma_1.0_episodes_20000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/lambda_mse_lfa_gamma_1.0_episodes_20000.png
--------------------------------------------------------------------------------
/vis/policy_gradient_rewards_episodes_100000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/policy_gradient_rewards_episodes_100000.pdf
--------------------------------------------------------------------------------
/vis/policy_gradient_rewards_episodes_100000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/policy_gradient_rewards_episodes_100000.png
--------------------------------------------------------------------------------
/vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.pdf
--------------------------------------------------------------------------------
/vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_lambda_0.0_gamma_1.0_episodes_20000.png
--------------------------------------------------------------------------------
/vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.pdf
--------------------------------------------------------------------------------
/vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_lambda_1.0_gamma_1.0_episodes_20000.png
--------------------------------------------------------------------------------
/vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.pdf
--------------------------------------------------------------------------------
/vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/lambda_mse_sarsa_gamma_1.0_episodes_20000.png
--------------------------------------------------------------------------------
/vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.pdf
--------------------------------------------------------------------------------
/vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_sarsa_lambda_0.0_gamma_1.0_episodes_20000.png
--------------------------------------------------------------------------------
/vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.pdf
--------------------------------------------------------------------------------
/vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_sarsa_lambda_1.0_gamma_1.0_episodes_20000.png
--------------------------------------------------------------------------------
/vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf
--------------------------------------------------------------------------------
/vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_identity_features_dynamic_alpha_lambda_0.0_gamma_1.0_episodes_50000.png
--------------------------------------------------------------------------------
/vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.pdf
--------------------------------------------------------------------------------
/vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hartikainen/easy21/HEAD/vis/V_lfa_identity_features_static_alpha_lambda_0.0_gamma_1.0_episodes_50000.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | appdirs==1.4.3
2 | cycler==0.10.0
3 | matplotlib==2.0.0
4 | nose==1.3.7
5 | numpy==1.12.1
6 | packaging==16.8
7 | pyparsing==2.2.0
8 | python-dateutil==2.6.0
9 | pytz==2017.2
10 | six==1.10.0
11 |
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from agents.monte_carlo import MonteCarloAgent
2 | from agents.sarsa import SarsaAgent
3 | from agents.function_approximation import FunctionApproximationAgent
4 | from agents.policy_gradient import PolicyGradientAgent
5 |
--------------------------------------------------------------------------------
/tests/fixtures/feature-1-1.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 1,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 0,
26 | 0
27 | ]
28 | ],
29 | "state": [
30 | 1,
31 | 1
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-2-4.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 1,
6 | 1,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 0,
26 | 0
27 | ]
28 | ],
29 | "state": [
30 | 2,
31 | 4
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-4-6.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 1,
6 | 1,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 1,
14 | 1,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 0,
26 | 0
27 | ]
28 | ],
29 | "state": [
30 | 4,
31 | 6
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-5-7.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 1,
15 | 1,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 0,
26 | 0
27 | ]
28 | ],
29 | "state": [
30 | 5,
31 | 7
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-7-9.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 1,
15 | 1,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 1,
23 | 1,
24 | 0,
25 | 0,
26 | 0
27 | ]
28 | ],
29 | "state": [
30 | 7,
31 | 9
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-12.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 1,
24 | 1,
25 | 0,
26 | 0
27 | ]
28 | ],
29 | "state": [
30 | 10,
31 | 12
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-13.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 1,
25 | 1,
26 | 0
27 | ]
28 | ],
29 | "state": [
30 | 10,
31 | 13
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-15.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 1,
26 | 1
27 | ]
28 | ],
29 | "state": [
30 | 10,
31 | 16
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 1,
26 | 1
27 | ]
28 | ],
29 | "state": [
30 | 10,
31 | 16
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-17.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 1,
26 | 1
27 | ]
28 | ],
29 | "state": [
30 | 10,
31 | 17
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-18.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 1,
26 | 1
27 | ]
28 | ],
29 | "state": [
30 | 10,
31 | 18
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-20.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 0,
26 | 1
27 | ]
28 | ],
29 | "state": [
30 | 10,
31 | 20
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-21.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 0,
24 | 0,
25 | 0,
26 | 1
27 | ]
28 | ],
29 | "state": [
30 | 10,
31 | 21
32 | ]
33 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-9-10.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": null,
3 | "expected_feats": [
4 | [
5 | 0,
6 | 0,
7 | 0,
8 | 0,
9 | 0,
10 | 0
11 | ],
12 | [
13 | 0,
14 | 0,
15 | 0,
16 | 0,
17 | 0,
18 | 0
19 | ],
20 | [
21 | 0,
22 | 0,
23 | 1,
24 | 1,
25 | 0,
26 | 0
27 | ]
28 | ],
29 | "state": [
30 | 9,
31 | 10
32 | ]
33 | }
--------------------------------------------------------------------------------
/agents/monte_carlo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from utils import epsilon_greedy_policy
4 | from environment import TERMINAL_STATE, STATE_SPACE_SHAPE
5 |
6 |
7 | class MonteCarloAgent:
8 | def __init__(self, env, num_episodes=1000, **kwargs):
9 | self.num_episodes = num_episodes
10 | self.env = env
11 | self.reset()
12 |
13 |
14 | def reset(self):
15 | self.Q = np.zeros(STATE_SPACE_SHAPE)
16 |
17 |
18 | def learn(self):
19 | env = self.env
20 | Q = self.Q
21 | N = np.zeros(STATE_SPACE_SHAPE)
22 |
23 | for episode in range(1, self.num_episodes+1):
24 | env.reset()
25 | state = env.observe()
26 | E = [] # experience from the episode
27 |
28 | while state != TERMINAL_STATE:
29 | action = epsilon_greedy_policy(Q, N, state)
30 | state_, reward = env.step(action)
31 |
32 | E.append([state, action, reward])
33 | state = state_
34 |
 35 |       G = E[-1][2]  # return G_t: every reward before the final step is zero
 36 |       for (dealer, player), action, _ in E:
 37 |         idx = dealer-1, player-1, action
 38 |         N[idx] += 1
 39 |         alpha = 1.0 / N[idx]
 40 |         Q[idx] += alpha * (G - Q[idx])
 41 |
 42 |     return Q
--------------------------------------------------------------------------------
/tests/fixtures/feature-1-1-hit.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": "hit",
3 | "expected_feats": [
4 | [
5 | [
6 | 1,
7 | 0
8 | ],
9 | [
10 | 0,
11 | 0
12 | ],
13 | [
14 | 0,
15 | 0
16 | ],
17 | [
18 | 0,
19 | 0
20 | ],
21 | [
22 | 0,
23 | 0
24 | ],
25 | [
26 | 0,
27 | 0
28 | ]
29 | ],
30 | [
31 | [
32 | 0,
33 | 0
34 | ],
35 | [
36 | 0,
37 | 0
38 | ],
39 | [
40 | 0,
41 | 0
42 | ],
43 | [
44 | 0,
45 | 0
46 | ],
47 | [
48 | 0,
49 | 0
50 | ],
51 | [
52 | 0,
53 | 0
54 | ]
55 | ],
56 | [
57 | [
58 | 0,
59 | 0
60 | ],
61 | [
62 | 0,
63 | 0
64 | ],
65 | [
66 | 0,
67 | 0
68 | ],
69 | [
70 | 0,
71 | 0
72 | ],
73 | [
74 | 0,
75 | 0
76 | ],
77 | [
78 | 0,
79 | 0
80 | ]
81 | ]
82 | ],
83 | "state": [
84 | 1,
85 | 1
86 | ]
87 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-1-1-stick.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": "stick",
3 | "expected_feats": [
4 | [
5 | [
6 | 0,
7 | 1
8 | ],
9 | [
10 | 0,
11 | 0
12 | ],
13 | [
14 | 0,
15 | 0
16 | ],
17 | [
18 | 0,
19 | 0
20 | ],
21 | [
22 | 0,
23 | 0
24 | ],
25 | [
26 | 0,
27 | 0
28 | ]
29 | ],
30 | [
31 | [
32 | 0,
33 | 0
34 | ],
35 | [
36 | 0,
37 | 0
38 | ],
39 | [
40 | 0,
41 | 0
42 | ],
43 | [
44 | 0,
45 | 0
46 | ],
47 | [
48 | 0,
49 | 0
50 | ],
51 | [
52 | 0,
53 | 0
54 | ]
55 | ],
56 | [
57 | [
58 | 0,
59 | 0
60 | ],
61 | [
62 | 0,
63 | 0
64 | ],
65 | [
66 | 0,
67 | 0
68 | ],
69 | [
70 | 0,
71 | 0
72 | ],
73 | [
74 | 0,
75 | 0
76 | ],
77 | [
78 | 0,
79 | 0
80 | ]
81 | ]
82 | ],
83 | "state": [
84 | 1,
85 | 1
86 | ]
87 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-10-15-hit.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": "hit",
3 | "expected_feats": [
4 | [
5 | [
6 | 0,
7 | 0
8 | ],
9 | [
10 | 0,
11 | 0
12 | ],
13 | [
14 | 0,
15 | 0
16 | ],
17 | [
18 | 0,
19 | 0
20 | ],
21 | [
22 | 0,
23 | 0
24 | ],
25 | [
26 | 0,
27 | 0
28 | ]
29 | ],
30 | [
31 | [
32 | 0,
33 | 0
34 | ],
35 | [
36 | 0,
37 | 0
38 | ],
39 | [
40 | 0,
41 | 0
42 | ],
43 | [
44 | 0,
45 | 0
46 | ],
47 | [
48 | 0,
49 | 0
50 | ],
51 | [
52 | 0,
53 | 0
54 | ]
55 | ],
56 | [
57 | [
58 | 0,
59 | 0
60 | ],
61 | [
62 | 0,
63 | 0
64 | ],
65 | [
66 | 0,
67 | 0
68 | ],
69 | [
70 | 1,
71 | 0
72 | ],
73 | [
74 | 1,
75 | 0
76 | ],
77 | [
78 | 0,
79 | 0
80 | ]
81 | ]
82 | ],
83 | "state": [
84 | 10,
85 | 15
86 | ]
87 | }
--------------------------------------------------------------------------------
/tests/fixtures/feature-7-12-hit.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": "hit",
3 | "expected_feats": [
4 | [
5 | [
6 | 0,
7 | 0
8 | ],
9 | [
10 | 0,
11 | 0
12 | ],
13 | [
14 | 0,
15 | 0
16 | ],
17 | [
18 | 0,
19 | 0
20 | ],
21 | [
22 | 0,
23 | 0
24 | ],
25 | [
26 | 0,
27 | 0
28 | ]
29 | ],
30 | [
31 | [
32 | 0,
33 | 0
34 | ],
35 | [
36 | 0,
37 | 0
38 | ],
39 | [
40 | 1,
41 | 0
42 | ],
43 | [
44 | 1,
45 | 0
46 | ],
47 | [
48 | 0,
49 | 0
50 | ],
51 | [
52 | 0,
53 | 0
54 | ]
55 | ],
56 | [
57 | [
58 | 0,
59 | 0
60 | ],
61 | [
62 | 0,
63 | 0
64 | ],
65 | [
66 | 1,
67 | 0
68 | ],
69 | [
70 | 1,
71 | 0
72 | ],
73 | [
74 | 0,
75 | 0
76 | ],
77 | [
78 | 0,
79 | 0
80 | ]
81 | ]
82 | ],
83 | "state": [
84 | 7,
85 | 12
86 | ]
87 | }
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
5 |
6 | from environment import ACTIONS
7 |
8 |
9 | def mse(A, B):
10 | return np.sum((A - B) ** 2) / np.size(A)
11 |
12 |
13 | def get_step_size(N):
14 | non_zeros = np.where(N > 0)
15 | steps = np.zeros_like(N)
16 | steps[non_zeros] = 1.0 / N[non_zeros]
17 |
18 | return steps
19 |
20 |
21 | N_0 = 100
22 | def get_epsilon(N):
23 | return N_0 / (N_0 + N)
24 |
25 |
26 | def epsilon_greedy_policy(Q, N, state):
27 | dealer, player = state
28 | epsilon = get_epsilon(np.sum(N[dealer-1, player-1, :]))
29 | if np.random.rand() < (1 - epsilon):
30 | action = np.argmax(Q[dealer-1, player-1, :])
31 | else:
32 | action = np.random.choice(ACTIONS)
33 | return action
34 |
35 |
36 | def policy_wrapper(Q, N):
37 | def policy(state):
38 | dealer, player = state
39 | assert(0 < dealer < 11 and 0 < player < 22)
40 | eps = get_epsilon(np.sum(N[dealer-1, player-1, :]))
41 |
42 | if np.random.rand() < (1 - eps):
43 | action = np.argmax(Q[dealer-1, player-1, :])
44 | else:
45 | action = np.random.choice(ACTIONS)
46 | return action
47 | return policy
48 |
--------------------------------------------------------------------------------
/tests/fixtures/feature-9-10-stick.json:
--------------------------------------------------------------------------------
1 | {
2 | "action": "stick",
3 | "expected_feats": [
4 | [
5 | [
6 | 0,
7 | 0
8 | ],
9 | [
10 | 0,
11 | 0
12 | ],
13 | [
14 | 0,
15 | 0
16 | ],
17 | [
18 | 0,
19 | 0
20 | ],
21 | [
22 | 0,
23 | 0
24 | ],
25 | [
26 | 0,
27 | 0
28 | ]
29 | ],
30 | [
31 | [
32 | 0,
33 | 0
34 | ],
35 | [
36 | 0,
37 | 0
38 | ],
39 | [
40 | 0,
41 | 0
42 | ],
43 | [
44 | 0,
45 | 0
46 | ],
47 | [
48 | 0,
49 | 0
50 | ],
51 | [
52 | 0,
53 | 0
54 | ]
55 | ],
56 | [
57 | [
58 | 0,
59 | 0
60 | ],
61 | [
62 | 0,
63 | 0
64 | ],
65 | [
66 | 0,
67 | 1
68 | ],
69 | [
70 | 0,
71 | 1
72 | ],
73 | [
74 | 0,
75 | 0
76 | ],
77 | [
78 | 0,
79 | 0
80 | ]
81 | ]
82 | ],
83 | "state": [
84 | 9,
85 | 10
86 | ]
87 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | .venv/
84 | ENV/
85 |
86 | # Spyder project settings
87 | .spyderproject
88 |
89 | # Rope project settings
90 | .ropeproject
91 |
92 | .DS_Store
--------------------------------------------------------------------------------
/tests/environment_tests.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 | from nose.tools import assert_equal
3 |
4 | from environment import (
5 | Easy21Env, TERMINAL_STATE, ACTIONS
6 | )
7 |
8 | HIT, STICK = ACTIONS
9 |
10 | def mock_draw_card(result):
11 | def fn(color=None):
12 | return result
13 | return fn
14 |
15 |
16 | class TestEnvironment():
17 | def setUp(self):
18 | self.env = Easy21Env()
19 |
20 |
21 | def tearDown(self):
22 | self.env = None
23 |
24 | def test_step_player_should_bust_if_sum_exceeds_max(self):
25 | CARD = { 'value': 8, 'color': "black" }
26 | card_mock = mock_draw_card(CARD)
27 | PLAYER_START = 0
28 | self.env.reset(dealer=5, player=PLAYER_START)
29 |
30 | with patch("environment.draw_card", card_mock):
31 | state = self.env.observe()
32 | for i in range(3):
33 | player = state[1]
34 | assert_equal(self.env.player, PLAYER_START + i * CARD["value"])
35 | state, reward = self.env.step(HIT)
36 |
37 | assert_equal(state, TERMINAL_STATE)
38 | assert_equal(reward, -1)
39 |
40 | def test_step_player_should_bust_if_sum_below_min(self):
41 | CARD = { 'value': 8, 'color': "red" }
42 | card_mock = mock_draw_card(CARD)
43 | PLAYER_START = 10
44 | self.env.reset(dealer=5, player=PLAYER_START)
45 |
46 | with patch("environment.draw_card", card_mock):
47 | state = self.env.observe()
48 | for i in range(2):
49 | assert_equal(self.env.player, PLAYER_START - i * CARD["value"])
50 | state, reward = self.env.step(HIT)
51 |
52 | assert_equal(state, TERMINAL_STATE)
53 | assert_equal(reward, -1)
54 |
55 | def test_step_dealer_finishes_between_17_21(self):
56 | CARD = { 'value': 8, 'color': "black" }
57 | pass
58 |
--------------------------------------------------------------------------------
/tests/function_approximation_tests.py:
--------------------------------------------------------------------------------
1 | from nose.tools import assert_equal
2 | import glob
3 | import json
4 | import os
5 |
6 | from environment import ACTIONS
7 | from agents.function_approximation import phi
8 |
9 | HIT, STICK = ACTIONS
10 | NAME_TO_ACTION = {
11 | "hit": HIT,
12 | "stick": STICK
13 | }
14 | ACTION_TO_NAME = {
15 | v: k for k, v in NAME_TO_ACTION.items()
16 | }
17 |
18 | FEATURE_PATH_TEMPLATE = "./tests/fixtures/feature-{}.json"
19 |
20 | def write_fixture(filepath, result):
21 | with open(filepath, "w") as f:
22 | json.dump(result, f, separators=(',', ': '), sort_keys=True, indent=2)
23 |
24 |
25 | class TestFunctionApproximationAgent:
26 | def verify_features(self, filepath):
27 | with open(filepath) as f:
28 | test_case = json.load(f)
29 |
30 | state = test_case["state"]
31 |
32 | action = None
33 | if test_case.get("action", None) is not None:
34 | action = NAME_TO_ACTION[test_case["action"]]
35 |
36 | feats = phi(test_case["state"], action).tolist()
37 | expected = test_case["expected_feats"]
38 |
39 | if expected != feats and os.environ.get("TESTS_UPDATE", False):
40 | result = {
41 | "state": test_case["state"],
42 | "action": ACTION_TO_NAME.get(action, None),
43 | "expected_feats": feats
44 | }
45 | write_fixture(filepath, result)
46 |
47 | assert_equal(expected, feats)
48 |
49 |
50 | def test_features(self):
51 | features_path = FEATURE_PATH_TEMPLATE.format("*")
52 | for filepath in glob.glob(features_path):
53 | yield self.verify_features, filepath
54 |
55 |
56 | def create_feature_fixtures(self, state_actions):
57 | for state, action in state_actions:
58 | fixture = {
59 | "state": state,
60 | "action": action,
61 | "expected_feats": []
62 | }
63 |
64 | filename = "{}-{}".format(state[0], state[1])
65 | if action is not None:
66 | filename += "-{}".format(action)
67 |
68 | filepath = FEATURE_PATH_TEMPLATE.format(filename)
69 | write_fixture(filepath, fixture)
70 | pass
71 |
--------------------------------------------------------------------------------
/agents/sarsa.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 |
4 | from utils import epsilon_greedy_policy, mse
5 | from vis import plot_V
6 | from environment import Easy21Env, TERMINAL_STATE, STATE_SPACE_SHAPE
7 |
8 |
9 | GAMMA = 1.0
10 | LAMBDA = 0
11 |
12 | class SarsaAgent:
13 | def __init__(self, env, num_episodes=1000,
14 | gamma=GAMMA, lmbd=LAMBDA,
15 | save_error_history=False,
16 | **kwargs):
17 | self.num_episodes = num_episodes
18 | self.env = env
19 | self.gamma = gamma
20 | self.lmbd = lmbd
21 |
22 | self.save_error_history = save_error_history
23 | if self.save_error_history:
24 | with open("./Q_opt.pkl", "rb") as f:
25 | self.opt_Q = pickle.load(f)
26 |
27 | self.reset()
28 |
29 |
30 | def reset(self):
31 | self.Q = np.zeros(STATE_SPACE_SHAPE)
32 |
33 | if self.save_error_history:
34 | self.error_history = []
35 |
36 |
37 | def learn(self):
38 | env = self.env
39 | Q = self.Q
40 | N = np.zeros(STATE_SPACE_SHAPE)
41 |
42 | for episode in range(1, self.num_episodes+1):
43 | env.reset()
44 | state1 = env.observe()
45 | E = np.zeros(STATE_SPACE_SHAPE) # eligibility traces
46 |
47 | while state1 != TERMINAL_STATE:
48 | action1 = epsilon_greedy_policy(Q, N, state1)
49 | state2, reward = env.step(action1)
50 |
51 | dealer1, player1 = state1
52 | idx1 = (dealer1-1, player1-1, action1)
53 | Q1 = Q[idx1]
54 |
55 | if state2 == TERMINAL_STATE:
56 | Q2 = 0.0
57 | else:
58 | action2 = epsilon_greedy_policy(Q, N, state2)
59 | dealer2, player2 = state2
60 | idx2 = (dealer2-1, player2-1, action2)
61 | Q2 = Q[idx2]
62 |
63 | N[idx1] += 1
64 | E[idx1] += 1
65 |
66 | alpha = 1.0 / N[idx1]
67 | delta = reward + self.gamma * Q2 - Q1
68 | Q += alpha * delta * E
69 | E *= self.gamma * self.lmbd
70 |
71 | state1 = state2
72 |
73 | if self.save_error_history:
74 | self.error_history.append((episode, mse(self.Q, self.opt_Q)))
75 |
76 | return Q
77 |
--------------------------------------------------------------------------------
/vis.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use("TkAgg")
3 |
4 | from mpl_toolkits.mplot3d import Axes3D
5 | from matplotlib import cm, rc
6 | rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
7 | rc('text', usetex=True)
8 | import matplotlib.pyplot as plt
9 |
10 | import numpy as np
11 | from pprint import pprint
12 |
13 |
14 | def create_surf_plot(X, Y, Z, fig_idx=1):
15 | fig = plt.figure(fig_idx)
16 | ax = fig.add_subplot(111, projection="3d")
17 |
18 | surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
19 | linewidth=0, antialiased=False)
20 | # surf = ax.plot_wireframe(X, Y, Z)
21 |
22 | return surf
23 |
24 |
25 | from environment import DEALER_RANGE, PLAYER_RANGE
26 | def plot_V(Q, save=None, fig_idx=0):
27 | V = np.max(Q, axis=2)
28 | X, Y = np.mgrid[DEALER_RANGE, PLAYER_RANGE]
29 |
30 | surf = create_surf_plot(X, Y, V)
31 |
32 | plt.title("V*")
33 | plt.ylabel('player sum', size=18)
34 | plt.xlabel('dealer', size=18)
35 |
36 | if save is not None:
37 | plt.savefig(save, format='pdf', transparent=True)
38 | else:
39 | plt.show()
40 |
41 | plt.clf()
42 |
43 |
44 | def plot_learning_curve(learning_curves, save=None, agent_args={}, fig_idx=2):
45 | fig = plt.figure(fig_idx)
46 |
47 | plt.title("Mean-squared error vs. 'true' Q values against episode number")
48 | plt.ylabel(r'$\frac{1}{|S||A|}\sum_{s,a}{(Q(s,a) - Q^{*}(s,a))^2}$', size=18)
49 | plt.xlabel(r'$episode$', size=18)
50 |
51 | colors = iter(cm.rainbow(np.linspace(0, 1, len(learning_curves))))
52 | for lmbd, D in learning_curves.items():
53 | X, Y = zip(*D)
54 | plt.plot(X, Y, label="lambda={:.1f}".format(lmbd),
55 | linewidth=1.0, color=next(colors))
56 |
57 | plt.legend()
58 |
59 | if save is not None:
60 | plt.savefig(save, format='pdf', transparent=True)
61 | else:
62 | plt.show()
63 |
64 | plt.clf()
65 |
66 | def plot_pg_rewards(mean_rewards, save=None, fig_idx=3):
67 | fig = plt.figure(fig_idx)
68 |
69 |   plt.title("Policy Gradient running average rewards (1000 episode window)")
70 | plt.ylabel(r'average reward', size=18)
71 | plt.xlabel(r'episode', size=18)
72 |
73 | Y = mean_rewards
74 | X = range(1, len(Y)+1)
75 |
76 | plt.plot(X, Y, linewidth=1.0)
77 |
78 | if save is not None:
79 | plt.savefig(save, format='pdf', transparent=True)
80 | else:
81 | plt.show()
82 |
83 | plt.clf()
84 |
--------------------------------------------------------------------------------
/agents/policy_gradient.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import numpy as np
3 |
4 | from utils import epsilon_greedy_policy
5 | from environment import ACTIONS, TERMINAL_STATE, STATE_SPACE_SHAPE
6 |
7 |
8 | class TwoLayerNet:
9 | def __init__(self, input_size, hidden_size, output_size, std=1e-4):
10 | self.params = {}
11 | self.params['W1'] = std * np.random.randn(input_size, hidden_size)
12 | self.params['b1'] = np.zeros(hidden_size)
13 | self.params['W2'] = std * np.random.randn(hidden_size, output_size)
14 | self.params['b2'] = np.zeros(output_size)
15 |
16 |
17 | def forward(self, X):
18 | if len(X.shape) == 1: X = X.reshape(-1, X.shape[0])
19 |
20 | W1, b1 = self.params['W1'], self.params['b1']
21 | W2, b2 = self.params['W2'], self.params['b2']
22 | N, D = 1, len(X)
23 |
24 | z = np.dot(X, W1) + b1
25 | h1 = np.maximum(z, 0)
26 | scores = np.dot(h1, W2) + b2
27 |
28 | exp_scores = np.exp(scores)
29 | sum_exp_scores = np.sum(exp_scores, axis=1, keepdims=True)
30 | probs = exp_scores / sum_exp_scores
31 |
32 | cache = (probs, h1, z)
33 |
34 | return probs, cache
35 |
36 |
37 | def backward(self, X, y, cache):
38 | if len(X.shape) == 1: X = X.reshape(-1, X.shape[0])
39 |
40 | probs, h1, z = cache
41 |
42 | W1, b1 = self.params['W1'], self.params['b1']
43 | W2, b2 = self.params['W2'], self.params['b2']
44 | N, D = 1, len(X)
45 |
46 | dscores = probs.copy()
47 | dscores[np.arange(N), y] -= 1.0
48 | dscores /= float(N)
49 |
50 | dW2 = np.dot(h1.T, dscores) + W2
51 | db2 = np.sum(dscores, axis=0)
52 |
53 | dh1 = np.dot(dscores, W2.T)
54 |
55 | dz = (z > 0.0) * dh1
56 |
57 | dW1 = np.dot(X.T, dz) + W1
58 | db1 = np.sum(dz, axis=0)
59 |
60 | grads = {
61 | 'W2': dW2, 'b2': db2,
62 | 'W1': dW1, 'b1': db1
63 | }
64 |
65 | return grads
66 |
67 |
68 | def update_params(self, grads):
69 | learning_rate = 5e-4
70 |
71 | for param in self.params.keys():
72 | self.params[param] -= learning_rate * grads[param]
73 |
74 |
75 | GAMMA = 1.0
76 | LAMBDA = 0
77 |
78 |
79 | class PolicyGradientAgent:
80 | def __init__(self, env, num_episodes=1000,
81 | gamma=GAMMA, lmbd=LAMBDA,
82 | save_error_history=False,
83 | **kwargs):
84 | self.num_episodes = num_episodes
85 | self.env = env
86 | self.gamma = gamma
87 | self.lmbd = lmbd
88 |
89 | self.reset()
90 |
91 |
92 | def reset(self):
93 | self.reward_history = []
94 |
95 |
96 | def learn(self):
97 | env = self.env
98 | net = TwoLayerNet(2, 20, 2)
99 | reward_window = deque([0] * 1000, 1000)
100 |
101 | for episode in range(1, self.num_episodes+1):
102 | env.reset()
103 | state = env.observe()
104 | E = [] # experiences
105 |
106 | while state != TERMINAL_STATE:
107 | (probs,), cache = net.forward(state)
108 | action = np.random.choice(ACTIONS, p=probs)
109 | state_, reward = env.step(action)
110 |
111 | E.append([state, action, reward, cache])
112 | state = state_
113 |
114 | G = np.cumsum([e[2] for e in reversed(E)])[::-1]
115 |       for (state, action, reward, cache), G_t in zip(E, G):
116 | grads = net.backward(state, action, cache)
117 | grads = { k: v * G_t for k, v in grads.items() }
118 | net.update_params(grads)
119 |
120 | reward_window.append(sum(e[2] for e in E))
121 | self.reward_history.append(np.mean(reward_window))
122 |
123 | return self.reward_history
124 |
--------------------------------------------------------------------------------
/environment.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 |
4 | DECK = range(1, 11)
5 | ACTIONS = (HIT, STICK) = (0, 1)
6 |
7 | DEALER_RANGE = range(1, 11)
8 | PLAYER_RANGE = range(1, 22)
9 | STATE_SPACE_SHAPE = (len(DEALER_RANGE), len(PLAYER_RANGE), len(ACTIONS))
10 |
11 | TERMINAL_STATE = "TERMINAL"
12 | COLOR_PROBS = { 'red': 1/3, 'black': 2/3 }
13 | COLOR_COEFFS = { 'red': -1, 'black': 1 }
14 |
15 |
16 | def draw_card(color=None):
17 | value = np.random.choice(DECK)
18 | if color is None:
19 | colors, probs = zip(*COLOR_PROBS.items())
20 | color = np.random.choice(colors, p=probs)
21 | return { 'value': value, 'color': color }
22 |
23 |
24 | def bust(x):
25 | return (x < 1 or 21 < x)
26 |
27 |
28 | class Easy21Env:
29 | """ Easy21 environment
30 |
 31 | Easy21 is a simple card game similar to Blackjack. The rules of the game are as
32 | follows:
33 |
34 | - The game is played with an infinite deck of cards (i.e. cards are sampled
35 | with replacement)
36 | - Each draw from the deck results in a value between 1 and 10 (uniformly
37 | distributed) with a colour of red (probability 1/3) or black (probability
38 | 2/3).
39 | - There are no aces or picture (face) cards in this game
40 | - At the start of the game both the player and the dealer draw one black
41 | card (fully observed)
42 | - Each turn the player may either stick or hit
43 | - If the player hits then she draws another card from the deck
44 | - If the player sticks she receives no further cards
45 | - The values of the player's cards are added (black cards) or subtracted (red
46 | cards)
47 | - If the player's sum exceeds 21, or becomes less than 1, then she "goes
48 | bust" and loses the game (reward -1)
49 | - If the player sticks then the dealer starts taking turns. The dealer always
50 | sticks on any sum of 17 or greater, and hits otherwise. If the dealer goes
51 | bust, then the player wins; otherwise, the outcome - win (reward +1),
 52 |     lose (reward -1), or draw (reward 0) - is determined by which player has the largest sum.
53 | """
54 | def __init__(self):
55 | self.reset()
56 |
57 |
58 | def reset(self, dealer=None, player=None):
59 | if dealer is None: dealer = draw_card()['value']
60 | self.dealer = dealer
61 | if player is None: player = draw_card()['value']
62 | self.player = player
63 |
64 |
65 | def observe(self):
66 | if not (self.dealer in DEALER_RANGE and self.player in PLAYER_RANGE):
67 | return TERMINAL_STATE
68 | return np.array((self.dealer, self.player))
69 |
70 |
71 | def step(self, action):
72 | """ Step function
73 |
74 | Inputs:
75 | - action: hit or stick
76 |
77 | Returns:
78 | - next_state: a sample of the next state (which may be terminal if the
79 | game is finished)
80 | - reward
81 | """
82 |
83 | if action == HIT:
84 | card = draw_card()
85 | self.player += COLOR_COEFFS[card['color']] * card['value']
86 |
87 | if bust(self.player):
88 | next_state, reward = TERMINAL_STATE, -1
89 | else:
90 | next_state, reward = (self.dealer, self.player), 0
91 | elif action == STICK:
92 | while 0 < self.dealer < 17:
93 | card = draw_card()
94 | self.dealer += COLOR_COEFFS[card['color']] * card['value']
95 |
96 | next_state = TERMINAL_STATE
97 | if bust(self.dealer):
98 | reward = 1
99 | else:
100 | reward = int(self.player > self.dealer) - int(self.player < self.dealer)
101 | else:
102 | raise ValueError("Action not in action space")
103 |
104 | return np.array(next_state), reward
105 |
--------------------------------------------------------------------------------
/agents/function_approximation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 |
4 | from utils import epsilon_greedy_policy, mse
5 | from vis import plot_V
6 | from environment import (
7 | Easy21Env, TERMINAL_STATE, STATE_SPACE_SHAPE, ACTIONS,
8 | DEALER_RANGE, PLAYER_RANGE
9 | )
10 |
11 |
12 | HIT, STICK = ACTIONS
13 |
14 | GAMMA = 1
15 | LAMBDA = 0
16 | EPSILON = 0.05
17 | ALPHA = 0.01
18 |
19 | CUBOID_INTERVALS = {
20 | "dealer": ((1, 4), (4, 7), (7, 10)),
21 | "player": ((1, 6), (4, 9), (7, 12), (10, 15), (13, 18), (16, 21)),
22 | "action": ((HIT,), (STICK,))
23 | }
24 |
25 | FEATS_SHAPE = tuple(
26 | len(CUBOID_INTERVALS[key]) for key in ("dealer", "player", "action")
27 | )
28 |
29 |
30 | def phi(state, action=None):
31 | if state == TERMINAL_STATE: return 0
32 |
33 | dealer, player = state
34 |
35 | state_features = np.array([
36 | (di[0] <= dealer <= di[1]) and (pi[0] <= player <= pi[1])
37 | for di in CUBOID_INTERVALS['dealer']
38 | for pi in CUBOID_INTERVALS['player']
39 | ]).astype(int).reshape(FEATS_SHAPE[:2])
40 |
41 | if action is None: return state_features
42 |
43 | features = np.zeros(FEATS_SHAPE)
44 | for i, ai in enumerate(CUBOID_INTERVALS['action']):
45 | if action in ai:
46 | features[:, :, i] = state_features
47 |
48 | return features.astype(int)
49 |
50 |
51 | def expand_Q(w):
52 | Q = np.zeros(STATE_SPACE_SHAPE)
53 |
54 | for dealer in DEALER_RANGE:
55 | for player in PLAYER_RANGE:
56 | for action in ACTIONS:
57 | state = (dealer, player)
58 | feats = phi(state, action)
59 | Q[dealer-1, player-1][action] = np.sum(feats * w)
60 |
61 | return Q
62 |
63 |
64 | class FunctionApproximationAgent:
65 | def __init__(self, env, num_episodes=1000,
66 | gamma=GAMMA, lmbd=LAMBDA,
67 | epsilon=EPSILON, alpha=ALPHA,
68 | save_error_history=False,
69 | **kwargs):
70 | self.num_episodes = num_episodes
71 | self.env = env
72 |
73 | self.gamma = gamma
74 | self.lmbd = lmbd
75 | self.epsilon = epsilon
76 | self.alpha = alpha
77 |
78 | self.save_error_history = save_error_history
79 | if self.save_error_history:
80 | with open("./Q_opt.pkl", "rb") as f:
81 | self.opt_Q = pickle.load(f)
82 |
83 | self.reset()
84 |
85 |
86 | def reset(self):
87 | self.Q = np.zeros(STATE_SPACE_SHAPE)
88 | self.w = (np.random.rand(*FEATS_SHAPE) - 0.5) * 0.001
89 |
90 | if self.save_error_history:
91 | self.error_history = []
92 |
93 | def policy(self, state):
94 | if state == TERMINAL_STATE:
95 | return 0.0, None
96 |
97 | if np.random.rand() < (1 - self.epsilon):
98 | Qhat, action = max(
99 | # same as dotproduct in our case
100 | ((np.sum(phi(state, a) * self.w), a) for a in ACTIONS),
101 | key=lambda x: x[0]
102 | )
103 | else:
104 | action = np.random.choice(ACTIONS)
105 | Qhat = np.sum(phi(state, action) * self.w)
106 |
107 | return Qhat, action
108 |
109 |
110 | def learn(self):
111 | env = self.env
112 |
113 | for episode in range(1, self.num_episodes+1):
114 | env.reset()
115 | state1 = env.observe()
116 | E = np.zeros_like(self.w)
117 |
118 | while state1 != TERMINAL_STATE:
119 | Qhat1, action1 = self.policy(state1)
120 | state2, reward = env.step(action1)
121 | Qhat2, action2 = self.policy(state2)
122 |
123 | feats1 = phi(state1, action1)
124 | grad_w_Qhat1 = feats1
125 |
126 | delta = reward + self.gamma * Qhat2 - Qhat1
127 | E = self.gamma * self.lmbd * E + grad_w_Qhat1
128 | dw = self.alpha * delta * E
129 |
130 | self.w += dw
131 | state1 = state2
132 |
133 | if self.save_error_history:
134 | self.Q = expand_Q(self.w)
135 | self.error_history.append((episode, mse(self.Q, self.opt_Q)))
136 |
137 | self.Q = expand_Q(self.w)
138 | return self.Q
139 |
--------------------------------------------------------------------------------
/easy21.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import numpy as np
4 | from distutils.util import strtobool
5 |
6 | from environment import Easy21Env, ACTIONS, DEALER_RANGE, PLAYER_RANGE
7 | from agents import (
8 | MonteCarloAgent, SarsaAgent, FunctionApproximationAgent, PolicyGradientAgent
9 | )
10 | from vis import plot_V, plot_learning_curve, plot_pg_rewards
11 | from utils import mse
12 |
13 |
14 | def range_float_type(s):
15 | """ Custom range float type for arg parser
16 | """
17 | try:
18 | parts = list(map(float, s.split(",")))
 19 |     if len(parts) == 1:
 20 |       return parts
 21 |     elif len(parts) == 3:
 22 |       return np.arange(*parts)
 23 |     raise ValueError(s)  # neither a single float nor a 'start,stop,step' triple
 24 |   except Exception:
 25 |     raise argparse.ArgumentTypeError(
 26 |       "range_float must be a single float or a 'start,stop,step' string that "
 27 |       "can be passed to np.arange, e.g. '0,1.1,0.1'.")
28 |
29 |
30 | def bool_type(x):
31 | return bool(strtobool(x))
32 |
33 |
34 | parser = argparse.ArgumentParser(
35 | description="Simple Reinforcement Learning Environment")
36 |
37 | parser.add_argument("-v", "--verbose", default=False, type=bool_type,
38 | help="Verbose")
39 |
40 | parser.add_argument("-a", "--agent", default="mc",
41 | choices=['mc', 'sarsa', 'lfa', 'pg'],
 42 |                     help=("Agent Type: "
 43 |                           "mc (monte carlo), "
 44 |                           "sarsa, lfa (linear function approximation), "
 45 |                           "pg (policy gradient)"))
46 | parser.add_argument("--num-episodes", default=1000, type=int,
47 | help="Number of episodes")
48 | parser.add_argument("--lmbd", default=[1.0], type=range_float_type, help="Lambda")
49 | parser.add_argument("--gamma", default=1, type=float, help="Gamma")
50 |
51 | parser.add_argument("--plot-v", default=False, type=bool_type,
52 | help="Plot the value function")
53 | parser.add_argument("--dump-q", default=False, type=bool_type,
54 | help="Dump the Q values to file")
55 | parser.add_argument("--plot-lambda-mse", default=False, type=bool_type,
56 | help=("Plot mean-squared error compared to the 'true' Q "
57 | "values obtained with monte-carlo"))
58 | parser.add_argument("--plot-learning-curve", default=False, type=bool_type,
59 | help=("Plot the learning curve of mean-squared error "
60 | "compared to the 'true' Q values obtained from "
61 | "monte-carlo against episode number"))
62 |
63 |
64 | AGENTS = {
65 | "mc": MonteCarloAgent,
66 | "sarsa": SarsaAgent,
67 | "lfa": FunctionApproximationAgent,
68 | "pg": PolicyGradientAgent
69 | }
70 |
71 | Q_DUMP_BASE_NAME = "Q_dump"
72 | def dump_Q(Q, args):
73 | filename = ("./{}_{}_lambda_{}_gamma_{}_episodes_{}.pkl"
74 | "".format(Q_DUMP_BASE_NAME,
75 | args["agent_type"], args["lmbd"],
76 | args.get("gamma", None), args["num_episodes"]))
77 |
78 | print("dumping Q: ", filename)
79 |
80 | with open(filename, "wb") as f:
81 | pickle.dump(Q, f)
82 |
83 |
84 | def get_agent_args(args):
85 | agent_type = args.agent
86 | agent_args = {
87 | "agent_type": agent_type,
88 | "num_episodes": args.num_episodes
89 | }
90 |
91 | if agent_type == "mc":
92 | return agent_args
93 | elif agent_type == "sarsa" or agent_type == "lfa":
94 | agent_args.update({
95 | key: getattr(args, key) for key in ["gamma"]
96 | if key in args
97 | })
98 | agent_args["save_error_history"] = getattr(
99 | args, "plot_learning_curve", False )
100 |
101 | return agent_args
102 |
103 |
104 | Q_OPT_FILE = "./Q_opt.pkl"
105 | def main(args):
106 | env = Easy21Env()
107 |
108 | if args.plot_learning_curve:
109 | learning_curves = {}
110 |
111 | for i, lmbd in enumerate(args.lmbd):
112 | agent_args = get_agent_args(args)
113 | agent_args["lmbd"] = lmbd
114 | agent = AGENTS[args.agent](env, **agent_args)
115 |
116 | agent.learn()
117 |
118 | if agent_args["agent_type"] == "pg":
119 | plot_file = ("./vis/policy_gradient_rewards_episodes_{}.pdf"
120 | "".format(args.num_episodes))
121 | plot_pg_rewards(agent.reward_history, save=plot_file)
122 | return
123 |
124 | if args.dump_q:
125 | dump_Q(agent.Q, agent_args)
126 |
127 | if args.plot_v:
128 | plot_file = ("./vis/V_{}_lambda_{}_gamma_{}_episodes_{}.pdf"
129 | "".format(agent_args["agent_type"],
130 | lmbd,
131 | args.gamma,
132 | args.num_episodes))
133 | plot_V(agent.Q, save=plot_file)
134 |
135 | if args.plot_learning_curve:
136 | learning_curves[lmbd] = agent.error_history
137 |
138 | if args.plot_learning_curve:
139 | plot_file = ("./vis/lambda_mse_{}_gamma_{}_episodes_{}.pdf"
140 | "".format(agent_args["agent_type"],
141 | args.gamma, args.num_episodes))
142 | plot_learning_curve(learning_curves, save=plot_file)
143 |
144 | if __name__ == "__main__":
145 | args = parser.parse_args()
146 | main(args)
147 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # easy21
 2 | This repository implements the assignment requirements for the reinforcement learning course given by David Silver [1]. It implements a reinforcement learning environment and four different agents (Monte-Carlo, Sarsa(λ), linear value function approximation, and neural network policy gradient) for a simple card game called Easy21, presented in [2].
3 |
4 |
5 | ## Setting up the environment
 6 | To set up the Python environment you need Python 3.x, pip, and virtualenv:
7 | ```
8 | git clone https://github.com/hartikainen/easy21.git
9 | cd easy21
10 |
11 | pip install virtualenv
12 | virtualenv .venv
13 | source .venv/bin/activate
14 |
15 | pip install -r requirements.txt
16 | ```
17 |
18 |
19 | ## Running the game
20 | To run the game and test the agents, run the `easy21.py` file as follows:
21 | ```
22 | python easy21.py [-h] [-v VERBOSE] [-a {mc,sarsa,lfa,pg}]
23 | [--num-episodes NUM_EPISODES] [--lmbd LMBD] [--gamma GAMMA]
24 | [--plot-v PLOT_V] [--dump-q DUMP_Q]
25 | [--plot-lambda-mse PLOT_LAMBDA_MSE]
26 | [--plot-learning-curve PLOT_LEARNING_CURVE]
27 | ```
28 | See `easy21.py` for more information about running the game and testing the agents. All the agents are found in the `/agents` folder.
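
For example, the Sarsa(λ) learning-curve experiment described below can be reproduced with a command along these lines (the flag values are illustrative; the learning-curve option additionally expects the precomputed `Q_opt.pkl` in the repository root):

```
python easy21.py -a sarsa --num-episodes 20000 --lmbd 0,1.1,0.1 --plot-learning-curve true
```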
29 |
30 |
31 | ## Easy21 Environment
32 | The Easy21 environment is implemented in the `Easy21Env` class found in `environment.py`. The environment keeps track of the game state (the dealer's card and the sum of the player's cards) and exposes a `step` method which, given an action (hit or stick), updates its state and returns the observed state (in our case the observation is equivalent to the game state) and the reward.
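
A minimal interaction loop, based on the API in `environment.py` (the constants `ACTIONS` and `TERMINAL_STATE` are exported by that module), looks roughly like this:

```python
from environment import Easy21Env, ACTIONS, TERMINAL_STATE

HIT, STICK = ACTIONS

env = Easy21Env()
state = env.observe()              # (dealer's first card, player's sum)
while state != TERMINAL_STATE:
  state, reward = env.step(HIT)    # always hit, just to illustrate the API
print("final reward:", reward)
```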
33 |
34 |
35 | ## Monte-Carlo Control in Easy21
36 | Monte-Carlo control for Easy21 is implemented in `agents/monte_carlo.py`. The default implementation uses a time-varying scalar step-size αt = 1/N(st, at) and an ε-greedy exploration strategy with εt = N0 / (N0 + N(st)), where N0 = 100, N(s) is the number of times state s has been visited, N(s,a) is the number of times action a has been selected from state s, and t is the time-step.
37 | The figure below presents the optimal value function V\* against the game state (player sum and dealer hand).
38 |
39 | 
40 |
41 |
42 | ## TD Learning in Easy21
43 | File `agents/sarsa.py` implements Sarsa(λ) control for Easy21. It uses the same step-size and exploration schedules as the Monte-Carlo agent described in the previous section. The agent is tested with parameter values λ ∈ {0, 0.1, 0.2, ..., 1}, each run for 20000 episodes. The first figure below presents the learning curve, i.e. the mean-squared error against the 'true' Q values as a function of episode number, for each λ. The next two figures plot the value function V\* (as in the Monte-Carlo section) for λ=0.0 and λ=1.0.
44 |
45 | 
46 |
47 | 
48 |
49 | 
50 |
51 |
52 | ### Bootstrapping in Easy21
53 |
54 | As in other settings, bootstrapping in Easy21 reduces the variance of the learned value estimates at the cost of increased bias. The Sarsa(λ) agent learns a reasonable policy faster (i.e. in a smaller number of episodes) than the Monte-Carlo agent. In a game as simple as Easy21, however, it is feasible to run enough episodes for the Monte-Carlo agent to converge to the optimal, unbiased policy.
55 |
56 | Episodes in Easy21 last longer on average than in the traditional Blackjack game because of the subtractive effect of red cards. Because of this, bootstrapping is likely to be more useful in Easy21 than it would be in traditional Blackjack.
57 |
58 |
59 | ## Linear Function Approximation in Easy21
60 | File `agents/function_approximation.py` implements a value function approximator with coarse coding for Easy21, using a binary feature vector φ(state, action) with 36 (3\*6\*2) features. Each binary feature takes the value 1 if the state lies in the state-space cuboid corresponding to that feature and the action matches the action corresponding to that feature, and 0 otherwise. The cuboids are defined in the variable `CUBOID_INTERVALS` in `agents/function_approximation.py`.
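
As a concrete example (matching the fixtures under `tests/fixtures/`), the state (dealer=10, player=15) with action hit falls into the dealer cuboid (7, 10) and the overlapping player cuboids (10, 15) and (13, 18), so exactly two of the 36 features are set:

```python
from environment import ACTIONS
from agents.function_approximation import phi

HIT, STICK = ACTIONS

feats = phi((10, 15), HIT)  # shape (3, 6, 2): dealer cuboid x player cuboid x action
print(feats[:, :, HIT])
# [[0 0 0 0 0 0]
#  [0 0 0 0 0 0]
#  [0 0 0 1 1 0]]
```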
61 |
62 | Similarly to the Sarsa(λ) agent in the previous section, we run tests with 20000 episodes for parameter values λ ∈ {0, 0.1, 0.2, ..., 1}, with a constant step-size α=0.01 and exploration parameter ε=0.05. The figures below plot the learning curve for each λ, and the value function V\* for λ=0.0 and λ=1.0.
63 |
64 | 
65 |
66 | 
67 |
68 | 
69 |
70 |
71 | ### Notes for function approximation
72 | Approximating the state and action space reduces the time and space complexity of the algorithm\*\*, because the agent needs to learn far fewer parameters than there are state-action pairs. However, this comes at the cost of reduced accuracy of the learned action-value function Q (and thus of the value function V and the policy π). The overlapping regions of the cuboid intervals seem to result in more extreme values in some states. This happens because each entry of the expanded Q function can be affected by multiple states and actions through the shared weights of the function approximator.
73 |
74 |
75 | ### The effect of constant step-size α
76 | One thing to notice is the effect of the constant step-size used for the linear function approximation. Because the step-size is kept constant during learning, some regions of the value function receive much less training than others, which results in an inaccurate value function in the extreme regions where the number of visits is small. The effect is tested with cuboid intervals corresponding to "identity" features, i.e. features that should result in exactly the same Q as learning with Sarsa(λ). Even after 50,000 episodes, the agent still has spots in the state space (low dealer and low player; low dealer and high player) where it does not match the Sarsa results. This effect is presented in the figure below (for λ=0).
77 |
78 | 
79 |
80 | When using a dynamic step-size α, as in the Sarsa(λ) section above, this effect disappears and the function approximation produces the expected value function, as shown in the figure below (again, for λ=0).
81 |
82 | 
83 |
84 | \*\* the actual running time is worse than with Sarsa(λ) because my function approximation implementation does not fully utilize numpy vectorization
85 |
86 | ## Policy Gradient in Easy21
87 | File `agents/policy_gradient.py` implements a simple two-layer feed-forward neural network and a policy gradient agent that utilizes the network. The network uses 20 hidden neurons with ReLU non-linearities and a learning rate α=5e-4. The weights are initialized from a Gaussian distribution with mean μ=0 and standard deviation σ=1e-4, and the biases are initially set to 0. The figure below shows the running average (over a 1000-episode window) of the reward received by the policy gradient agent over a total of 100,000 episodes.
88 |
89 | 
90 |
91 | [1] http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html
92 | [2] http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching_files/Easy21-Johannes.pdf
93 |
--------------------------------------------------------------------------------