├── requirements.txt
├── q_table.pickle
├── utils.py
├── README.md
├── evaluate.py
└── train.py

/requirements.txt:
--------------------------------------------------------------------------------
gym==0.17.1
click
--------------------------------------------------------------------------------
/q_table.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/satwikkansal/q-learning-taxi-v3/HEAD/q_table.pickle
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
def select_optimal_action(q_table, state):
    """Return the greedy (highest Q-value) action for the given state.

    Returns None when the state has no Q-values recorded yet.
    """
    max_q_value_action = None
    max_q_value = float("-inf")

    # .get() keeps the lookup safe for plain dicts (e.g. a table loaded
    # from a pickle) as well as the defaultdict used during training.
    if q_table.get(state):
        for action, action_q_value in q_table[state].items():
            if action_q_value >= max_q_value:
                max_q_value = action_q_value
                max_q_value_action = action

    return max_q_value_action
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# q-learning-taxi-v3

Table-based Q-learning implementation for the Taxi-v3 environment of OpenAI Gym.

Read the tutorial here: [https://www.learndatasci.com/tutorials/reinforcement-q-learning-scratch-python-openai-gym/](https://www.learndatasci.com/tutorials/reinforcement-q-learning-scratch-python-openai-gym/)

## Instructions to run

```shell
$ pip install -r requirements.txt
```

### Training

```shell
$ python train.py --help
Usage: train.py [OPTIONS]

Options:
  --num-episodes INTEGER  Number of episodes to train on  [default: 100000]
  --save-path TEXT        Path to save the Q-table dump  [default:
                          q_table.pickle]
  --help                  Show this message and exit.
```

### Evaluation

```shell
$ python evaluate.py --help
Usage: evaluate.py [OPTIONS]

Options:
  --num-episodes INTEGER  Number of episodes to evaluate on  [default: 100]
  --q-path TEXT           Path to read the q-table values from  [default:
                          q_table.pickle]
  --help                  Show this message and exit.
```
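
## How it works

Training uses an epsilon-greedy policy: with probability `epsilon` (0.1 by default) a random action is sampled for exploration, otherwise the greedy action from `utils.select_optimal_action` is taken. Each transition then updates one cell of the Q-table. Below is a minimal sketch of the update rule implemented in `train.py` (the `q_update` helper name is illustrative, not part of the repo):

```python
alpha, gamma = 0.1, 0.6  # learning rate and discount factor from train.py


def q_update(old_q, reward, next_max):
    # Blend the old estimate with the bootstrapped target
    # reward + gamma * max_a' Q(s', a'), moving alpha of the way there.
    return (1 - alpha) * old_q + alpha * (reward + gamma * next_max)


# e.g. a -1 step reward from a state whose best next Q-value is 5:
print(q_update(0.0, -1, 5))  # 0.1 * (-1 + 0.6 * 5) = 0.2
```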

## Similar projects

- [https://github.com/satwikkansal/smartcab](https://github.com/satwikkansal/smartcab)
- [https://github.com/satwikkansal/snakepy](https://github.com/satwikkansal/snakepy)
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
import pickle

import click
import gym

from utils import select_optimal_action

NUM_EPISODES = 100


def evaluate_agent(q_table, env, num_trials):
    total_epochs, total_penalties = 0, 0

    print("Running episodes...")
    for _ in range(num_trials):
        state = env.reset()
        epochs, num_penalties, reward = 0, 0, 0

        # Taxi-v3 grants +20 only for a successful drop-off, so the
        # episode is finished exactly when that reward is observed.
        while reward != 20:
            next_action = select_optimal_action(q_table, state)
            state, reward, _, _ = env.step(next_action)

            if reward == -10:  # illegal pickup/drop-off attempts cost -10
                num_penalties += 1

            epochs += 1

        total_penalties += num_penalties
        total_epochs += epochs

    average_time = total_epochs / float(num_trials)
    average_penalties = total_penalties / float(num_trials)
    print("Evaluation results after {} trials".format(num_trials))
    print("Average time steps taken: {}".format(average_time))
    print("Average number of penalties incurred: {}".format(average_penalties))
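

# Note: evaluation assumes the loaded table covers every state the greedy
# policy visits. select_optimal_action returns None for a state with no
# recorded values, and env.step(None) would then fail, so evaluate with a
# table trained long enough to cover Taxi-v3's 500 discrete states (the
# default 100,000 training episodes is ample).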


@click.command()
@click.option('--num-episodes', default=NUM_EPISODES, help='Number of episodes to evaluate on', show_default=True)
@click.option('--q-path', default="q_table.pickle", help='Path to read the q-table values from', show_default=True)
def main(num_episodes, q_path):
    env = gym.make("Taxi-v3")
    with open(q_path, 'rb') as f:
        q_table = pickle.load(f)
    evaluate_agent(q_table, env, num_episodes)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from collections import defaultdict
import pickle
import random

import click
import gym

from utils import select_optimal_action

# The hyperparameters
alpha = 0.1    # learning rate: how far each update moves the Q-value
gamma = 0.6    # discount factor applied to future rewards
epsilon = 0.1  # exploration rate for the epsilon-greedy policy

NUM_EPISODES = 100000


def update(q_table, env, state):
    # Epsilon-greedy action selection: explore with probability epsilon,
    # otherwise exploit the best currently known action.
    if random.uniform(0, 1) < epsilon:
        action = env.action_space.sample()
    else:
        action = select_optimal_action(q_table, state)

    next_state, reward, _, _ = env.step(action)
    old_q_value = q_table[state][action]

    # Initialise q-values for next_state if it has none yet
    if not q_table[next_state]:
        q_table[next_state] = {action: 0 for action in range(env.action_space.n)}

    # Maximum q-value over the actions available in the next state
    next_max = max(q_table[next_state].values())

    # Calculate the new q-value
    new_q_value = (1 - alpha) * old_q_value + alpha * (reward + gamma * next_max)

    # Finally, update the q-value
    q_table[state][action] = new_q_value

    return next_state, reward


def train_agent(q_table, env, num_episodes):
    for i in range(num_episodes):
        state = env.reset()
        if not q_table[state]:
            q_table[state] = {
                action: 0 for action in range(env.action_space.n)}

        epochs = 0
        num_penalties, reward, total_reward = 0, 0, 0
        # The +20 reward marks a successful drop-off and ends the episode.
        while reward != 20:
            state, reward = update(q_table, env, state)
            total_reward += reward

            if reward == -10:  # illegal pickup/drop-off attempts cost -10
                num_penalties += 1

            epochs += 1
        print("\nTraining episode {}".format(i + 1))
        print("Time steps: {}, Penalties: {}, Reward: {}".format(epochs,
                                                                 num_penalties,
                                                                 total_reward))

    print("Training finished.\n")

    return q_table


@click.command()
@click.option('--num-episodes', default=NUM_EPISODES, help='Number of episodes to train on', show_default=True)
@click.option('--save-path', default="q_table.pickle", help='Path to save the Q-table dump', show_default=True)
def main(num_episodes, save_path):
    env = gym.make("Taxi-v3")
    # defaultdict(int) makes unseen states read as a falsy 0 until they are
    # initialised with a per-action dict of zeros.
    q_table = defaultdict(int)
    q_table = train_agent(q_table, env, num_episodes)
    # Save the table for future use
    with open(save_path, "wb") as f:
        pickle.dump(dict(q_table), f)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------