├── .gitignore ├── README.md ├── gail ├── __init__.py ├── config │ ├── CartPole-v0 │ │ ├── config_gail.json │ │ ├── config_ppo.json │ │ └── config_traj.json │ └── LunarLander-v2 │ │ ├── config_gail.json │ │ ├── config_ppo.json │ │ └── config_traj.json ├── main.py ├── model_args.py ├── ppo.py ├── traj.py ├── utils.py └── visualize.py ├── pretrained ├── CartPole-v0 │ ├── gail │ │ ├── discriminator.ckpt │ │ ├── policy.ckpt │ │ ├── record1.pkl │ │ ├── record2.pkl │ │ ├── record3.pkl │ │ ├── record4.pkl │ │ ├── record5.pkl │ │ └── rewards_gail.png │ ├── ppo │ │ ├── policy.ckpt │ │ ├── record1.pkl │ │ ├── record2.pkl │ │ ├── record3.pkl │ │ ├── record4.pkl │ │ ├── record5.pkl │ │ └── rewards_ppo.png │ └── trajectory │ │ ├── actions.csv │ │ └── states.csv └── LunarLander-v2 │ └── ppo │ └── policy.ckpt ├── report.pdf ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/wandb/ 2 | *.egg-info/ 3 | **/__pycache__/ 4 | .idea/ 5 | out/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Adversarial Imitation Learning 2 | 3 | PyTorch implementation of the paper: 4 | 5 | Ho, Jonathan, and Stefano Ermon. "[Generative adversarial imitation learning.](https://arxiv.org/pdf/1606.03476.pdf)" Proceedings of the 30th International Conference on Neural Information Processing Systems. 2016. 6 | 7 | We also present a [report](report.pdf) with theoretical and empirical studies based on our understanding of the paper and other related works. 8 | ## Installation 9 | ```commandline 10 | pip install -r requirements.txt 11 | pip install -e . 12 | 13 | conda install swig  # optional 14 | pip install box2d-py  # optional 15 | ``` 16 | 17 | Note: `swig` and `box2d-py` are needed only for the `LunarLander-v2` environment.
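As a quick check before running the `LunarLander-v2` configs, the sketch below reports whether the optional Box2D bindings are importable. It assumes that `box2d-py` provides the `Box2D` module; the helper names `has_box2d` and `runnable_envs` are hypothetical and are not part of this repository.

```python
import importlib.util


def has_box2d() -> bool:
    """Return True if the Box2D bindings (module `Box2D`, assumed to come from box2d-py) are importable."""
    # find_spec returns None for a missing top-level module instead of raising
    return importlib.util.find_spec("Box2D") is not None


def runnable_envs() -> list:
    """List the environments used by this repo's configs whose extra dependencies are satisfied."""
    envs = ["CartPole-v0"]  # needs only gym itself
    if has_box2d():
        envs.append("LunarLander-v2")  # additionally needs swig + box2d-py
    return envs


if __name__ == "__main__":
    print("environments with satisfied dependencies:", runnable_envs())
```

If `LunarLander-v2` is missing from the printed list, install the two optional packages above.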
18 | 19 | ## Run Setup 20 | 21 | Run the commands below from inside the `gail/` directory, since the config paths and their `out_root` values are relative to it, and review the parameters in the corresponding config file first. We also provide 22 | pretrained models and sampled `CartPole-v0` expert trajectories under `pretrained/` that can be used directly. 23 | 24 | ### Train PPO to learn an expert policy 25 | 26 | ```shell script 27 | python ppo.py --config config/CartPole-v0/config_ppo.json 28 | ``` 29 | 30 | ### Sample expert trajectories 31 | 32 | ```shell script 33 | python traj.py --config config/CartPole-v0/config_traj.json 34 | ``` 35 | 36 | ### Train GAIL for imitation learning 37 | 38 | ```shell script 39 | python main.py --config config/CartPole-v0/config_gail.json 40 | ``` 41 | 42 | ### Generate training graphs 43 | 44 | ```shell script 45 | python visualize.py --env_id CartPole-v0 --out_dir ../pretrained --model_name ppo 46 | python visualize.py --env_id CartPole-v0 --out_dir ../pretrained --model_name gail 47 | ``` 48 | 49 | ## CartPole-v0 Experiment 50 | 51 | 52 | 53 | ## References 54 | 55 | 1. [GitHub: nav74neet/gail_gym](https://github.com/nav74neet/gail_gym) 56 | 2. [GitHub: nikhilbarhate99/PPO-PyTorch](https://github.com/nikhilbarhate99/PPO-PyTorch) 57 | 3. [Medium: Article on GAIL](https://medium.com/@sanketgujar95/generative-adversarial-imitation-learning-266f45634e60) 58 | 4. [Blog post on PPO algorithm](https://towardsdatascience.com/proximal-policy-optimization-tutorial-part-1-actor-critic-method-d53f9afffbf6) 59 | 5. [White Paper on MCE IRL](https://apps.dtic.mil/sti/pdfs/AD1090741.pdf) 60 | 6. [Blog post on PPO](https://jonathan-hui.medium.com/rl-proximal-policy-optimization-ppo-explained-77f014ec3f12) 61 | 7.
[Blog post on TRPO](https://jonathan-hui.medium.com/rl-trust-region-policy-optimization-trpo-explained-a6ee04eeeee9) 62 | 63 | ## Acknowledgements 64 | 65 | This work was completed as a course project for [CS498: Reinforcement Learning](https://nanjiang.cs.illinois.edu/cs498/), 66 | taught by [Professor Nan Jiang](https://nanjiang.cs.illinois.edu/). I thank the instructor and the course teaching 67 | assistants for their guidance and support throughout the course. 68 | 69 | ## Contact 70 | 71 | Jatin Arora 72 | 73 | University Mail: [jatin2@illinois.edu](mailto:jatin2@illinois.edu) 74 | 75 | External Mail: [jatinarora2702@gmail.com](mailto:jatinarora2702@gmail.com) 76 | 77 | LinkedIn: [linkedin.com/in/jatinarora2702](https://www.linkedin.com/in/jatinarora2702) 78 | 79 | -------------------------------------------------------------------------------- /gail/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/gail/__init__.py -------------------------------------------------------------------------------- /gail/config/CartPole-v0/config_gail.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_id": "CartPole-v0", 3 | "model_name": "gail", 4 | "resume": null, 5 | "train": false, 6 | "reward_threshold": 195.0, 7 | "discount_factor": 0.99, 8 | "clip_eps": 0.2, 9 | "lr_actor": 0.0003, 10 | "lr_critic": 0.001, 11 | "lr_discriminator": 0.001, 12 | "train_steps": 1e5, 13 | "max_episode_len": 400, 14 | "update_steps": 1600, 15 | "checkpoint_steps": 1600, 16 | "num_epochs": 40, 17 | "num_d_epochs": 2, 18 | "out_root": "../pretrained", 19 | "wandb_mode": "dryrun", 20 | "seed": 42, 21 | "device": "cuda:0" 22 | } 23 | -------------------------------------------------------------------------------- /gail/config/CartPole-v0/config_ppo.json:
-------------------------------------------------------------------------------- 1 | { 2 | "env_id": "CartPole-v0", 3 | "model_name": "ppo", 4 | "resume": null, 5 | "train": false, 6 | "reward_threshold": 195.0, 7 | "discount_factor": 0.99, 8 | "clip_eps": 0.2, 9 | "lr_actor": 0.0003, 10 | "lr_critic": 0.001, 11 | "train_steps": 1e5, 12 | "max_episode_len": 400, 13 | "update_steps": 1600, 14 | "checkpoint_steps": 1600, 15 | "num_epochs": 40, 16 | "out_root": "../pretrained", 17 | "expert_trajectory_dir": "../trajectory", 18 | "wandb_mode": "dryrun", 19 | "seed": 42, 20 | "device": "cuda:0" 21 | } 22 | -------------------------------------------------------------------------------- /gail/config/CartPole-v0/config_traj.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_id": "CartPole-v0", 3 | "model_name": "ppo", 4 | "train": false, 5 | "num_trajectories": 25, 6 | "max_episode_len": 400, 7 | "out_root": "../out", 8 | "wandb_mode": "dryrun", 9 | "seed": 42, 10 | "device": "cuda:0" 11 | } 12 | -------------------------------------------------------------------------------- /gail/config/LunarLander-v2/config_gail.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_id": "LunarLander-v2", 3 | "model_name": "gail", 4 | "resume": null, 5 | "train": true, 6 | "reward_threshold": 200.0, 7 | "discount_factor": 0.99, 8 | "clip_eps": 0.2, 9 | "lr_actor": 0.0003, 10 | "lr_critic": 0.001, 11 | "lr_discriminator": 0.001, 12 | "train_steps": 1e6, 13 | "max_episode_len": 300, 14 | "update_steps": 900, 15 | "checkpoint_steps": 900, 16 | "num_epochs": 30, 17 | "num_d_epochs": 2, 18 | "out_root": "../out", 19 | "expert_trajectory_dir": "../trajectory", 20 | "wandb_mode": "dryrun", 21 | "seed": 42, 22 | "device": "cuda:0" 23 | } 24 | -------------------------------------------------------------------------------- /gail/config/LunarLander-v2/config_ppo.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "env_id": "LunarLander-v2", 3 | "model_name": "ppo", 4 | "resume": null, 5 | "train": true, 6 | "reward_threshold": 200.0, 7 | "discount_factor": 0.99, 8 | "clip_eps": 0.2, 9 | "lr_actor": 0.0003, 10 | "lr_critic": 0.001, 11 | "train_steps": 1e6, 12 | "max_episode_len": 300, 13 | "update_steps": 900, 14 | "checkpoint_steps": 900, 15 | "num_epochs": 30, 16 | "out_root": "../out", 17 | "expert_trajectory_dir": "../trajectory", 18 | "wandb_mode": "dryrun", 19 | "seed": 42, 20 | "device": "cuda:0" 21 | } 22 | -------------------------------------------------------------------------------- /gail/config/LunarLander-v2/config_traj.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_id": "LunarLander-v2", 3 | "model_name": "ppo", 4 | "train": false, 5 | "num_trajectories": 25, 6 | "max_episode_len": 300, 7 | "out_root": "../out", 8 | "expert_trajectory_dir": "../trajectory", 9 | "wandb_mode": "dryrun", 10 | "seed": 42, 11 | "device": "cuda:0" 12 | } 13 | -------------------------------------------------------------------------------- /gail/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | 6 | import gym 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.distributions import Categorical 11 | 12 | from gail.model_args import ModelArguments 13 | from gail.utils import parse_config, setup_logging, set_wandb, set_all_seeds 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class PolicyModel(nn.Module): 19 | def __init__(self, args): 20 | super(PolicyModel, self).__init__() 21 | self.args = args 22 | self.actor = nn.Sequential( 23 | nn.Linear(self.args.state_dim, 64), 24 | nn.Tanh(), 25 | nn.Linear(64, 64), 26 | nn.Tanh(), 27 | nn.Linear(64, self.args.num_actions), 28 | 
nn.Softmax(dim=-1) 29 | ) 30 | self.critic = nn.Sequential( 31 | nn.Linear(self.args.state_dim, 64), 32 | nn.Tanh(), 33 | nn.Linear(64, 64), 34 | nn.Tanh(), 35 | nn.Linear(64, 1) 36 | ) 37 | 38 | def act(self, state): 39 | action_prob = self.actor(state) 40 | action_dist = Categorical(action_prob) 41 | action = action_dist.sample() 42 | action_log_prob = action_dist.log_prob(action) 43 | return action, action_log_prob 44 | 45 | def evaluate(self, state, action): 46 | action_prob = self.actor(state) 47 | action_dist = Categorical(action_prob) 48 | action_log_prob = action_dist.log_prob(action) 49 | entropy = action_dist.entropy() 50 | value = self.critic(state).squeeze(-1)  # shape (N,), aligned with the per-step returns 51 | return value, action_log_prob, entropy 52 | 53 | def forward(self): 54 | raise NotImplementedError 55 | 56 | 57 | class Discriminator(nn.Module): 58 | def __init__(self, args): 59 | super(Discriminator, self).__init__() 60 | self.args = args 61 | self.model = nn.Sequential( 62 | nn.Linear(self.args.state_dim + self.args.num_actions, 64), 63 | nn.Tanh(), 64 | nn.Linear(64, 64), 65 | nn.Tanh(), 66 | nn.Linear(64, 1), 67 | nn.Sigmoid() 68 | ) 69 | 70 | def forward(self, state_action): 71 | prob = self.model(state_action)  # D(s, a): probability that (s, a) was generated by the agent 72 | return prob 73 | 74 | 75 | class GailExecutor: 76 | def __init__(self, args): 77 | self.args = args 78 | 79 | os.environ["WANDB_MODE"] = self.args.wandb_mode 80 | set_wandb(self.args.wandb_dir) 81 | logger.info("args: {0}".format(self.args.to_json_string())) 82 | set_all_seeds(self.args.seed) 83 | 84 | self.env = gym.make(self.args.env_id) 85 | self.env.seed(self.args.seed) 86 | self.args.state_dim = self.env.observation_space.shape[0] 87 | self.args.num_actions = self.env.action_space.n 88 | 89 | self.policy = PolicyModel(self.args).to(self.args.device) 90 | self.policy_old = PolicyModel(self.args).to(self.args.device) 91 | self.policy_old.load_state_dict(self.policy.state_dict()) 92 | self.discriminator = Discriminator(self.args).to(self.args.device) 93 | 94 | self.optimizer =
torch.optim.Adam([ 95 | {"params": self.policy.actor.parameters(), "lr": self.args.lr_actor}, 96 | {"params": self.policy.critic.parameters(), "lr": self.args.lr_critic} 97 | ]) 98 | self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.args.lr_discriminator) 99 | self.mse_loss = nn.MSELoss() 100 | self.bce_loss = nn.BCELoss() 101 | 102 | expert_states = np.genfromtxt("{0}/trajectory/states.csv".format(self.args.env_root)) 103 | expert_states = torch.tensor(expert_states, dtype=torch.float32, device=self.args.device) 104 | expert_actions = np.genfromtxt("{0}/trajectory/actions.csv".format(self.args.env_root), dtype=np.int32) 105 | expert_actions = torch.tensor(expert_actions, dtype=torch.int64, device=self.args.device) 106 | expert_actions = torch.eye(self.args.num_actions, device=self.args.device)[expert_actions]  # one-hot on the same device as the indices 107 | self.expert_state_actions = torch.cat([expert_states, expert_actions], dim=1) 108 | 109 | self.states = [] 110 | self.actions = [] 111 | self.log_prob_actions = [] 112 | self.rewards = [] 113 | self.is_terminal = [] 114 | 115 | def reset_buffers(self): 116 | self.states = [] 117 | self.actions = [] 118 | self.log_prob_actions = [] 119 | self.rewards = [] 120 | self.is_terminal = [] 121 | 122 | def take_action(self, state): 123 | state = torch.tensor(state, dtype=torch.float32, device=self.args.device) 124 | with torch.no_grad(): 125 | action, action_log_prob = self.policy_old.act(state) 126 | self.states.append(state.detach()) 127 | self.actions.append(action.detach()) 128 | self.log_prob_actions.append(action_log_prob.detach()) 129 | 130 | action = action.detach().item() 131 | next_state, reward, done, info = self.env.step(action) 132 | self.rewards.append(reward) 133 | self.is_terminal.append(done) 134 | 135 | return next_state, reward, done 136 | 137 | def update(self): 138 | prev_states = torch.stack(self.states, dim=0).to(self.args.device) 139 | prev_actions = torch.stack(self.actions, dim=0).to(self.args.device) 140 |
prev_log_prob_actions = torch.stack(self.log_prob_actions, dim=0).to(self.args.device) 141 | prev_actions_one_hot = torch.eye(self.args.num_actions, device=self.args.device)[prev_actions.long()] 142 | agent_state_actions = torch.cat([prev_states, prev_actions_one_hot], dim=1) 143 | 144 | for ep in range(self.args.num_d_epochs): 145 | expert_prob = self.discriminator(self.expert_state_actions) 146 | agent_prob = self.discriminator(agent_state_actions) 147 | term1 = self.bce_loss(agent_prob, torch.ones((agent_state_actions.shape[0], 1), device=self.args.device)) 148 | term2 = self.bce_loss(expert_prob, torch.zeros((self.expert_state_actions.shape[0], 1), 149 | device=self.args.device)) 150 | loss = term1 + term2 151 | self.d_optimizer.zero_grad() 152 | loss.backward() 153 | self.d_optimizer.step() 154 | 155 | with torch.no_grad(): 156 | d_rewards = (-torch.log(self.discriminator(agent_state_actions).clamp(min=1e-8))).squeeze(-1).tolist()  # D -> 1 on agent samples, so reward = -log D(s, a) 157 | 158 | rewards = [] 159 | cumulative_discounted_reward = 0. 160 | for i in range(len(d_rewards) - 1, -1, -1): 161 | cumulative_discounted_reward = d_rewards[i] + self.args.discount_factor * (0. if self.is_terminal[i] else cumulative_discounted_reward) 162 | rewards.insert(0, cumulative_discounted_reward)  # keep returns in time order 163 | 164 | rewards = torch.tensor(rewards, dtype=torch.float32, device=self.args.device) 165 | rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7) 166 | 167 | for ep in range(self.args.num_epochs): 168 | values, log_prob_actions, entropy = self.policy.evaluate(prev_states, prev_actions) 169 | advantages = rewards - values.detach() 170 | imp_ratios = torch.exp(log_prob_actions - prev_log_prob_actions) 171 | clamped_imp_ratio = torch.clamp(imp_ratios, 1 - self.args.clip_eps, 1 + self.args.clip_eps) 172 | term1 = -torch.min(imp_ratios * advantages, clamped_imp_ratio * advantages) 173 | term2 = 0.5 * self.mse_loss(values, rewards) 174 | term3 = -0.01 * entropy 175 | loss = term1 + term2 + term3 176 | self.optimizer.zero_grad() 177 | loss.mean().backward() 178 | self.optimizer.step() 179 | 180 |
self.policy_old.load_state_dict(self.policy.state_dict()) 181 | self.reset_buffers() 182 | 183 | def run(self): 184 | t = 1 185 | success_count = 0 186 | finish = False 187 | record = [] 188 | while t <= self.args.train_steps: 189 | state = self.env.reset() 190 | total_reward = 0 191 | done = False 192 | ep_len = 0 193 | while ep_len < self.args.max_episode_len: 194 | state, reward, done = self.take_action(state) 195 | total_reward += reward 196 | if self.args.train and t % self.args.update_steps == 0: 197 | logger.info("updating model") 198 | self.update() 199 | if self.args.train and t % self.args.checkpoint_steps == 0: 200 | logger.info("saving checkpoint") 201 | self.save(self.args.checkpoint_dir) 202 | t += 1 203 | ep_len += 1 204 | if done: 205 | if total_reward >= self.args.reward_threshold: 206 | success_count += 1 207 | if success_count >= 100: 208 | logger.info("model trained. saving checkpoint") 209 | self.save(self.args.checkpoint_dir) 210 | finish = True 211 | else: 212 | success_count = 0 213 | logger.info("iter: {0} | reward: {1:.1f}".format(t, total_reward)) 214 | if not self.args.train: 215 | self.reset_buffers() 216 | break 217 | record.append((ep_len, total_reward)) 218 | if not done: 219 | logger.info("truncated at horizon") 220 | if finish: 221 | break 222 | with open("{0}/record.pkl".format(self.args.checkpoint_dir), "wb") as handle: 223 | pickle.dump(record, handle, protocol=pickle.HIGHEST_PROTOCOL) 224 | 225 | def save(self, checkpoint_dir): 226 | torch.save(self.policy_old.state_dict(), "{0}/policy.ckpt".format(checkpoint_dir)) 227 | torch.save(self.discriminator.state_dict(), "{0}/discriminator.ckpt".format(checkpoint_dir)) 228 | 229 | def load(self, checkpoint_dir): 230 | policy_model_path = "{0}/policy.ckpt".format(checkpoint_dir) 231 | self.policy_old.load_state_dict(torch.load(policy_model_path, map_location=lambda x, y: x)) 232 | self.policy.load_state_dict(self.policy_old.state_dict()) 233 | discriminator_model_path = 
"{0}/discriminator.ckpt".format(checkpoint_dir) 234 | self.discriminator.load_state_dict(torch.load(discriminator_model_path, map_location=lambda x, y: x)) 235 | 236 | 237 | def main(args): 238 | setup_logging() 239 | model_args = parse_config(ModelArguments, args.config) 240 | executor = GailExecutor(model_args) 241 | if not executor.args.train: 242 | executor.load(executor.args.checkpoint_dir) 243 | executor.run() 244 | 245 | 246 | if __name__ == "__main__": 247 | ap = argparse.ArgumentParser(description="GAIL model") 248 | ap.add_argument("--config", default="config/CartPole-v0/config_gail.json", help="config json file") 249 | ap = ap.parse_args() 250 | main(ap) 251 | -------------------------------------------------------------------------------- /gail/model_args.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from enum import Enum 4 | from typing import Optional 5 | 6 | import dataclasses 7 | from dataclasses import dataclass, field 8 | 9 | 10 | @dataclass 11 | class ModelArguments: 12 | model_name: str = field(default="gail", metadata={"help": "model identifier"}) 13 | train: bool = field(default=True, metadata={"help": "do training. If set to False, we load existing model"}) 14 | resume: Optional[str] = field(default=None, metadata={"help": "checkpoint to resume. 
Starts from scratch, if None"}) 15 | reward_threshold: float = field(default=195.0, metadata={"help": "cumulative episode reward counted as a success"}) 16 | discount_factor: float = field(default=0.99, metadata={"help": "discount factor"}) 17 | clip_eps: float = field(default=0.2, metadata={"help": "clipping epsilon in PPO loss"}) 18 | lr_actor: float = field(default=0.0003, metadata={"help": "actor model learning rate"}) 19 | lr_critic: float = field(default=0.001, metadata={"help": "critic model learning rate"}) 20 | lr_discriminator: float = field(default=0.001, metadata={"help": "discriminator model learning rate"}) 21 | num_trajectories: int = field(default=10, metadata={"help": "number of expert trajectories to sample from the pretrained PPO model"}) 22 | train_steps: int = field(default=100000, metadata={"help": "maximum training time steps"}) 23 | max_episode_len: int = field(default=400, metadata={"help": "maximum episode length"}) 24 | update_steps: int = field(default=1600, metadata={"help": "model update frequency (in time steps)"}) 25 | checkpoint_steps: int = field(default=20000, metadata={"help": "checkpoint saving frequency (in time steps)"}) 26 | num_epochs: int = field(default=40, metadata={"help": "training epochs of PPO model"}) 27 | num_d_epochs: int = field(default=2, metadata={"help": "training epochs of discriminator model"}) 28 | 29 | out_root: str = field(default="../out", metadata={"help": "outputs root directory"}) 30 | env_id: str = field(default="CartPole-v0", metadata={"help": "simulation environment identifier"}) 31 | wandb_mode: str = field(default="run", metadata={"help": "can enable/disable wandb online sync (run/dryrun)"}) 32 | seed: int = field(default=42, metadata={"help": "random seed for reproducibility of results"}) 33 | device: str = field(default="cuda:0", metadata={"help": "device (cpu|cuda:0)"}) 34 | 35 | def __post_init__(self): 36 | self.env_root = os.path.join(self.out_root, self.env_id) 37 | self.checkpoint_dir = os.path.join(self.env_root, self.model_name) 38 | self.wandb_dir =
self.checkpoint_dir 39 | 40 | def to_dict(self): 41 | """ 42 | Serializes this instance while replacing `Enum` members with their values (for JSON serialization support). 43 | """ 44 | d = dataclasses.asdict(self) 45 | for k, v in d.items(): 46 | if isinstance(v, Enum): 47 | d[k] = v.value 48 | return d 49 | 50 | def to_json_string(self): 51 | """ 52 | Serializes this instance to a JSON string. 53 | """ 54 | return json.dumps(self.to_dict(), indent=2) 55 | -------------------------------------------------------------------------------- /gail/ppo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | 6 | import gym 7 | import torch 8 | import torch.nn as nn 9 | from torch.distributions import Categorical 10 | 11 | from gail.model_args import ModelArguments 12 | from gail.utils import parse_config, setup_logging, set_wandb, set_all_seeds 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class PolicyModel(nn.Module): 18 | def __init__(self, args): 19 | super(PolicyModel, self).__init__() 20 | self.args = args 21 | self.actor = nn.Sequential( 22 | nn.Linear(self.args.state_dim, 64), 23 | nn.Tanh(), 24 | nn.Linear(64, 64), 25 | nn.Tanh(), 26 | nn.Linear(64, self.args.num_actions), 27 | nn.Softmax(dim=-1) 28 | ) 29 | self.critic = nn.Sequential( 30 | nn.Linear(self.args.state_dim, 64), 31 | nn.Tanh(), 32 | nn.Linear(64, 64), 33 | nn.Tanh(), 34 | nn.Linear(64, 1) 35 | ) 36 | 37 | def act(self, state): 38 | action_prob = self.actor(state) 39 | action_dist = Categorical(action_prob) 40 | action = action_dist.sample() 41 | action_log_prob = action_dist.log_prob(action) 42 | return action, action_log_prob 43 | 44 | def evaluate(self, state, action): 45 | action_prob = self.actor(state) 46 | action_dist = Categorical(action_prob) 47 | action_log_prob = action_dist.log_prob(action) 48 | entropy = action_dist.entropy() 49 | value = self.critic(state).squeeze(-1)  # shape (N,), aligned with the per-step returns 50 | return value,
action_log_prob, entropy 51 | 52 | def forward(self): 53 | raise NotImplementedError 54 | 55 | 56 | class PpoExecutor: 57 | def __init__(self, args): 58 | self.args = args 59 | 60 | os.environ["WANDB_MODE"] = self.args.wandb_mode 61 | set_wandb(self.args.wandb_dir) 62 | logger.info("args: {0}".format(self.args.to_json_string())) 63 | set_all_seeds(self.args.seed) 64 | 65 | self.env = gym.make(self.args.env_id) 66 | self.env.seed(self.args.seed) 67 | self.args.state_dim = self.env.observation_space.shape[0] 68 | self.args.num_actions = self.env.action_space.n 69 | 70 | self.policy = PolicyModel(self.args).to(self.args.device) 71 | self.policy_old = PolicyModel(self.args).to(self.args.device) 72 | self.policy_old.load_state_dict(self.policy.state_dict()) 73 | 74 | self.optimizer = torch.optim.Adam([ 75 | {"params": self.policy.actor.parameters(), "lr": self.args.lr_actor}, 76 | {"params": self.policy.critic.parameters(), "lr": self.args.lr_critic} 77 | ]) 78 | self.mse_loss = nn.MSELoss() 79 | 80 | self.states = [] 81 | self.actions = [] 82 | self.log_prob_actions = [] 83 | self.rewards = [] 84 | self.is_terminal = [] 85 | 86 | def reset_buffers(self): 87 | self.states = [] 88 | self.actions = [] 89 | self.log_prob_actions = [] 90 | self.rewards = [] 91 | self.is_terminal = [] 92 | 93 | def take_action(self, state): 94 | state = torch.tensor(state, dtype=torch.float32, device=self.args.device) 95 | with torch.no_grad(): 96 | action, action_log_prob = self.policy_old.act(state) 97 | self.states.append(state.detach()) 98 | self.actions.append(action.detach()) 99 | self.log_prob_actions.append(action_log_prob.detach()) 100 | 101 | action = action.detach().item() 102 | next_state, reward, done, info = self.env.step(action) 103 | self.rewards.append(reward) 104 | self.is_terminal.append(done) 105 | 106 | return next_state, reward, done 107 | 108 | def update(self): 109 | prev_states = torch.stack(self.states, dim=0).to(self.args.device) 110 | prev_actions = 
torch.stack(self.actions, dim=0).to(self.args.device) 111 | prev_log_prob_actions = torch.stack(self.log_prob_actions, dim=0).to(self.args.device) 112 | 113 | rewards = [] 114 | cumulative_discounted_reward = 0. 115 | for i in range(len(self.rewards) - 1, -1, -1): 116 | if self.is_terminal[i]: 117 | cumulative_discounted_reward = 0. 118 | cumulative_discounted_reward = self.rewards[i] + self.args.discount_factor * cumulative_discounted_reward 119 | rewards.insert(0, cumulative_discounted_reward)  # keep returns in time order 120 | 121 | rewards = torch.tensor(rewards, dtype=torch.float32, device=self.args.device) 122 | rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7) 123 | 124 | for ep in range(self.args.num_epochs): 125 | values, log_prob_actions, entropy = self.policy.evaluate(prev_states, prev_actions) 126 | advantages = rewards - values.detach() 127 | imp_ratios = torch.exp(log_prob_actions - prev_log_prob_actions) 128 | clamped_imp_ratio = torch.clamp(imp_ratios, 1 - self.args.clip_eps, 1 + self.args.clip_eps) 129 | term1 = -torch.min(imp_ratios * advantages, clamped_imp_ratio * advantages) 130 | term2 = 0.5 * self.mse_loss(values, rewards) 131 | term3 = -0.01 * entropy 132 | loss = term1 + term2 + term3 133 | self.optimizer.zero_grad() 134 | loss.mean().backward() 135 | self.optimizer.step() 136 | 137 | self.policy_old.load_state_dict(self.policy.state_dict()) 138 | self.reset_buffers() 139 | 140 | def run(self): 141 | t = 1 142 | success_count = 0 143 | finish = False 144 | record = [] 145 | while t <= self.args.train_steps: 146 | state = self.env.reset() 147 | total_reward = 0 148 | done = False 149 | ep_len = 0 150 | while ep_len < self.args.max_episode_len: 151 | state, reward, done = self.take_action(state) 152 | total_reward += reward 153 | if self.args.train and t % self.args.update_steps == 0: 154 | logger.info("updating model") 155 | self.update() 156 | if self.args.train and t % self.args.checkpoint_steps == 0: 157 | logger.info("saving checkpoint") 158 |
self.save(self.args.checkpoint_dir) 159 | t += 1 160 | ep_len += 1 161 | if done: 162 | if total_reward >= self.args.reward_threshold: 163 | success_count += 1 164 | if success_count >= 100: 165 | logger.info("model trained. saving checkpoint") 166 | self.save(self.args.checkpoint_dir) 167 | finish = True 168 | else: 169 | success_count = 0 170 | logger.info("iter: {0} | reward: {1:.1f}".format(t, total_reward)) 171 | if not self.args.train: 172 | self.reset_buffers() 173 | break 174 | record.append((ep_len, total_reward)) 175 | if not done: 176 | logger.info("truncated at horizon") 177 | if finish: 178 | break 179 | with open("{0}/record.pkl".format(self.args.checkpoint_dir), "wb") as handle: 180 | pickle.dump(record, handle, protocol=pickle.HIGHEST_PROTOCOL) 181 | 182 | def save(self, checkpoint_dir): 183 | torch.save(self.policy_old.state_dict(), "{0}/policy.ckpt".format(checkpoint_dir)) 184 | 185 | def load(self, checkpoint_dir): 186 | policy_model_path = "{0}/policy.ckpt".format(checkpoint_dir) 187 | self.policy_old.load_state_dict(torch.load(policy_model_path, map_location=lambda x, y: x)) 188 | self.policy.load_state_dict(self.policy_old.state_dict()) 189 | 190 | 191 | def main(args): 192 | setup_logging() 193 | model_args = parse_config(ModelArguments, args.config) 194 | executor = PpoExecutor(model_args) 195 | if not executor.args.train: 196 | executor.load(executor.args.checkpoint_dir) 197 | executor.run() 198 | 199 | 200 | if __name__ == "__main__": 201 | ap = argparse.ArgumentParser(description="PPO for sampling expert trajectories") 202 | ap.add_argument("--config", default="config/CartPole-v0/config_ppo.json", help="config json file") 203 | ap = ap.parse_args() 204 | main(ap) 205 | -------------------------------------------------------------------------------- /gail/traj.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import gym 6 | import numpy as np 7 
| import torch 8 | 9 | from gail.model_args import ModelArguments 10 | from gail.ppo import PolicyModel 11 | from gail.utils import parse_config, setup_logging, set_wandb, set_all_seeds 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class PpoExecutor: 17 | def __init__(self, args): 18 | self.args = args 19 | 20 | os.environ["WANDB_MODE"] = self.args.wandb_mode 21 | set_wandb(self.args.wandb_dir) 22 | logger.info("args: {0}".format(self.args.to_json_string())) 23 | set_all_seeds(self.args.seed) 24 | 25 | self.env = gym.make(self.args.env_id) 26 | self.env.seed(self.args.seed) 27 | self.args.state_dim = self.env.observation_space.shape[0] 28 | self.args.num_actions = self.env.action_space.n 29 | 30 | self.policy = PolicyModel(self.args).to(self.args.device) 31 | self.policy_old = PolicyModel(self.args).to(self.args.device) 32 | self.policy_old.load_state_dict(self.policy.state_dict()) 33 | 34 | self.states = [] 35 | self.actions = [] 36 | 37 | def reset_buffers(self): 38 | self.states = [] 39 | self.actions = [] 40 | 41 | def take_action(self, state): 42 | state_tensor = torch.tensor(state, dtype=torch.float32, device=self.args.device) 43 | with torch.no_grad(): 44 | action, action_log_prob = self.policy_old.act(state_tensor) 45 | action = action.detach().item() 46 | 47 | self.states.append(state) 48 | self.actions.append(action) 49 | next_state, reward, done, info = self.env.step(action) 50 | 51 | return next_state, reward, done 52 | 53 | def run(self): 54 | for t in range(self.args.num_trajectories): 55 | state = self.env.reset() 56 | for ep in range(self.args.max_episode_len): 57 | state, reward, done = self.take_action(state) 58 | if done: 59 | break 60 | 61 | PpoExecutor.save_to_file(self.states, "{0}/trajectory/states.csv".format(self.args.env_root)) 62 | PpoExecutor.save_to_file(self.actions, "{0}/trajectory/actions.csv".format(self.args.env_root)) 63 | self.reset_buffers() 64 | 65 | def load(self, checkpoint_dir): 66 | policy_model_path = 
"{0}/policy.ckpt".format(checkpoint_dir) 67 | self.policy_old.load_state_dict(torch.load(policy_model_path, map_location=lambda x, y: x)) 68 | self.policy.load_state_dict(self.policy_old.state_dict()) 69 | 70 | @staticmethod 71 | def save_to_file(data, file_path): 72 | # "ab" creates the file if it is missing and appends otherwise; only the parent directory has to exist 73 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 74 | with open(file_path, "ab") as handle: 75 | np.savetxt(handle, data, fmt="%s") 76 | 77 | 78 | def main(args): 79 | setup_logging() 80 | model_args = parse_config(ModelArguments, args.config) 81 | executor = PpoExecutor(model_args) 82 | executor.load(executor.args.checkpoint_dir) 83 | executor.run() 84 | 85 | 86 | if __name__ == "__main__": 87 | ap = argparse.ArgumentParser(description="sample trajectories from a pretrained PPO model") 88 | ap.add_argument("--config", default="config/CartPole-v0/config_traj.json", help="config json file") 89 | ap = ap.parse_args() 90 | main(ap) 91 | -------------------------------------------------------------------------------- /gail/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import random 5 | from pathlib import Path 6 | 7 | import dataclasses 8 | import numpy as np 9 | import torch 10 | import wandb 11 | 12 | 13 | def set_all_seeds(seed=42): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | 18 | os.environ['PYTHONHASHSEED'] = str(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def setup_logging():
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO)


def set_wandb(wandb_dir):
    os.environ["WANDB_WATCH"] = "all"
    os.makedirs(os.path.join(wandb_dir, "wandb"), exist_ok=True)
    wandb.init(project=os.getenv("WANDB_PROJECT", "gail-pytorch"), dir=wandb_dir)


def parse_config(args_class, json_file):
    data = json.loads(Path(json_file).read_text())

    # curr_run_output_dir = os.path.join(data["out_root"], data["dataset_dir"], data["model_name"])
    # data["output_dir"] = os.path.join(curr_run_output_dir, "checkpoints")
    # data["logging_dir"] = os.path.join(curr_run_output_dir, default_logdir())

    keys = {f.name for f in dataclasses.fields(args_class)}
    inputs = {k: v for k, v in data.items() if k in keys}
    return args_class(**inputs)

--------------------------------------------------------------------------------
/gail/visualize.py:
--------------------------------------------------------------------------------
import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats


def smooth_fn(a, window=100):
    # Trailing moving average: the first `window` outputs average over the prefix seen so far.
    smooth_a = []
    running_sum = 0
    for i in range(min(window, len(a))):  # min() avoids an IndexError on runs shorter than `window`
        running_sum += a[i]
        smooth_a.append(running_sum / (i + 1))
    for i in range(window, len(a)):
        running_sum += a[i] - a[i - window]
        smooth_a.append(running_sum / window)
    return smooth_a


def main(args):
    root_dir = os.path.join(args.out_dir, args.env_id, args.model_name)

    record_map = dict()
    max_cnt = 0
    for i in range(1, 6):
        pickle_path = os.path.join(root_dir, "record{}.pkl".format(i))
        with open(pickle_path, "rb") as handle:
            record = pickle.load(handle)
        max_cnt = max(max_cnt, len(record))
        episode_rewards = smooth_fn([tup[1] for tup in record])
        for it in range(len(episode_rewards)):
            if it not in record_map:
                record_map[it] = []
            record_map[it].append(episode_rewards[it])

    values = []
    errors = []
    for i in range(max_cnt):
        m = np.mean(record_map[i])
        e = 2.0 * stats.sem(record_map[i])  # ~95% confidence interval of the mean across runs (normal approximation)
        # e = 2.0 * np.std(record_map[i])  # alternative: ~2 std devs, the spread across runs rather than a CI
        values.append(m)
        errors.append(e)

    iteration_counts = list(range(max_cnt))

    fig = plt.figure()
    markers, caps, bars = plt.errorbar(iteration_counts, values, errors, ecolor='lightskyblue')
    for artist in bars + caps:
        artist.set_alpha(0.1)

    plt.ylabel("Episode Rewards")
    plt.xlabel("Training Iteration")
    fig.suptitle(args.model_name.upper())
    fig.savefig(os.path.join(root_dir, "rewards.png"), dpi=600)


if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Visualize Training")
    ap.add_argument("--env_id", type=str, default="CartPole-v0", help="gym env identifier")
    ap.add_argument("--model_name", type=str, default="ppo", help="algorithm name (ppo|gail)")
    ap.add_argument("--out_dir", type=str, default="../out", help="outputs parent directory")
    ap = ap.parse_args()
    main(ap)

--------------------------------------------------------------------------------
/pretrained/CartPole-v0/gail/discriminator.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/gail/discriminator.ckpt
--------------------------------------------------------------------------------
/pretrained/CartPole-v0/gail/policy.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/gail/policy.ckpt -------------------------------------------------------------------------------- /pretrained/CartPole-v0/gail/record1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/gail/record1.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/gail/record2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/gail/record2.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/gail/record3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/gail/record3.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/gail/record4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/gail/record4.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/gail/record5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/gail/record5.pkl -------------------------------------------------------------------------------- 
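The `record1.pkl` to `record5.pkl` files above are what `visualize.py` aggregates into the GAIL reward curve: each run's episode rewards are smoothed with a trailing moving average (`window=100`), and the per-iteration mean across runs is plotted with a band of 2 × the standard error of the mean. Below is a minimal, stdlib-only sketch of that aggregation, assuming (as the `tup[1]` access in `visualize.py` suggests) that each record is a sequence of tuples whose second element is the episode reward. `moving_average` and `aggregate_runs` are hypothetical names, and the `statistics` module stands in for NumPy/SciPy; `statistics.stdev` uses the same n - 1 denominator as `scipy.stats.sem` with its default `ddof=1`.

```python
import math
import statistics


def moving_average(values, window=100):
    """Trailing mean; the first `window` outputs average over the prefix seen so far."""
    smoothed = []
    running_sum = 0.0
    for i, v in enumerate(values):
        running_sum += v
        if i >= window:
            # Drop the value that just left the window.
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed


def aggregate_runs(runs, window=100):
    """Per-iteration mean and 2 * standard error of the mean across runs.

    Each run is assumed to be a list of tuples whose index 1 holds the episode
    reward. Runs may differ in length; an iteration uses only the runs that reached it.
    """
    curves = [moving_average([rec[1] for rec in run], window) for run in runs]
    means, errors = [], []
    for it in range(max(len(c) for c in curves)):
        column = [c[it] for c in curves if it < len(c)]
        means.append(statistics.fmean(column))
        if len(column) > 1:
            errors.append(2.0 * statistics.stdev(column) / math.sqrt(len(column)))
        else:
            # The standard error is undefined for a single run (scipy returns nan too).
            errors.append(float("nan"))
    return means, errors


if __name__ == "__main__":
    # Two toy runs of (iteration, reward) tuples; this layout is an assumption.
    run_a = [(0, 10.0), (1, 20.0), (2, 30.0)]
    run_b = [(0, 12.0), (1, 18.0)]
    means, errors = aggregate_runs([run_a, run_b], window=2)
    print(means)  # [11.0, 15.0, 25.0]
```

Dividing by `min(i + 1, window)` keeps the sketch well-defined even for runs shorter than `window`, and a run that stops early simply drops out of the later iterations instead of skewing them.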
/pretrained/CartPole-v0/gail/rewards_gail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/gail/rewards_gail.png -------------------------------------------------------------------------------- /pretrained/CartPole-v0/ppo/policy.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/ppo/policy.ckpt -------------------------------------------------------------------------------- /pretrained/CartPole-v0/ppo/record1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/ppo/record1.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/ppo/record2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/ppo/record2.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/ppo/record3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/ppo/record3.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/ppo/record4.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/ppo/record4.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/ppo/record5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/ppo/record5.pkl -------------------------------------------------------------------------------- /pretrained/CartPole-v0/ppo/rewards_ppo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/CartPole-v0/ppo/rewards_ppo.png -------------------------------------------------------------------------------- /pretrained/CartPole-v0/trajectory/actions.csv: -------------------------------------------------------------------------------- 1 | 1 2 | 0 3 | 0 4 | 0 5 | 1 6 | 1 7 | 1 8 | 0 9 | 0 10 | 1 11 | 0 12 | 1 13 | 1 14 | 0 15 | 1 16 | 1 17 | 0 18 | 1 19 | 1 20 | 0 21 | 0 22 | 1 23 | 0 24 | 0 25 | 0 26 | 0 27 | 1 28 | 1 29 | 1 30 | 0 31 | 0 32 | 1 33 | 0 34 | 1 35 | 1 36 | 1 37 | 0 38 | 0 39 | 1 40 | 1 41 | 0 42 | 0 43 | 0 44 | 0 45 | 1 46 | 0 47 | 1 48 | 1 49 | 1 50 | 1 51 | 1 52 | 0 53 | 0 54 | 0 55 | 1 56 | 0 57 | 1 58 | 0 59 | 1 60 | 1 61 | 1 62 | 0 63 | 0 64 | 0 65 | 1 66 | 1 67 | 0 68 | 0 69 | 0 70 | 1 71 | 0 72 | 1 73 | 0 74 | 1 75 | 1 76 | 0 77 | 1 78 | 0 79 | 1 80 | 1 81 | 0 82 | 0 83 | 0 84 | 1 85 | 1 86 | 0 87 | 1 88 | 0 89 | 0 90 | 1 91 | 0 92 | 1 93 | 0 94 | 0 95 | 1 96 | 0 97 | 1 98 | 1 99 | 1 100 | 0 101 | 0 102 | 0 103 | 0 104 | 1 105 | 1 106 | 1 107 | 0 108 | 0 109 | 0 110 | 1 111 | 0 112 | 1 113 | 0 114 | 1 115 | 1 116 | 1 117 | 1 118 | 0 119 | 0 120 | 0 121 | 1 122 | 1 123 | 0 124 | 1 125 | 1 126 | 0 127 | 0 128 | 1 129 | 0 130 | 1 131 | 1 132 | 0 
133 | 0 134 | 1 135 | 1 136 | 0 137 | 1 138 | 1 139 | 0 140 | 0 141 | 0 142 | 0 143 | 1 144 | 0 145 | 1 146 | 0 147 | 1 148 | 1 149 | 0 150 | 1 151 | 0 152 | 0 153 | 1 154 | 0 155 | 1 156 | 0 157 | 1 158 | 0 159 | 0 160 | 1 161 | 1 162 | 0 163 | 0 164 | 1 165 | 1 166 | 0 167 | 1 168 | 1 169 | 0 170 | 0 171 | 0 172 | 1 173 | 0 174 | 1 175 | 0 176 | 0 177 | 1 178 | 0 179 | 0 180 | 1 181 | 1 182 | 0 183 | 0 184 | 1 185 | 1 186 | 1 187 | 0 188 | 1 189 | 0 190 | 1 191 | 0 192 | 1 193 | 0 194 | 1 195 | 0 196 | 1 197 | 0 198 | 0 199 | 1 200 | 0 201 | 1 202 | 0 203 | 0 204 | 1 205 | 1 206 | 1 207 | 0 208 | 1 209 | 0 210 | 1 211 | 0 212 | 0 213 | 0 214 | 1 215 | 0 216 | 1 217 | 1 218 | 0 219 | 0 220 | 1 221 | 0 222 | 0 223 | 0 224 | 1 225 | 1 226 | 0 227 | 1 228 | 1 229 | 0 230 | 0 231 | 0 232 | 1 233 | 1 234 | 0 235 | 0 236 | 0 237 | 1 238 | 0 239 | 1 240 | 0 241 | 1 242 | 0 243 | 1 244 | 1 245 | 0 246 | 1 247 | 0 248 | 1 249 | 0 250 | 0 251 | 0 252 | 0 253 | 1 254 | 1 255 | 0 256 | 0 257 | 1 258 | 1 259 | 1 260 | 0 261 | 1 262 | 0 263 | 1 264 | 0 265 | 1 266 | 1 267 | 0 268 | 0 269 | 0 270 | 1 271 | 0 272 | 0 273 | 1 274 | 1 275 | 1 276 | 1 277 | 0 278 | 1 279 | 0 280 | 1 281 | 0 282 | 1 283 | 1 284 | 0 285 | 0 286 | 0 287 | 1 288 | 1 289 | 0 290 | 0 291 | 1 292 | 1 293 | 1 294 | 1 295 | 0 296 | 0 297 | 1 298 | 0 299 | 1 300 | 0 301 | 0 302 | 1 303 | 0 304 | 1 305 | 1 306 | 0 307 | 1 308 | 1 309 | 0 310 | 1 311 | 0 312 | 1 313 | 1 314 | 0 315 | 0 316 | 0 317 | 1 318 | 0 319 | 1 320 | 0 321 | 1 322 | 1 323 | 0 324 | 1 325 | 0 326 | 1 327 | 0 328 | 1 329 | 0 330 | 1 331 | 0 332 | 1 333 | 0 334 | 0 335 | 1 336 | 1 337 | 1 338 | 0 339 | 1 340 | 0 341 | 1 342 | 1 343 | 0 344 | 0 345 | 0 346 | 0 347 | 1 348 | 1 349 | 0 350 | 1 351 | 0 352 | 1 353 | 0 354 | 0 355 | 1 356 | 1 357 | 0 358 | 1 359 | 0 360 | 0 361 | 1 362 | 1 363 | 1 364 | 0 365 | 0 366 | 0 367 | 1 368 | 1 369 | 1 370 | 0 371 | 1 372 | 1 373 | 0 374 | 1 375 | 1 376 | 0 377 | 0 378 | 1 379 | 1 380 | 0 381 | 0 382 | 1 
383 | 1 384 | 1 385 | 0 386 | 0 387 | 1 388 | 1 389 | 0 390 | 1 391 | 0 392 | 0 393 | 0 394 | 1 395 | 1 396 | 1 397 | 1 398 | 0 399 | 0 400 | 1 401 | 0 402 | 1 403 | 1 404 | 0 405 | 0 406 | 1 407 | 1 408 | 0 409 | 0 410 | 1 411 | 0 412 | 1 413 | 0 414 | 0 415 | 1 416 | 0 417 | 1 418 | 1 419 | 0 420 | 1 421 | 1 422 | 0 423 | 1 424 | 0 425 | 0 426 | 1 427 | 1 428 | 0 429 | 1 430 | 0 431 | 0 432 | 0 433 | 0 434 | 0 435 | 0 436 | 1 437 | 1 438 | 1 439 | 1 440 | 0 441 | 0 442 | 0 443 | 1 444 | 1 445 | 1 446 | 0 447 | 1 448 | 1 449 | 0 450 | 0 451 | 0 452 | 1 453 | 0 454 | 1 455 | 1 456 | 0 457 | 0 458 | 0 459 | 1 460 | 1 461 | 0 462 | 1 463 | 0 464 | 1 465 | 0 466 | 0 467 | 1 468 | 1 469 | 1 470 | 0 471 | 0 472 | 1 473 | 1 474 | 0 475 | 0 476 | 1 477 | 0 478 | 1 479 | 1 480 | 0 481 | 1 482 | 0 483 | 1 484 | 1 485 | 1 486 | 0 487 | 0 488 | 1 489 | 1 490 | 0 491 | 0 492 | 0 493 | 1 494 | 1 495 | 1 496 | 0 497 | 1 498 | 1 499 | 0 500 | 1 501 | 0 502 | 1 503 | 0 504 | 0 505 | 1 506 | 0 507 | 0 508 | 1 509 | 0 510 | 1 511 | 0 512 | 0 513 | 0 514 | 1 515 | 1 516 | 1 517 | 1 518 | 1 519 | 0 520 | 1 521 | 0 522 | 0 523 | 1 524 | 0 525 | 1 526 | 0 527 | 0 528 | 1 529 | 1 530 | 0 531 | 0 532 | 1 533 | 1 534 | 0 535 | 1 536 | 0 537 | 1 538 | 0 539 | 1 540 | 0 541 | 0 542 | 1 543 | 0 544 | 1 545 | 0 546 | 0 547 | 1 548 | 0 549 | 1 550 | 1 551 | 0 552 | 0 553 | 0 554 | 1 555 | 1 556 | 0 557 | 0 558 | 1 559 | 0 560 | 1 561 | 0 562 | 1 563 | 1 564 | 1 565 | 0 566 | 0 567 | 0 568 | 0 569 | 1 570 | 1 571 | 0 572 | 1 573 | 0 574 | 1 575 | 0 576 | 0 577 | 1 578 | 0 579 | 0 580 | 0 581 | 1 582 | 1 583 | 1 584 | 1 585 | 0 586 | 1 587 | 0 588 | 1 589 | 0 590 | 0 591 | 0 592 | 1 593 | 1 594 | 0 595 | 0 596 | 1 597 | 1 598 | 0 599 | 1 600 | 0 601 | 0 602 | 1 603 | 1 604 | 1 605 | 0 606 | 1 607 | 1 608 | 0 609 | 0 610 | 1 611 | 0 612 | 0 613 | 1 614 | 1 615 | 1 616 | 0 617 | 0 618 | 0 619 | 0 620 | 0 621 | 1 622 | 1 623 | 0 624 | 1 625 | 0 626 | 1 627 | 1 628 | 0 629 | 0 630 | 1 631 | 0 632 | 1 
633 | 0 634 | 1 635 | 1 636 | 1 637 | 0 638 | 0 639 | 0 640 | 1 641 | 1 642 | 0 643 | 0 644 | 0 645 | 1 646 | 0 647 | 1 648 | 0 649 | 0 650 | 1 651 | 1 652 | 1 653 | 0 654 | 1 655 | 1 656 | 1 657 | 0 658 | 0 659 | 0 660 | 0 661 | 1 662 | 0 663 | 1 664 | 1 665 | 0 666 | 0 667 | 1 668 | 0 669 | 0 670 | 1 671 | 0 672 | 1 673 | 1 674 | 0 675 | 1 676 | 0 677 | 1 678 | 1 679 | 0 680 | 0 681 | 1 682 | 1 683 | 0 684 | 0 685 | 1 686 | 0 687 | 0 688 | 1 689 | 0 690 | 1 691 | 0 692 | 1 693 | 1 694 | 1 695 | 0 696 | 0 697 | 1 698 | 1 699 | 1 700 | 0 701 | 0 702 | 1 703 | 0 704 | 0 705 | 1 706 | 1 707 | 0 708 | 0 709 | 1 710 | 0 711 | 1 712 | 0 713 | 1 714 | 0 715 | 1 716 | 0 717 | 0 718 | 1 719 | 1 720 | 0 721 | 1 722 | 1 723 | 1 724 | 1 725 | 0 726 | 1 727 | 0 728 | 0 729 | 0 730 | 1 731 | 0 732 | 0 733 | 0 734 | 1 735 | 0 736 | 0 737 | 0 738 | 1 739 | 1 740 | 0 741 | 1 742 | 1 743 | 1 744 | 0 745 | 0 746 | 1 747 | 1 748 | 0 749 | 0 750 | 0 751 | 1 752 | 1 753 | 0 754 | 0 755 | 1 756 | 1 757 | 0 758 | 1 759 | 0 760 | 1 761 | 1 762 | 1 763 | 0 764 | 0 765 | 0 766 | 0 767 | 1 768 | 1 769 | 1 770 | 1 771 | 0 772 | 1 773 | 0 774 | 1 775 | 1 776 | 1 777 | 0 778 | 0 779 | 1 780 | 0 781 | 0 782 | 1 783 | 1 784 | 0 785 | 1 786 | 0 787 | 1 788 | 1 789 | 0 790 | 0 791 | 1 792 | 0 793 | 0 794 | 1 795 | 1 796 | 1 797 | 0 798 | 0 799 | 0 800 | 0 801 | 0 802 | 1 803 | 1 804 | 1 805 | 0 806 | 0 807 | 0 808 | 1 809 | 0 810 | 0 811 | 0 812 | 1 813 | 1 814 | 1 815 | 1 816 | 0 817 | 0 818 | 1 819 | 0 820 | 1 821 | 0 822 | 1 823 | 0 824 | 0 825 | 1 826 | 0 827 | 1 828 | 1 829 | 0 830 | 0 831 | 1 832 | 1 833 | 0 834 | 0 835 | 1 836 | 0 837 | 1 838 | 1 839 | 1 840 | 0 841 | 0 842 | 0 843 | 1 844 | 0 845 | 1 846 | 0 847 | 0 848 | 0 849 | 1 850 | 0 851 | 0 852 | 1 853 | 0 854 | 1 855 | 0 856 | 1 857 | 1 858 | 0 859 | 1 860 | 1 861 | 0 862 | 1 863 | 1 864 | 0 865 | 0 866 | 1 867 | 0 868 | 0 869 | 0 870 | 1 871 | 0 872 | 0 873 | 0 874 | 1 875 | 0 876 | 1 877 | 0 878 | 1 879 | 0 880 | 1 881 | 0 882 | 1 
883 | 0 884 | 0 885 | 1 886 | 1 887 | 0 888 | 1 889 | 0 890 | 1 891 | 0 892 | 1 893 | 0 894 | 1 895 | 1 896 | 0 897 | 0 898 | 1 899 | 0 900 | 1 901 | 1 902 | 0 903 | 1 904 | 1 905 | 1 906 | 0 907 | 0 908 | 1 909 | 0 910 | 1 911 | 0 912 | 0 913 | 1 914 | 1 915 | 0 916 | 1 917 | 0 918 | 0 919 | 1 920 | 1 921 | 1 922 | 1 923 | 1 924 | 0 925 | 0 926 | 0 927 | 1 928 | 1 929 | 0 930 | 0 931 | 1 932 | 1 933 | 1 934 | 1 935 | 0 936 | 0 937 | 1 938 | 1 939 | 0 940 | 1 941 | 0 942 | 1 943 | 0 944 | 0 945 | 1 946 | 1 947 | 0 948 | 1 949 | 1 950 | 0 951 | 0 952 | 1 953 | 1 954 | 1 955 | 0 956 | 0 957 | 1 958 | 0 959 | 1 960 | 1 961 | 1 962 | 0 963 | 0 964 | 1 965 | 1 966 | 0 967 | 0 968 | 1 969 | 0 970 | 0 971 | 1 972 | 1 973 | 1 974 | 0 975 | 0 976 | 1 977 | 0 978 | 1 979 | 1 980 | 0 981 | 1 982 | 0 983 | 1 984 | 1 985 | 0 986 | 1 987 | 1 988 | 0 989 | 0 990 | 0 991 | 1 992 | 0 993 | 0 994 | 1 995 | 1 996 | 0 997 | 0 998 | 1 999 | 0 1000 | 0 1001 | 0 1002 | 1 1003 | 0 1004 | 1 1005 | 1 1006 | 0 1007 | 1 1008 | 0 1009 | 0 1010 | 1 1011 | 0 1012 | 1 1013 | 0 1014 | 1 1015 | 1 1016 | 0 1017 | 1 1018 | 0 1019 | 1 1020 | 0 1021 | 1 1022 | 0 1023 | 0 1024 | 1 1025 | 0 1026 | 1 1027 | 0 1028 | 1 1029 | 0 1030 | 1 1031 | 1 1032 | 0 1033 | 0 1034 | 1 1035 | 1 1036 | 1 1037 | 0 1038 | 0 1039 | 1 1040 | 0 1041 | 0 1042 | 1 1043 | 1 1044 | 0 1045 | 1 1046 | 0 1047 | 1 1048 | 0 1049 | 1 1050 | 0 1051 | 0 1052 | 1 1053 | 1 1054 | 1 1055 | 1 1056 | 0 1057 | 0 1058 | 1 1059 | 0 1060 | 1 1061 | 0 1062 | 1 1063 | 1 1064 | 0 1065 | 1 1066 | 0 1067 | 1 1068 | 0 1069 | 1 1070 | 0 1071 | 0 1072 | 1 1073 | 1 1074 | 0 1075 | 1 1076 | 0 1077 | 1 1078 | 0 1079 | 1 1080 | 1 1081 | 0 1082 | 1 1083 | 1 1084 | 0 1085 | 0 1086 | 0 1087 | 1 1088 | 0 1089 | 1 1090 | 1 1091 | 0 1092 | 0 1093 | 0 1094 | 0 1095 | 0 1096 | 1 1097 | 0 1098 | 1 1099 | 1 1100 | 0 1101 | 1 1102 | 0 1103 | 1 1104 | 1 1105 | 0 1106 | 0 1107 | 1 1108 | 1 1109 | 0 1110 | 0 1111 | 0 1112 | 0 1113 | 1 1114 | 0 1115 | 1 1116 | 0 1117 | 1 
1118 | 1 1119 | 1 1120 | 0 1121 | 0 1122 | 0 1123 | 1 1124 | 1 1125 | 0 1126 | 0 1127 | 1 1128 | 1 1129 | 0 1130 | 1 1131 | 0 1132 | 0 1133 | 0 1134 | 0 1135 | 1 1136 | 0 1137 | 1 1138 | 1 1139 | 0 1140 | 1 1141 | 0 1142 | 0 1143 | 0 1144 | 1 1145 | 1 1146 | 0 1147 | 1 1148 | 1 1149 | 0 1150 | 1 1151 | 0 1152 | 0 1153 | 1 1154 | 1 1155 | 0 1156 | 0 1157 | 1 1158 | 0 1159 | 1 1160 | 0 1161 | 0 1162 | 0 1163 | 1 1164 | 1 1165 | 0 1166 | 1 1167 | 0 1168 | 1 1169 | 1 1170 | 1 1171 | 0 1172 | 0 1173 | 1 1174 | 1 1175 | 0 1176 | 1 1177 | 1 1178 | 0 1179 | 0 1180 | 1 1181 | 0 1182 | 0 1183 | 1 1184 | 0 1185 | 0 1186 | 1 1187 | 1 1188 | 1 1189 | 0 1190 | 1 1191 | 1 1192 | 0 1193 | 0 1194 | 1 1195 | 0 1196 | 1 1197 | 1 1198 | 1 1199 | 0 1200 | 0 1201 | 1 1202 | 0 1203 | 0 1204 | 1 1205 | 1 1206 | 0 1207 | 0 1208 | 1 1209 | 0 1210 | 1 1211 | 1 1212 | 1 1213 | 0 1214 | 0 1215 | 1 1216 | 0 1217 | 0 1218 | 1 1219 | 0 1220 | 1 1221 | 1 1222 | 0 1223 | 1 1224 | 1 1225 | 0 1226 | 1 1227 | 0 1228 | 0 1229 | 1 1230 | 0 1231 | 1 1232 | 0 1233 | 1 1234 | 0 1235 | 0 1236 | 1 1237 | 1 1238 | 1 1239 | 0 1240 | 0 1241 | 0 1242 | 0 1243 | 1 1244 | 1 1245 | 0 1246 | 1 1247 | 1 1248 | 1 1249 | 0 1250 | 1 1251 | 0 1252 | 1 1253 | 0 1254 | 0 1255 | 1 1256 | 0 1257 | 1 1258 | 0 1259 | 1 1260 | 0 1261 | 0 1262 | 0 1263 | 1 1264 | 1 1265 | 1 1266 | 0 1267 | 0 1268 | 1 1269 | 0 1270 | 1 1271 | 1 1272 | 0 1273 | 0 1274 | 0 1275 | 1 1276 | 1 1277 | 1 1278 | 0 1279 | 1 1280 | 0 1281 | 0 1282 | 1 1283 | 0 1284 | 1 1285 | 1 1286 | 0 1287 | 1 1288 | 0 1289 | 0 1290 | 0 1291 | 1 1292 | 1 1293 | 1 1294 | 1 1295 | 0 1296 | 0 1297 | 0 1298 | 1 1299 | 0 1300 | 1 1301 | 1 1302 | 1 1303 | 0 1304 | 0 1305 | 0 1306 | 1 1307 | 1 1308 | 0 1309 | 1 1310 | 0 1311 | 1 1312 | 0 1313 | 0 1314 | 1 1315 | 1 1316 | 0 1317 | 0 1318 | 1 1319 | 1 1320 | 0 1321 | 1 1322 | 1 1323 | 0 1324 | 0 1325 | 0 1326 | 1 1327 | 1 1328 | 0 1329 | 1 1330 | 0 1331 | 0 1332 | 0 1333 | 1 1334 | 1 1335 | 1 1336 | 0 1337 | 1 1338 | 1 1339 | 0 
1340 | 1 1341 | 0 1342 | 0 1343 | 1 1344 | 0 1345 | 0 1346 | 1 1347 | 1 1348 | 1 1349 | 0 1350 | 1 1351 | 0 1352 | 0 1353 | 1 1354 | 0 1355 | 0 1356 | 1 1357 | 1 1358 | 0 1359 | 1 1360 | 0 1361 | 0 1362 | 1 1363 | 0 1364 | 1 1365 | 1 1366 | 1 1367 | 0 1368 | 0 1369 | 0 1370 | 1 1371 | 1 1372 | 0 1373 | 0 1374 | 1 1375 | 1 1376 | 0 1377 | 1 1378 | 1 1379 | 0 1380 | 1 1381 | 0 1382 | 0 1383 | 1 1384 | 1 1385 | 1 1386 | 0 1387 | 1 1388 | 1 1389 | 1 1390 | 0 1391 | 0 1392 | 1 1393 | 0 1394 | 0 1395 | 0 1396 | 1 1397 | 0 1398 | 1 1399 | 0 1400 | 0 1401 | 1 1402 | 0 1403 | 0 1404 | 1 1405 | 0 1406 | 1 1407 | 0 1408 | 1 1409 | 0 1410 | 1 1411 | 1 1412 | 0 1413 | 1 1414 | 0 1415 | 0 1416 | 0 1417 | 1 1418 | 0 1419 | 1 1420 | 1 1421 | 0 1422 | 1 1423 | 1 1424 | 0 1425 | 0 1426 | 1 1427 | 0 1428 | 1 1429 | 0 1430 | 0 1431 | 0 1432 | 1 1433 | 1 1434 | 1 1435 | 1 1436 | 0 1437 | 0 1438 | 1 1439 | 0 1440 | 1 1441 | 1 1442 | 0 1443 | 0 1444 | 0 1445 | 1 1446 | 0 1447 | 1 1448 | 1 1449 | 0 1450 | 1 1451 | 0 1452 | 1 1453 | 1 1454 | 0 1455 | 0 1456 | 1 1457 | 1 1458 | 0 1459 | 0 1460 | 1 1461 | 1 1462 | 0 1463 | 1 1464 | 1 1465 | 0 1466 | 0 1467 | 1 1468 | 0 1469 | 0 1470 | 0 1471 | 1 1472 | 1 1473 | 0 1474 | 1 1475 | 1 1476 | 0 1477 | 1 1478 | 0 1479 | 0 1480 | 0 1481 | 1 1482 | 1 1483 | 0 1484 | 1 1485 | 0 1486 | 1 1487 | 0 1488 | 0 1489 | 1 1490 | 0 1491 | 1 1492 | 0 1493 | 1 1494 | 0 1495 | 1 1496 | 0 1497 | 0 1498 | 0 1499 | 1 1500 | 1 1501 | 0 1502 | 1 1503 | 0 1504 | 1 1505 | 1 1506 | 1 1507 | 0 1508 | 0 1509 | 0 1510 | 1 1511 | 0 1512 | 0 1513 | 1 1514 | 0 1515 | 1 1516 | 1 1517 | 1 1518 | 0 1519 | 0 1520 | 0 1521 | 0 1522 | 0 1523 | 0 1524 | 1 1525 | 1 1526 | 1 1527 | 0 1528 | 1 1529 | 0 1530 | 1 1531 | 0 1532 | 0 1533 | 0 1534 | 0 1535 | 1 1536 | 1 1537 | 1 1538 | 0 1539 | 0 1540 | 1 1541 | 0 1542 | 1 1543 | 1 1544 | 0 1545 | 1 1546 | 1 1547 | 0 1548 | 1 1549 | 1 1550 | 0 1551 | 0 1552 | 1 1553 | 1 1554 | 0 1555 | 0 1556 | 0 1557 | 1 1558 | 1 1559 | 1 1560 | 0 1561 | 0 
1562 | 1 1563 | 0 1564 | 1 1565 | 1 1566 | 1 1567 | 1 1568 | 0 1569 | 1 1570 | 0 1571 | 0 1572 | 1 1573 | 0 1574 | 1 1575 | 1 1576 | 1 1577 | 1 1578 | 1 1579 | 0 1580 | 0 1581 | 0 1582 | 1 1583 | 1 1584 | 0 1585 | 0 1586 | 1 1587 | 0 1588 | 1 1589 | 1 1590 | 1 1591 | 1 1592 | 0 1593 | 1 1594 | 0 1595 | 1 1596 | 0 1597 | 0 1598 | 0 1599 | 1 1600 | 1 1601 | 1 1602 | 0 1603 | 1 1604 | 1 1605 | 0 1606 | 0 1607 | 1 1608 | 1 1609 | 1 1610 | 0 1611 | 0 1612 | 0 1613 | 0 1614 | 0 1615 | 0 1616 | 1 1617 | 0 1618 | 1 1619 | 0 1620 | 1 1621 | 1 1622 | 0 1623 | 0 1624 | 1 1625 | 0 1626 | 0 1627 | 1 1628 | 1 1629 | 1 1630 | 0 1631 | 0 1632 | 0 1633 | 1 1634 | 1 1635 | 0 1636 | 1 1637 | 1 1638 | 1 1639 | 0 1640 | 0 1641 | 0 1642 | 1 1643 | 0 1644 | 1 1645 | 1 1646 | 1 1647 | 0 1648 | 1 1649 | 0 1650 | 0 1651 | 1 1652 | 1 1653 | 1 1654 | 0 1655 | 1 1656 | 0 1657 | 0 1658 | 0 1659 | 1 1660 | 1 1661 | 0 1662 | 1 1663 | 0 1664 | 0 1665 | 1 1666 | 0 1667 | 0 1668 | 0 1669 | 1 1670 | 1 1671 | 0 1672 | 0 1673 | 1 1674 | 1 1675 | 1 1676 | 0 1677 | 1 1678 | 1 1679 | 0 1680 | 1 1681 | 1 1682 | 1 1683 | 0 1684 | 0 1685 | 0 1686 | 1 1687 | 1 1688 | 1 1689 | 0 1690 | 1 1691 | 1 1692 | 0 1693 | 0 1694 | 1 1695 | 0 1696 | 0 1697 | 1 1698 | 0 1699 | 1 1700 | 0 1701 | 1 1702 | 1 1703 | 1 1704 | 1 1705 | 0 1706 | 0 1707 | 0 1708 | 0 1709 | 1 1710 | 0 1711 | 0 1712 | 1 1713 | 0 1714 | 1 1715 | 0 1716 | 1 1717 | 0 1718 | 1 1719 | 1 1720 | 0 1721 | 0 1722 | 1 1723 | 0 1724 | 0 1725 | 1 1726 | 1 1727 | 0 1728 | 0 1729 | 1 1730 | 0 1731 | 1 1732 | 0 1733 | 1 1734 | 1 1735 | 0 1736 | 0 1737 | 1 1738 | 1 1739 | 0 1740 | 0 1741 | 0 1742 | 1 1743 | 0 1744 | 1 1745 | 1 1746 | 1 1747 | 0 1748 | 0 1749 | 0 1750 | 0 1751 | 1 1752 | 1 1753 | 0 1754 | 1 1755 | 1 1756 | 0 1757 | 1 1758 | 0 1759 | 0 1760 | 1 1761 | 0 1762 | 1 1763 | 0 1764 | 1 1765 | 0 1766 | 1 1767 | 1 1768 | 0 1769 | 1 1770 | 0 1771 | 0 1772 | 1 1773 | 1 1774 | 1 1775 | 0 1776 | 1 1777 | 0 1778 | 1 1779 | 0 1780 | 0 1781 | 0 1782 | 0 1783 | 1 
1784 | 1 1785 | 1 1786 | 1 1787 | 1 1788 | 0 1789 | 1 1790 | 1 1791 | 0 1792 | 0 1793 | 0 1794 | 1 1795 | 0 1796 | 1 1797 | 0 1798 | 0 1799 | 0 1800 | 1 1801 | 1 1802 | 0 1803 | 1 1804 | 0 1805 | 0 1806 | 1 1807 | 1 1808 | 0 1809 | 1 1810 | 0 1811 | 0 1812 | 1 1813 | 1 1814 | 0 1815 | 0 1816 | 1 1817 | 0 1818 | 1 1819 | 0 1820 | 1 1821 | 1 1822 | 0 1823 | 1 1824 | 0 1825 | 1 1826 | 0 1827 | 1 1828 | 0 1829 | 0 1830 | 0 1831 | 1 1832 | 0 1833 | 0 1834 | 1 1835 | 1 1836 | 0 1837 | 1 1838 | 0 1839 | 0 1840 | 1 1841 | 0 1842 | 0 1843 | 1 1844 | 1 1845 | 1 1846 | 0 1847 | 0 1848 | 1 1849 | 1 1850 | 0 1851 | 0 1852 | 1 1853 | 0 1854 | 1 1855 | 0 1856 | 0 1857 | 0 1858 | 1 1859 | 0 1860 | 0 1861 | 1 1862 | 1 1863 | 1 1864 | 1 1865 | 1 1866 | 0 1867 | 0 1868 | 0 1869 | 1 1870 | 1 1871 | 0 1872 | 1 1873 | 0 1874 | 1 1875 | 1 1876 | 0 1877 | 1 1878 | 0 1879 | 1 1880 | 0 1881 | 1 1882 | 1 1883 | 0 1884 | 1 1885 | 1 1886 | 0 1887 | 0 1888 | 0 1889 | 1 1890 | 1 1891 | 0 1892 | 1 1893 | 1 1894 | 1 1895 | 1 1896 | 0 1897 | 0 1898 | 0 1899 | 0 1900 | 0 1901 | 1 1902 | 1 1903 | 1 1904 | 0 1905 | 0 1906 | 1 1907 | 1 1908 | 0 1909 | 0 1910 | 1 1911 | 0 1912 | 1 1913 | 1 1914 | 0 1915 | 0 1916 | 0 1917 | 1 1918 | 1 1919 | 0 1920 | 1 1921 | 1 1922 | 0 1923 | 1 1924 | 0 1925 | 1 1926 | 1 1927 | 0 1928 | 0 1929 | 1 1930 | 1 1931 | 0 1932 | 1 1933 | 0 1934 | 1 1935 | 0 1936 | 1 1937 | 1 1938 | 0 1939 | 1 1940 | 0 1941 | 1 1942 | 1 1943 | 0 1944 | 0 1945 | 0 1946 | 1 1947 | 1 1948 | 0 1949 | 0 1950 | 0 1951 | 1 1952 | 1 1953 | 0 1954 | 0 1955 | 1 1956 | 0 1957 | 1 1958 | 0 1959 | 1 1960 | 0 1961 | 1 1962 | 1 1963 | 0 1964 | 0 1965 | 1 1966 | 0 1967 | 0 1968 | 0 1969 | 1 1970 | 1 1971 | 1 1972 | 1 1973 | 0 1974 | 0 1975 | 0 1976 | 1 1977 | 0 1978 | 1 1979 | 0 1980 | 1 1981 | 0 1982 | 0 1983 | 1 1984 | 0 1985 | 1 1986 | 0 1987 | 1 1988 | 0 1989 | 1 1990 | 0 1991 | 0 1992 | 0 1993 | 1 1994 | 0 1995 | 1 1996 | 1 1997 | 0 1998 | 1 1999 | 0 2000 | 1 2001 | 0 2002 | 1 2003 | 0 2004 | 1 2005 | 0 
2006 | 1 2007 | 0 2008 | 1 2009 | 1 2010 | 0 2011 | 0 2012 | 0 2013 | 1 2014 | 0 2015 | 0 2016 | 1 2017 | 1 2018 | 1 2019 | 0 2020 | 1 2021 | 0 2022 | 1 2023 | 1 2024 | 0 2025 | 1 2026 | 0 2027 | 1 2028 | 0 2029 | 0 2030 | 0 2031 | 0 2032 | 1 2033 | 1 2034 | 0 2035 | 1 2036 | 1 2037 | 0 2038 | 0 2039 | 0 2040 | 0 2041 | 1 2042 | 1 2043 | 1 2044 | 0 2045 | 0 2046 | 0 2047 | 1 2048 | 0 2049 | 0 2050 | 1 2051 | 0 2052 | 1 2053 | 0 2054 | 1 2055 | 0 2056 | 1 2057 | 1 2058 | 0 2059 | 1 2060 | 1 2061 | 0 2062 | 0 2063 | 0 2064 | 1 2065 | 0 2066 | 0 2067 | 1 2068 | 0 2069 | 1 2070 | 0 2071 | 1 2072 | 1 2073 | 0 2074 | 1 2075 | 0 2076 | 0 2077 | 1 2078 | 0 2079 | 1 2080 | 1 2081 | 1 2082 | 0 2083 | 0 2084 | 0 2085 | 1 2086 | 0 2087 | 0 2088 | 1 2089 | 1 2090 | 1 2091 | 1 2092 | 0 2093 | 1 2094 | 0 2095 | 1 2096 | 1 2097 | 0 2098 | 1 2099 | 0 2100 | 0 2101 | 1 2102 | 1 2103 | 1 2104 | 1 2105 | 1 2106 | 0 2107 | 0 2108 | 0 2109 | 1 2110 | 0 2111 | 1 2112 | 1 2113 | 0 2114 | 1 2115 | 1 2116 | 1 2117 | 0 2118 | 0 2119 | 1 2120 | 0 2121 | 1 2122 | 0 2123 | 1 2124 | 0 2125 | 1 2126 | 1 2127 | 0 2128 | 0 2129 | 1 2130 | 0 2131 | 1 2132 | 0 2133 | 1 2134 | 0 2135 | 1 2136 | 1 2137 | 1 2138 | 0 2139 | 1 2140 | 0 2141 | 0 2142 | 1 2143 | 1 2144 | 0 2145 | 0 2146 | 0 2147 | 1 2148 | 1 2149 | 0 2150 | 1 2151 | 0 2152 | 1 2153 | 0 2154 | 0 2155 | 1 2156 | 1 2157 | 0 2158 | 0 2159 | 1 2160 | 0 2161 | 0 2162 | 1 2163 | 0 2164 | 1 2165 | 0 2166 | 1 2167 | 0 2168 | 0 2169 | 1 2170 | 0 2171 | 1 2172 | 0 2173 | 1 2174 | 0 2175 | 1 2176 | 1 2177 | 1 2178 | 0 2179 | 0 2180 | 1 2181 | 0 2182 | 0 2183 | 1 2184 | 0 2185 | 1 2186 | 0 2187 | 0 2188 | 0 2189 | 1 2190 | 1 2191 | 0 2192 | 1 2193 | 0 2194 | 0 2195 | 1 2196 | 1 2197 | 1 2198 | 1 2199 | 0 2200 | 0 2201 | 0 2202 | 1 2203 | 1 2204 | 0 2205 | 0 2206 | 1 2207 | 0 2208 | 1 2209 | 1 2210 | 0 2211 | 1 2212 | 1 2213 | 0 2214 | 0 2215 | 1 2216 | 0 2217 | 1 2218 | 0 2219 | 1 2220 | 1 2221 | 0 2222 | 1 2223 | 1 2224 | 0 2225 | 1 2226 | 0 2227 | 1 
2228 | 1 2229 | 0 2230 | 1 2231 | 1 2232 | 0 2233 | 0 2234 | 1 2235 | 1 2236 | 0 2237 | 0 2238 | 0 2239 | 1 2240 | 1 2241 | 0 2242 | 0 2243 | 1 2244 | 0 2245 | 0 2246 | 1 2247 | 1 2248 | 0 2249 | 1 2250 | 0 2251 | 1 2252 | 1 2253 | 0 2254 | 0 2255 | 0 2256 | 1 2257 | 1 2258 | 0 2259 | 1 2260 | 0 2261 | 1 2262 | 0 2263 | 0 2264 | 0 2265 | 1 2266 | 1 2267 | 0 2268 | 1 2269 | 0 2270 | 0 2271 | 0 2272 | 0 2273 | 1 2274 | 1 2275 | 0 2276 | 1 2277 | 1 2278 | 1 2279 | 0 2280 | 1 2281 | 1 2282 | 0 2283 | 0 2284 | 0 2285 | 0 2286 | 1 2287 | 0 2288 | 1 2289 | 1 2290 | 0 2291 | 0 2292 | 1 2293 | 0 2294 | 1 2295 | 0 2296 | 1 2297 | 0 2298 | 1 2299 | 1 2300 | 0 2301 | 0 2302 | 0 2303 | 0 2304 | 1 2305 | 1 2306 | 1 2307 | 0 2308 | 1 2309 | 0 2310 | 0 2311 | 1 2312 | 1 2313 | 0 2314 | 0 2315 | 0 2316 | 1 2317 | 1 2318 | 0 2319 | 1 2320 | 1 2321 | 1 2322 | 0 2323 | 1 2324 | 1 2325 | 0 2326 | 0 2327 | 1 2328 | 1 2329 | 0 2330 | 0 2331 | 0 2332 | 1 2333 | 1 2334 | 1 2335 | 0 2336 | 1 2337 | 1 2338 | 0 2339 | 0 2340 | 0 2341 | 0 2342 | 1 2343 | 1 2344 | 1 2345 | 1 2346 | 0 2347 | 1 2348 | 0 2349 | 0 2350 | 0 2351 | 1 2352 | 0 2353 | 1 2354 | 1 2355 | 1 2356 | 0 2357 | 0 2358 | 1 2359 | 1 2360 | 0 2361 | 0 2362 | 1 2363 | 1 2364 | 0 2365 | 1 2366 | 0 2367 | 0 2368 | 1 2369 | 0 2370 | 1 2371 | 1 2372 | 0 2373 | 0 2374 | 0 2375 | 1 2376 | 0 2377 | 0 2378 | 1 2379 | 1 2380 | 0 2381 | 0 2382 | 1 2383 | 1 2384 | 0 2385 | 1 2386 | 0 2387 | 1 2388 | 0 2389 | 0 2390 | 1 2391 | 1 2392 | 1 2393 | 0 2394 | 1 2395 | 0 2396 | 1 2397 | 0 2398 | 1 2399 | 0 2400 | 0 2401 | 0 2402 | 1 2403 | 0 2404 | 0 2405 | 1 2406 | 1 2407 | 1 2408 | 0 2409 | 1 2410 | 0 2411 | 0 2412 | 0 2413 | 0 2414 | 1 2415 | 1 2416 | 0 2417 | 1 2418 | 1 2419 | 0 2420 | 0 2421 | 0 2422 | 1 2423 | 1 2424 | 1 2425 | 0 2426 | 1 2427 | 0 2428 | 0 2429 | 1 2430 | 1 2431 | 1 2432 | 0 2433 | 0 2434 | 0 2435 | 1 2436 | 1 2437 | 1 2438 | 1 2439 | 0 2440 | 1 2441 | 0 2442 | 0 2443 | 1 2444 | 0 2445 | 0 2446 | 1 2447 | 0 2448 | 1 2449 | 0 
2450 | 1 2451 | 1 2452 | 0 2453 | 0 2454 | 0 2455 | 1 2456 | 1 2457 | 0 2458 | 0 2459 | 1 2460 | 0 2461 | 0 2462 | 1 2463 | 0 2464 | 1 2465 | 1 2466 | 0 2467 | 1 2468 | 0 2469 | 1 2470 | 0 2471 | 1 2472 | 0 2473 | 1 2474 | 0 2475 | 0 2476 | 1 2477 | 1 2478 | 1 2479 | 0 2480 | 1 2481 | 0 2482 | 0 2483 | 0 2484 | 0 2485 | 1 2486 | 1 2487 | 0 2488 | 1 2489 | 0 2490 | 1 2491 | 0 2492 | 1 2493 | 1 2494 | 0 2495 | 0 2496 | 1 2497 | 0 2498 | 1 2499 | 0 2500 | 0 2501 | 1 2502 | 1 2503 | 0 2504 | 0 2505 | 1 2506 | 0 2507 | 0 2508 | 1 2509 | 0 2510 | 1 2511 | 0 2512 | 0 2513 | 1 2514 | 0 2515 | 0 2516 | 1 2517 | 1 2518 | 1 2519 | 1 2520 | 0 2521 | 1 2522 | 0 2523 | 0 2524 | 0 2525 | 0 2526 | 1 2527 | 1 2528 | 1 2529 | 1 2530 | 0 2531 | 1 2532 | 0 2533 | 1 2534 | 0 2535 | 0 2536 | 1 2537 | 1 2538 | 1 2539 | 0 2540 | 1 2541 | 0 2542 | 0 2543 | 1 2544 | 1 2545 | 0 2546 | 1 2547 | 1 2548 | 0 2549 | 1 2550 | 1 2551 | 0 2552 | 0 2553 | 1 2554 | 0 2555 | 1 2556 | 0 2557 | 1 2558 | 0 2559 | 0 2560 | 1 2561 | 0 2562 | 0 2563 | 1 2564 | 1 2565 | 1 2566 | 0 2567 | 0 2568 | 1 2569 | 1 2570 | 1 2571 | 0 2572 | 0 2573 | 1 2574 | 1 2575 | 0 2576 | 1 2577 | 0 2578 | 1 2579 | 0 2580 | 0 2581 | 1 2582 | 0 2583 | 0 2584 | 1 2585 | 1 2586 | 1 2587 | 1 2588 | 0 2589 | 0 2590 | 0 2591 | 0 2592 | 1 2593 | 0 2594 | 1 2595 | 0 2596 | 1 2597 | 0 2598 | 0 2599 | 1 2600 | 1 2601 | 1 2602 | 0 2603 | 0 2604 | 1 2605 | 0 2606 | 1 2607 | 0 2608 | 1 2609 | 1 2610 | 0 2611 | 0 2612 | 1 2613 | 0 2614 | 1 2615 | 0 2616 | 1 2617 | 0 2618 | 0 2619 | 1 2620 | 0 2621 | 1 2622 | 0 2623 | 1 2624 | 1 2625 | 1 2626 | 0 2627 | 1 2628 | 0 2629 | 0 2630 | 1 2631 | 0 2632 | 1 2633 | 0 2634 | 1 2635 | 0 2636 | 1 2637 | 0 2638 | 0 2639 | 1 2640 | 1 2641 | 0 2642 | 1 2643 | 0 2644 | 1 2645 | 1 2646 | 0 2647 | 1 2648 | 0 2649 | 0 2650 | 1 2651 | 1 2652 | 0 2653 | 0 2654 | 0 2655 | 1 2656 | 1 2657 | 0 2658 | 1 2659 | 1 2660 | 1 2661 | 0 2662 | 1 2663 | 0 2664 | 0 2665 | 1 2666 | 1 2667 | 0 2668 | 0 2669 | 1 2670 | 1 2671 | 0 
2672 | 1 2673 | 0 2674 | 0 2675 | 1 2676 | 1 2677 | 0 2678 | 1 2679 | 0 2680 | 0 2681 | 1 2682 | 0 2683 | 0 2684 | 1 2685 | 1 2686 | 1 2687 | 0 2688 | 0 2689 | 1 2690 | 0 2691 | 0 2692 | 1 2693 | 0 2694 | 1 2695 | 1 2696 | 0 2697 | 1 2698 | 0 2699 | 0 2700 | 1 2701 | 0 2702 | 0 2703 | 1 2704 | 0 2705 | 0 2706 | 1 2707 | 0 2708 | 1 2709 | 1 2710 | 0 2711 | 0 2712 | 1 2713 | 0 2714 | 0 2715 | 1 2716 | 1 2717 | 0 2718 | 0 2719 | 0 2720 | 1 2721 | 1 2722 | 0 2723 | 0 2724 | 1 2725 | 1 2726 | 1 2727 | 0 2728 | 0 2729 | 1 2730 | 1 2731 | 0 2732 | 1 2733 | 1 2734 | 0 2735 | 0 2736 | 1 2737 | 0 2738 | 1 2739 | 0 2740 | 1 2741 | 1 2742 | 0 2743 | 1 2744 | 0 2745 | 0 2746 | 1 2747 | 1 2748 | 0 2749 | 1 2750 | 0 2751 | 1 2752 | 0 2753 | 0 2754 | 1 2755 | 1 2756 | 0 2757 | 0 2758 | 1 2759 | 1 2760 | 1 2761 | 0 2762 | 1 2763 | 1 2764 | 0 2765 | 1 2766 | 0 2767 | 1 2768 | 1 2769 | 1 2770 | 0 2771 | 0 2772 | 1 2773 | 1 2774 | 0 2775 | 0 2776 | 0 2777 | 1 2778 | 0 2779 | 1 2780 | 1 2781 | 1 2782 | 0 2783 | 0 2784 | 1 2785 | 0 2786 | 1 2787 | 0 2788 | 1 2789 | 0 2790 | 1 2791 | 1 2792 | 0 2793 | 0 2794 | 0 2795 | 1 2796 | 0 2797 | 1 2798 | 1 2799 | 0 2800 | 1 2801 | 1 2802 | 1 2803 | 0 2804 | 0 2805 | 1 2806 | 0 2807 | 0 2808 | 1 2809 | 0 2810 | 1 2811 | 0 2812 | 0 2813 | 1 2814 | 1 2815 | 1 2816 | 1 2817 | 0 2818 | 1 2819 | 0 2820 | 0 2821 | 0 2822 | 1 2823 | 1 2824 | 0 2825 | 1 2826 | 1 2827 | 1 2828 | 0 2829 | 1 2830 | 0 2831 | 0 2832 | 1 2833 | 1 2834 | 0 2835 | 1 2836 | 0 2837 | 1 2838 | 0 2839 | 0 2840 | 1 2841 | 1 2842 | 0 2843 | 1 2844 | 0 2845 | 1 2846 | 1 2847 | 1 2848 | 0 2849 | 0 2850 | 1 2851 | 1 2852 | 0 2853 | 0 2854 | 1 2855 | 1 2856 | 0 2857 | 0 2858 | 1 2859 | 0 2860 | 0 2861 | 1 2862 | 1 2863 | 0 2864 | 0 2865 | 0 2866 | 1 2867 | 1 2868 | 0 2869 | 1 2870 | 0 2871 | 0 2872 | 1 2873 | 0 2874 | 0 2875 | 1 2876 | 0 2877 | 0 2878 | 1 2879 | 0 2880 | 1 2881 | 1 2882 | 0 2883 | 0 2884 | 0 2885 | 1 2886 | 0 2887 | 0 2888 | 1 2889 | 1 2890 | 0 2891 | 0 2892 | 0 2893 | 1 
2894 | 1 2895 | 0 2896 | 0 2897 | 1 2898 | 1 2899 | 1 2900 | 1 2901 | 0 2902 | 0 2903 | 0 2904 | 1 2905 | 1 2906 | 0 2907 | 0 2908 | 0 2909 | 1 2910 | 1 2911 | 0 2912 | 1 2913 | 1 2914 | 0 2915 | 1 2916 | 0 2917 | 1 2918 | 0 2919 | 0 2920 | 0 2921 | 1 2922 | 1 2923 | 0 2924 | 1 2925 | 0 2926 | 1 2927 | 0 2928 | 1 2929 | 1 2930 | 0 2931 | 1 2932 | 0 2933 | 0 2934 | 0 2935 | 0 2936 | 1 2937 | 1 2938 | 0 2939 | 0 2940 | 0 2941 | 1 2942 | 1 2943 | 1 2944 | 0 2945 | 1 2946 | 0 2947 | 0 2948 | 1 2949 | 1 2950 | 0 2951 | 0 2952 | 1 2953 | 1 2954 | 0 2955 | 1 2956 | 0 2957 | 1 2958 | 0 2959 | 1 2960 | 1 2961 | 0 2962 | 1 2963 | 1 2964 | 0 2965 | 0 2966 | 1 2967 | 0 2968 | 1 2969 | 0 2970 | 1 2971 | 1 2972 | 0 2973 | 1 2974 | 0 2975 | 1 2976 | 0 2977 | 1 2978 | 0 2979 | 1 2980 | 1 2981 | 0 2982 | 1 2983 | 0 2984 | 0 2985 | 1 2986 | 1 2987 | 0 2988 | 1 2989 | 1 2990 | 0 2991 | 1 2992 | 0 2993 | 0 2994 | 1 2995 | 1 2996 | 1 2997 | 0 2998 | 1 2999 | 0 3000 | 1 3001 | 1 3002 | 0 3003 | 1 3004 | 0 3005 | 0 3006 | 0 3007 | 1 3008 | 1 3009 | 0 3010 | 0 3011 | 1 3012 | 1 3013 | 0 3014 | 0 3015 | 1 3016 | 1 3017 | 0 3018 | 1 3019 | 0 3020 | 0 3021 | 1 3022 | 1 3023 | 0 3024 | 1 3025 | 0 3026 | 1 3027 | 0 3028 | 0 3029 | 1 3030 | 1 3031 | 0 3032 | 1 3033 | 1 3034 | 1 3035 | 0 3036 | 0 3037 | 0 3038 | 1 3039 | 0 3040 | 1 3041 | 0 3042 | 0 3043 | 1 3044 | 1 3045 | 1 3046 | 1 3047 | 0 3048 | 0 3049 | 0 3050 | 0 3051 | 0 3052 | 1 3053 | 1 3054 | 0 3055 | 0 3056 | 1 3057 | 1 3058 | 0 3059 | 1 3060 | 0 3061 | 0 3062 | 1 3063 | 0 3064 | 1 3065 | 0 3066 | 1 3067 | 0 3068 | 0 3069 | 0 3070 | 1 3071 | 0 3072 | 1 3073 | 1 3074 | 0 3075 | 1 3076 | 1 3077 | 1 3078 | 0 3079 | 0 3080 | 0 3081 | 1 3082 | 1 3083 | 0 3084 | 1 3085 | 1 3086 | 0 3087 | 1 3088 | 1 3089 | 0 3090 | 0 3091 | 1 3092 | 0 3093 | 1 3094 | 0 3095 | 0 3096 | 0 3097 | 1 3098 | 0 3099 | 1 3100 | 0 3101 | 1 3102 | 1 3103 | 1 3104 | 1 3105 | 0 3106 | 1 3107 | 1 3108 | 1 3109 | 1 3110 | 0 3111 | 0 3112 | 0 3113 | 0 3114 | 1 3115 | 0 
3116 | 1 3117 | 1 3118 | 1 3119 | 1 3120 | 0 3121 | 0 3122 | 1 3123 | 1 3124 | 0 3125 | 1 3126 | 0 3127 | 1 3128 | 0 3129 | 0 3130 | 1 3131 | 0 3132 | 1 3133 | 0 3134 | 1 3135 | 0 3136 | 1 3137 | 1 3138 | 0 3139 | 0 3140 | 0 3141 | 1 3142 | 1 3143 | 0 3144 | 1 3145 | 0 3146 | 0 3147 | 0 3148 | 1 3149 | 0 3150 | 0 3151 | 1 3152 | 0 3153 | 1 3154 | 1 3155 | 0 3156 | 1 3157 | 0 3158 | 1 3159 | 0 3160 | 1 3161 | 1 3162 | 0 3163 | 1 3164 | 0 3165 | 0 3166 | 1 3167 | 1 3168 | 0 3169 | 0 3170 | 1 3171 | 1 3172 | 0 3173 | 1 3174 | 0 3175 | 0 3176 | 1 3177 | 0 3178 | 1 3179 | 1 3180 | 1 3181 | 0 3182 | 0 3183 | 1 3184 | 0 3185 | 1 3186 | 1 3187 | 0 3188 | 0 3189 | 1 3190 | 0 3191 | 1 3192 | 1 3193 | 0 3194 | 0 3195 | 0 3196 | 1 3197 | 1 3198 | 0 3199 | 1 3200 | 0 3201 | 0 3202 | 1 3203 | 0 3204 | 1 3205 | 1 3206 | 0 3207 | 1 3208 | 0 3209 | 0 3210 | 1 3211 | 1 3212 | 0 3213 | 1 3214 | 0 3215 | 1 3216 | 0 3217 | 1 3218 | 1 3219 | 0 3220 | 1 3221 | 1 3222 | 0 3223 | 0 3224 | 0 3225 | 0 3226 | 1 3227 | 1 3228 | 1 3229 | 1 3230 | 0 3231 | 0 3232 | 1 3233 | 1 3234 | 0 3235 | 0 3236 | 0 3237 | 1 3238 | 1 3239 | 0 3240 | 1 3241 | 0 3242 | 0 3243 | 1 3244 | 1 3245 | 1 3246 | 0 3247 | 1 3248 | 0 3249 | 0 3250 | 0 3251 | 1 3252 | 0 3253 | 1 3254 | 1 3255 | 0 3256 | 0 3257 | 1 3258 | 0 3259 | 0 3260 | 0 3261 | 1 3262 | 1 3263 | 0 3264 | 0 3265 | 1 3266 | 0 3267 | 0 3268 | 0 3269 | 1 3270 | 0 3271 | 1 3272 | 1 3273 | 1 3274 | 0 3275 | 0 3276 | 0 3277 | 1 3278 | 0 3279 | 0 3280 | 1 3281 | 1 3282 | 1 3283 | 0 3284 | 0 3285 | 0 3286 | 1 3287 | 0 3288 | 1 3289 | 0 3290 | 1 3291 | 0 3292 | 1 3293 | 1 3294 | 0 3295 | 1 3296 | 1 3297 | 0 3298 | 1 3299 | 0 3300 | 1 3301 | 0 3302 | 1 3303 | 1 3304 | 0 3305 | 0 3306 | 1 3307 | 0 3308 | 1 3309 | 0 3310 | 0 3311 | 0 3312 | 1 3313 | 0 3314 | 1 3315 | 1 3316 | 1 3317 | 1 3318 | 1 3319 | 1 3320 | 0 3321 | 1 3322 | 0 3323 | 1 3324 | 1 3325 | 0 3326 | 0 3327 | 1 3328 | 1 3329 | 0 3330 | 1 3331 | 1 3332 | 0 3333 | 0 3334 | 0 3335 | 1 3336 | 1 3337 | 0 
3338 | 0 3339 | 1 3340 | 1 3341 | 0 3342 | 1 3343 | 1 3344 | 1 3345 | 0 3346 | 0 3347 | 0 3348 | 0 3349 | 1 3350 | 1 3351 | 0 3352 | 1 3353 | 0 3354 | 1 3355 | 0 3356 | 0 3357 | 1 3358 | 1 3359 | 0 3360 | 1 3361 | 1 3362 | 0 3363 | 1 3364 | 0 3365 | 0 3366 | 0 3367 | 0 3368 | 1 3369 | 1 3370 | 1 3371 | 0 3372 | 0 3373 | 1 3374 | 0 3375 | 1 3376 | 0 3377 | 1 3378 | 1 3379 | 0 3380 | 0 3381 | 0 3382 | 1 3383 | 1 3384 | 0 3385 | 1 3386 | 0 3387 | 0 3388 | 0 3389 | 1 3390 | 1 3391 | 0 3392 | 1 3393 | 0 3394 | 0 3395 | 1 3396 | 1 3397 | 0 3398 | 1 3399 | 0 3400 | 1 3401 | 1 3402 | 1 3403 | 0 3404 | 0 3405 | 0 3406 | 1 3407 | 1 3408 | 0 3409 | 0 3410 | 0 3411 | 1 3412 | 1 3413 | 0 3414 | 1 3415 | 0 3416 | 0 3417 | 1 3418 | 0 3419 | 0 3420 | 1 3421 | 0 3422 | 1 3423 | 0 3424 | 1 3425 | 1 3426 | 0 3427 | 1 3428 | 1 3429 | 0 3430 | 0 3431 | 1 3432 | 0 3433 | 1 3434 | 0 3435 | 0 3436 | 1 3437 | 1 3438 | 1 3439 | 0 3440 | 1 3441 | 0 3442 | 0 3443 | 0 3444 | 1 3445 | 0 3446 | 1 3447 | 1 3448 | 0 3449 | 0 3450 | 1 3451 | 1 3452 | 1 3453 | 0 3454 | 0 3455 | 1 3456 | 0 3457 | 0 3458 | 1 3459 | 0 3460 | 0 3461 | 1 3462 | 1 3463 | 1 3464 | 0 3465 | 0 3466 | 0 3467 | 1 3468 | 1 3469 | 1 3470 | 1 3471 | 0 3472 | 1 3473 | 0 3474 | 1 3475 | 1 3476 | 0 3477 | 0 3478 | 0 3479 | 0 3480 | 1 3481 | 1 3482 | 1 3483 | 1 3484 | 0 3485 | 0 3486 | 0 3487 | 1 3488 | 1 3489 | 0 3490 | 0 3491 | 1 3492 | 1 3493 | 0 3494 | 1 3495 | 0 3496 | 1 3497 | 0 3498 | 0 3499 | 1 3500 | 1 3501 | 0 3502 | 0 3503 | 1 3504 | 1 3505 | 1 3506 | 0 3507 | 0 3508 | 1 3509 | 1 3510 | 1 3511 | 0 3512 | 1 3513 | 0 3514 | 1 3515 | 1 3516 | 0 3517 | 0 3518 | 1 3519 | 1 3520 | 0 3521 | 0 3522 | 1 3523 | 1 3524 | 1 3525 | 0 3526 | 0 3527 | 0 3528 | 1 3529 | 0 3530 | 1 3531 | 1 3532 | 0 3533 | 1 3534 | 0 3535 | 1 3536 | 0 3537 | 0 3538 | 0 3539 | 1 3540 | 1 3541 | 0 3542 | 1 3543 | 0 3544 | 0 3545 | 1 3546 | 1 3547 | 0 3548 | 0 3549 | 1 3550 | 0 3551 | 0 3552 | 1 3553 | 0 3554 | 0 3555 | 1 3556 | 0 3557 | 0 3558 | 1 3559 | 1 
3560 | 0 3561 | 0 3562 | 1 3563 | 1 3564 | 0 3565 | 1 3566 | 0 3567 | 0 3568 | 1 3569 | 0 3570 | 1 3571 | 1 3572 | 0 3573 | 0 3574 | 1 3575 | 1 3576 | 0 3577 | 0 3578 | 0 3579 | 1 3580 | 0 3581 | 1 3582 | 0 3583 | 0 3584 | 1 3585 | 1 3586 | 0 3587 | 1 3588 | 1 3589 | 0 3590 | 1 3591 | 0 3592 | 0 3593 | 1 3594 | 0 3595 | 1 3596 | 1 3597 | 1 3598 | 0 3599 | 0 3600 | 0 3601 | 1 3602 | 0 3603 | 1 3604 | 1 3605 | 0 3606 | 0 3607 | 0 3608 | 1 3609 | 1 3610 | 0 3611 | 0 3612 | 1 3613 | 1 3614 | 0 3615 | 0 3616 | 1 3617 | 0 3618 | 1 3619 | 1 3620 | 1 3621 | 0 3622 | 0 3623 | 0 3624 | 0 3625 | 1 3626 | 0 3627 | 1 3628 | 0 3629 | 0 3630 | 1 3631 | 0 3632 | 1 3633 | 0 3634 | 0 3635 | 1 3636 | 1 3637 | 0 3638 | 1 3639 | 0 3640 | 1 3641 | 0 3642 | 1 3643 | 0 3644 | 0 3645 | 1 3646 | 1 3647 | 0 3648 | 1 3649 | 0 3650 | 0 3651 | 1 3652 | 0 3653 | 1 3654 | 1 3655 | 0 3656 | 1 3657 | 0 3658 | 1 3659 | 0 3660 | 1 3661 | 0 3662 | 1 3663 | 0 3664 | 1 3665 | 0 3666 | 1 3667 | 0 3668 | 0 3669 | 1 3670 | 1 3671 | 1 3672 | 0 3673 | 1 3674 | 0 3675 | 1 3676 | 0 3677 | 0 3678 | 0 3679 | 1 3680 | 0 3681 | 1 3682 | 1 3683 | 1 3684 | 0 3685 | 0 3686 | 1 3687 | 1 3688 | 0 3689 | 1 3690 | 0 3691 | 1 3692 | 1 3693 | 0 3694 | 0 3695 | 1 3696 | 1 3697 | 0 3698 | 0 3699 | 1 3700 | 0 3701 | 1 3702 | 1 3703 | 0 3704 | 1 3705 | 0 3706 | 0 3707 | 1 3708 | 0 3709 | 1 3710 | 1 3711 | 1 3712 | 0 3713 | 1 3714 | 0 3715 | 0 3716 | 0 3717 | 1 3718 | 1 3719 | 0 3720 | 1 3721 | 0 3722 | 1 3723 | 1 3724 | 0 3725 | 1 3726 | 1 3727 | 0 3728 | 0 3729 | 1 3730 | 1 3731 | 0 3732 | 1 3733 | 0 3734 | 0 3735 | 0 3736 | 1 3737 | 1 3738 | 0 3739 | 1 3740 | 0 3741 | 1 3742 | 0 3743 | 0 3744 | 1 3745 | 1 3746 | 0 3747 | 0 3748 | 1 3749 | 1 3750 | 0 3751 | 0 3752 | 1 3753 | 1 3754 | 0 3755 | 1 3756 | 0 3757 | 0 3758 | 1 3759 | 1 3760 | 0 3761 | 1 3762 | 0 3763 | 1 3764 | 1 3765 | 0 3766 | 0 3767 | 1 3768 | 0 3769 | 0 3770 | 0 3771 | 1 3772 | 0 3773 | 1 3774 | 0 3775 | 1 3776 | 0 3777 | 1 3778 | 1 3779 | 0 3780 | 1 3781 | 0 
3782 | 1 3783 | 1 3784 | 0 3785 | 1 3786 | 0 3787 | 1 3788 | 0 3789 | 1 3790 | 1 3791 | 0 3792 | 0 3793 | 1 3794 | 1 3795 | 1 3796 | 0 3797 | 1 3798 | 0 3799 | 1 3800 | 0 3801 | 1 3802 | 1 3803 | 0 3804 | 1 3805 | 0 3806 | 0 3807 | 1 3808 | 1 3809 | 0 3810 | 1 3811 | 0 3812 | 1 3813 | 1 3814 | 0 3815 | 0 3816 | 0 3817 | 0 3818 | 1 3819 | 0 3820 | 1 3821 | 0 3822 | 0 3823 | 0 3824 | 1 3825 | 1 3826 | 1 3827 | 0 3828 | 0 3829 | 1 3830 | 0 3831 | 1 3832 | 0 3833 | 0 3834 | 0 3835 | 1 3836 | 1 3837 | 0 3838 | 0 3839 | 1 3840 | 1 3841 | 1 3842 | 0 3843 | 0 3844 | 0 3845 | 1 3846 | 0 3847 | 1 3848 | 1 3849 | 0 3850 | 0 3851 | 1 3852 | 1 3853 | 0 3854 | 1 3855 | 0 3856 | 0 3857 | 1 3858 | 1 3859 | 0 3860 | 1 3861 | 1 3862 | 0 3863 | 0 3864 | 0 3865 | 1 3866 | 1 3867 | 0 3868 | 0 3869 | 0 3870 | 1 3871 | 1 3872 | 1 3873 | 1 3874 | 0 3875 | 0 3876 | 1 3877 | 0 3878 | 1 3879 | 0 3880 | 0 3881 | 1 3882 | 0 3883 | 0 3884 | 0 3885 | 1 3886 | 1 3887 | 0 3888 | 1 3889 | 1 3890 | 1 3891 | 0 3892 | 0 3893 | 1 3894 | 1 3895 | 1 3896 | 0 3897 | 0 3898 | 0 3899 | 1 3900 | 0 3901 | 1 3902 | 0 3903 | 1 3904 | 0 3905 | 0 3906 | 1 3907 | 1 3908 | 1 3909 | 1 3910 | 1 3911 | 1 3912 | 0 3913 | 1 3914 | 1 3915 | 0 3916 | 0 3917 | 0 3918 | 1 3919 | 1 3920 | 0 3921 | 1 3922 | 1 3923 | 0 3924 | 0 3925 | 1 3926 | 0 3927 | 1 3928 | 1 3929 | 0 3930 | 0 3931 | 1 3932 | 0 3933 | 1 3934 | 0 3935 | 1 3936 | 0 3937 | 1 3938 | 0 3939 | 0 3940 | 1 3941 | 1 3942 | 0 3943 | 0 3944 | 0 3945 | 1 3946 | 0 3947 | 1 3948 | 1 3949 | 1 3950 | 1 3951 | 0 3952 | 0 3953 | 1 3954 | 0 3955 | 1 3956 | 0 3957 | 1 3958 | 1 3959 | 0 3960 | 1 3961 | 1 3962 | 0 3963 | 0 3964 | 0 3965 | 0 3966 | 1 3967 | 0 3968 | 1 3969 | 1 3970 | 1 3971 | 0 3972 | 0 3973 | 0 3974 | 1 3975 | 0 3976 | 1 3977 | 1 3978 | 0 3979 | 1 3980 | 0 3981 | 0 3982 | 0 3983 | 1 3984 | 0 3985 | 0 3986 | 1 3987 | 0 3988 | 1 3989 | 1 3990 | 0 3991 | 0 3992 | 0 3993 | 1 3994 | 0 3995 | 1 3996 | 0 3997 | 0 3998 | 1 3999 | 0 4000 | 1 4001 | 
-------------------------------------------------------------------------------- /pretrained/LunarLander-v2/ppo/policy.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/pretrained/LunarLander-v2/ppo/policy.ckpt -------------------------------------------------------------------------------- /report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jatinarora2702/gail-pytorch/d66c6a9bff115e38c62672e0b2a175654794193f/report.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | dataclasses 4 | torch 5 | wandb 6 | gym 7 | matplotlib 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name="gail", version="1.0", packages=find_packages()) 4 | --------------------------------------------------------------------------------