├── algos ├── algo_lib │ ├── __init__.py │ ├── atari.py │ ├── a3c.py │ ├── common.py │ └── player.py ├── .gitignore ├── README.md ├── ini │ ├── a3c_breakout_0.ini │ ├── a3c_breakout_1.ini │ └── a3c_pong_0.ini ├── a3c_atari_play.py ├── elite.py ├── others │ └── p.py ├── a3c_atari.py ├── a3c_async.py ├── a3c.py ├── dqn.py └── pg.py ├── misc └── nn_plus │ ├── lib │ ├── __init__.py │ ├── model.py │ └── common.py │ ├── .gitignore │ ├── README.md │ └── train_pong.py ├── rl_lib ├── rl_lib │ ├── __init__.py │ └── wrappers.py └── setup.py ├── articles └── 01_rubic │ ├── tests │ ├── __init__.py │ └── libcube │ │ ├── __init__.py │ │ └── cubes │ │ ├── __init__.py │ │ ├── test_cube2x2.py │ │ └── test_cube3x3.py │ ├── models │ ├── cube3x3 │ │ ├── paper.dat │ │ ├── zero-goal.dat │ │ ├── paper-d200-t1 │ │ │ └── best_3.3371e-02.dat │ │ └── zero-goal-d200-decay=200 │ │ │ └── best_2.1692e-02.dat │ ├── .gitattributes │ └── cube2x2 │ │ ├── .gitattributes │ │ ├── best_1.8184e-01.dat │ │ └── t2-zero-goal-best_1.4547e-02.dat │ ├── .gitignore │ ├── requirements.txt │ ├── libcube │ ├── cubes │ │ ├── __init__.py │ │ ├── _common.py │ │ ├── _env.py │ │ ├── cube2x2.py │ │ └── cube3x3.py │ ├── conf.py │ ├── model.py │ └── mcts.py │ ├── README.md │ ├── ini │ ├── cube2x2-paper-d200.ini │ ├── README.md │ ├── cube3x3-paper-d200.ini │ ├── cube2x2-zero-goal-d200.ini │ ├── cube3x3-zero-goal-d200.ini │ ├── cube3x3-paper-d20.ini │ ├── cube3x3-zero-goal-d20.ini │ └── cube3x3-zero-goal-d20-noweight.ini │ ├── run_tests.sh │ ├── csvs │ ├── README.md │ └── c3x3 │ │ ├── c3-zg-d20-noweight-no-decay=5.501e-1.csv │ │ ├── c3-zg-d20-noweight.csv │ │ ├── c3-zg-d20-noweight-no-decay=5.61e-1.csv │ │ └── c3-zg-d20-noweight-no-decay=7.29e-1.csv │ ├── cubes_tests │ ├── cube2x2_d3.txt │ ├── cube3x3_d3.txt │ ├── cube3x3_d3_norepeat.txt │ ├── cube2x2_d4.txt │ ├── cube2x2_d5.txt │ ├── cube2x2_d6.txt │ └── cube3x3_d10.txt │ ├── gen_cubes.py │ ├── train_debug.py │ ├── docs │ └── Notes.md │ └── train.py ├── gym-submit ├── .gitignore ├── setup.py └── gym-submit.py ├── ptan └── README.md ├── .gitignore ├── requirements.txt └── gym_bugs └── atari_race.py /algos/algo_lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /misc/nn_plus/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rl_lib/rl_lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /misc/nn_plus/.gitignore: -------------------------------------------------------------------------------- 1 | runs 2 | -------------------------------------------------------------------------------- /articles/01_rubic/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /articles/01_rubic/tests/libcube/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /articles/01_rubic/tests/libcube/cubes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /misc/nn_plus/README.md: -------------------------------------------------------------------------------- 1 | Experiments on NoisyNets+ 2 | -------------------------------------------------------------------------------- /algos/.gitignore: -------------------------------------------------------------------------------- 1 | logs 2 | logs-a3c 3 | *.log 4 | *.txt 5 | -------------------------------------------------------------------------------- /gym-submit/.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | build 3 | dist 4 | -------------------------------------------------------------------------------- /articles/01_rubic/models/cube3x3/paper.dat: -------------------------------------------------------------------------------- 1 | paper-d200-t1/best_3.3371e-02.dat -------------------------------------------------------------------------------- /ptan/README.md: -------------------------------------------------------------------------------- 1 | This project was moved to https://github.com/Shmuma/ptan 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.hdf5 3 | *.egg-info 4 | __pycache__ 5 | *.pyc 6 | res* 7 | -------------------------------------------------------------------------------- /articles/01_rubic/models/.gitattributes: -------------------------------------------------------------------------------- 1 | *.dat filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /articles/01_rubic/models/cube3x3/zero-goal.dat: -------------------------------------------------------------------------------- 1 | zero-goal-d200-decay=200/best_2.1692e-02.dat -------------------------------------------------------------------------------- /articles/01_rubic/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | runs 3 | saves 4 | *.png 5 | .ipynb_checkpoints 6 | *.log 7 | -------------------------------------------------------------------------------- /articles/01_rubic/requirements.txt: -------------------------------------------------------------------------------- 1 | nose 2 | seaborn 3 | torch 4 | numpy 5 | tqdm 6 | tensorboard-pytorch 7 | -------------------------------------------------------------------------------- /articles/01_rubic/models/cube2x2/.gitattributes: -------------------------------------------------------------------------------- 1 | best_1.8184e-01.dat filter=lfs diff=lfs merge=lfs -text 2 | *.dat filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /articles/01_rubic/libcube/cubes/__init__.py: -------------------------------------------------------------------------------- 1 | from ._env import CubeEnv, get, names 2 | from . import cube3x3 3 | from . import cube2x2 4 | 5 | __all__ = ('CubeEnv', 'get', 'names') 6 | -------------------------------------------------------------------------------- /articles/01_rubic/models/cube2x2/best_1.8184e-01.dat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c479735965ac356008d10857aa0232ed458f0bc3bc383e5de95acf3365770cc5 3 | size 45146173 4 | -------------------------------------------------------------------------------- /articles/01_rubic/models/cube2x2/t2-zero-goal-best_1.4547e-02.dat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eeea764b512e6cc8a0c430816cbac33cacf0c59cd4998ffb10a4f07e6cc91272 3 | size 45146173 4 | -------------------------------------------------------------------------------- /articles/01_rubic/models/cube3x3/paper-d200-t1/best_3.3371e-02.dat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7982179f44d9579d5f8371e249069c68a132eb1c8381d273c45243a3ee1a10f2 3 | size 49864767 4 | -------------------------------------------------------------------------------- /articles/01_rubic/models/cube3x3/zero-goal-d200-decay=200/best_2.1692e-02.dat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e002a5cc866d18a8fcff5f53606fa95df0e68f8995a9e3d67a12ddcb7d703cec 3 | size 49864767 4 | -------------------------------------------------------------------------------- /articles/01_rubic/README.md: -------------------------------------------------------------------------------- 1 | Code for article about Rubic's cube solution with RL: [Reinforcement learning to solve Rubik's cube](https://medium.com/datadriveninvestor/reinforcement-learning-to-solve-rubiks-cube-and-other-complex-problems-106424cf26ff) 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym==0.7.3 2 | h5py==2.6.0 3 | Keras==1.2.1 4 | numpy==1.12.0 5 | protobuf==3.2.0 6 | pyglet==1.2.4 7 | PyOpenGL==3.1.0 8 | PyYAML==3.12 9 | requests==2.20.0 10 | scipy==0.18.1 11 | six==1.10.0 12 | tensorflow==0.12.1 13 | Theano==0.8.2 14 | tqdm==4.11.2 15 | -------------------------------------------------------------------------------- /rl_lib/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='rl-lib', 5 | author="Max Lapan", 6 | author_email="max.lapan@gmail.com", 7 | license='GPL-v3', 8 | version='0.1', 9 | description="Common RL libraries", 10 | packages=["rl_lib"] 11 | ) 12 | -------------------------------------------------------------------------------- /articles/01_rubic/ini/cube2x2-paper-d200.ini: -------------------------------------------------------------------------------- 1 | [general] 2 | cube_type=cube2x2 3 | run_name=paper 4 | 5 | [train] 6 | cuda=True 7 | lr=1e-5 8 | batch_size=10000 9 | scramble_depth=200 10 | report_batches=10 11 | checkpoint_batches=100 12 | lr_decay=True 13 | lr_decay_gamma=0.95 14 | lr_decay_batches=1000 15 | -------------------------------------------------------------------------------- /gym-submit/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='gym-submit', 5 | author="Max Lapan", 6 | author_email="max.lapan@gmail.com", 7 | license='GPL-v3', 8 | version='0.1', 9 | description="Tool to submit solution to OpenAI Gym", 10 | instal_requires=[ 11 | 'gym', 12 | ], 13 | scripts=["gym-submit.py"], 14 | ) 15 | -------------------------------------------------------------------------------- /articles/01_rubic/ini/README.md: -------------------------------------------------------------------------------- 1 | Configuration files with training/testing settings. 2 | 3 | # cube2x2-paper-d200 4 | Method from the paper applied to 2x2 cube with scrambling depth 200 during the training. 5 | 6 | Best policy is achieved after 8k batches (3.5 hours on 1080Ti), after 10k batches training diverges. 7 | 8 | # cube2x2-zero-goal-d200 9 | The same as in paper, but target value for goal states set to zero, which helps convergence a lot 10 | -------------------------------------------------------------------------------- /algos/README.md: -------------------------------------------------------------------------------- 1 | # RL algorithms 2 | 3 | This dir contains implementation of various RL methods 4 | 5 | ## Asyncronous Advantage Actor-Critic (A3C) 6 | 7 | * a3c.py: minimalistic implementation, applicable to simple gym environments, like CartPole or MountainCar 8 | * a3c_atari.py: synchronous version with convolution nets 9 | * a3c_async.py: latest version with convolutions and async subprocesses. 10 | 11 | 12 | ## Other methods 13 | 14 | * dqn.py: Q-iteration 15 | * elite.py: variant of PG with examples filtering 16 | * pg.py: policy gradient 17 | -------------------------------------------------------------------------------- /articles/01_rubic/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ./solver.py -e cube2x2 -m saves/cube2x2-paper-d200-t1/best_2.9156e-02.dat --max-steps 30000 --cuda -o c2x2-paper-d200-t1-v2.csv & 4 | ./solver.py -e cube2x2 -m saves/cube2x2-zero-goal-d200-t1/best_1.4547e-02.dat --max-steps 30000 --cuda -o c2x2-zero-goal-d200-t1-v2.csv & 5 | ./solver.py -e cube3x3 --cuda --max-steps 30000 -m saves/cube3x3-paper-d200-t1/best_3.3371e-02.dat -o c3x3-paper-d200-t1-v2.csv & 6 | ./solver.py -e cube3x3 --cuda --max-steps 30000 -m saves/cube3x3-zero-goal-d200-t1/best_2.2160e-02.dat -o c3x3-zero-goal-d200-t1-v2.csv & 7 | ./solver.py -e cube3x3 --cuda --max-steps 30000 -m saves/cube3x3-zero-goal-d200-no-decay/best_2.1798e-02.dat -o c3x3-zero-goal-d200-no-decay-v2.csv & 8 | -------------------------------------------------------------------------------- /algos/ini/a3c_breakout_0.ini: -------------------------------------------------------------------------------- 1 | [game] 2 | ; gym environment to be created 3 | env = Breakout-v0 4 | ; how many frames we'll track 5 | history = 2 6 | ; scaled image X 7 | image_x = 84 8 | ; scaled image Y 9 | image_y = 84 10 | ; limit of steps 11 | max_steps = 40000 12 | 13 | [a3c] 14 | entropy_beta = 0.01 15 | ; how many frames will be used to estimate total reward 16 | reward_steps = 5 17 | ; discount factor 18 | gamma = 0.99 19 | 20 | [swarm] 21 | ; how many parallel proceses to start 22 | swarms = 3 23 | ; how many environments to play in parallel in each process 24 | swarm_size = 16 25 | 26 | [training] 27 | batch_size = 128 28 | learning_rate = 0.001 29 | ; clip norm for gradient. Disabled if not present 30 | grad_clip_norm = 0.1 31 | 32 | -------------------------------------------------------------------------------- /algos/ini/a3c_breakout_1.ini: -------------------------------------------------------------------------------- 1 | [game] 2 | ; gym environment to be created 3 | env = Breakout-v0 4 | ; how many frames we'll track 5 | history = 2 6 | ; scaled image X 7 | image_x = 84 8 | ; scaled image Y 9 | image_y = 84 10 | ; limit of steps 11 | max_steps = 40000 12 | 13 | [a3c] 14 | entropy_beta = 0.001 15 | ; how many frames will be used to estimate total reward 16 | reward_steps = 10 17 | ; discount factor 18 | gamma = 0.99 19 | 20 | [swarm] 21 | ; how many parallel proceses to start 22 | swarms = 3 23 | ; how many environments to play in parallel in each process 24 | swarm_size = 16 25 | 26 | [training] 27 | batch_size = 128 28 | learning_rate = 0.0001 29 | ; clip norm for gradient. Disabled if not present 30 | grad_clip_norm = 0.1 31 | 32 | -------------------------------------------------------------------------------- /algos/ini/a3c_pong_0.ini: -------------------------------------------------------------------------------- 1 | [game] 2 | ; gym environment to be created 3 | env = PongNoFrameskip-v4 4 | ; how many frames we'll track 5 | history = 4 6 | ; scaled image X 7 | image_x = 84 8 | ; scaled image Y 9 | image_y = 84 10 | ; limit of steps 11 | max_steps = 40000 12 | 13 | [a3c] 14 | entropy_beta = 0.01 15 | ; how many frames will be used to estimate total reward 16 | reward_steps = 5 17 | ; discount factor 18 | gamma = 0.99 19 | 20 | [swarm] 21 | ; how many parallel proceses to start 22 | swarms = 3 23 | ; how many environments to play in parallel in each process 24 | swarm_size = 16 25 | 26 | [training] 27 | batch_size = 128 28 | learning_rate = 0.001 29 | ; clip norm for gradient. Disabled if not present 30 | grad_clip_norm = 0.1 31 | 32 | -------------------------------------------------------------------------------- /articles/01_rubic/ini/cube3x3-paper-d200.ini: -------------------------------------------------------------------------------- 1 | [general] 2 | cube_type=cube3x3 3 | run_name=paper 4 | 5 | [train] 6 | ; how to calculate target values, default is 'paper' 7 | value_targets_method=paper 8 | ; limit of batches to train (train iterations) 9 | ;max_batches=4000 10 | ; use cuda 11 | cuda=True 12 | ; learning rate 13 | lr=1e-5 14 | ; count of cubes in single batch 15 | batch_size=10000 16 | ; how deeply to scramble cube 17 | scramble_depth=200 18 | ; how frequently to report training progress 19 | report_batches=10 20 | ; how frequently to save model (if commented out, won't be saved) 21 | ;checkpoint_batches=100 22 | ; enables LR decay 23 | lr_decay=True 24 | ; LR decay gamma (if enabled) 25 | lr_decay_gamma=0.95 26 | ; interval between decays 27 | lr_decay_batches=1000 28 | -------------------------------------------------------------------------------- /articles/01_rubic/ini/cube2x2-zero-goal-d200.ini: -------------------------------------------------------------------------------- 1 | [general] 2 | cube_type=cube2x2 3 | run_name=zero-goal 4 | 5 | [train] 6 | ; how to calculate target values, default is 'paper' 7 | value_targets_method=zero_goal_value 8 | ; limit of batches to train (train iterations) 9 | max_batches=4000 10 | ; use cuda 11 | cuda=True 12 | ; learning rate 13 | lr=1e-5 14 | ; count of cubes in single batch 15 | batch_size=10000 16 | ; how deeply to scramble cube 17 | scramble_depth=200 18 | ; how frequently to report training progress 19 | report_batches=10 20 | ; how frequently to save model (if commented out, won't be saved) 21 | ;checkpoint_batches=100 22 | ; enables LR decay 23 | lr_decay=True 24 | ; LR decay gamma (if enabled) 25 | lr_decay_gamma=0.95 26 | ; interval between decays 27 | lr_decay_batches=100 28 | -------------------------------------------------------------------------------- /articles/01_rubic/ini/cube3x3-zero-goal-d200.ini: -------------------------------------------------------------------------------- 1 | [general] 2 | cube_type=cube3x3 3 | run_name=zero-goal 4 | 5 | [train] 6 | ; how to calculate target values, default is 'paper' 7 | value_targets_method=zero_goal_value 8 | ; limit of batches to train (train iterations) 9 | max_batches=40000 10 | ; use cuda 11 | cuda=True 12 | ; learning rate 13 | lr=1e-5 14 | ; count of cubes in single batch 15 | batch_size=10000 16 | ; how deeply to scramble cube 17 | scramble_depth=200 18 | ; how frequently to report training progress 19 | report_batches=10 20 | ; how frequently to save model (if commented out, won't be saved) 21 | ;checkpoint_batches=100 22 | ; enables LR decay 23 | lr_decay=True 24 | ; LR decay gamma (if enabled) 25 | lr_decay_gamma=0.95 26 | ; interval between decays 27 | lr_decay_batches=200 28 | -------------------------------------------------------------------------------- /articles/01_rubic/libcube/cubes/_common.py: -------------------------------------------------------------------------------- 1 | def _permute(t, m, is_inv=False): 2 | """ 3 | Perform permutation of tuple according to mapping m 4 | """ 5 | r = list(t) 6 | for from_idx, to_idx in m: 7 | if is_inv: 8 | r[from_idx] = t[to_idx] 9 | else: 10 | r[to_idx] = t[from_idx] 11 | return r 12 | 13 | 14 | def _rotate(corner_ort, corners): 15 | """ 16 | Rotate given corners 120 degrees 17 | """ 18 | r = list(corner_ort) 19 | for c, angle in corners: 20 | r[c] = (r[c] + angle) % 3 21 | return r 22 | 23 | 24 | # orient corner cubelet 25 | def _map_orient(cols, orient_id): 26 | if orient_id == 0: 27 | return cols 28 | elif orient_id == 1: 29 | return cols[2], cols[0], cols[1] 30 | else: 31 | return cols[1], cols[2], cols[0] 32 | 33 | -------------------------------------------------------------------------------- /articles/01_rubic/ini/cube3x3-paper-d20.ini: -------------------------------------------------------------------------------- 1 | [general] 2 | cube_type=cube3x3 3 | run_name=paper 4 | 5 | [train] 6 | ; how to calculate target values, default is 'paper' 7 | value_targets_method=paper 8 | ; limit of batches to train (train iterations) 9 | max_batches=100000 10 | ; use cuda 11 | cuda=True 12 | ; learning rate 13 | lr=1e-5 14 | ; count of cubes in single batch 15 | batch_size=10000 16 | ; batches to keep in scramble buffer 17 | scramble_buffer_batches=10 18 | ; after how many iterations push fresh batch into the scramble buffer 19 | push_scramble_buffer_iters=100 20 | ; how deeply to scramble cube 21 | scramble_depth=20 22 | ; how frequently to report training progress 23 | report_batches=10 24 | ; how frequently to save model (if commented out, won't be saved) 25 | ;checkpoint_batches=100 26 | ; enables LR decay 27 | lr_decay=True 28 | ; LR decay gamma (if enabled) 29 | lr_decay_gamma=0.95 30 | ; interval between decays 31 | lr_decay_batches=1000 32 | -------------------------------------------------------------------------------- /articles/01_rubic/csvs/README.md: -------------------------------------------------------------------------------- 1 | Description of produced test results 2 | 3 | # First results 4 | 5 | Test results from first models (paper versus zero-goal method). Solve tool run for 30k MCTS searches 6 | (but due to bug, actual amount of steps in some tests was much lower). 7 | 8 | ```` 9 | c2x2-paper-d200-t1.csv 10 | c2x2-zero-goal-d200-t1.csv 11 | c3x3-paper-d200-t1.csv 12 | c3x3-zero-goal-d200-no-decay.csv 13 | c3x3-zero-goal-d200-t1.csv 14 | ```` 15 | 16 | Analysis of the results are in notebook 17 | https://github.com/Shmuma/rl/blob/master/articles/01_rubic/nbs/01_paper-vs-zero_goal.ipynb 18 | 19 | # Fix of wrong steps 20 | 21 | Fixed with https://github.com/Shmuma/rl/commit/793aebc81b7bf323a8db930e8224521700383af5#diff-b9a7f0478383b0f6ad54ae87c8769b03 22 | 23 | ```` 24 | c2x2-paper-d200-t1-v2.csv 25 | c2x2-zero-goal-d200-t1-v2.csv 26 | c3x3-paper-d200-t1-v2.csv 27 | c3x3-zero-goal-d200-no-decay-v2.csv 28 | c3x3-zero-goal-d200-t1-v2.csv 29 | ```` 30 | 31 | -------------------------------------------------------------------------------- /articles/01_rubic/ini/cube3x3-zero-goal-d20.ini: -------------------------------------------------------------------------------- 1 | [general] 2 | cube_type=cube3x3 3 | run_name=zero-goal 4 | 5 | [train] 6 | ; how to calculate target values, default is 'paper' 7 | value_targets_method=zero_goal_value 8 | ; limit of batches to train (train iterations) 9 | max_batches=100000 10 | ; use cuda 11 | cuda=True 12 | ; learning rate 13 | lr=1e-4 14 | ; count of cubes in single batch 15 | batch_size=10000 16 | ; batches to keep in scramble buffer 17 | scramble_buffer_batches=10 18 | ; after how many iterations push fresh batch into the scramble buffer 19 | push_scramble_buffer_iters=100 20 | ; how deeply to scramble cube 21 | scramble_depth=20 22 | ; how frequently to report training progress 23 | report_batches=10 24 | ; how frequently to save model (if commented out, won't be saved) 25 | ;checkpoint_batches=100 26 | ; enables LR decay 27 | lr_decay=True 28 | ; LR decay gamma (if enabled) 29 | lr_decay_gamma=0.95 30 | ; interval between decays 31 | lr_decay_batches=1000 32 | -------------------------------------------------------------------------------- /articles/01_rubic/ini/cube3x3-zero-goal-d20-noweight.ini: -------------------------------------------------------------------------------- 1 | [general] 2 | cube_type=cube3x3 3 | run_name=zero-goal-noweight 4 | 5 | [train] 6 | ; how to calculate target values, default is 'paper' 7 | value_targets_method=zero_goal_value 8 | ; limit of batches to train (train iterations) 9 | max_batches=100000 10 | ; use cuda 11 | cuda=True 12 | ; learning rate 13 | lr=1e-5 14 | ; count of cubes in single batch 15 | batch_size=10000 16 | ; batches to keep in scramble buffer 17 | scramble_buffer_batches=10 18 | ; after how many iterations push fresh batch into the scramble buffer 19 | push_scramble_buffer_iters=100 20 | ; how deeply to scramble cube 21 | scramble_depth=20 22 | ; how frequently to report training progress 23 | report_batches=10 24 | ; how frequently to save model (if commented out, won't be saved) 25 | checkpoint_batches=1000 26 | ; enables LR decay 27 | lr_decay=False 28 | ; LR decay gamma (if enabled) 29 | lr_decay_gamma=0.95 30 | ; interval between decays 31 | lr_decay_batches=1000 32 | ; perform weighting of training samples inverse by scramble depth, default=True 33 | weight_samples=False 34 | -------------------------------------------------------------------------------- /rl_lib/rl_lib/wrappers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | def HistoryWrapper(steps): 5 | class HistoryWrapper(gym.Wrapper): 6 | """ 7 | Track history of observations for given amount of steps 8 | Initial steps are zero-filled 9 | """ 10 | def __init__(self, env): 11 | super(HistoryWrapper, self).__init__(env) 12 | self.steps = steps 13 | self.history = self._make_history() 14 | 15 | def _make_history(self): 16 | return [np.zeros(shape=self.env.observation_space.shape) for _ in range(steps)] 17 | 18 | def _step(self, action): 19 | obs, reward, done, info = self.env.step(action) 20 | self.history.pop(0) 21 | self.history.append(obs) 22 | return np.array(self.history), reward, done, info 23 | 24 | def _reset(self): 25 | self.history = self._make_history() 26 | self.history.pop(0) 27 | self.history.append(self.env.reset()) 28 | return np.array(self.history) 29 | 30 | return HistoryWrapper 31 | -------------------------------------------------------------------------------- /articles/01_rubic/cubes_tests/cube2x2_d3.txt: -------------------------------------------------------------------------------- 1 | 10,1,0 2 | 11,4,3 3 | 3,2,11 4 | 1,10,11 5 | 8,1,9 6 | 6,1,3 7 | 3,8,9 8 | 0,8,3 9 | 11,10,11 10 | 8,6,3 11 | 7,9,4 12 | 0,2,11 13 | 6,5,4 14 | 2,3,5 15 | 1,1,6 16 | 1,5,5 17 | 9,4,0 18 | 11,7,8 19 | 1,6,1 20 | 8,4,9 21 | 5,9,11 22 | 1,0,10 23 | 3,4,1 24 | 3,1,6 25 | 4,7,10 26 | 5,2,5 27 | 5,3,10 28 | 4,11,10 29 | 10,1,9 30 | 10,2,11 31 | 3,2,7 32 | 6,4,11 33 | 8,3,10 34 | 5,0,3 35 | 0,5,6 36 | 4,1,3 37 | 9,11,3 38 | 10,7,6 39 | 10,7,2 40 | 4,2,3 41 | 11,8,8 42 | 4,11,9 43 | 6,9,6 44 | 5,3,2 45 | 8,7,0 46 | 1,2,10 47 | 2,10,6 48 | 9,1,6 49 | 6,9,7 50 | 8,4,8 51 | 0,10,11 52 | 1,10,8 53 | 4,5,1 54 | 4,6,2 55 | 7,0,11 56 | 11,4,8 57 | 2,1,10 58 | 4,8,9 59 | 3,2,5 60 | 2,0,9 61 | 5,7,0 62 | 1,5,4 63 | 3,0,3 64 | 9,1,1 65 | 11,7,8 66 | 2,2,10 67 | 7,8,4 68 | 8,9,6 69 | 3,8,11 70 | 11,3,11 71 | 4,6,10 72 | 10,5,7 73 | 8,7,3 74 | 3,1,5 75 | 0,9,8 76 | 3,3,0 77 | 1,11,10 78 | 0,3,1 79 | 0,5,1 80 | 8,3,4 81 | 10,7,3 82 | 8,11,9 83 | 9,7,3 84 | 7,6,3 85 | 1,1,10 86 | 6,5,6 87 | 6,7,11 88 | 0,10,10 89 | 10,1,0 90 | 6,11,1 91 | 3,3,3 92 | 8,7,2 93 | 6,2,4 94 | 7,3,1 95 | 7,8,1 96 | 0,10,8 97 | 0,1,3 98 | 2,6,7 99 | 7,3,6 100 | 0,2,6 101 | -------------------------------------------------------------------------------- /articles/01_rubic/cubes_tests/cube3x3_d3.txt: -------------------------------------------------------------------------------- 1 | 10,1,0 2 | 11,4,3 3 | 3,2,11 4 | 1,10,11 5 | 8,1,9 6 | 6,0,0 7 | 1,3,3 8 | 8,9,0 9 | 8,3,11 10 | 10,11,8 11 | 6,3,7 12 | 9,4,0 13 | 2,11,6 14 | 5,4,2 15 | 3,5,1 16 | 1,6,1 17 | 5,5,9 18 | 4,0,11 19 | 7,8,1 20 | 6,1,8 21 | 4,10,9 22 | 5,9,3 23 | 11,1,0 24 | 10,3,4 25 | 1,3,1 26 | 6,4,7 27 | 10,5,2 28 | 5,5,3 29 | 10,4,11 30 | 10,10,1 31 | 9,10,2 32 | 8,11,3 33 | 2,7,6 34 | 4,10,11 35 | 8,3,10 36 | 5,0,3 37 | 0,5,6 38 | 4,1,3 39 | 9,11,5 40 | 3,10,7 41 | 6,10,7 42 | 2,4,2 43 | 3,11,8 44 | 8,4,11 45 | 9,6,9 46 | 6,5,3 47 | 2,8,7 48 | 1,0,1 49 | 2,10,2 50 | 10,6,9 51 | 1,6,6 52 | 9,7,8 53 | 4,8,0 54 | 10,11,1 55 | 10,8,4 56 | 10,5,1 57 | 4,6,2 58 | 7,0,11 59 | 11,4,8 60 | 2,8,1 61 | 10,4,10 62 | 8,9,3 63 | 2,5,2 64 | 8,8,0 65 | 9,5,7 66 | 0,1,5 67 | 4,3,0 68 | 3,9,1 69 | 1,11,7 70 | 1,8,2 71 | 2,10,7 72 | 8,2,4 73 | 8,9,6 74 | 3,8,11 75 | 11,3,11 76 | 4,6,10 77 | 10,5,7 78 | 8,7,1 79 | 3,3,1 80 | 5,0,9 81 | 8,3,9 82 | 3,0,1 83 | 11,10,0 84 | 3,1,0 85 | 5,1,8 86 | 3,4,10 87 | 7,3,8 88 | 2,11,9 89 | 9,7,3 90 | 7,6,3 91 | 1,1,10 92 | 6,5,6 93 | 6,7,11 94 | 0,10,10 95 | 10,1,0 96 | 6,11,5 97 | 1,3,3 98 | 3,8,7 99 | 2,6,2 100 | 4,7,3 101 | -------------------------------------------------------------------------------- /articles/01_rubic/cubes_tests/cube3x3_d3_norepeat.txt: -------------------------------------------------------------------------------- 1 | 10,1,0 2 | 11,4,3 3 | 3,2,11 4 | 1,10,11 5 | 8,1,9 6 | 6,1,3 7 | 3,8,9 8 | 0,8,3 9 | 11,10,11 10 | 8,6,3 11 | 7,9,4 12 | 0,2,11 13 | 6,5,4 14 | 2,3,5 15 | 1,1,6 16 | 1,5,5 17 | 9,4,0 18 | 11,7,8 19 | 1,6,1 20 | 8,4,9 21 | 5,9,11 22 | 1,0,10 23 | 3,4,1 24 | 3,1,6 25 | 4,7,10 26 | 5,2,5 27 | 5,3,10 28 | 4,11,10 29 | 10,1,9 30 | 10,2,11 31 | 3,2,7 32 | 6,4,11 33 | 8,3,10 34 | 5,0,3 35 | 0,5,6 36 | 4,1,3 37 | 9,11,3 38 | 10,7,6 39 | 10,7,2 40 | 4,2,3 41 | 11,8,8 42 | 4,11,9 43 | 6,9,6 44 | 5,3,2 45 | 8,7,0 46 | 1,2,10 47 | 2,10,6 48 | 9,1,6 49 | 6,9,7 50 | 8,4,8 51 | 0,10,11 52 | 1,10,8 53 | 4,5,1 54 | 4,6,2 55 | 7,0,11 56 | 11,4,8 57 | 2,1,10 58 | 4,8,9 59 | 3,2,5 60 | 2,0,9 61 | 5,7,0 62 | 1,5,4 63 | 3,0,3 64 | 9,1,1 65 | 11,7,8 66 | 2,2,10 67 | 7,8,4 68 | 8,9,6 69 | 3,8,11 70 | 11,3,11 71 | 4,6,10 72 | 10,5,7 73 | 8,7,3 74 | 3,1,5 75 | 0,9,8 76 | 3,3,0 77 | 1,11,10 78 | 0,3,1 79 | 0,5,1 80 | 8,3,4 81 | 10,7,3 82 | 8,11,9 83 | 9,7,3 84 | 7,6,3 85 | 1,1,10 86 | 6,5,6 87 | 6,7,11 88 | 0,10,10 89 | 10,1,0 90 | 6,11,1 91 | 3,3,3 92 | 8,7,2 93 | 6,2,4 94 | 7,3,1 95 | 7,8,1 96 | 0,10,8 97 | 0,1,3 98 | 2,6,7 99 | 7,3,6 100 | 0,2,6 101 | -------------------------------------------------------------------------------- /gym_bugs/atari_race.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import gym 3 | 4 | from rl_lib.wrappers import HistoryWrapper 5 | 6 | ENV_NAME = "Breakout-v0" 7 | ENV_COUNT = 50 8 | 9 | if __name__ == "__main__": 10 | envs = {} 11 | HistoryWrapper(4)(gym.make(ENV_NAME)) 12 | 13 | for idx in range(ENV_COUNT): 14 | e = HistoryWrapper(4)(gym.make(ENV_NAME)) 15 | envs[idx] = { 16 | 'done': False, 17 | 'steps': 0, 18 | 'reward': 0.0, 19 | 'state': e.reset(), 20 | 'env': e 21 | } 22 | 23 | # play randomly 24 | while not all(map(lambda e: e['done'], envs.values())): 25 | for idx, e in envs.items(): 26 | if e['done']: 27 | continue 28 | e['steps'] += 1 29 | e['state'], r, done, _ = e['env'].step(e['env'].action_space.sample()) 30 | e['reward'] += r 31 | e['done'] = done 32 | if done: 33 | print("Env %d done after %d steps with reward %s" % (idx, e['steps'], e['reward'])) 34 | if e['steps'] < 100: 35 | sys.exit(0) 36 | e['steps'] = 0 37 | e['reward'] = 0.0 38 | e['done'] = False 39 | e['state'] = e['env'].reset() 40 | 41 | pass 42 | -------------------------------------------------------------------------------- /articles/01_rubic/cubes_tests/cube2x2_d4.txt: -------------------------------------------------------------------------------- 1 | 10,1,0,11 2 | 4,3,3,2 3 | 11,1,10,11 4 | 8,1,9,6 5 | 0,0,1,3 6 | 3,8,9,0 7 | 8,3,11,10 8 | 11,8,6,3 9 | 7,9,4,0 10 | 2,11,6,5 11 | 4,2,3,5 12 | 1,1,6,1 13 | 5,5,9,4 14 | 0,11,7,8 15 | 1,6,1,8 16 | 4,9,5,9 17 | 3,11,1,0 18 | 10,3,4,1 19 | 3,1,6,4 20 | 7,10,5,2 21 | 5,5,3,10 22 | 4,11,10,10 23 | 1,9,10,2 24 | 8,11,3,2 25 | 7,6,4,11 26 | 8,3,10,5 27 | 0,3,0,5 28 | 6,4,1,3 29 | 9,11,3,10 30 | 7,6,10,7 31 | 2,4,2,3 32 | 11,8,8,4 33 | 11,9,6,9 34 | 6,5,3,2 35 | 8,7,0,1 36 | 2,10,2,10 37 | 6,9,1,6 38 | 6,9,7,8 39 | 4,8,0,10 40 | 11,1,10,8 41 | 4,5,1,4 42 | 6,2,7,0 43 | 11,11,4,8 44 | 2,1,10,10 45 | 8,9,2,5 46 | 2,0,9,5 47 | 7,0,1,5 48 | 4,3,0,3 49 | 9,1,1,11 50 | 7,8,10,7 51 | 8,4,8,9 52 | 6,3,8,11 53 | 11,3,11,4 54 | 6,10,10,5 55 | 7,8,7,3 56 | 3,1,5,0 57 | 9,8,3,3 58 | 0,1,11,10 59 | 0,3,1,0 60 | 5,1,8,3 61 | 4,7,3,8 62 | 2,11,9,9 63 | 7,3,7,6 64 | 3,1,1,10 65 | 6,5,6,6 66 | 7,11,0,10 67 | 10,10,1,0 68 | 6,11,1,3 69 | 3,3,8,7 70 | 2,6,2,4 71 | 7,3,1,8 72 | 1,0,10,8 73 | 0,1,3,2 74 | 6,7,7,3 75 | 6,2,6,6 76 | 4,7,4,6 77 | 11,11,8,10 78 | 11,7,2,3 79 | 4,3,0,9 80 | 11,8,0,11 81 | 5,0,0,9 82 | 7,8,8,0 83 | 8,1,2,1 84 | 9,1,10,3 85 | 6,1,9,9 86 | 9,0,9,1 87 | 6,10,9,9 88 | 8,5,4,3 89 | 10,11,3,4 90 | 6,2,10,10 91 | 4,7,5,1 92 | 0,7,9,9 93 | 1,1,8,3 94 | 8,4,2,5 95 | 1,3,5,4 96 | 2,7,8,11 97 | 4,9,10,8 98 | 0,10,8,4 99 | 10,1,2,4 100 | 1,1,11,8 101 | -------------------------------------------------------------------------------- /articles/01_rubic/gen_cubes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Tool to generate test set for solver 4 | """ 5 | import argparse 6 | import random 7 | 8 | from libcube import cubes 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("-e", "--env", required=True, help="Type of env to train, supported types=%s" % cubes.names()) 14 | parser.add_argument("-n", "--number", type=int, default=10, help="Amount of scramble rounds, default=10") 15 | parser.add_argument("-d", "--depth", type=int, default=100, help="Scramble depth, default=10") 16 | parser.add_argument("--seed", type=int, default=42, help="Seed to use, if zero, no seed used. default=42") 17 | parser.add_argument("-o", "--output", required=True, help="Output file to produce") 18 | args = parser.parse_args() 19 | 20 | if args.seed: 21 | random.seed(args.seed) 22 | 23 | cube_env = cubes.get(args.env) 24 | assert isinstance(cube_env, cubes.CubeEnv) 25 | 26 | with open(args.output, "w+t", encoding="utf-8") as fd_out: 27 | for _ in range(args.number): 28 | s = cube_env.initial_state 29 | path = [] 30 | prev_a = None 31 | for _ in range(args.depth): 32 | a = cube_env.sample_action(prev_action=prev_a) 33 | path.append(a.value) 34 | s = cube_env.transform(s, a) 35 | prev_a = a 36 | fd_out.write(",".join(map(str, path)) + "\n") 37 | -------------------------------------------------------------------------------- /articles/01_rubic/cubes_tests/cube2x2_d5.txt: -------------------------------------------------------------------------------- 1 | 10,1,0,11,4 2 | 3,3,2,11,1 3 | 10,11,8,1,9 4 | 6,1,3,3,8 5 | 9,0,8,3,11 6 | 10,11,8,6,3 7 | 7,9,4,0,2 8 | 11,6,5,4,2 9 | 3,5,1,1,6 10 | 1,5,5,9,4 11 | 0,11,7,8,1 12 | 6,1,8,4,9 13 | 5,9,11,1,0 14 | 10,3,4,1,3 15 | 1,6,4,7,10 16 | 5,2,5,5,3 17 | 10,11,10,10,1 18 | 9,10,2,11,3 19 | 2,7,6,4,11 20 | 8,3,10,5,0 21 | 3,0,5,6,4 22 | 1,3,11,3,10 23 | 7,6,10,7,2 24 | 4,2,3,11,8 25 | 8,4,11,9,6 26 | 9,6,5,3,2 27 | 8,7,0,1,2 28 | 10,2,10,6,9 29 | 1,6,6,9,7 30 | 8,4,8,0,10 31 | 11,1,10,8,4 32 | 10,5,1,4,6 33 | 2,7,0,11,11 34 | 4,8,8,1,10 35 | 4,8,9,2,5 36 | 2,0,9,5,7 37 | 0,1,5,4,3 38 | 0,3,1,1,11 39 | 7,8,10,7,8 40 | 2,4,8,9,6 41 | 3,8,11,11,3 42 | 11,4,6,10,10 43 | 5,7,8,7,3 44 | 3,1,5,0,9 45 | 8,3,3,0,1 46 | 11,10,0,3,1 47 | 0,5,1,8,3 48 | 4,7,3,8,11 49 | 9,9,7,3,7 50 | 6,3,1,1,10 51 | 6,5,6,6,7 52 | 11,0,10,10,10 53 | 1,0,11,1,3 54 | 3,3,8,7,2 55 | 6,2,4,7,3 56 | 1,8,1,0,10 57 | 8,0,1,3,2 58 | 6,7,7,3,6 59 | 0,2,6,6,4 60 | 7,4,6,11,11 61 | 8,10,11,7,2 62 | 3,4,3,0,9 63 | 11,8,0,11,0 64 | 0,9,7,8,8 65 | 2,0,8,1,2 66 | 1,9,1,10,3 67 | 6,1,9,9,9 68 | 0,9,1,6,10 69 | 9,9,8,5,4 70 | 3,10,11,3,4 71 | 6,2,10,10,7 72 | 5,1,0,7,9 73 | 9,1,1,8,3 74 | 8,4,2,5,1 75 | 3,5,4,2,7 76 | 8,11,4,9,10 77 | 8,0,10,8,4 78 | 10,1,2,4,1 79 | 1,11,8,4,4 80 | 9,11,3,10,10 81 | 4,8,7,4,0 82 | 1,10,6,4,0 83 | 0,5,2,10,2 84 | 11,7,8,11,6 85 | 8,0,1,1,11 86 | 2,0,5,9,8 87 | 2,6,2,0,4 88 | 5,0,5,3,10 89 | 3,10,1,5,8 90 | 6,9,11,2,3 91 | 2,2,6,2,11 92 | 5,6,10,11,3 93 | 4,2,11,1,6 94 | 0,7,3,3,7 95 | 5,4,3,3,0 96 | 10,3,6,5,4 97 | 1,4,5,10,8 98 | 6,10,8,5,0 99 | 1,4,2,9,4 100 | 0,1,9,6,5 101 | -------------------------------------------------------------------------------- /articles/01_rubic/cubes_tests/cube2x2_d6.txt: -------------------------------------------------------------------------------- 1 | 10,1,0,11,4,3 2 | 3,2,11,1,10,11 3 | 8,1,9,6,1,3 4 | 3,8,9,0,8,3 5 | 11,10,11,8,6,3 6 | 7,9,4,0,2,11 7 | 6,5,4,2,3,5 8 | 1,1,6,1,5,5 9 | 9,4,0,11,7,8 10 | 1,6,1,8,4,9 11 | 5,9,11,1,0,10 12 | 3,4,1,3,1,6 13 | 4,7,10,5,2,5 14 | 5,3,10,11,10,10 15 | 1,9,10,2,11,3 16 | 2,7,6,4,11,8 17 | 3,10,5,0,3,0 18 | 5,6,4,1,3,11 19 | 5,3,10,7,6,10 20 | 7,2,4,2,3,11 21 | 8,8,4,11,9,6 22 | 9,6,5,3,2,7 23 | 1,0,1,2,10,2 24 | 10,6,9,1,6,6 25 | 9,7,8,4,8,0 26 | 10,11,1,10,8,4 27 | 10,5,1,4,6,2 28 | 7,0,11,11,4,8 29 | 2,1,10,10,8,9 30 | 3,2,5,2,0,9 31 | 5,7,0,1,5,4 32 | 3,0,3,1,1,11 33 | 7,8,10,7,8,4 34 | 8,9,6,3,8,11 35 | 11,3,11,4,6,10 36 | 10,5,7,8,7,3 37 | 3,1,5,0,9,8 38 | 3,3,0,1,11,10 39 | 0,3,1,0,5,1 40 | 8,3,4,7,3,8 41 | 2,11,9,9,7,3 42 | 7,6,3,1,1,10 43 | 6,5,6,6,7,11 44 | 0,10,10,10,1,0 45 | 6,11,1,3,3,3 46 | 8,7,2,6,2,4 47 | 7,3,1,8,1,0 48 | 10,8,0,1,3,2 49 | 6,7,7,3,6,2 50 | 6,6,4,7,4,6 51 | 11,11,8,10,11,7 52 | 2,3,4,3,0,9 53 | 11,8,0,11,0,0 54 | 9,7,8,8,0,8 55 | 1,2,1,9,1,10 56 | 3,6,1,9,9,9 57 | 0,9,1,6,10,9 58 | 9,8,5,4,3,10 59 | 11,3,4,6,2,10 60 | 10,7,5,1,0,7 61 | 9,9,1,1,8,3 62 | 8,4,2,5,1,3 63 | 5,4,2,7,8,11 64 | 4,9,10,8,0,10 65 | 8,4,1,2,4,1 66 | 1,11,8,4,4,9 67 | 3,11,3,10,10,8 68 | 7,4,0,1,10,6 69 | 4,0,0,5,2,10 70 | 4,2,11,7,8,11 71 | 6,8,0,1,1,11 72 | 2,0,5,9,8,6 73 | 2,0,4,5,0,5 74 | 3,10,3,10,1,5 75 | 8,6,9,11,2,3 76 | 2,2,6,2,11,6 77 | 10,11,3,4,2,11 78 | 1,6,7,3,3,7 79 | 5,4,3,3,0,10 80 | 3,6,5,4,1,4 81 | 5,10,8,6,10,8 82 | 5,0,1,4,2,9 83 | 4,0,1,9,6,5 84 | 11,6,9,8,1,6 85 | 9,4,0,11,6,8 86 | 8,10,11,11,11,10 87 | 3,5,6,1,10,5 88 | 9,5,10,1,11,4 89 | 8,4,6,5,6,11 90 | 4,8,3,6,10,6 91 | 10,11,2,9,9,4 92 | 6,8,0,4,4,3 93 | 6,9,9,10,5,7 94 | 7,7,10,3,8,7 95 | 11,2,10,1,4,8 96 | 10,10,9,5,1,3 97 | 10,3,3,2,0,0 98 | 3,7,9,1,6,10 99 | 9,11,11,6,7,6 100 | 3,2,10,11,0,1 101 | -------------------------------------------------------------------------------- /gym-submit/gym-submit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import configparser 5 | import argparse 6 | import gym 7 | 8 | 9 | ENV_VAR_NAME = 'OPENAI_GYM_KEY' 10 | CONF_FILE_NAME = '~/.config/gym-submit.conf' 11 | CONF_SECTION_NAME = 'gym-submit' 12 | CONF_VALUE_NAME = 'Key' 13 | 14 | 15 | def look_for_key(): 16 | env_key = os.environ.get(ENV_VAR_NAME) 17 | if env_key is not None: 18 | return env_key 19 | 20 | conf_path = os.path.expanduser(CONF_FILE_NAME) 21 | if os.path.exists(conf_path): 22 | conf = configparser.ConfigParser() 23 | conf.read(conf_path) 24 | if CONF_SECTION_NAME in conf.sections(): 25 | key = conf[CONF_SECTION_NAME].get(CONF_VALUE_NAME) 26 | if key: 27 | return key 28 | 29 | return None 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("dirs", nargs='+', help="Directories to submit") 34 | parser.add_argument("-k", "--key", help="Submission key. If not provided, we'll check {env_name} " 35 | "and config {conf_name}".format(env_name=ENV_VAR_NAME, 36 | conf_name=CONF_FILE_NAME)) 37 | args = parser.parse_args() 38 | 39 | if args.key is not None: 40 | key = args.key 41 | else: 42 | key = look_for_key() 43 | 44 | # if nothing have found, complain about it 45 | if key is None: 46 | print("""No OpenAI Gym key was provided. You can specify it: 47 | 1. as -k argument, 48 | 2. with {env_name} environment variable, 49 | 3. put in file {conf_name} under section '{section_name}' and '{value_name}' value, like in example: 50 | 51 | [{section_name}] 52 | {value_name}=YOUR_KEY 53 | """.format(env_name=ENV_VAR_NAME, conf_name=CONF_FILE_NAME, section_name=CONF_SECTION_NAME, 54 | value_name=CONF_VALUE_NAME)) 55 | sys.exit(-1) 56 | 57 | for dir in args.dirs: 58 | gym.upload(dir, api_key=key) 59 | -------------------------------------------------------------------------------- /algos/a3c_atari_play.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import numpy as np 4 | 5 | from algo_lib.common import make_env, HistoryWrapper 6 | from algo_lib.atari import HISTORY_STEPS, net_input, RescaleWrapper 7 | from algo_lib.a3c import make_run_model 8 | from algo_lib.player import softmax 9 | 10 | try: 11 | from keras.utils.visualize_util import plot 12 | except ImportError: 13 | plot = None 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("-r", "--model", required=True, help="Model file to read") 19 | parser.add_argument("-e", "--env", required=True, help="Environment to use") 20 | parser.add_argument("-m", "--monitor", help="Enable monitor and write to directory, default=disabled") 21 | parser.add_argument("--iters", type=int, default=100, help="Episodes to play, default=100") 22 | parser.add_argument("-v", "--verbose", action="store_true", default=False, help="Show individual episode results") 23 | parser.add_argument("--netimg", action='store_true', default=False, help="Save image of network") 24 | args = parser.parse_args() 25 | 26 | env_wrappers = (HistoryWrapper(HISTORY_STEPS), RescaleWrapper()) 27 | env = make_env(args.env, args.monitor, wrappers=env_wrappers) 28 | state_shape = env.observation_space.shape 29 | n_actions = env.action_space.n 30 | 31 | input_t, conv_out_t = net_input() 32 | model = make_run_model(input_t, conv_out_t, n_actions) 33 | model.summary() 34 | model.load_weights(args.model) 35 | 36 | if plot is not None and args.netimg: 37 | plot(model, to_file="net.png", show_layer_names=False, show_shapes=False) 38 | 39 | rewards = [] 40 | steps = [] 41 | 42 | for iter in range(args.iters): 43 | state = env.reset() 44 | sum_reward = 0.0 45 | step = 0 46 | while True: 47 | probs, value = model.predict_on_batch([ 48 | np.array([state]), 49 | ]) 50 | probs, value = probs[0], value[0][0] 51 | # take action 52 | action = np.random.choice(len(probs), p=softmax(probs)) 53 | state, reward, done, _ = env.step(action) 54 | step += 1 55 | sum_reward += reward 56 | if done: 57 | if args.verbose: 58 | print("Episode %d done in %d steps, reward %f" % (iter, step, sum_reward)) 59 | break 60 | rewards.append(sum_reward) 61 | steps.append(step) 62 | print("Done %d episodes, mean reward %.3f, mean steps %.2f" % (args.iters, np.mean(rewards), np.mean(steps))) 63 | pass 64 | -------------------------------------------------------------------------------- /articles/01_rubic/cubes_tests/cube3x3_d10.txt: -------------------------------------------------------------------------------- 1 | 10,1,0,11,4,3,3,2,11,1 2 | 10,11,8,1,9,6,0,0,1,3 3 | 3,8,9,0,8,3,11,10,11,8 4 | 6,3,7,9,4,0,2,11,6,5 5 | 4,2,3,5,1,1,6,1,5,5 6 | 9,4,0,11,7,8,1,6,1,8 7 | 4,10,9,5,9,3,11,1,0,10 8 | 3,4,1,3,1,6,4,7,10,5 9 | 2,5,5,3,10,4,11,10,10,1 10 | 9,10,2,8,11,3,2,7,6,4 11 | 10,11,8,3,10,5,0,3,0,5 12 | 6,4,1,3,9,11,5,3,10,7 13 | 6,10,7,2,4,2,3,11,8,8 14 | 4,11,9,6,9,6,5,3,2,8 15 | 7,1,0,1,2,10,2,10,6,9 16 | 1,6,6,9,7,8,4,8,0,10 17 | 11,1,10,8,4,10,5,1,4,6 18 | 2,7,0,11,11,4,8,2,8,1 19 | 10,4,10,8,9,3,2,5,2,8 20 | 8,0,9,5,7,0,1,5,4,3 21 | 0,3,9,1,1,11,7,1,8,2 22 | 2,10,7,8,2,4,8,9,6,3 23 | 8,11,11,3,11,4,6,10,10,5 24 | 7,8,7,1,3,3,1,5,0,9 25 | 8,3,9,3,0,1,11,10,0,3 26 | 1,0,5,1,8,3,4,10,7,3 27 | 8,2,11,9,9,7,3,7,6,3 28 | 1,1,10,6,5,6,6,7,11,0 29 | 10,10,10,1,0,6,11,5,1,3 30 | 3,3,8,7,2,6,2,4,7,3 31 | 1,7,8,1,0,10,8,0,1,3 32 | 2,6,7,7,3,6,0,2,6,0 33 | 6,4,7,4,6,11,11,8,10,11 34 | 7,2,3,4,3,0,9,11,8,0 35 | 11,5,0,0,9,7,8,8,2,0 36 | 8,1,2,1,9,1,10,3,6,1 37 | 9,3,9,9,0,9,1,6,10,9 38 | 9,8,5,4,3,10,11,5,3,4 39 | 6,2,10,10,4,7,5,1,0,7 40 | 9,9,1,1,8,3,8,4,2,5 41 | 1,3,5,4,2,7,8,11,4,9 42 | 10,8,0,10,8,4,10,1,2,4 43 | 1,1,11,8,2,4,4,9,3,11 44 | 5,3,10,10,4,8,7,4,0,1 45 | 10,6,4,0,0,5,2,10,4,2 46 | 11,7,8,11,6,8,0,1,1,11 47 | 2,8,0,5,9,8,2,6,2,0 48 | 4,5,0,5,3,10,3,10,1,5 49 | 8,6,9,11,2,3,2,2,6,0 50 | 2,11,5,6,10,11,3,4,2,11 51 | 1,6,0,7,3,3,7,5,4,3 52 | 3,0,10,3,6,5,4,1,4,5 53 | 10,8,6,10,8,5,0,1,4,2 54 | 9,4,0,1,9,6,5,11,5,6 55 | 9,8,1,6,9,3,4,0,11,6 56 | 0,8,8,10,11,11,11,10,3,5 57 | 6,1,10,5,9,5,10,1,11,4 58 | 8,4,10,6,5,6,11,4,8,2 59 | 3,6,10,6,10,11,2,9,9,4 60 | 6,8,0,4,4,3,6,9,9,10 61 | 5,7,7,7,10,3,8,7,11,2 62 | 10,1,4,8,10,10,9,5,1,3 63 | 10,4,3,3,2,0,0,3,7,9 64 | 1,7,6,10,9,3,11,11,6,7 65 | 6,3,2,10,11,0,1,6,3,2 66 | 11,8,7,0,8,3,1,7,2,7 67 | 10,8,8,9,5,7,9,11,8,6 68 | 8,7,2,11,7,7,4,3,10,4 69 | 8,7,10,3,4,7,1,11,4,3 70 | 4,5,5,8,1,2,2,3,6,11 71 | 2,11,3,1,6,6,5,8,7,6 72 | 0,3,6,6,9,11,0,9,6,7 73 | 0,5,4,6,6,8,11,11,8,9 74 | 3,7,3,4,6,7,0,6,5,10 75 | 10,6,11,2,7,2,9,8,0,6 76 | 9,9,10,0,1,10,6,2,7,2 77 | 0,4,6,5,3,7,5,5,6,4 78 | 6,4,1,7,0,11,8,0,5,3 79 | 10,1,10,0,0,3,3,0,9,2 80 | 3,2,7,10,1,9,3,7,11,4 81 | 5,2,9,9,11,11,1,2,4,1 82 | 9,0,4,9,10,6,6,11,3,1 83 | 9,11,10,3,1,11,4,10,9,1 84 | 9,0,5,8,6,10,5,1,8,10 85 | 5,0,6,7,1,6,5,10,7,11 86 | 2,6,2,11,8,10,4,9,8,7 87 | 7,6,11,9,4,5,3,1,4,7 88 | 3,7,9,9,10,6,5,0,7,5 89 | 2,7,3,5,4,5,4,9,11,4 90 | 8,0,8,3,1,3,11,6,7,8 91 | 3,11,7,10,11,7,7,0,1,4 92 | 3,6,11,3,4,10,9,5,7,8 93 | 8,5,6,11,8,5,5,11,7,4 94 | 4,4,3,1,11,3,5,1,11,8 95 | 11,2,3,3,11,7,4,11,9,8 96 | 9,4,1,3,4,3,5,2,4,0 97 | 11,8,2,4,0,0,8,4,11,2 98 | 10,7,1,0,9,4,7,7,7,5 99 | 2,0,4,7,1,1,6,7,1,9 100 | 10,10,0,2,2,9,4,1,3,1 101 | -------------------------------------------------------------------------------- /articles/01_rubic/libcube/conf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import configparser 3 | 4 | 5 | class Config: 6 | """ 7 | Configuration for train/test/solve 8 | """ 9 | log = logging.getLogger("Config") 10 | 11 | def __init__(self, file_name): 12 | self.data = configparser.ConfigParser() 13 | self.log.info("Reading config file %s", file_name) 14 | if not self.data.read(file_name): 15 | raise ValueError("Config file %s not found" % file_name) 16 | 17 | # sections acessors 18 | @property 19 | def sect_general(self): 20 | return self.data['general'] 21 | 22 | @property 23 | def sect_train(self): 24 | return self.data['train'] 25 | 26 | # general section 27 | @property 28 | def cube_type(self): 29 | return self.sect_general['cube_type'] 30 | 31 | @property 32 | def run_name(self): 33 | return self.sect_general['run_name'] 34 | 35 | # train section 36 | @property 37 | def train_scramble_depth(self): 38 | return self.sect_train.getint('scramble_depth') 39 | 40 | @property 41 | def train_cuda(self): 42 | return self.sect_train.getboolean('cuda', fallback=False) 43 | 44 | @property 45 | def train_learning_rate(self): 46 | return self.sect_train.getfloat('lr') 47 | 48 | @property 49 | def train_batch_size(self): 50 | return self.sect_train.getint('batch_size') 51 | 52 | @property 53 | def train_report_batches(self): 54 | return self.sect_train.getint('report_batches') 55 | 56 | @property 57 | def train_checkpoint_batches(self): 58 | return self.sect_train.getint('checkpoint_batches') 59 | 60 | @property 61 | def train_lr_decay_enabled(self): 62 | return self.sect_train.getboolean('lr_decay', fallback=False) 63 | 64 | @property 65 | def train_lr_decay_batches(self): 66 | return self.sect_train.getint('lr_decay_batches') 67 | 68 | @property 69 | def train_lr_decay_gamma(self): 70 | return self.sect_train.getfloat('lr_decay_gamma', fallback=1.0) 71 | 72 | @property 73 | def train_value_targets_method(self): 74 | return self.sect_train.get('value_targets_method', fallback='paper') 75 | 76 | @property 77 | def train_max_batches(self): 78 | return self.sect_train.getint('max_batches') 79 | 80 | @property 81 | def scramble_buffer_batches(self): 82 | return self.sect_train.getint("scramble_buffer_batches", 10) 83 | 84 | @property 85 | def push_scramble_buffer_iters(self): 86 | return self.sect_train.getint('push_scramble_buffer_iters', 100) 87 | 88 | @property 89 | def weight_samples(self): 90 | return self.sect_train.getboolean('weight_samples', True) 91 | 92 | # higher-level functions 93 | def train_name(self, suffix=None): 94 | res = "%s-%s-d%d" % (self.cube_type, self.run_name, self.train_scramble_depth) 95 | if suffix is not None: 96 | res += "-" + suffix 97 | return res 98 | -------------------------------------------------------------------------------- /algos/algo_lib/atari.py: -------------------------------------------------------------------------------- 1 | # Atari-specific options for environments 2 | import gym 3 | import gym.spaces 4 | from keras.layers import Input, Flatten, Conv2D, MaxPooling2D, Dense 5 | import numpy as np 6 | import cv2 7 | 8 | from . import common 9 | 10 | 11 | class AtariEnvFactory: 12 | def __init__(self, config): 13 | self.config = config 14 | self.common_factory = common.EnvFactory(config) 15 | 16 | def __call__(self): 17 | env = self.common_factory() 18 | return RescaleWrapper(self.config)(env) 19 | 20 | 21 | class RescaleWrapper: 22 | def __init__(self, config): 23 | self.config = config 24 | 25 | class _RescaleWrapper(gym.Wrapper): 26 | """ 27 | Track history of observations for given amount of steps 28 | Initial steps are zero-filled 29 | """ 30 | def __init__(self, config, env): 31 | super(RescaleWrapper._RescaleWrapper, self).__init__(env) 32 | self.shape = config.image_shape 33 | self.observation_space = self._make_observation_space(env.observation_space, self.shape) 34 | 35 | def _step(self, action): 36 | obs, reward, done, info = self.env.step(action) 37 | return self._preprocess(obs), reward, done, info 38 | 39 | def _reset(self): 40 | return self._preprocess(self.env.reset()) 41 | 42 | @staticmethod 43 | def _make_observation_space(orig_space, target_shape): 44 | assert isinstance(orig_space, gym.spaces.Box) 45 | shape = target_shape + (orig_space.shape[0] * orig_space.shape[-1], ) 46 | low = np.ones(shape) * orig_space.low.min() 47 | high = np.ones(shape) * orig_space.high.max() 48 | return gym.spaces.Box(low, high) 49 | 50 | def _preprocess(self, state): 51 | """ 52 | Convert input from atari game + history buffer to shape expected by net_input function. 53 | :param state: input state 54 | :return: 55 | """ 56 | state = np.transpose(state, (1, 2, 3, 0)) 57 | state = np.reshape(state, (state.shape[0], state.shape[1], state.shape[2] * state.shape[3])) 58 | 59 | state = state.astype(np.float32) 60 | res = cv2.resize(state, self.shape) 61 | res /= 255 62 | return res 63 | 64 | def __call__(self, env): 65 | return self._RescaleWrapper(self.config, env) 66 | 67 | 68 | def net_input(env): 69 | """ 70 | Create input part of the network with optional prescaling. 71 | :return: input_tensor, output_tensor 72 | """ 73 | in_t = Input(shape=env.observation_space.shape, name='input') 74 | out_t = Conv2D(32, 5, 5, activation='relu', border_mode='same')(in_t) 75 | out_t = MaxPooling2D((2, 2))(out_t) 76 | out_t = Conv2D(32, 5, 5, activation='relu', border_mode='same')(out_t) 77 | out_t = MaxPooling2D((2, 2))(out_t) 78 | out_t = Conv2D(64, 4, 4, activation='relu', border_mode='same')(out_t) 79 | out_t = MaxPooling2D((2, 2))(out_t) 80 | out_t = Conv2D(64, 3, 3, activation='relu', border_mode='same')(out_t) 81 | out_t = Flatten(name='flat')(out_t) 82 | out_t = Dense(512, name='l1', activation='relu')(out_t) 83 | 84 | return in_t, out_t 85 | 86 | 87 | -------------------------------------------------------------------------------- /articles/01_rubic/train_debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Ad-hoc utility to analyze trained model and various training process details 4 | """ 5 | import argparse 6 | import logging 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import seaborn as sns 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | 14 | from libcube import cubes 15 | from libcube import model 16 | 17 | 18 | log = logging.getLogger("train_debug") 19 | 20 | 21 | # How many data to generate for plots 22 | MAX_DEPTH = 10 23 | ROUND_COUNTS = 100 24 | # debug params 25 | #MAX_DEPTH = 5 26 | #ROUND_COUNTS = 2 27 | 28 | 29 | def gen_states(cube_env, max_depth, round_counts): 30 | """ 31 | Generate random states of various scramble depth 32 | :param cube_env: CubeEnv instance 33 | :return: list of list of (state, correct_action_index) pairs 34 | """ 35 | assert isinstance(cube_env, cubes.CubeEnv) 36 | 37 | result = [[] for _ in range(max_depth)] 38 | for _ in range(round_counts): 39 | data = cube_env.scramble_cube(max_depth, return_inverse=True) 40 | for depth, state, inv_action in data: 41 | result[depth-1].append((state, inv_action.value)) 42 | return result 43 | 44 | 45 | if __name__ == "__main__": 46 | sns.set() 47 | 48 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("-e", "--env", required=True, help="Type of env to train, supported types=%s" % cubes.names()) 51 | parser.add_argument("-m", "--model", required=True, help="Model file to load") 52 | parser.add_argument("-o", "--output", required=True, help="Output prefix for plots") 53 | args = parser.parse_args() 54 | 55 | cube_env = cubes.get(args.env) 56 | log.info("Selected cube: %s", cube_env) 57 | net = model.Net(cube_env.encoded_shape, len(cube_env.action_enum)) 58 | net.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage)) 59 | net.eval() 60 | log.info("Network loaded from %s", args.model) 61 | 62 | # model.make_train_data(cube_env, net, device='cpu', batch_size=10, scramble_depth=2, shuffle=False) 63 | 64 | states_by_depth = gen_states(cube_env, max_depth=MAX_DEPTH, round_counts=ROUND_COUNTS) 65 | # for idx, states in enumerate(states_by_depth): 66 | # log.info("%d: %s", idx, states) 67 | 68 | # flatten returned data 69 | data = [] 70 | for depth, states in enumerate(states_by_depth): 71 | for s, inv_action in states: 72 | data.append((depth+1, s, inv_action)) 73 | depths, states, inv_actions = map(list, zip(*data)) 74 | 75 | # process states with net 76 | enc_states = model.encode_states(cube_env, states) 77 | enc_states_t = torch.tensor(enc_states) 78 | policy_t, value_t = net(enc_states_t) 79 | value_t = value_t.squeeze(-1) 80 | value = value_t.cpu().detach().numpy() 81 | policy = F.softmax(policy_t, dim=1).cpu().detach().numpy() 82 | 83 | # plot value per depth of scramble 84 | plot = sns.lineplot(depths, value) 85 | plot.set_title("Values per depths") 86 | plot.get_figure().savefig(args.output + "-vals_vs_depths.png") 87 | 88 | # plot action match 89 | plt.clf() 90 | actions = np.argmax(policy, axis=1) 91 | actions_match = (actions == inv_actions).astype(np.int8) 92 | plot = sns.lineplot(depths, actions_match) 93 | plot.set_title("Actions accuracy per depths") 94 | plot.get_figure().savefig(args.output + "-acts_vs_depths.png") 95 | 96 | pass 97 | -------------------------------------------------------------------------------- /algos/elite.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Multi-layer perceptron inspired by this: https://gym.openai.com/evaluations/eval_P4KyYPwIQdSg6EqvHgYjiw 3 | # https://gist.githubusercontent.com/anonymous/d829ec2f8bda088ac897aa2055dcd3a8/raw/d3fcdfdcc9038bf24385589e94939dcd3c198349/crossentropy_method.py 4 | import gym 5 | import argparse 6 | from gym import wrappers 7 | import numpy as np 8 | 9 | from keras.models import Sequential 10 | from keras.layers import Dense, Activation 11 | from keras.utils.np_utils import to_categorical 12 | from keras.optimizers import Adagrad, RMSprop 13 | 14 | 15 | BATCH_SIZE = 16 16 | 17 | 18 | def make_model(state_shape, actions_n): 19 | m = Sequential() 20 | m.add(Dense(40, input_shape=state_shape, activation='relu')) 21 | m.add(Dense(40)) 22 | m.add(Dense(actions_n)) 23 | m.add(Activation('softmax')) 24 | return m 25 | 26 | 27 | def generate_session(env, model, n_actions, limit=None): 28 | states = [] 29 | actions = [] 30 | s = env.reset() 31 | total_reward = 0 32 | 33 | while True: 34 | probs = model.predict_proba(np.array([s]), verbose=0)[0] 35 | 36 | action = np.random.choice(n_actions, p=probs) 37 | new_s, reward, done, _ = env.step(action) 38 | states.append(s) 39 | actions.append(action) 40 | total_reward += reward 41 | s = new_s 42 | if done: 43 | break 44 | if limit is not None and len(actions) >= limit: 45 | break 46 | 47 | return states, actions, total_reward 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("-r", "--read", help="Read model weight from file, default=None") 53 | parser.add_argument("-m", "--monitor", help="Enable monitor and save data into provided dir, default=disabled") 54 | parser.add_argument("-e", "--env", default="CartPole-v1", help="Environment to test on, default=CartPole-v1") 55 | parser.add_argument("-l", "--limit", default=500, type=int, help="Limit of steps per episode") 56 | parser.add_argument("--iters", type=int, default=100, help="How many learning iterations to do, default=100") 57 | args = parser.parse_args() 58 | 59 | env = gym.make(args.env) 60 | if args.monitor: 61 | env = wrappers.Monitor(env, args.monitor) 62 | state_shape = env.observation_space.shape 63 | n_actions = env.action_space.n 64 | 65 | m = make_model(state_shape, n_actions) 66 | m.summary() 67 | m.compile(optimizer=RMSprop(lr=0.001), loss='categorical_crossentropy') 68 | 69 | if args.read: 70 | m.load_weights(args.read) 71 | 72 | for idx in range(args.iters): 73 | batch = [generate_session(env, m, n_actions, limit=args.limit) for _ in range(BATCH_SIZE)] 74 | b_states, b_actions, b_rewards = map(np.array, zip(*batch)) 75 | 76 | threshold = np.percentile(b_rewards, 50) 77 | 78 | elite_states = b_states[b_rewards > threshold] 79 | elite_actions = b_actions[b_rewards > threshold] 80 | 81 | if len(elite_states) > 0: 82 | elite_states, elite_actions = map(np.concatenate, [elite_states, elite_actions]) 83 | oh_actions = to_categorical(elite_actions, nb_classes=n_actions) 84 | m.fit(elite_states, oh_actions, verbose=0, nb_epoch=50) 85 | print("%d: mean reward = %.5f\tthreshold = %.1f" % (idx, np.mean(b_rewards), threshold)) 86 | # m.save_weights("t0-iter=%03d-thr=%.2f.hdf5" % (idx, threshold)) 87 | else: 88 | print("%d: no improvement\tmean reward = %.5f\tthreshold = %.1f" % (idx, np.mean(b_rewards), threshold)) 89 | 90 | pass 91 | -------------------------------------------------------------------------------- /algos/algo_lib/a3c.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Dense, Input, Lambda, BatchNormalization 2 | from keras.models import Model 3 | import keras.backend as K 4 | import tensorflow as tf 5 | 6 | 7 | def net_prediction(input_t, n_actions): 8 | """ 9 | Make prediction part of A3C network 10 | :param input_t: flattened input from previous layer 11 | :return: policy_tensor and value_tensor 12 | """ 13 | value_t = Dense(1, name='value')(input_t) 14 | policy_t = Dense(n_actions, name='policy')(input_t) 15 | 16 | return policy_t, value_t 17 | 18 | 19 | def create_policy_loss(policy_t, value_t, n_actions): 20 | """ 21 | Policy loss 22 | :param policy_t: policy tensor from prediction part 23 | :param value_t: value tensor from prediction part 24 | :param n_actions: count of actions in space 25 | :param entropy_beta: entropy loss scaling factor 26 | :return: action_t, advantage_t, policy_loss_t 27 | """ 28 | action_t = Input(batch_shape=(None, 1), name='action', dtype='int32') 29 | reward_t = Input(batch_shape=(None, 1), name="reward") 30 | 31 | def policy_loss_func(args): 32 | p_t, v_t, act_t, rew_t = args 33 | log_p_t = tf.nn.log_softmax(p_t) 34 | oh_t = K.one_hot(act_t, n_actions) 35 | oh_t = K.squeeze(oh_t, 1) 36 | p_oh_t = K.sum(log_p_t * oh_t, axis=-1, keepdims=True) 37 | adv_t = (rew_t - K.stop_gradient(v_t)) 38 | tf.summary.scalar("advantage_mean", K.mean(adv_t)) 39 | tf.summary.scalar("advantage_rms", K.sqrt(K.mean(K.square(adv_t)))) 40 | 41 | res_t = -adv_t * p_oh_t 42 | tf.summary.scalar("loss_policy_mean", K.mean(res_t)) 43 | tf.summary.scalar("loss_policy_rms", K.sqrt(K.mean(K.square(res_t)))) 44 | return res_t 45 | 46 | loss_args = [policy_t, value_t, action_t, reward_t] 47 | policy_loss_t = Lambda(policy_loss_func, output_shape=(1,), name='policy_loss')(loss_args) 48 | 49 | tf.summary.scalar("value_mean", K.mean(value_t)) 50 | tf.summary.scalar("reward_mean", K.mean(reward_t)) 51 | 52 | return action_t, reward_t, policy_loss_t 53 | 54 | 55 | def create_value_loss(value_t, reward_t): 56 | value_loss_func = lambda args: K.mean(K.square(args[0] - args[1]), axis=-1, keepdims=True) 57 | value_loss_t = Lambda(value_loss_func, name="value_loss", output_shape=(1,))([reward_t, value_t]) 58 | return value_loss_t 59 | 60 | 61 | def create_entropy_loss(policy_t, beta): 62 | def entropy_loss_func(p_t): 63 | log_p_t = tf.nn.log_softmax(p_t) 64 | sigm_p_t = K.softmax(p_t) 65 | entropy_t = beta * K.sum(sigm_p_t * log_p_t, axis=-1, keepdims=True) 66 | return entropy_t 67 | 68 | entropy_loss_t = Lambda(entropy_loss_func, name="entropy_loss", output_shape=(1,))(policy_t) 69 | return entropy_loss_t 70 | 71 | 72 | def make_run_model(input_t, conv_output_t, n_actions): 73 | policy_t, value_t = net_prediction(conv_output_t, n_actions) 74 | return Model(input=input_t, output=[policy_t, value_t]) 75 | 76 | 77 | def make_train_model(input_t, conv_output_t, n_actions, entropy_beta=0.01): 78 | policy_t, value_t = net_prediction(conv_output_t, n_actions) 79 | action_t, reward_t, policy_loss_t = create_policy_loss(policy_t, value_t, n_actions) 80 | 81 | value_loss_t = create_value_loss(value_t=value_t, reward_t=reward_t) 82 | entropy_loss_t = create_entropy_loss(policy_t, entropy_beta) 83 | 84 | tf.summary.scalar("loss_value", K.mean(value_loss_t)) 85 | tf.summary.scalar("loss_entropy", K.mean(entropy_loss_t)) 86 | 87 | return Model(input=[input_t, action_t, reward_t], output=[policy_loss_t, entropy_loss_t, value_loss_t]) 88 | -------------------------------------------------------------------------------- /misc/nn_plus/train_pong.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import gym 3 | import ptan 4 | import argparse 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.autograd import Variable 11 | 12 | from tensorboardX import SummaryWriter 13 | 14 | from lib import model, common 15 | 16 | 17 | class NoisyDQN(nn.Module): 18 | def __init__(self, input_shape, n_actions): 19 | super(NoisyDQN, self).__init__() 20 | 21 | self.conv = nn.Sequential( 22 | nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), 23 | nn.ReLU(), 24 | nn.Conv2d(32, 64, kernel_size=4, stride=2), 25 | nn.ReLU(), 26 | nn.Conv2d(64, 64, kernel_size=3, stride=1), 27 | nn.ReLU() 28 | ) 29 | 30 | conv_out_size = self._get_conv_out(input_shape) 31 | self.noisy_out = model.NoisyLinearExt(512, n_actions) 32 | 33 | self.fc = nn.Sequential( 34 | nn.Linear(conv_out_size, 512), 35 | nn.ReLU(), 36 | ) 37 | 38 | self.sigma_layers = nn.Sequential( 39 | nn.Linear(conv_out_size, 512), 40 | nn.ReLU(), 41 | nn.Linear(512, 1) 42 | ) 43 | 44 | def _get_conv_out(self, shape): 45 | o = self.conv(Variable(torch.zeros(1, *shape))) 46 | return int(np.prod(o.size())) 47 | 48 | def forward(self, x): 49 | fx = x.float() / 256 50 | conv_out = self.conv(fx).view(fx.size()[0], -1) 51 | sigma = self.sigma_layers(conv_out) 52 | fc_out = self.fc(conv_out) 53 | out = self.noisy_out(fc_out, sigma=sigma) 54 | return out 55 | 56 | 57 | if __name__ == "__main__": 58 | params = common.HYPERPARAMS['breakout'] 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--cuda", default=False, action="store_true", help="Enable cuda") 61 | args = parser.parse_args() 62 | 63 | env = gym.make(params['env_name']) 64 | env = ptan.common.wrappers.wrap_dqn(env) 65 | 66 | writer = SummaryWriter(comment="-" + params['run_name'] + "-noisy-plus-1") 67 | net = NoisyDQN(env.observation_space.shape, env.action_space.n) 68 | if args.cuda: 69 | net.cuda() 70 | 71 | tgt_net = ptan.agent.TargetNet(net) 72 | agent = ptan.agent.DQNAgent(net, ptan.actions.ArgmaxActionSelector(), cuda=args.cuda) 73 | 74 | exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=params['gamma'], steps_count=1) 75 | buffer = ptan.experience.ExperienceReplayBuffer(exp_source, buffer_size=params['replay_size']) 76 | optimizer = optim.Adam(net.parameters(), lr=params['learning_rate']) 77 | 78 | frame_idx = 0 79 | 80 | with common.RewardTracker(writer, params['stop_reward']) as reward_tracker: 81 | while True: 82 | frame_idx += 1 83 | buffer.populate(1) 84 | 85 | new_rewards = exp_source.pop_total_rewards() 86 | if new_rewards: 87 | if reward_tracker.reward(new_rewards[0], frame_idx): 88 | break 89 | 90 | if len(buffer) < params['replay_initial']: 91 | continue 92 | 93 | optimizer.zero_grad() 94 | batch = buffer.sample(params['batch_size']) 95 | loss_v = common.calc_loss_dqn(batch, net, tgt_net.target_model, gamma=params['gamma'], cuda=args.cuda) 96 | loss_v.backward() 97 | optimizer.step() 98 | 99 | if frame_idx % params['target_net_sync'] == 0: 100 | tgt_net.sync() 101 | -------------------------------------------------------------------------------- /misc/nn_plus/lib/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import numpy as np 8 | 9 | 10 | class NoisyLinear(nn.Linear): 11 | def __init__(self, in_features, out_features, sigma_init=0.017, bias=True): 12 | super(NoisyLinear, self).__init__(in_features, out_features, bias=bias) 13 | self.sigma_weight = nn.Parameter(torch.Tensor(out_features, in_features).fill_(sigma_init)) 14 | self.register_buffer("epsilon_weight", torch.zeros(out_features, in_features)) 15 | if bias: 16 | self.sigma_bias = nn.Parameter(torch.Tensor(out_features).fill_(sigma_init)) 17 | self.register_buffer("epsilon_bias", torch.zeros(out_features)) 18 | self.reset_parameters() 19 | 20 | def reset_parameters(self): 21 | std = math.sqrt(3 / self.in_features) 22 | nn.init.uniform(self.weight, -std, std) 23 | nn.init.uniform(self.bias, -std, std) 24 | 25 | def forward(self, input): 26 | torch.randn(self.epsilon_weight.size(), out=self.epsilon_weight) 27 | bias = self.bias 28 | if bias is not None: 29 | torch.randn(self.epsilon_bias.size(), out=self.epsilon_bias) 30 | bias = bias + self.sigma_bias * Variable(self.epsilon_bias) 31 | return F.linear(input, self.weight + self.sigma_weight * Variable(self.epsilon_weight), bias) 32 | 33 | 34 | class NoisyFactorizedLinear(nn.Linear): 35 | """ 36 | NoisyNet layer with factorized gaussian noise 37 | 38 | N.B. nn.Linear already initializes weight and bias to 39 | """ 40 | def __init__(self, in_features, out_features, sigma_zero=0.4, bias=True): 41 | super(NoisyFactorizedLinear, self).__init__(in_features, out_features, bias=bias) 42 | sigma_init = sigma_zero / math.sqrt(in_features) 43 | self.sigma_weight = nn.Parameter(torch.Tensor(out_features, in_features).fill_(sigma_init)) 44 | self.register_buffer("epsilon_input", torch.zeros(1, in_features)) 45 | self.register_buffer("epsilon_output", torch.zeros(out_features, 1)) 46 | if bias: 47 | self.sigma_bias = nn.Parameter(torch.Tensor(out_features).fill_(sigma_init)) 48 | 49 | def forward(self, input): 50 | torch.randn(self.epsilon_input.size(), out=self.epsilon_input) 51 | torch.randn(self.epsilon_output.size(), out=self.epsilon_output) 52 | 53 | func = lambda x: torch.sign(x) * torch.sqrt(torch.abs(x)) 54 | eps_in = func(self.epsilon_input) 55 | eps_out = func(self.epsilon_output) 56 | 57 | bias = self.bias 58 | if bias is not None: 59 | bias = bias + self.sigma_bias * Variable(eps_out.t()) 60 | noise_v = Variable(torch.mul(eps_in, eps_out)) 61 | return F.linear(input, self.weight + self.sigma_weight * noise_v, bias) 62 | 63 | 64 | class NoisyLinearExt(nn.Linear): 65 | """ 66 | Noisy layer with externally-controllable sigma 67 | """ 68 | def __init__(self, in_features, out_features, bias=True): 69 | super(NoisyLinearExt, self).__init__(in_features, out_features, bias=bias) 70 | self.rand_buf = None 71 | 72 | def forward(self, input, sigma=None): 73 | res = F.linear(input, self.weight, self.bias) 74 | if sigma is None: 75 | return res 76 | if self.rand_buf is None or self.rand_buf.size() != res.size(): 77 | self.rand_buf = torch.FloatTensor(res.size()) 78 | if input.is_cuda: 79 | self.rand_buf = self.rand_buf.cuda() 80 | torch.randn(self.rand_buf.size(), out=self.rand_buf) 81 | # print(m.size(), res.size()) 82 | return res + torch.mul(sigma, Variable(self.rand_buf)) 83 | -------------------------------------------------------------------------------- /algos/others/p.py: -------------------------------------------------------------------------------- 1 | """ 2 | idea(and code) from Karpathy's PG Pong 3 | 4 | Karpathy's PG Pong code : https://gist.github.com/karpathy/a4166c7fe253700972fcbc77e4ea32c5 5 | 6 | Karpathy's PG blog post : http://karpathy.github.io/2016/05/31/rl/ 7 | 8 | https://gist.github.com/zzing0907/de3665f9f7bbe9329b283da90d72049e#file-cartpole_pg-py 9 | """ 10 | import numpy as np 11 | import pickle 12 | import gym, gym.wrappers 13 | 14 | H = 10 15 | learning_rate = 2e-3 16 | gamma = 0.99 17 | decay_rate = 0.99 18 | score_queue_size = 100 19 | resume = False 20 | D = 3 21 | 22 | if resume: 23 | model = pickle.load(open('save.p', 'rb')) 24 | else: 25 | model = {} 26 | model['W1'] = np.random.randn(H, D) / np.sqrt(D) 27 | model['W2'] = np.random.randn(H) / np.sqrt(H) 28 | 29 | grad_buffer = {k: np.zeros_like(v) for k, v in model.items()} 30 | rmsprop_cache = {k: np.zeros_like(v) for k, v in model.items()} 31 | 32 | 33 | def sigmoid(x): 34 | return 1.0 / (1.0 + np.exp(-x)) 35 | 36 | 37 | def prepro(I): 38 | return I[1:] 39 | 40 | 41 | def discount_rewards(r): 42 | discounted_r = np.zeros_like(r) 43 | running_add = 0 44 | for t in reversed(range(0, r.size)): 45 | running_add = running_add * gamma + r[t] 46 | discounted_r[t] = running_add 47 | 48 | return discounted_r 49 | 50 | 51 | def policy_forward(x): 52 | h = np.dot(model['W1'], x) 53 | h = sigmoid(h) 54 | logp = np.dot(model['W2'], h) 55 | p = sigmoid(logp) 56 | return p, h 57 | 58 | 59 | def policy_backward(eph, epdlogp, epx): 60 | global grad_buffer 61 | dW2 = np.dot(eph.T, epdlogp).ravel() 62 | dh = np.outer(epdlogp, model['W2']) 63 | eph_dot = eph * (1 - eph) 64 | dW1 = dh * eph_dot 65 | dW1 = np.dot(dW1.T, epx) 66 | 67 | for k in model: grad_buffer[k] += {'W1': dW1, 'W2': dW2}[k] 68 | 69 | 70 | env = gym.make('CartPole-v0') 71 | env = gym.wrappers.Monitor(env, "res-1") 72 | #env.monitor.start('CartPole', force=True) 73 | observation = env.reset() 74 | reward_sum, episode_num = 0, 0 75 | xs, hs, dlogps, drs = [], [], [], [] 76 | score_queue = [] 77 | 78 | while True: 79 | 80 | x = prepro(observation) 81 | 82 | act_prob, h = policy_forward(x) 83 | 84 | if np.mean(score_queue) > 180: 85 | action = 1 if 0.5 < act_prob else 0 86 | else: 87 | action = 1 if np.random.uniform() < act_prob else 0 88 | 89 | xs.append(x) 90 | hs.append(h) 91 | y = action 92 | dlogps.append(y - act_prob) 93 | 94 | observation, reward, done, info = env.step(action) 95 | reward_sum += reward 96 | 97 | drs.append(reward) 98 | 99 | if done: 100 | episode_num += 1 101 | 102 | if episode_num > score_queue_size: 103 | score_queue.append(reward_sum) 104 | score_queue.pop(0) 105 | else: 106 | score_queue.append(reward_sum) 107 | 108 | print("episode : " + str(episode_num) + ", reward : " + str(reward_sum) + ", reward_mean : " + str( 109 | np.mean(score_queue))) 110 | 111 | if np.mean(score_queue) >= 200: 112 | print("CartPole solved!!!!!") 113 | break 114 | 115 | epx = np.vstack(xs) 116 | eph = np.vstack(hs) 117 | epdlogp = np.vstack(dlogps) 118 | epr = np.vstack(drs) 119 | xs, hs, dlogps, drs = [], [], [], [] 120 | 121 | discounted_epr = discount_rewards(epr) 122 | discounted_epr -= np.mean(discounted_epr) 123 | discounted_epr /= np.std(discounted_epr) 124 | 125 | epdlogp *= discounted_epr 126 | 127 | policy_backward(eph, epdlogp, epx) 128 | for k, v in model.items(): 129 | g = grad_buffer[k] 130 | rmsprop_cache[k] = decay_rate * rmsprop_cache[k] + (1 - decay_rate) * g ** 2 131 | model[k] += learning_rate * g / (np.sqrt(rmsprop_cache[k]) + 1e-5) 132 | grad_buffer[k] = np.zeros_like(v) 133 | 134 | if episode_num % 1000 == 0: pickle.dump(model, open('Cart.p', 'wb')) 135 | 136 | reward_sum = 0 137 | observation = env.reset() 138 | 139 | #env.monitor.close() 140 | -------------------------------------------------------------------------------- /algos/a3c_atari.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Quick-n-dirty implementation of Advantage Actor-Critic method from https://arxiv.org/abs/1602.01783 3 | import os 4 | import argparse 5 | import logging 6 | import time 7 | import numpy as np 8 | 9 | from keras.optimizers import Adam 10 | from keras import backend as K 11 | import tensorflow as tf 12 | 13 | from algo_lib.common import make_env, summarize_gradients, summary_value 14 | from algo_lib import common 15 | from algo_lib import atari 16 | from algo_lib.a3c import make_run_model, make_train_model 17 | from algo_lib.player import Player, generate_batches 18 | 19 | HISTORY_STEPS = 4 20 | 21 | logger = logging.getLogger() 22 | logger.setLevel(logging.INFO) 23 | 24 | PLAYERS_COUNT = 50 25 | BATCH_SIZE = 128 26 | 27 | SUMMARY_EVERY_BATCH = 10 28 | SYNC_MODEL_EVERY_BATCH = 1 29 | SAVE_MODEL_EVERY_BATCH = 3000 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("-i", "--ini", required=True, help="Ini file with configuration") 35 | parser.add_argument("-n", "--name", required=True, help="Run name") 36 | args = parser.parse_args() 37 | 38 | # limit GPU memory 39 | config = tf.ConfigProto() 40 | config.gpu_options.per_process_gpu_memory_fraction = 0.2 41 | K.set_session(tf.Session(config=config)) 42 | 43 | config = common.Configuration(args.ini) 44 | env_factory = atari.AtariEnvFactory(config) 45 | 46 | env = env_factory() 47 | state_shape = env.observation_space.shape 48 | n_actions = env.action_space.n 49 | logger.info("Created environment %s, state: %s, actions: %s", config.env_name, state_shape, n_actions) 50 | 51 | tr_input_t, tr_conv_out_t = atari.net_input(env) 52 | value_policy_model = make_train_model(tr_input_t, tr_conv_out_t, n_actions) 53 | 54 | r_input_t, r_conv_out_t = atari.net_input(env) 55 | run_model = make_run_model(r_input_t, r_conv_out_t, n_actions) 56 | 57 | value_policy_model.summary() 58 | 59 | loss_dict = { 60 | 'policy_loss': lambda y_true, y_pred: y_pred, 61 | 'value_loss': lambda y_true, y_pred: y_pred, 62 | 'entropy_loss': lambda y_true, y_pred: y_pred, 63 | } 64 | 65 | value_policy_model.compile(optimizer=Adam(lr=0.001, epsilon=1e-3, clipnorm=0.1), loss=loss_dict) 66 | 67 | # keras summary magic 68 | summary_writer = tf.summary.FileWriter("logs/" + args.name) 69 | summarize_gradients(value_policy_model) 70 | value_policy_model.metrics_names.append("value_summary") 71 | value_policy_model.metrics_tensors.append(tf.summary.merge_all()) 72 | 73 | players = [Player(env_factory(), reward_steps=config.a3c_steps, 74 | gamma=config.a3c_gamma, max_steps=config.max_steps, player_index=idx) 75 | for idx in range(PLAYERS_COUNT)] 76 | 77 | bench_samples = 0 78 | bench_ts = time.time() 79 | 80 | for iter_idx, x_batch in enumerate(generate_batches(run_model, players, BATCH_SIZE)): 81 | y_stub = np.zeros(len(x_batch[0])) 82 | l = value_policy_model.train_on_batch(x_batch, [y_stub]*3) 83 | bench_samples += BATCH_SIZE 84 | 85 | if iter_idx % SUMMARY_EVERY_BATCH == 0: 86 | l_dict = dict(zip(value_policy_model.metrics_names, l)) 87 | done_rewards = Player.gather_done_rewards(*players) 88 | 89 | if done_rewards: 90 | summary_value("reward_episode_mean", np.mean(done_rewards), summary_writer, iter_idx) 91 | summary_value("reward_episode_max", np.max(done_rewards), summary_writer, iter_idx) 92 | 93 | summary_value("speed", bench_samples / (time.time() - bench_ts), summary_writer, iter_idx) 94 | summary_value("reward_batch", np.mean(x_batch[2]), summary_writer, iter_idx) 95 | summary_value("loss_full", l_dict['loss'], summary_writer, iter_idx) 96 | summary_writer.add_summary(l_dict['value_summary'], global_step=iter_idx) 97 | summary_writer.flush() 98 | bench_samples = 0 99 | bench_ts = time.time() 100 | 101 | if iter_idx % SYNC_MODEL_EVERY_BATCH == 0: 102 | run_model.set_weights(value_policy_model.get_weights()) 103 | # logger.info("Models synchronized, iter %d", iter_idx) 104 | 105 | if iter_idx % SAVE_MODEL_EVERY_BATCH == 0 and iter_idx > 0: 106 | value_policy_model.save(os.path.join("logs", args.name, "model-%06d.h5" % iter_idx)) 107 | -------------------------------------------------------------------------------- /articles/01_rubic/libcube/cubes/_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic cube env representation and registry 3 | """ 4 | import logging 5 | import random 6 | 7 | log = logging.getLogger("cube.env") 8 | _registry = {} 9 | 10 | 11 | class CubeEnv: 12 | def __init__(self, name, state_type, initial_state, is_goal_pred, 13 | action_enum, transform_func, inverse_action_func, 14 | render_func, encoded_shape, encode_func): 15 | self.name = name 16 | self._state_type = state_type 17 | self.initial_state = initial_state 18 | self._is_goal_pred = is_goal_pred 19 | self.action_enum = action_enum 20 | self._transform_func = transform_func 21 | self._inverse_action_func = inverse_action_func 22 | self._render_func = render_func 23 | self.encoded_shape = encoded_shape 24 | self._encode_func = encode_func 25 | 26 | def __repr__(self): 27 | return "CubeEnv(%r)" % self.name 28 | 29 | # wrapper functions 30 | def is_goal(self, state): 31 | assert isinstance(state, self._state_type) 32 | return self._is_goal_pred(state) 33 | 34 | def transform(self, state, action): 35 | assert isinstance(state, self._state_type) 36 | assert isinstance(action, self.action_enum) 37 | return self._transform_func(state, action) 38 | 39 | def inverse_action(self, action): 40 | return self._inverse_action_func(action) 41 | 42 | def render(self, state): 43 | assert isinstance(state, self._state_type) 44 | return self._render_func(state) 45 | 46 | def encode_inplace(self, target, state): 47 | assert isinstance(state, self._state_type) 48 | return self._encode_func(target, state) 49 | 50 | # Utility functions 51 | def sample_action(self, prev_action=None): 52 | while True: 53 | res = self.action_enum(random.randrange(len(self.action_enum))) 54 | if prev_action is None or self.inverse_action(res) != prev_action: 55 | return res 56 | 57 | def scramble(self, actions): 58 | s = self.initial_state 59 | for action in actions: 60 | s = self.transform(s, action) 61 | return s 62 | 63 | def is_state(self, state): 64 | return isinstance(state, self._state_type) 65 | 66 | def scramble_cube(self, scrambles_count, return_inverse=False, include_initial=False): 67 | """ 68 | Generate sequence of random cube scrambles 69 | :param scrambles_count: count of scrambles to perform 70 | :param return_inverse: if True, inverse action is returned 71 | :return: list of tuples (depth, state[, inverse_action]) 72 | """ 73 | assert isinstance(scrambles_count, int) 74 | assert scrambles_count > 0 75 | 76 | state = self.initial_state 77 | result = [] 78 | if include_initial: 79 | assert not return_inverse 80 | result.append((1, state)) 81 | prev_action = None 82 | for depth in range(scrambles_count): 83 | action = self.sample_action(prev_action=prev_action) 84 | state = self.transform(state, action) 85 | prev_action = action 86 | if return_inverse: 87 | inv_action = self.inverse_action(action) 88 | res = (depth+1, state, inv_action) 89 | else: 90 | res = (depth+1, state) 91 | result.append(res) 92 | return result 93 | 94 | def explore_state(self, state): 95 | """ 96 | Expand cube state by applying every action to it 97 | :param state: state to explore 98 | :return: tuple of two lists: [states reachable], [flag that state is initial] 99 | """ 100 | res_states, res_flags = [], [] 101 | for action in self.action_enum: 102 | new_state = self.transform(state, action) 103 | is_init = self.is_goal(new_state) 104 | res_states.append(new_state) 105 | res_flags.append(is_init) 106 | return res_states, res_flags 107 | 108 | 109 | def register(cube_env): 110 | assert isinstance(cube_env, CubeEnv) 111 | global _registry 112 | 113 | if cube_env.name in _registry: 114 | log.warning("Cube environment %s is already registered, ignored", cube_env) 115 | else: 116 | _registry[cube_env.name] = cube_env 117 | 118 | 119 | def get(name): 120 | assert isinstance(name, str) 121 | return _registry.get(name) 122 | 123 | 124 | def names(): 125 | return list(sorted(_registry.keys())) 126 | -------------------------------------------------------------------------------- /algos/a3c_async.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import argparse 4 | import logging 5 | import numpy as np 6 | 7 | import time 8 | import datetime 9 | import tensorflow as tf 10 | import multiprocessing as mp 11 | 12 | from keras.optimizers import Adam 13 | 14 | from algo_lib import common 15 | from algo_lib import atari 16 | from algo_lib import player 17 | 18 | from algo_lib.a3c import make_train_model, make_run_model 19 | 20 | logger = logging.getLogger() 21 | logger.setLevel(logging.INFO) 22 | 23 | SUMMARY_EVERY_BATCH = 100 24 | SYNC_MODEL_EVERY_BATCH = 1 25 | SAVE_MODEL_EVERY_BATCH = 3000 26 | 27 | 28 | if __name__ == "__main__": 29 | # work-around for TF multiprocessing problems 30 | mp.set_start_method('spawn') 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("-r", "--read", help="Model file name to read") 34 | parser.add_argument("-n", "--name", required=True, help="Run name") 35 | parser.add_argument("-i", "--ini", required=True, help="Ini file with configuration") 36 | args = parser.parse_args() 37 | 38 | config = common.Configuration(args.ini) 39 | 40 | env_factory = atari.AtariEnvFactory(config) 41 | 42 | env = env_factory() 43 | state_shape = env.observation_space.shape 44 | n_actions = env.action_space.n 45 | logger.info("Created environment %s, state: %s, actions: %s", config.env_name, state_shape, n_actions) 46 | 47 | input_t, conv_out_t = atari.net_input(env) 48 | value_policy_model = make_train_model(input_t, conv_out_t, n_actions, entropy_beta=config.a3c_beta) 49 | value_policy_model.summary() 50 | run_model = make_run_model(input_t, conv_out_t, n_actions) 51 | run_model.summary() 52 | 53 | loss_dict = { 54 | 'policy_loss': lambda y_true, y_pred: y_pred, 55 | 'value_loss': lambda y_true, y_pred: y_pred, 56 | 'entropy_loss': lambda y_true, y_pred: y_pred, 57 | } 58 | 59 | optimizer = Adam(lr=config.learning_rate, epsilon=1e-3, clipnorm=config.gradient_clip_norm) 60 | value_policy_model.compile(optimizer=optimizer, loss=loss_dict) 61 | 62 | # keras summary magic 63 | summary_writer = tf.summary.FileWriter("logs/" + args.name) 64 | common.summarize_gradients(value_policy_model) 65 | value_policy_model.metrics_names.append("value_summary") 66 | value_policy_model.metrics_tensors.append(tf.summary.merge_all()) 67 | 68 | if args.read: 69 | logger.info("Loading model from %s", args.read) 70 | value_policy_model.load_weights(args.read) 71 | 72 | tweaker = common.ParamsTweaker() 73 | tweaker.add("lr", optimizer.lr) 74 | 75 | players = player.AsyncPlayersSwarm(config, env_factory, run_model) 76 | iter_idx = 0 77 | bench_samples = 0 78 | bench_ts = time.time() 79 | 80 | while True: 81 | if iter_idx % SYNC_MODEL_EVERY_BATCH == 0: 82 | players.push_model_weights(value_policy_model.get_weights()) 83 | 84 | iter_idx += 1 85 | batch_ts = time.time() 86 | x_batch = players.get_batch() 87 | # stub for y 88 | y_stub = np.zeros(len(x_batch[0])) 89 | 90 | l = value_policy_model.train_on_batch(x_batch, [y_stub]*3) 91 | bench_samples += config.batch_size 92 | 93 | if iter_idx % SUMMARY_EVERY_BATCH == 0: 94 | l_dict = dict(zip(value_policy_model.metrics_names, l)) 95 | done_rewards = players.get_done_rewards() 96 | 97 | if done_rewards: 98 | common.summary_value("reward_episode_mean", np.mean(done_rewards), summary_writer, iter_idx) 99 | common.summary_value("reward_episode_max", np.max(done_rewards), summary_writer, iter_idx) 100 | common.summary_value("reward_episode_min", np.min(done_rewards), summary_writer, iter_idx) 101 | 102 | # summary_value("rewards_norm_mean", np.mean(y_batch[0]), summary_writer, iter_idx) 103 | common.summary_value("speed", bench_samples / (time.time() - bench_ts), summary_writer, iter_idx) 104 | common.summary_value("loss", l_dict['loss'], summary_writer, iter_idx) 105 | summary_writer.add_summary(l_dict['value_summary'], global_step=iter_idx) 106 | summary_writer.flush() 107 | bench_samples = 0 108 | logger.info("Iter %d: speed %s per batch", iter_idx, 109 | datetime.timedelta(seconds=(time.time() - bench_ts)/SUMMARY_EVERY_BATCH)) 110 | bench_ts = time.time() 111 | 112 | if iter_idx % SAVE_MODEL_EVERY_BATCH == 0: 113 | value_policy_model.save(os.path.join("logs", args.name, "model-%06d.h5" % iter_idx)) 114 | 115 | tweaker.check() 116 | -------------------------------------------------------------------------------- /algos/a3c.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Quick-n-dirty implementation of Advantage Actor-Critic method from https://arxiv.org/abs/1602.01783 3 | import uuid 4 | import os 5 | import argparse 6 | import logging 7 | import numpy as np 8 | import pickle 9 | 10 | logger = logging.getLogger() 11 | logger.setLevel(logging.INFO) 12 | 13 | import tensorflow as tf 14 | from keras.layers import Input, Dense, Flatten 15 | from keras.optimizers import Adam 16 | 17 | from algo_lib import common# import make_env, summarize_gradients, summary_value, HistoryWrapper 18 | from algo_lib import a3c 19 | from algo_lib.player import Player, generate_batches 20 | 21 | HISTORY_STEPS = 4 22 | SIMPLE_L1_SIZE = 50 23 | SIMPLE_L2_SIZE = 50 24 | 25 | SUMMARY_EVERY_BATCH = 100 26 | SAVE_MODEL_EVERY_BATCH = 3000 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("-e", "--env", default="CartPole-v0", help="Environment name to use") 32 | parser.add_argument("-m", "--monitor", help="Enable monitor and save data into provided dir, default=disabled") 33 | parser.add_argument("--gamma", type=float, default=1.0, help="Gamma for reward discount, default=1.0") 34 | parser.add_argument("-n", "--name", required=True, help="Run name") 35 | args = parser.parse_args() 36 | 37 | env_wrappers = (common.HistoryWrapper(HISTORY_STEPS),) 38 | env = common.make_env(args.env, args.monitor, wrappers=env_wrappers) 39 | state_shape = env.observation_space.shape 40 | n_actions = env.action_space.n 41 | 42 | logger.info("Created environment %s, state: %s, actions: %s", args.env, state_shape, n_actions) 43 | 44 | in_t = Input(shape=state_shape, name='input') 45 | fl_t = Flatten(name='flat')(in_t) 46 | l1_t = Dense(SIMPLE_L1_SIZE, activation='relu', name='in_l1')(fl_t) 47 | out_t = Dense(SIMPLE_L2_SIZE, activation='relu', name='in_l2')(l1_t) 48 | 49 | run_model = a3c.make_run_model(in_t, out_t, n_actions) 50 | value_policy_model = a3c.make_train_model(in_t, out_t, n_actions, entropy_beta=0.01) 51 | value_policy_model.summary() 52 | 53 | loss_dict = { 54 | 'policy_loss': lambda y_true, y_pred: y_pred, 55 | 'value_loss': lambda y_true, y_pred: y_pred, 56 | 'entropy_loss': lambda y_true, y_pred: y_pred, 57 | } 58 | value_policy_model.compile(optimizer=Adam(lr=0.0005, clipnorm=0.1), loss=loss_dict) 59 | 60 | # keras summary magic 61 | summary_writer = tf.summary.FileWriter("logs-a3c/" + args.name) 62 | common.summarize_gradients(value_policy_model) 63 | value_policy_model.metrics_names.append("value_summary") 64 | value_policy_model.metrics_tensors.append(tf.summary.merge_all()) 65 | 66 | if args.env.startswith("MountainCar"): 67 | reward_hook = lambda reward, done, step: int(done)*10.0 68 | else: 69 | reward_hook = None 70 | 71 | players = [ 72 | Player(common.make_env(args.env, args.monitor, wrappers=env_wrappers), reward_steps=20, gamma=0.999, 73 | max_steps=40000, player_index=idx, reward_hook=reward_hook) 74 | for idx in range(10) 75 | ] 76 | 77 | for iter_idx, x_batch in enumerate(generate_batches(run_model, players, 128)): 78 | y_stub = np.zeros(len(x_batch[0])) 79 | pre_weights = value_policy_model.get_weights() 80 | l = value_policy_model.train_on_batch(x_batch, [y_stub]*3) 81 | post_weights = value_policy_model.get_weights() 82 | 83 | # logger.info("Iteration %d, loss: %s", iter_idx, l[:-1]) 84 | if np.isnan(l[:-1]).any(): 85 | break 86 | 87 | if iter_idx % SUMMARY_EVERY_BATCH == 0: 88 | l_dict = dict(zip(value_policy_model.metrics_names, l)) 89 | done_rewards = Player.gather_done_rewards(*players) 90 | 91 | if done_rewards: 92 | common.summary_value("reward_episode_mean", np.mean(done_rewards), summary_writer, iter_idx) 93 | common.summary_value("reward_episode_max", np.max(done_rewards), summary_writer, iter_idx) 94 | common.summary_value("reward_episode_min", np.min(done_rewards), summary_writer, iter_idx) 95 | 96 | common.summary_value("reward_batch", np.mean(x_batch[2]), summary_writer, iter_idx) 97 | common.summary_value("loss", l_dict['loss'], summary_writer, iter_idx) 98 | summary_writer.add_summary(l_dict['value_summary'], global_step=iter_idx) 99 | summary_writer.flush() 100 | 101 | if iter_idx % SAVE_MODEL_EVERY_BATCH == 0: 102 | value_policy_model.save(os.path.join("logs-a3c", args.name, "model-%06d.h5" % iter_idx)) 103 | 104 | if iter_idx % 20 == 0: 105 | run_model.set_weights(value_policy_model.get_weights()) 106 | pass 107 | -------------------------------------------------------------------------------- /misc/nn_plus/lib/common.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | 8 | 9 | HYPERPARAMS = { 10 | 'pong': { 11 | 'env_name': "PongNoFrameskip-v4", 12 | 'stop_reward': 18.0, 13 | 'run_name': 'pong', 14 | 'replay_size': 10000, 15 | 'replay_initial': 10000, 16 | 'target_net_sync': 1000, 17 | 'epsilon_frames': 10**5, 18 | 'epsilon_start': 1.0, 19 | 'epsilon_final': 0.02, 20 | 'learning_rate': 0.0001, 21 | 'gamma': 0.99, 22 | 'batch_size': 32 23 | }, 24 | 'breakout': { 25 | 'env_name': "BreakoutNoFrameskip-v4", 26 | 'stop_reward': 500.0, 27 | 'run_name': 'breakout', 28 | 'replay_size': 10 ** 6, 29 | 'replay_initial': 50000, 30 | 'target_net_sync': 10000, 31 | 'epsilon_frames': 10 ** 6, 32 | 'epsilon_start': 1.0, 33 | 'epsilon_final': 0.1, 34 | 'learning_rate': 0.00025, 35 | 'gamma': 0.99, 36 | 'batch_size': 32 37 | }, 38 | 'invaders': { 39 | 'env_name': "SpaceInvadersNoFrameskip-v4", 40 | 'stop_reward': 500.0, 41 | 'run_name': 'breakout', 42 | 'replay_size': 10 ** 6, 43 | 'replay_initial': 50000, 44 | 'target_net_sync': 10000, 45 | 'epsilon_frames': 10 ** 6, 46 | 'epsilon_start': 1.0, 47 | 'epsilon_final': 0.1, 48 | 'learning_rate': 0.00025, 49 | 'gamma': 0.99, 50 | 'batch_size': 32 51 | }, 52 | } 53 | 54 | 55 | def unpack_batch(batch): 56 | states, actions, rewards, dones, last_states = [], [], [], [], [] 57 | for exp in batch: 58 | state = np.array(exp.state, copy=False) 59 | states.append(state) 60 | actions.append(exp.action) 61 | rewards.append(exp.reward) 62 | dones.append(exp.last_state is None) 63 | if exp.last_state is None: 64 | last_states.append(state) # the result will be masked anyway 65 | else: 66 | last_states.append(np.array(exp.last_state, copy=False)) 67 | return np.array(states, copy=False), np.array(actions), np.array(rewards, dtype=np.float32), \ 68 | np.array(dones, dtype=np.uint8), np.array(last_states, copy=False) 69 | 70 | 71 | def calc_loss_dqn(batch, net, tgt_net, gamma, cuda=False): 72 | states, actions, rewards, dones, next_states = unpack_batch(batch) 73 | 74 | states_v = Variable(torch.from_numpy(states)) 75 | next_states_v = Variable(torch.from_numpy(next_states), volatile=True) 76 | actions_v = Variable(torch.from_numpy(actions)) 77 | rewards_v = Variable(torch.from_numpy(rewards)) 78 | done_mask = torch.ByteTensor(dones) 79 | if cuda: 80 | states_v = states_v.cuda() 81 | next_states_v = next_states_v.cuda() 82 | actions_v = actions_v.cuda() 83 | rewards_v = rewards_v.cuda() 84 | done_mask = done_mask.cuda() 85 | 86 | state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1) 87 | next_state_values = tgt_net(next_states_v).max(1)[0] 88 | next_state_values[done_mask] = 0.0 89 | next_state_values.volatile = False 90 | 91 | expected_state_action_values = next_state_values * gamma + rewards_v 92 | return nn.MSELoss()(state_action_values, expected_state_action_values) 93 | 94 | 95 | class RewardTracker: 96 | def __init__(self, writer, stop_reward): 97 | self.writer = writer 98 | self.stop_reward = stop_reward 99 | 100 | def __enter__(self): 101 | self.ts = time.time() 102 | self.ts_frame = 0 103 | self.total_rewards = [] 104 | return self 105 | 106 | def __exit__(self, *args): 107 | self.writer.close() 108 | 109 | def reward(self, reward, frame, epsilon=None): 110 | self.total_rewards.append(reward) 111 | speed = (frame - self.ts_frame) / (time.time() - self.ts) 112 | self.ts_frame = frame 113 | self.ts = time.time() 114 | mean_reward = np.mean(self.total_rewards[-100:]) 115 | epsilon_str = "" if epsilon is None else ", eps %.2f" % epsilon 116 | print("%d: done %d games, mean reward %.3f, speed %.2f f/s%s" % ( 117 | frame, len(self.total_rewards), mean_reward, speed, epsilon_str 118 | )) 119 | sys.stdout.flush() 120 | if epsilon is not None: 121 | self.writer.add_scalar("epsilon", epsilon, frame) 122 | self.writer.add_scalar("speed", speed, frame) 123 | self.writer.add_scalar("reward_100", mean_reward, frame) 124 | self.writer.add_scalar("reward", reward, frame) 125 | if mean_reward > self.stop_reward: 126 | print("Solved in %d frames!" % frame) 127 | return True 128 | return False 129 | -------------------------------------------------------------------------------- /articles/01_rubic/libcube/cubes/cube2x2.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import collections 3 | 4 | from . import _env 5 | from . import _common 6 | 7 | State = collections.namedtuple("State", field_names=['corner_pos', 'corner_ort']) 8 | RenderedState = collections.namedtuple("RenderedState", field_names=['top', 'front', 'left', 9 | 'right', 'back', 'bottom']) 10 | 11 | initial_state = State(corner_pos=tuple(range(8)), corner_ort=tuple([0]*8)) 12 | 13 | 14 | def is_initial(state): 15 | assert isinstance(state, State) 16 | return state.corner_pos == initial_state.corner_pos and \ 17 | state.corner_ort == initial_state.corner_ort 18 | 19 | 20 | # available actions. Capital actions denote clockwise rotation 21 | class Action(enum.Enum): 22 | R = 0 23 | L = 1 24 | T = 2 25 | D = 3 26 | F = 4 27 | B = 5 28 | r = 6 29 | l = 7 30 | t = 8 31 | d = 9 32 | f = 10 33 | b = 11 34 | 35 | 36 | _inverse_action = { 37 | Action.R: Action.r, 38 | Action.r: Action.R, 39 | Action.L: Action.l, 40 | Action.l: Action.L, 41 | Action.T: Action.t, 42 | Action.t: Action.T, 43 | Action.D: Action.d, 44 | Action.d: Action.D, 45 | Action.F: Action.f, 46 | Action.f: Action.F, 47 | Action.B: Action.b, 48 | Action.b: Action.B 49 | } 50 | 51 | 52 | def inverse_action(action): 53 | assert isinstance(action, Action) 54 | return _inverse_action[action] 55 | 56 | 57 | _transform_map = { 58 | Action.R: [ 59 | ((1, 2), (2, 6), (6, 5), (5, 1)), # corner map 60 | ((1, 2), (2, 1), (5, 1), (6, 2)), # corner rotate 61 | ], 62 | Action.L: [ 63 | ((3, 0), (7, 3), (0, 4), (4, 7)), 64 | ((0, 1), (3, 2), (4, 2), (7, 1)), 65 | ], 66 | Action.T: [ 67 | ((0, 3), (1, 0), (2, 1), (3, 2)), 68 | (), 69 | ], 70 | Action.D: [ 71 | ((4, 5), (5, 6), (6, 7), (7, 4)), 72 | (), 73 | ], 74 | Action.F: [ 75 | ((0, 1), (1, 5), (5, 4), (4, 0)), 76 | ((0, 2), (1, 1), (4, 1), (5, 2)), 77 | ], 78 | Action.B: [ 79 | ((2, 3), (3, 7), (7, 6), (6, 2)), 80 | ((2, 2), (3, 1), (6, 1), (7, 2)), 81 | ] 82 | } 83 | 84 | 85 | def transform(state, action): 86 | assert isinstance(state, State) 87 | assert isinstance(action, Action) 88 | global _transform_map 89 | 90 | is_inv = action not in _transform_map 91 | if is_inv: 92 | action = inverse_action(action) 93 | c_map, c_rot = _transform_map[action] 94 | corner_pos = _common._permute(state.corner_pos, c_map, is_inv) 95 | corner_ort = _common._permute(state.corner_ort, c_map, is_inv) 96 | corner_ort = _common._rotate(corner_ort, c_rot) 97 | return State(corner_pos=tuple(corner_pos), corner_ort=tuple(corner_ort)) 98 | 99 | 100 | # create initial sides in the right order 101 | def _init_sides(): 102 | return [ 103 | [None for _ in range(4)] 104 | for _ in range(6) # top, left, back, front, right, bottom 105 | ] 106 | 107 | 108 | # corner cubelets colors (clockwise from main label). Order of cubelets are first top, 109 | # in counter-clockwise, started from front left 110 | corner_colors = ( 111 | ('W', 'R', 'G'), ('W', 'B', 'R'), ('W', 'O', 'B'), ('W', 'G', 'O'), 112 | ('Y', 'G', 'R'), ('Y', 'R', 'B'), ('Y', 'B', 'O'), ('Y', 'O', 'G') 113 | ) 114 | 115 | 116 | # map every 3-side cubelet to their projection on sides 117 | # sides are indexed in the order of _init_sides() function result 118 | corner_maps = ( 119 | # top layer 120 | ((0, 2), (3, 0), (1, 1)), 121 | ((0, 3), (4, 0), (3, 1)), 122 | ((0, 1), (2, 0), (4, 1)), 123 | ((0, 0), (1, 0), (2, 1)), 124 | # bottom layer 125 | ((5, 0), (1, 3), (3, 2)), 126 | ((5, 1), (3, 3), (4, 2)), 127 | ((5, 3), (4, 3), (2, 2)), 128 | ((5, 2), (2, 3), (1, 2)) 129 | ) 130 | 131 | 132 | # render state into human readable form 133 | def render(state): 134 | assert isinstance(state, State) 135 | global corner_colors, corner_maps 136 | 137 | sides = _init_sides() 138 | 139 | for corner, orient, maps in zip(state.corner_pos, state.corner_ort, corner_maps): 140 | cols = corner_colors[corner] 141 | cols = _common._map_orient(cols, orient) 142 | for (arr_idx, index), col in zip(maps, cols): 143 | sides[arr_idx][index] = col 144 | 145 | return RenderedState(top=sides[0], left=sides[1], back=sides[2], front=sides[3], 146 | right=sides[4], bottom=sides[5]) 147 | 148 | 149 | encoded_shape = (8, 24) 150 | 151 | 152 | def encode_inplace(target, state): 153 | """ 154 | Encode cude into existig zeroed numpy array 155 | Follows encoding described in paper https://arxiv.org/abs/1805.07470 156 | :param target: numpy array 157 | :param state: state to be encoded 158 | """ 159 | assert isinstance(state, State) 160 | 161 | # handle corner cubelets: find their permuted position 162 | for corner_idx in range(8): 163 | perm_pos = state.corner_pos.index(corner_idx) 164 | corn_ort = state.corner_ort[perm_pos] 165 | target[corner_idx, perm_pos * 3 + corn_ort] = 1 166 | 167 | 168 | # register env 169 | _env.register(_env.CubeEnv(name="cube2x2", state_type=State, initial_state=initial_state, 170 | is_goal_pred=is_initial, action_enum=Action, 171 | transform_func=transform, inverse_action_func=inverse_action, 172 | render_func=render, encoded_shape=encoded_shape, encode_func=encode_inplace)) 173 | -------------------------------------------------------------------------------- /articles/01_rubic/docs/Notes.md: -------------------------------------------------------------------------------- 1 | # 2018-11-08 2 | 3 | MCTS performance drops with increase of c: 4 | ```` 5 | (art_01_cube) shmuma@gpu:~/work/rl/articles/01_rubic$ ./solver.py -e cube2x2 -m saves/cube2x2-zero-goal-d200-t1/best_1.4547e-02.dat --max-steps 1000 --cuda -r 20 6 | 2018-11-08 06:33:56,195 INFO Using environment CubeEnv('cube2x2') 7 | 2018-11-08 06:33:58,169 INFO Network loaded from saves/cube2x2-zero-goal-d200-t1/best_1.4547e-02.dat 8 | 2018-11-08 06:33:58,169 INFO Got task [10, 1, 0, 11, 4, 3, 3, 2, 11, 1, 10, 11, 8, 1, 9, 6, 1, 3, 3, 8], solving... 9 | 2018-11-08 06:34:01,330 INFO Maximum amount of steps has reached, cube wasn't solved. Did 1001 searches, speed 316.77 searches/s 10 | ```` 11 | 12 | * c=10k: 316 searches/s 13 | * c=100k: 58 searches/s 14 | * c=1m: 4.94 searches/s 15 | 16 | Root tree state is the same for 10k and 100k. 17 | 18 | Mean search depth: 1k: 57, 10k: 129.7, 100k: 861 19 | 20 | Conclusion: 21 | Larger C makes tree taller by exploring less options around, but delving deeper into the search space. 22 | This leads to longer search paths which take more and more time to back up. 23 | It is likely that my C value is too large and I just need to speed up MCTS. 24 | 25 | TODO: 26 | * measure depths of resulting tree 27 | * analyze the length of solution (both naive and BFS) 28 | * check effect of C on those parameters 29 | 30 | Depths with 1000 steps limit: 31 | * c=1m: {'min': 1, 'max': 16, 'mean': 7.849963731321631, 'leaves': 34465} 32 | * c=100k: {'min': 1, 'max': 17, 'mean': 9.103493774652236, 'leaves': 71241} 33 | * c=10k: {'min': 1, 'max': 18, 'mean': 10.28626504647809, 'leaves': 70033} 34 | * c=1k: {'min': 1, 'max': 18, 'mean': 9.942448384493218, 'leaves': 76818} 35 | * c=100: {'min': 1, 'max': 14, 'mean': 8.938883245826121, 'leaves': 69899} 36 | * c=10: {'min': 1, 'max': 13, 'mean': 8.59500956472128, 'leaves': 59594} 37 | 38 | Depths with 10000 steps limit: 39 | * c=10k: {'min': 1, 'max': 27, 'mean': 15.374430775030191, 'leaves': 1289253} 40 | * c=1k: {'min': 1, 'max': 26, 'mean': 14.057022074409328, 'leaves': 1004874} 41 | * c=100: {'min': 1, 'max': 19, 'mean': 12.376234716455224, 'leaves': 1113616} 42 | * c=10: {'min': 1, 'max': 19, 'mean': 11.707333613164712, 'leaves': 886248} 43 | 44 | 45 | # Weird case in search 46 | Looks like virtual loss needs tuning as well 47 | 48 | ```` 49 | (art_01_cube) shmuma@gpu:~/work/rl/articles/01_rubic$ ./solver.py -e cube2x2 -m saves/cube2x2-zero-goal-d200-t1/best_1.4547e-02.dat --max-steps 10000 --cuda -r 10 --seed 41 50 | 2018-11-08 14:43:34,360 INFO Using environment CubeEnv('cube2x2') 51 | 2018-11-08 14:43:36,328 INFO Network loaded from saves/cube2x2-zero-goal-d200-t1/best_1.4547e-02.dat 52 | 2018-11-08 14:43:36,329 INFO Got task [6, 5, 3, 2, 6, 9, 11, 4, 8, 4], solving... 53 | 2018-11-08 14:43:59,362 INFO On step 8544 we found goal state, unroll. Speed 370.94 searches/s 54 | 2018-11-08 14:43:59,627 INFO Tree depths: {'max': 22, 'mean': 11.557521172600728, 'leaves': 604673} 55 | 2018-11-08 14:43:59,627 INFO Solutions: naive [10, 0, 6, 3, 9, 6, 0, 2, 8, 4, 0, 10, 8, 2, 4, 6, 1, 0, 5, 6, 0, 9, 3, 11, 6, 7, 3, 9, 2, 8, 6, 0, 9, 3, 8, 2, 4, 10, 11, 5, 7, 1, 5, 11, 10, 1, 7, 11, 5, 9, 3, 8, 2, 5, 11, 10, 4, 7, 1, 0, 2, 6, 0, 8, 8, 2, 0, 0, 3, 9, 6, 6, 11, 5, 4, 6, 10, 0, 6, 8, 2, 4, 4, 1, 10, 10, 4, 4, 4, 4, 10, 10, 1, 7, 0, 6, 7, 10, 3, 9, 0, 10, 1, 6, 4, 7, 10, 4, 6, 0, 2, 8, 8, 2, 7, 1, 4, 10, 1, 10, 3, 7, 7, 1, 9, 3, 1, 4, 7, 7, 10, 3, 0, 6, 9, 4, 1, 7, 3, 9, 8, 2, 0, 6, 4, 10, 7, 1, 11, 5, 9, 3, 2, 8, 5, 11, 6, 0, 10, 6, 3] (161) 56 | ```` 57 | 58 | # 2018-11-16 59 | ## Experiment with lower c, but more steps 60 | 61 | With decrease of C, solve ratio drops (with fixed amount of steps). But lower C is generally faster. 62 | Maybe more steps will increase the solve ratio and will fit the same time frame? 63 | 64 | Experiments: 65 | * t3.1-c2x2-mcts-c=1000.csv 66 | * t3.1-c2x2-mcts-c=100.csv 67 | * t3.1-c2x2-mcts-c=100-steps=60k.csv 68 | * t3.1-c2x2-mcts-c=100-steps=100k.csv 69 | * t3.1-c2x2-mcts-c=10.csv 70 | * t4-c2x2-mcts-c=10-steps=100k.csv 71 | * t4-c2x2-mcts-c=10-steps=200k.csv 72 | * t4-c2x2-mcts-c=10-steps=500k.csv 73 | 74 | Charts are in 04_mcts_C-extra-data.ipynb 75 | 76 | Conclusion: c=100 is optimal, more steps solve all cubes (checked up to depth 50) 77 | 78 | ## Experiment with batched search 79 | 80 | Main question: what increase in speed we've got and did it decreased quality of search? 81 | 82 | Experiments: 83 | * t4-c2x2-c=100-steps=100k.csv: batch=1 84 | * t4-c2x2-c=100-steps=100k-b10.csv: batch=10 85 | * t4-c2x2-c=100-steps=100k-b100.csv: batch=100 86 | 87 | Charts are in 05_batch_search.ipynb 88 | 89 | With larger batch, solve ration drops. Speed increases, but not proportionally - b=100 has speed increase 2-3 times in 90 | terms of raw steps. 91 | 92 | Maybe, we need to tune virtual loss as well. Do an experiment on it. 93 | 94 | ## Experiment with models with different loss 95 | 96 | Main question: does lower loss mean better solve ratio? 97 | 98 | Setup: 99 | cube 2x2, c=100, steps=100k, batch=1, models: 100 | * t2-zero-goal-best_1.4547e-02.dat 101 | * best_3.0742e-02.dat 102 | * best_6.0737e-02.dat 103 | * best_1.0366e-01.dat 104 | 105 | Results: 106 | * t4-c2x2-mcts-c=100-steps=100k.csv 107 | * t5-c2x2-3.0742e-02.csv 108 | * t5-c2x2-6.0737e-02.csv 109 | * t5-c2x2-1.0366e-01.csv 110 | 111 | ~~Started, waiting for results~~ 112 | **2018-11-19**: run done, notebook is in nbs/06_compare_models 113 | Conclusion: with lower loss, solve ratio increases significantly. 114 | 115 | ## Experiment with different virtual loss 116 | 117 | Setup: 118 | cube 2x2, c=100, steps=100k, batch=10 119 | * nu=100 (default) 120 | * nu=10 121 | * nu=1 122 | * nu=1000 123 | 124 | Results: 125 | * t4-c2x2-mcts-c=100-steps=100k-b10.csv 126 | * t6-c2x2-nu=10.csv 127 | * t6-c2x2-nu=1.csv 128 | * t6-c2x2-nu=1000.csv 129 | 130 | ## Final check -- compare best paper solution with best zero-goal 131 | 132 | Results: 133 | * t4-c2x2-mcts-c=100-steps=100k.csv 134 | * t7-best-paper-1.8184e-1.csv 135 | -------------------------------------------------------------------------------- /algos/dqn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # n-step Q-learning 3 | import argparse 4 | import logging 5 | 6 | import numpy as np 7 | 8 | from rl_lib.wrappers import HistoryWrapper 9 | 10 | logger = logging.getLogger() 11 | logger.setLevel(logging.INFO) 12 | 13 | import gym, gym.wrappers 14 | 15 | from keras.models import Model 16 | from keras.layers import Input, Dense, Flatten 17 | from keras.optimizers import Adagrad, RMSprop 18 | 19 | HISTORY_STEPS = 2 20 | SIMPLE_L1_SIZE = 50 21 | SIMPLE_L2_SIZE = 50 22 | 23 | 24 | def make_env(env_name, monitor_dir): 25 | env = HistoryWrapper(HISTORY_STEPS)(gym.make(env_name)) 26 | if monitor_dir: 27 | env = gym.wrappers.Monitor(env, monitor_dir) 28 | return env 29 | 30 | 31 | def make_model(state_shape, n_actions): 32 | in_t = Input(shape=(HISTORY_STEPS,) + state_shape, name='input') 33 | fl_t = Flatten(name='flat')(in_t) 34 | l1_t = Dense(SIMPLE_L1_SIZE, activation='relu', name='l1')(fl_t) 35 | l2_t = Dense(SIMPLE_L2_SIZE, activation='relu', name='l2')(l1_t) 36 | value_t = Dense(n_actions, name='value')(l2_t) 37 | 38 | return Model(input=in_t, output=value_t) 39 | 40 | 41 | def create_batch(iter_no, env, run_model, num_episodes, n_steps, steps_limit=1000, gamma=1.0, tau=0.20): 42 | """ 43 | Play given amount of episodes and prepare data to train on 44 | :param env: Environment instance 45 | :param run_model: Model to take actions 46 | :param num_episodes: count of episodes to run 47 | :param n_steps: boolean, do we use n-steps DQN or 1-step 48 | :return: batch in format required by model 49 | """ 50 | samples = [] 51 | rewards = [] 52 | 53 | for _ in range(num_episodes): 54 | state = env.reset() 55 | step = 0 56 | sum_reward = 0.0 57 | episode = [] 58 | while True: 59 | # chose action to take 60 | q_value = run_model.predict_on_batch(np.array([state]))[0] 61 | if np.random.random() < tau: 62 | action = np.random.randint(0, len(q_value)) 63 | else: 64 | action = np.argmax(q_value) 65 | next_state, reward, done, _ = env.step(action) 66 | episode.append((state, q_value, action, reward)) 67 | sum_reward = reward + gamma * sum_reward 68 | 69 | state = next_state 70 | step += 1 71 | 72 | # if episode is done, last_q is None 73 | if done: 74 | last_q = None 75 | rewards.append(sum_reward) 76 | break 77 | # otherwise, we'll need last_q as estimation of total reward 78 | elif steps_limit is not None and steps_limit == step: 79 | last_q = run_model.predict_on_batch(np.array([state]))[0] 80 | rewards.append(sum_reward) 81 | break 82 | 83 | # R_sum is used only in n-steps DQN and holds discounted reward for all episode 84 | if last_q is None: 85 | R_sum = 0.0 86 | else: 87 | R_sum = max(last_q) 88 | # now we need to unroll our episode backward to generate training samples 89 | for state, q_value, action, reward in reversed(episode): 90 | # get approximated target reward for this state 91 | R_sum = R_sum*gamma + reward 92 | if n_steps: 93 | R = R_sum 94 | else: 95 | R = reward 96 | if last_q is not None: 97 | R += gamma * max(last_q) 98 | target_q = np.copy(q_value) 99 | target_q[action] = R 100 | samples.append((state, target_q)) 101 | last_q = q_value 102 | 103 | logger.info("%d: Have %d samples, mean final reward: %.3f, max: %.3f", 104 | iter_no, len(samples), np.mean(rewards), np.max(rewards)) 105 | # convert data to train format 106 | np.random.shuffle(samples) 107 | return list(map(np.array, zip(*samples))) 108 | 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument("-e", "--env", default="CartPole-v0", help="Environment name to use") 113 | parser.add_argument("-m", "--monitor", help="Enable monitor and save data into provided dir, default=disabled") 114 | parser.add_argument("-t", "--tau", type=float, default=0.2, help="Ratio of random steps, default=0.2") 115 | parser.add_argument("-i", "--iters", type=int, default=100, help="Count if iterations to take, default=100") 116 | parser.add_argument("--n-steps", action='store_true', default=False, 117 | help="Enable n-step DQN, default=1-step") 118 | args = parser.parse_args() 119 | 120 | env = make_env(args.env, args.monitor) 121 | state_shape = env.observation_space.shape 122 | n_actions = env.action_space.n 123 | 124 | logger.info("Created environment %s, state: %s, actions: %s", args.env, state_shape, n_actions) 125 | 126 | model = make_model(state_shape, n_actions) 127 | model.summary() 128 | 129 | model.compile(optimizer=Adagrad(), loss='mse') 130 | 131 | # test run, to check correctness 132 | if args.monitor is None: 133 | st = env.reset() 134 | r = model.predict_on_batch([ 135 | np.array([st]) 136 | ]) 137 | print(r) 138 | 139 | epoch_limit = 10 140 | step_limit = 300 141 | if args.monitor is not None: 142 | step_limit = None 143 | 144 | for iter in range(args.iters): 145 | batch, target_y = create_batch(iter, env, model, n_steps=args.n_steps, tau=args.tau, 146 | num_episodes=20, steps_limit=step_limit) 147 | # iterate until our losses decreased 10 times or epoches limit exceeded 148 | start_loss = None 149 | loss = None 150 | converged = False 151 | for epoch in range(epoch_limit): 152 | p_h = model.fit(batch, target_y, verbose=0, batch_size=128) 153 | loss = np.min(p_h.history['loss']) 154 | 155 | if start_loss is None: 156 | start_loss = np.max(p_h.history['loss']) 157 | else: 158 | if start_loss / loss > 1.5: 159 | break 160 | pass 161 | -------------------------------------------------------------------------------- /algos/algo_lib/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import configparser 3 | import gym 4 | import gym.spaces 5 | import gym.wrappers 6 | import numpy as np 7 | import logging as log 8 | 9 | import keras.backend as K 10 | import tensorflow as tf 11 | import collections 12 | 13 | 14 | def HistoryWrapper(steps): 15 | class _HistoryWrapper(gym.Wrapper): 16 | """ 17 | Track history of observations for given amount of steps 18 | Initial steps are zero-filled 19 | """ 20 | def __init__(self, env): 21 | super(_HistoryWrapper, self).__init__(env) 22 | self.steps = steps 23 | self.history = self._make_history() 24 | self.observation_space = self._make_observation_space(steps, env.observation_space) 25 | 26 | @staticmethod 27 | def _make_observation_space(steps, orig_obs): 28 | assert isinstance(orig_obs, gym.spaces.Box) 29 | low = np.repeat(np.expand_dims(orig_obs.low, 0), steps, axis=0) 30 | high = np.repeat(np.expand_dims(orig_obs.high, 0), steps, axis=0) 31 | return gym.spaces.Box(low, high) 32 | 33 | def _make_history(self, last_item = None): 34 | size = self.steps if last_item is None else self.steps-1 35 | res = collections.deque([np.zeros(shape=self.env.observation_space.shape)] * size) 36 | if last_item is not None: 37 | res.append(last_item) 38 | return res 39 | 40 | def _step(self, action): 41 | obs, reward, done, info = self.env.step(action) 42 | self.history.popleft() 43 | self.history.append(obs) 44 | return self.history, reward, done, info 45 | 46 | def _reset(self): 47 | self.history = self._make_history(last_item=self.env.reset()) 48 | return self.history 49 | 50 | return _HistoryWrapper 51 | 52 | 53 | def make_env(env_name, monitor_dir=None, wrappers=()): 54 | """ 55 | Make gym environment with optional monitor 56 | :param env_name: name of the environment to create 57 | :param monitor_dir: optional directory to save monitor results 58 | :param wrappers: list of optional Wrapper object instances 59 | :return: environment object 60 | """ 61 | env = gym.make(env_name) 62 | for wrapper in wrappers: 63 | env = wrapper(env) 64 | if monitor_dir: 65 | env = gym.wrappers.Monitor(env, monitor_dir) 66 | return env 67 | 68 | 69 | def summarize_gradients(model): 70 | """ 71 | Add summaries of gradients 72 | :param model: compiled keras model 73 | """ 74 | gradients = model.optimizer.get_gradients(model.total_loss, model._collected_trainable_weights) 75 | for var, grad in zip(model._collected_trainable_weights, gradients): 76 | n = var.name.split(':', maxsplit=1)[0] 77 | tf.summary.scalar("gradrms_" + n, K.sqrt(K.mean(K.square(grad)))) 78 | 79 | 80 | def summary_value(name, value, writer, step_no): 81 | """ 82 | Add given actual value to summary writer 83 | :param name: name of value to add 84 | :param value: scalar value 85 | :param writer: SummaryWriter instance 86 | :param step_no: global step index 87 | """ 88 | summ = tf.Summary() 89 | summ_value = summ.value.add() 90 | summ_value.simple_value = value 91 | summ_value.tag = name 92 | writer.add_summary(summ, global_step=step_no) 93 | 94 | 95 | class ParamsTweaker: 96 | logger = log.getLogger("ParamsTweaker") 97 | 98 | def __init__(self, file_name="tweak_params.txt"): 99 | self.file_name = file_name 100 | self.params = {} 101 | 102 | def add(self, name, var): 103 | self.params[name] = var 104 | 105 | def check(self): 106 | if not os.path.exists(self.file_name): 107 | return 108 | 109 | self.logger.info("Tweak file detected: %s", self.file_name) 110 | with open(self.file_name, "rt", encoding='utf-8') as fd: 111 | for idx, l in enumerate(fd): 112 | name, val = list(map(str.strip, l.split('=', maxsplit=2))) 113 | var = self.params.get(name) 114 | if not var: 115 | self.logger.info("Unknown param '%s' found in file at line %d, ignored", name, idx+1) 116 | continue 117 | self.logger.info("Param %s <-- %s", name, val) 118 | K.set_value(var, float(val)) 119 | os.remove(self.file_name) 120 | pass 121 | 122 | 123 | class Configuration: 124 | def __init__(self, file_name): 125 | self.file_name = file_name 126 | self.config = configparser.ConfigParser() 127 | if not self.config.read(file_name): 128 | raise FileNotFoundError(file_name) 129 | 130 | @property 131 | def env_name(self): 132 | return self.config.get('game', 'env') 133 | 134 | @property 135 | def history(self): 136 | return self.config.getint('game', 'history', fallback=1) 137 | 138 | @property 139 | def image_shape(self): 140 | x = self.config.getint('game', 'image_x') 141 | y = self.config.getint('game', 'image_y') 142 | if x is not None and y is not None: 143 | return (x, y) 144 | return None 145 | 146 | @property 147 | def max_steps(self): 148 | return self.config.getint('game', 'max_steps') 149 | 150 | @property 151 | def a3c_beta(self): 152 | return self.config.getfloat('a3c', 'entropy_beta') 153 | 154 | @property 155 | def a3c_steps(self): 156 | return self.config.getint('a3c', 'reward_steps') 157 | 158 | @property 159 | def a3c_gamma(self): 160 | return self.config.getfloat('a3c', 'gamma') 161 | 162 | @property 163 | def batch_size(self): 164 | return self.config.getint('training', 'batch_size') 165 | 166 | @property 167 | def learning_rate(self): 168 | return self.config.getfloat('training', 'learning_rate') 169 | 170 | @property 171 | def gradient_clip_norm(self): 172 | return self.config.getfloat('training', 'grad_clip_norm') 173 | 174 | @property 175 | def swarms_count(self): 176 | return self.config.getint('swarm', 'swarms') 177 | 178 | @property 179 | def swarm_size(self): 180 | return self.config.getint('swarm', 'swarm_size') 181 | 182 | 183 | class EnvFactory: 184 | def __init__(self, config): 185 | self.config = config 186 | 187 | def __call__(self): 188 | env = gym.make(self.config.env_name) 189 | history = self.config.history 190 | if history > 1: 191 | env = HistoryWrapper(history)(env) 192 | return env 193 | -------------------------------------------------------------------------------- /articles/01_rubic/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import time 4 | import argparse 5 | import logging 6 | import numpy as np 7 | import collections 8 | 9 | import torch 10 | import torch.optim as optim 11 | import torch.optim.lr_scheduler as scheduler 12 | import torch.nn.functional as F 13 | 14 | from tensorboardX import SummaryWriter 15 | 16 | from libcube import cubes 17 | from libcube import model 18 | from libcube import conf 19 | 20 | log = logging.getLogger("train") 21 | 22 | 23 | if __name__ == "__main__": 24 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("-i", "--ini", required=True, help="Ini file to use for this run") 27 | parser.add_argument("-n", "--name", required=True, help="Name of the run") 28 | args = parser.parse_args() 29 | config = conf.Config(args.ini) 30 | device = torch.device("cuda" if config.train_cuda else "cpu") 31 | 32 | name = config.train_name(suffix=args.name) 33 | writer = SummaryWriter(comment="-" + name) 34 | save_path = os.path.join("saves", name) 35 | os.makedirs(save_path) 36 | 37 | cube_env = cubes.get(config.cube_type) 38 | assert isinstance(cube_env, cubes.CubeEnv) 39 | log.info("Selected cube: %s", cube_env) 40 | value_targets_method = model.ValueTargetsMethod(config.train_value_targets_method) 41 | 42 | net = model.Net(cube_env.encoded_shape, len(cube_env.action_enum)).to(device) 43 | print(net) 44 | opt = optim.Adam(net.parameters(), lr=config.train_learning_rate) 45 | sched = scheduler.StepLR(opt, 1, gamma=config.train_lr_decay_gamma) if config.train_lr_decay_enabled else None 46 | 47 | step_idx = 0 48 | buf_policy_loss, buf_value_loss, buf_loss = [], [], [] 49 | buf_policy_loss_raw, buf_value_loss_raw, buf_loss_raw = [], [], [] 50 | buf_mean_values = [] 51 | ts = time.time() 52 | best_loss = None 53 | 54 | log.info("Generate scramble buffer...") 55 | scramble_buf = collections.deque(maxlen=config.scramble_buffer_batches*config.train_batch_size) 56 | scramble_buf.extend(model.make_scramble_buffer(cube_env, config.train_batch_size*2, config.train_scramble_depth)) 57 | log.info("Generated buffer of size %d", len(scramble_buf)) 58 | 59 | while True: 60 | if config.train_lr_decay_enabled and step_idx % config.train_lr_decay_batches == 0: 61 | sched.step() 62 | log.info("LR decrease to %s", sched.get_lr()[0]) 63 | writer.add_scalar("lr", sched.get_lr()[0], step_idx) 64 | 65 | step_idx += 1 66 | x_t, weights_t, y_policy_t, y_value_t = model.sample_batch( 67 | scramble_buf, net, device, config.train_batch_size, value_targets_method) 68 | 69 | opt.zero_grad() 70 | policy_out_t, value_out_t = net(x_t) 71 | value_out_t = value_out_t.squeeze(-1) 72 | value_loss_t = (value_out_t - y_value_t)**2 73 | value_loss_raw_t = value_loss_t.mean() 74 | if config.weight_samples: 75 | value_loss_t *= weights_t 76 | value_loss_t = value_loss_t.mean() 77 | policy_loss_t = F.cross_entropy(policy_out_t, y_policy_t, reduction='none') 78 | policy_loss_raw_t = policy_loss_t.mean() 79 | if config.weight_samples: 80 | policy_loss_t *= weights_t 81 | policy_loss_t = policy_loss_t.mean() 82 | loss_raw_t = policy_loss_raw_t + value_loss_raw_t 83 | loss_t = value_loss_t + policy_loss_t 84 | loss_t.backward() 85 | opt.step() 86 | 87 | # save data 88 | buf_mean_values.append(value_out_t.mean().item()) 89 | buf_policy_loss.append(policy_loss_t.item()) 90 | buf_value_loss.append(value_loss_t.item()) 91 | buf_loss.append(loss_t.item()) 92 | buf_loss_raw.append(loss_raw_t.item()) 93 | buf_value_loss_raw.append(value_loss_raw_t.item()) 94 | buf_policy_loss_raw.append(policy_loss_raw_t.item()) 95 | 96 | if config.train_report_batches is not None and step_idx % config.train_report_batches == 0: 97 | m_policy_loss = np.mean(buf_policy_loss) 98 | m_value_loss = np.mean(buf_value_loss) 99 | m_loss = np.mean(buf_loss) 100 | buf_value_loss.clear() 101 | buf_policy_loss.clear() 102 | buf_loss.clear() 103 | 104 | m_policy_loss_raw = np.mean(buf_policy_loss_raw) 105 | m_value_loss_raw = np.mean(buf_value_loss_raw) 106 | m_loss_raw = np.mean(buf_loss_raw) 107 | buf_value_loss_raw.clear() 108 | buf_policy_loss_raw.clear() 109 | buf_loss_raw.clear() 110 | 111 | m_values = np.mean(buf_mean_values) 112 | buf_mean_values.clear() 113 | 114 | dt = time.time() - ts 115 | ts = time.time() 116 | speed = config.train_batch_size * config.train_report_batches / dt 117 | log.info("%d: p_loss=%.3e, v_loss=%.3e, loss=%.3e, speed=%.1f cubes/s", 118 | step_idx, m_policy_loss, m_value_loss, m_loss, speed) 119 | sum_train_data = 0.0 120 | sum_opt = 0.0 121 | writer.add_scalar("loss_policy", m_policy_loss, step_idx) 122 | writer.add_scalar("loss_value", m_value_loss, step_idx) 123 | writer.add_scalar("loss", m_loss, step_idx) 124 | writer.add_scalar("loss_policy_raw", m_policy_loss_raw, step_idx) 125 | writer.add_scalar("loss_value_raw", m_value_loss_raw, step_idx) 126 | writer.add_scalar("loss_raw", m_loss_raw, step_idx) 127 | writer.add_scalar("values", m_values, step_idx) 128 | writer.add_scalar("speed", speed, step_idx) 129 | 130 | if best_loss is None: 131 | best_loss = m_loss 132 | elif best_loss > m_loss: 133 | name = os.path.join(save_path, "best_%.4e.dat" % m_loss) 134 | torch.save(net.state_dict(), name) 135 | best_loss = m_loss 136 | 137 | if step_idx % config.push_scramble_buffer_iters == 0: 138 | scramble_buf.extend(model.make_scramble_buffer(cube_env, config.train_batch_size, 139 | config.train_scramble_depth)) 140 | log.info("Pushed new data in scramble buffer, new size = %d", len(scramble_buf)) 141 | 142 | if config.train_checkpoint_batches is not None and step_idx % config.train_checkpoint_batches == 0: 143 | name = os.path.join(save_path, "chpt_%06d.dat" % step_idx) 144 | torch.save(net.state_dict(), name) 145 | 146 | if config.train_max_batches is not None and config.train_max_batches <= step_idx: 147 | log.info("Limit of train batches reached, exiting") 148 | break 149 | 150 | writer.close() 151 | -------------------------------------------------------------------------------- /algos/algo_lib/player.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import multiprocessing as mp 5 | import tensorflow as tf 6 | import queue 7 | from keras.models import model_from_json 8 | 9 | 10 | def softmax(x): 11 | e_x = np.exp(x - np.max(x)) 12 | return e_x / e_x.sum() 13 | 14 | 15 | class Player: 16 | """ 17 | Simple syncronous pool of players 18 | """ 19 | def __init__(self, env, reward_steps, gamma, max_steps, player_index, reward_hook=None): 20 | self.env = env 21 | self.reward_steps = reward_steps 22 | self.gamma = gamma 23 | self.reward_hook = reward_hook 24 | 25 | self.state = env.reset() 26 | 27 | self.memory = [] 28 | self.episode_reward = 0.0 29 | self.step_index = 0 30 | self.max_steps = max_steps 31 | self.player_index = player_index 32 | 33 | self.done_rewards = [] 34 | 35 | @classmethod 36 | def step_players(cls, model, players): 37 | """ 38 | Do one step for list of players 39 | :param model: model to use for predictions 40 | :param players: player instances 41 | :return: list of samples 42 | """ 43 | input = np.array([ 44 | p.state for p in players 45 | ]) 46 | probs, values = model.predict_on_batch(input) 47 | result = [] 48 | 49 | for idx, player in enumerate(players): 50 | prb = softmax(probs[idx]) 51 | action = np.random.choice(len(prb), p=prb) 52 | result.extend(player.step(action, values[idx][0])) 53 | return result 54 | 55 | def step(self, action, value): 56 | result = [] 57 | new_state, reward, done, _ = self.env.step(action) 58 | self.episode_reward += reward 59 | if self.reward_hook is not None: 60 | reward = self.reward_hook(reward=reward, done=done, step=self.step_index) 61 | self.memory.append((self.state, action, reward, value)) 62 | self.state = new_state 63 | self.step_index += 1 64 | 65 | if done or self.step_index > self.max_steps: 66 | self.state = self.env.reset() 67 | logging.info("%3d: Episode done @ step %5d, sum reward %d", 68 | self.player_index, self.step_index, int(self.episode_reward)) 69 | self.done_rewards.append(self.episode_reward) 70 | self.episode_reward = 0.0 71 | self.step_index = 0 72 | result.extend(self._memory_to_samples(is_done=done)) 73 | elif len(self.memory) == self.reward_steps + 1: 74 | result.extend(self._memory_to_samples(is_done=False)) 75 | return result 76 | 77 | def _memory_to_samples(self, is_done): 78 | """ 79 | From existing memory, generate samples 80 | :param is_done: is episode done 81 | :return: list of training samples 82 | """ 83 | result = [] 84 | sum_r, last_item = 0.0, None 85 | 86 | if not is_done: 87 | last_item = self.memory.pop() 88 | sum_r = last_item[-1] 89 | 90 | for state, action, reward, value in reversed(self.memory): 91 | sum_r = reward + sum_r * self.gamma 92 | result.append((state, action, sum_r)) 93 | 94 | self.memory = [] if is_done else [last_item] 95 | return result 96 | 97 | @classmethod 98 | def gather_done_rewards(cls, *players): 99 | """ 100 | Collect rewards from list of players 101 | :param players: list of players 102 | :return: list of steps, list of rewards of done episodes 103 | """ 104 | res = [] 105 | for p in players: 106 | res.extend(p.done_rewards) 107 | p.done_rewards = [] 108 | return res 109 | 110 | 111 | def generate_batches(model, players, batch_size): 112 | samples = [] 113 | 114 | while True: 115 | samples.extend(Player.step_players(model, players)) 116 | while len(samples) >= batch_size: 117 | states, actions, rewards = list(map(np.array, zip(*samples[:batch_size]))) 118 | yield [states, actions, rewards] 119 | samples = samples[batch_size:] 120 | 121 | 122 | class AsyncPlayersSwarm: 123 | def __init__(self, config, env_factory, model): 124 | self.config = config 125 | self.batch_size = config.batch_size 126 | self.samples_queue = mp.Queue(maxsize=self.batch_size * 10) 127 | self.done_rewards_queue = mp.Queue() 128 | self.control_queues = [] 129 | self.processes = [] 130 | for _ in range(config.swarms_count): 131 | ctrl_queue = mp.Queue() 132 | self.control_queues.append(ctrl_queue) 133 | args = (config, env_factory, model.to_json(), ctrl_queue, self.samples_queue, self.done_rewards_queue) 134 | proc = mp.Process(target=AsyncPlayersSwarm.player, args=args) 135 | self.processes.append(proc) 136 | proc.start() 137 | 138 | def push_model_weights(self, weights): 139 | for q in self.control_queues: 140 | q.put(weights) 141 | 142 | def get_batch(self): 143 | batch = [] 144 | while len(batch) < self.batch_size: 145 | batch.append(self.samples_queue.get()) 146 | states, actions, rewards = list(map(np.array, zip(*batch))) 147 | return [states, actions, rewards] 148 | 149 | def get_done_rewards(self): 150 | res = [] 151 | try: 152 | while True: 153 | res.append(self.done_rewards_queue.get_nowait()) 154 | except queue.Empty: 155 | pass 156 | return res 157 | 158 | @classmethod 159 | def player(cls, config, env_factory, model_json, ctrl_queue, out_queue, done_rewards_queue): 160 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 161 | with tf.device("/cpu:0"): 162 | model = model_from_json(model_json) 163 | players = [Player(env_factory(), config.a3c_steps, config.a3c_gamma, config.max_steps, idx) 164 | for idx in range(config.swarm_size)] 165 | # input_t, conv_out_t = atari.net_input(players[0].env) 166 | # n_actions = players[0].env.action_space.n 167 | # model = make_run_model(input_t, conv_out_t, n_actions) 168 | while True: 169 | # check ctrl queue for new model 170 | if not ctrl_queue.empty(): 171 | weights = ctrl_queue.get() 172 | # stop requested 173 | if weights is None: 174 | break 175 | model.set_weights(weights) 176 | 177 | for sample in Player.step_players(model, players): 178 | out_queue.put(sample) 179 | for rw in Player.gather_done_rewards(*players): 180 | done_rewards_queue.put(rw) 181 | -------------------------------------------------------------------------------- /algos/pg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Stochastic policy gradient: http://karpathy.github.io/2016/05/31/rl/ 3 | import argparse 4 | import logging 5 | 6 | import numpy as np 7 | 8 | from rl_lib.wrappers import HistoryWrapper 9 | 10 | logger = logging.getLogger() 11 | logger.setLevel(logging.INFO) 12 | 13 | import gym, gym.wrappers 14 | 15 | from keras.models import Model 16 | from keras.layers import Input, Dense, Flatten, Lambda 17 | from keras.optimizers import Adagrad, RMSprop 18 | from keras import backend as K 19 | 20 | HISTORY_STEPS = 1 21 | SIMPLE_L1_SIZE = 50 22 | SIMPLE_L2_SIZE = 50 23 | 24 | 25 | def make_env(env_name, monitor_dir): 26 | env = HistoryWrapper(HISTORY_STEPS)(gym.make(env_name)) 27 | if monitor_dir: 28 | env = gym.wrappers.Monitor(env, monitor_dir) 29 | return env 30 | 31 | 32 | def make_model(state_shape, n_actions): 33 | in_t = Input(shape=(HISTORY_STEPS,) + state_shape, name='input') 34 | action_t = Input(shape=(1,), dtype='int32', name='action') 35 | advantage_t = Input(shape=(1,), name='advantage') 36 | 37 | fl_t = Flatten(name='flat')(in_t) 38 | l1_t = Dense(SIMPLE_L1_SIZE, activation='relu', name='l1')(fl_t) 39 | l2_t = Dense(SIMPLE_L2_SIZE, activation='relu', name='l2')(l1_t) 40 | policy_t = Dense(n_actions, name='policy', activation='softmax')(l2_t) 41 | 42 | def loss_func(args): 43 | p_t, act_t, adv_t = args 44 | oh_t = K.one_hot(act_t, n_actions) 45 | oh_t = K.squeeze(oh_t, 1) 46 | p_oh_t = K.log(1e-6 + K.sum(oh_t * p_t, axis=-1, keepdims=True)) 47 | res_t = adv_t * p_oh_t 48 | return -res_t 49 | 50 | loss_t = Lambda(loss_func, output_shape=(1,), name='loss')([policy_t, action_t, advantage_t]) 51 | 52 | return Model(input=[in_t, action_t, advantage_t], output=[policy_t, loss_t]) 53 | 54 | 55 | def create_batch(iter_no, env, run_model, num_episodes, steps_limit=1000, gamma=1.0, tau=0.20, min_samples=None): 56 | """ 57 | Play given amount of episodes and prepare data to train on 58 | :param env: Environment instance 59 | :param run_model: Model to take actions 60 | :param num_episodes: count of episodes to run 61 | :return: batch in format required by model 62 | """ 63 | samples = [] 64 | rewards = [] 65 | 66 | episodes_counter = 0 67 | while True: 68 | state = env.reset() 69 | step = 0 70 | sum_reward = 0.0 71 | episode = [] 72 | loc_rewards = [] 73 | while True: 74 | # chose action to take 75 | probs = run_model.predict_on_batch([ 76 | np.array([state]), 77 | np.array([0]), 78 | np.array([0.0]) 79 | ])[0][0] 80 | if np.random.random() < tau: 81 | action = np.random.randint(0, len(probs)) 82 | probs = np.ones_like(probs) 83 | probs /= np.sum(probs) 84 | else: 85 | action = np.random.choice(len(probs), p=probs) 86 | next_state, reward, done, _ = env.step(action) 87 | episode.append((state, probs, action)) 88 | loc_rewards.append(reward) 89 | sum_reward = reward + gamma * sum_reward 90 | state = next_state 91 | step += 1 92 | 93 | if done or (steps_limit is not None and steps_limit == step): 94 | rewards.append(sum_reward) 95 | break 96 | 97 | # create reversed reward 98 | sum_reward = 0.0 99 | rev_rewards = [] 100 | for r in reversed(loc_rewards): 101 | sum_reward = sum_reward * gamma + r 102 | rev_rewards.append(sum_reward) 103 | rev_rewards = np.copy(rev_rewards) 104 | rev_rewards -= np.mean(rev_rewards) 105 | rev_rewards /= np.std(rev_rewards) 106 | 107 | # generate samples from episode 108 | for reward, (state, probs, action) in zip(rev_rewards, reversed(episode)): 109 | samples.append((state, action, reward)) 110 | episodes_counter += 1 111 | 112 | if min_samples is None: 113 | if episodes_counter == num_episodes: 114 | break 115 | elif len(samples) >= min_samples and episodes_counter >= num_episodes: 116 | break 117 | 118 | logger.info("%d: Have %d samples from %d episodes, mean final reward: %.3f, max: %.3f", 119 | iter_no, len(samples), episodes_counter, np.mean(rewards), np.max(rewards)) 120 | # convert data to train format 121 | np.random.shuffle(samples) 122 | return list(map(np.array, zip(*samples))) 123 | 124 | 125 | def create_fake_target(n_actions, batch_len): 126 | return [ 127 | np.array([[0.0] * n_actions] * batch_len), 128 | np.array([0.0] * batch_len) 129 | ] 130 | 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument("-e", "--env", default="CartPole-v0", help="Environment name to use") 135 | parser.add_argument("-m", "--monitor", help="Enable monitor and save data into provided dir, default=disabled") 136 | parser.add_argument("-t", "--tau", type=float, default=0.2, help="Ratio of random steps, default=0.2") 137 | parser.add_argument("-i", "--iters", type=int, default=100, help="Count if iterations to take, default=100") 138 | args = parser.parse_args() 139 | 140 | env = make_env(args.env, args.monitor) 141 | state_shape = env.observation_space.shape 142 | n_actions = env.action_space.n 143 | 144 | logger.info("Created environment %s, state: %s, actions: %s", args.env, state_shape, n_actions) 145 | 146 | model = make_model(state_shape, n_actions) 147 | model.summary() 148 | 149 | loss_dict = { 150 | # our model already outputs loss, so just take it as-is 151 | 'loss': lambda y_true, y_pred: y_pred, 152 | # this will make zero gradients contribution 153 | 'policy': lambda y_true, y_pred: y_true, 154 | } 155 | 156 | model.compile(optimizer=Adagrad(), loss=loss_dict) 157 | 158 | # gradient check 159 | if False: 160 | batch, action, advantage = create_batch(0, env, model, tau=0, num_episodes=1, steps_limit=10, min_samples=None) 161 | r = model.predict_on_batch([batch, action, advantage]) 162 | fake_out = create_fake_target(n_actions, len(batch)) 163 | l = model.train_on_batch([batch, action, advantage], fake_out) 164 | r2 = model.predict_on_batch([batch, action, advantage]) 165 | logger.info("Test fit, mean loss: %s -> %s", np.mean(r[1]), np.mean(r2[1])) 166 | 167 | step_limit = 300 168 | if args.monitor is not None: 169 | step_limit = None 170 | 171 | for iter in range(args.iters): 172 | batch, action, advantage = create_batch(iter, env, model, tau=args.tau, num_episodes=10, 173 | steps_limit=step_limit, min_samples=500) 174 | fake_out = create_fake_target(n_actions, len(batch)) 175 | l = model.train_on_batch([batch, action, advantage], fake_out) 176 | #logger.info("Loss: %s", l[0]) 177 | pass 178 | -------------------------------------------------------------------------------- /articles/01_rubic/csvs/c3x3/c3-zg-d20-noweight-no-decay=5.501e-1.csv: -------------------------------------------------------------------------------- 1 | start_dt,stop_dt,duration,depth,scramble,is_solved,solve_steps,sol_len_naive,sol_len_bfs,tree_depth_max,tree_depth_mean 2 | 2019-01-08T08:28:54.421080,2019-01-08T08:28:54.424637,0.003557,1,10,1,1,1,1,1,1.0 3 | 2019-01-08T08:28:54.424757,2019-01-08T08:28:54.425990,0.001233,1,1,1,1,1,1,1,1.0 4 | 2019-01-08T08:28:54.426094,2019-01-08T08:28:54.427316,0.001222,1,0,1,1,1,1,1,1.0 5 | 2019-01-08T08:28:54.427419,2019-01-08T08:28:54.428638,0.001219,1,11,1,1,1,1,1,1.0 6 | 2019-01-08T08:28:54.428743,2019-01-08T08:28:54.429962,0.001219,1,4,1,1,1,1,1,1.0 7 | 2019-01-08T08:28:54.430514,2019-01-08T08:28:54.434303,0.003789,2,3 3,1,3,2,2,2,1.6875 8 | 2019-01-08T08:28:54.434580,2019-01-08T08:28:54.438282,0.003702,2,11 1,1,3,2,2,2,1.6875 9 | 2019-01-08T08:28:54.438475,2019-01-08T08:28:54.442157,0.003682,2,11 8,1,3,2,2,2,1.6875 10 | 2019-01-08T08:28:54.442424,2019-01-08T08:28:54.446456,0.004032,2,9 6,1,2,2,2,2,1.5 11 | 2019-01-08T08:28:54.446644,2019-01-08T08:28:54.449056,0.002412,2,0 1,1,3,2,2,2,1.6875 12 | 2019-01-08T08:28:54.449729,2019-01-08T08:28:54.450597,0.000868,3,3 8 9,1,1,1,1,1,1.0 13 | 2019-01-08T08:28:54.450702,2019-01-08T08:28:54.454821,0.004119,3,0 8 3,1,5,3,3,3,2.230769230769231 14 | 2019-01-08T08:28:54.455185,2019-01-08T08:28:54.459302,0.004117,3,11 8 6,1,5,3,3,3,2.230769230769231 15 | 2019-01-08T08:28:54.459658,2019-01-08T08:28:54.463764,0.004106,3,9 4 0,1,5,3,3,3,2.230769230769231 16 | 2019-01-08T08:28:54.464120,2019-01-08T08:28:54.468213,0.004093,3,6 5 4,1,5,3,3,3,2.230769230769231 17 | 2019-01-08T08:28:54.468978,2019-01-08T08:28:54.474907,0.005929,4,5 1 1 6,1,7,4,4,4,2.75 18 | 2019-01-08T08:28:54.475501,2019-01-08T08:28:54.481349,0.005848,4,9 4 0 11,1,7,4,4,4,2.75 19 | 2019-01-08T08:28:54.481942,2019-01-08T08:28:54.486974,0.005032,4,6 1 8 4,1,6,4,4,4,2.8548387096774195 20 | 2019-01-08T08:28:54.487408,2019-01-08T08:28:54.494913,0.007505,4,9 11 1 0,1,9,4,4,4,2.975 21 | 2019-01-08T08:28:54.495761,2019-01-08T08:28:54.501612,0.005851,4,1 6 4 7,1,7,4,4,4,2.75 22 | 2019-01-08T08:28:54.502598,2019-01-08T08:28:54.531406,0.028808,5,5 5 3 10 11,1,29,15,5,7,5.107925801011804 23 | 2019-01-08T08:28:54.533474,2019-01-08T08:28:54.541945,0.008471,5,10 11 8 3 10,1,10,5,5,5,3.375 24 | 2019-01-08T08:28:54.542847,2019-01-08T08:28:54.549634,0.006787,5,6 4 1 3 11,1,8,5,5,5,3.402439024390244 25 | 2019-01-08T08:28:54.550305,2019-01-08T08:28:54.570914,0.020609,5,6 10 7 2 4,1,21,5,5,5,3.547169811320755 26 | 2019-01-08T08:28:54.572579,2019-01-08T08:28:54.580226,0.007647,5,5 3 2 7 0,1,9,5,5,5,3.66025641025641 27 | 2019-01-08T08:28:54.581756,2019-01-08T08:28:54.669988,0.088232,6,10 6 9 1 6 6,1,74,50,6,9,6.165 28 | 2019-01-08T08:28:54.674038,2019-01-08T08:28:54.683616,0.009578,6,8 8 0 9 5 7,1,11,6,6,6,3.6785714285714284 29 | 2019-01-08T08:28:54.684582,2019-01-08T08:28:59.667883,4.983301,6,0 3 1 1 11 7,1,1644,34,10,17,9.305150631681244 30 | 2019-01-08T08:28:59.838534,2019-01-08T08:29:02.231358,2.392824,6,3 2 1 2 0 7,1,730,300,6,15,8.885909543922928 31 | 2019-01-08T08:29:02.252519,2019-01-08T08:29:02.259834,0.007315,6,2 4 7 4 11 1,1,10,6,6,6,3.9785714285714286 32 | 2019-01-08T08:29:02.261416,2019-01-08T08:29:05.705523,3.444107,7,11 10 6 8 5 1 6,1,1306,135,7,15,9.60963029631681 33 | 2019-01-08T08:29:05.760336,2019-01-08T08:29:05.768457,0.008121,7,3 0 5 1 11 10 5,1,9,5,5,5,3.260869565217391 34 | 2019-01-08T08:29:05.769273,2019-01-08T08:29:05.777006,0.007733,7,7 11 2 11 7 7 9,1,14,7,7,7,4.59047619047619 35 | 2019-01-08T08:29:05.778645,2019-01-08T08:29:05.846827,0.068182,7,8 11 9 8 6 1 4,1,104,7,7,12,7.352678571428571 36 | 2019-01-08T08:29:05.853553,2019-01-08T08:29:05.912739,0.059186,7,4 11 0 5 1 5 3,1,79,37,7,11,7.134691195795007 37 | 2019-01-08T08:29:05.919962,2019-01-08T08:29:06.152056,0.232094,8,4 5 0 10 5 4 1 5,1,221,6,6,14,10.142761729838037 38 | 2019-01-08T08:29:06.158885,2019-01-08T08:29:07.155494,0.996609,8,7 10 5 10 5 1 8 8,1,761,10,8,19,13.074936838716367 39 | 2019-01-08T08:29:07.279780,2019-01-08T08:29:07.337167,0.057387,8,1 5 7 10 11 7 3 3,1,68,40,8,10,6.552910052910053 40 | 2019-01-08T08:29:07.346629,2019-01-08T08:29:35.231284,27.884655,8,10 7 7 10 3 4 5 4,1,5083,424,12,20,12.115751153805146 41 | 2019-01-08T08:29:36.067374,2019-01-08T08:34:17.645057,281.577683,8,1 1 10 3 0 5 5 10,1,30257,8,8,26,16.28422838822429 42 | 2019-01-08T08:34:19.440161,2019-01-08T08:35:05.026055,45.585894,9,11 6 8 7 10 8 5 9 9,1,21361,9,9,27,15.596400754642922 43 | 2019-01-08T08:35:06.617234,2019-01-08T08:45:06.700297,600.083063,9,7 0 3 0 8 10 8 0 8,0,90115,-1,-1,32,18.815614329315796 44 | 2019-01-08T08:45:18.905195,2019-01-08T08:45:39.503956,20.598761,9,9 10 2 5 5 8 3 0 0,1,7187,497,9,24,14.719080658315688 45 | 2019-01-08T08:45:40.257448,2019-01-08T08:55:40.292897,600.035449,9,9 2 2 3 5 6 11 2 2,0,55613,-1,-1,26,14.354985218258578 46 | 2019-01-08T08:55:42.061797,2019-01-08T09:01:37.950623,355.888826,9,8 9 1 5 9 6 5 4 5,1,52112,9,9,32,19.87793362161064 47 | 2019-01-08T09:01:46.129826,2019-01-08T09:11:46.337487,600.207661,10,7 5 6 2 6 9 1 4 2 11,0,160416,-1,-1,41,23.824312527745455 48 | 2019-01-08T09:13:06.757816,2019-01-08T09:23:07.514951,600.757135,10,0 0 1 8 4 2 10 10 1 1,0,49431,-1,-1,28,16.43874448816716 49 | 2019-01-08T09:23:13.688648,2019-01-08T09:33:13.903465,600.214817,10,4 1 8 7 4 8 7 7 5 10,0,163435,-1,-1,35,19.601319282013876 50 | 2019-01-08T09:33:26.297088,2019-01-08T09:43:27.089217,600.792129,10,4 5 1 3 5 8 11 1 6 8,0,100775,-1,-1,33,19.297588776527807 51 | 2019-01-08T09:43:30.206149,2019-01-08T09:53:30.634290,600.428141,10,3 11 3 8 4 6 3 8 5 1,0,123365,-1,-1,35,20.788868672140758 52 | 2019-01-08T09:53:49.853546,2019-01-08T09:55:38.630723,108.777177,11,10 5 9 5 6 9 2 11 10 10 10,1,47851,9,9,30,19.22284663476597 53 | 2019-01-08T09:55:46.105435,2019-01-08T10:01:01.189792,315.084357,11,1 10 9 7 5 1 11 4 1 0 2,1,30873,393,11,33,19.65320113615534 54 | 2019-01-08T10:01:03.575722,2019-01-08T10:11:03.700703,600.124981,11,3 5 5 4 5 7 7 3 6 4 3,0,75300,-1,-1,31,18.782011027092818 55 | 2019-01-08T10:11:07.938569,2019-01-08T10:21:08.255516,600.316947,11,5 5 2 4 7 0 11 1 1 6 10,0,112687,-1,-1,31,16.92932849415105 56 | 2019-01-08T10:21:14.037032,2019-01-08T10:31:14.542200,600.505168,11,5 1 11 2 10 5 7 4 8 10 1,0,69301,-1,-1,32,19.501281511182288 57 | 2019-01-08T10:31:17.111130,2019-01-08T10:41:17.411955,600.300825,12,1 2 6 10 11 2 3 2 5 4 5 7,0,107018,-1,-1,34,19.67145124999614 58 | 2019-01-08T10:41:24.710406,2019-01-08T10:41:27.270346,2.55994,12,2 9 1 1 3 2 2 9 7 4 3 3,1,1838,38,10,22,12.393955335111913 59 | 2019-01-08T10:41:27.435633,2019-01-08T10:51:27.441101,600.005468,12,2 10 10 5 3 6 2 2 0 7 3 6,0,128835,-1,-1,33,20.792059014325897 60 | 2019-01-08T10:51:49.966913,2019-01-08T11:01:50.597983,600.63107,12,4 1 9 10 2 9 10 9 9 4 5 6,0,90493,-1,-1,34,20.557958848297993 61 | 2019-01-08T11:01:59.923270,2019-01-08T11:12:00.315070,600.3918,12,11 6 8 7 9 11 8 1 0 9 11 4,0,168991,-1,-1,36,21.452424147530355 62 | 2019-01-08T11:12:14.636385,2019-01-08T11:22:15.499135,600.86275,13,8 5 8 0 1 5 7 8 9 4 9 6 6,0,149831,-1,-1,31,17.157095620648587 63 | 2019-01-08T11:22:21.570155,2019-01-08T11:32:22.273320,600.703165,13,0 0 0 11 10 0 10 10 5 10 8 0 11,0,98208,-1,-1,34,20.93405447819751 64 | 2019-01-08T11:32:31.892007,2019-01-08T11:42:32.329119,600.437112,13,5 6 3 8 6 5 6 5 5 4 9 9 6,0,163028,-1,-1,35,19.647722747698264 65 | 2019-01-08T11:42:51.468257,2019-01-08T11:52:52.282630,600.814373,13,2 10 2 4 7 6 4 4 3 2 2 1 2,0,168858,-1,-1,33,19.62627900942845 66 | 2019-01-08T11:53:03.882629,2019-01-08T12:03:05.171077,601.288448,13,1 6 11 3 4 4 2 5 1 9 10 1 8,0,146697,-1,-1,38,21.091279248597374 67 | -------------------------------------------------------------------------------- /articles/01_rubic/libcube/model.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import time 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from . import cubes 10 | 11 | 12 | class Net(nn.Module): 13 | def __init__(self, input_shape, actions_count): 14 | super(Net, self).__init__() 15 | 16 | self.input_size = int(np.prod(input_shape)) 17 | self.body = nn.Sequential( 18 | nn.Linear(self.input_size, 4096), 19 | nn.ELU(), 20 | nn.Linear(4096, 2048), 21 | nn.ELU() 22 | ) 23 | self.policy = nn.Sequential( 24 | nn.Linear(2048, 512), 25 | nn.ELU(), 26 | nn.Linear(512, actions_count) 27 | ) 28 | self.value = nn.Sequential( 29 | nn.Linear(2048, 512), 30 | nn.ELU(), 31 | nn.Linear(512, 1) 32 | ) 33 | 34 | def forward(self, batch, value_only=False): 35 | x = batch.view((-1, self.input_size)) 36 | body_out = self.body(x) 37 | value_out = self.value(body_out) 38 | if value_only: 39 | return value_out 40 | policy_out = self.policy(body_out) 41 | return policy_out, value_out 42 | 43 | 44 | def encode_states(cube_env, states): 45 | assert isinstance(cube_env, cubes.CubeEnv) 46 | assert isinstance(states, (list, tuple)) 47 | 48 | # states could be list of lists or just list of states 49 | if isinstance(states[0], list): 50 | encoded = np.zeros((len(states), len(states[0])) + cube_env.encoded_shape, dtype=np.float32) 51 | 52 | for i, st_list in enumerate(states): 53 | for j, state in enumerate(st_list): 54 | cube_env.encode_inplace(encoded[i, j], state) 55 | else: 56 | encoded = np.zeros((len(states), ) + cube_env.encoded_shape, dtype=np.float32) 57 | for i, state in enumerate(states): 58 | cube_env.encode_inplace(encoded[i], state) 59 | 60 | return encoded 61 | 62 | 63 | class ValueTargetsMethod(enum.Enum): 64 | # method from the paper 65 | Paper = 'paper' 66 | # paper, but value of goal state equals zero 67 | ZeroGoalValue = 'zero_goal_value' 68 | 69 | 70 | def make_scramble_buffer(cube_env, buf_size, scramble_depth): 71 | """ 72 | Create data buffer with scramble states and explored substates 73 | :param cube_env: env to use 74 | :param buf_size: how many states to generate 75 | :param scramble_depth: how deep to scramble 76 | :return: list of tuples 77 | """ 78 | result = [] 79 | data = [] 80 | rounds = buf_size // scramble_depth 81 | for _ in range(rounds): 82 | data.extend(cube_env.scramble_cube(scramble_depth, include_initial=True)) 83 | 84 | # explore each state 85 | for depth, s in data: 86 | states, goals = cube_env.explore_state(s) 87 | enc_s = encode_states(cube_env, [s]) 88 | enc_states = encode_states(cube_env, states) 89 | result.append((enc_s, depth, cube_env.is_goal(s), enc_states, goals)) 90 | return result 91 | 92 | 93 | def sample_batch(scramble_buffer, net, device, batch_size, value_targets): 94 | """ 95 | Sample batch of given size from scramble buffer produced by make_scramble_buffer 96 | :param scramble_buffer: scramble buffer 97 | :param net: network to use to calculate targets 98 | :param device: device to move values 99 | :param batch_size: size of batch to generate 100 | :param value_targets: targets 101 | :return: tensors 102 | """ 103 | data = random.sample(scramble_buffer, batch_size) 104 | states, depths, is_goals, explored_states, explored_goals = zip(*data) 105 | 106 | # handle explored states 107 | explored_states = np.stack(explored_states) 108 | shape = explored_states.shape 109 | explored_states_t = torch.tensor(explored_states).to(device) 110 | explored_states_t = explored_states_t.view(shape[0]*shape[1], *shape[2:]) # shape: (states*actions, encoded_shape) 111 | value_t = net(explored_states_t, value_only=True) 112 | value_t = value_t.squeeze(-1).view(shape[0], shape[1]) # shape: (states, actions) 113 | if value_targets == ValueTargetsMethod.Paper: 114 | # add reward to the values 115 | goals_mask_t = torch.tensor(explored_goals, dtype=torch.int8).to(device) 116 | goals_mask_t += goals_mask_t - 1 # has 1 at final states and -1 elsewhere 117 | value_t += goals_mask_t.type(dtype=torch.float32) 118 | # find target value and target policy 119 | max_val_t, max_act_t = value_t.max(dim=1) 120 | elif value_targets == ValueTargetsMethod.ZeroGoalValue: 121 | value_t -= 1.0 122 | max_val_t, max_act_t = value_t.max(dim=1) 123 | goal_indices = np.nonzero(is_goals) 124 | max_val_t[goal_indices] = 0.0 125 | max_act_t[goal_indices] = 0 126 | else: 127 | assert False, "Unsupported method of value targets" 128 | 129 | # train input 130 | enc_input = np.stack(states) 131 | enc_input_t = torch.tensor(enc_input).to(device) 132 | depths_t = torch.tensor(depths, dtype=torch.float32).to(device) 133 | weights_t = 1/depths_t 134 | return enc_input_t.detach(), weights_t.detach(), max_act_t.detach(), max_val_t.detach() 135 | 136 | 137 | def make_train_data(cube_env, net, device, batch_size, scramble_depth, shuffle=False, 138 | value_targets=ValueTargetsMethod.Paper): 139 | assert isinstance(cube_env, cubes.CubeEnv) 140 | assert isinstance(value_targets, ValueTargetsMethod) 141 | 142 | # scramble cube states and their depths 143 | data = [] 144 | rounds = batch_size // scramble_depth 145 | for _ in range(rounds): 146 | data.extend(cube_env.scramble_cube(scramble_depth, include_initial=True)) 147 | if shuffle: 148 | random.shuffle(data) 149 | cube_depths, cube_states = zip(*data) 150 | 151 | # explore each state by doing 1-step BFS search and keep a mask of goal states (for reward calculation) 152 | explored_states, explored_goals = [], [] 153 | goal_indices = [] 154 | for idx, s in enumerate(cube_states): 155 | states, goals = cube_env.explore_state(s) 156 | explored_states.append(states) 157 | explored_goals.append(goals) 158 | if cube_env.is_goal(s): 159 | goal_indices.append(idx) 160 | 161 | # obtain network's values for all explored states 162 | enc_explored = encode_states(cube_env, explored_states) # shape: (states, actions, encoded_shape) 163 | 164 | shape = enc_explored.shape 165 | enc_explored_t = torch.tensor(enc_explored).to(device) 166 | enc_explored_t = enc_explored_t.view(shape[0]*shape[1], *shape[2:]) # shape: (states*actions, encoded_shape) 167 | value_t = net(enc_explored_t, value_only=True) 168 | value_t = value_t.squeeze(-1).view(shape[0], shape[1]) # shape: (states, actions) 169 | if value_targets == ValueTargetsMethod.Paper: 170 | # add reward to the values 171 | goals_mask_t = torch.tensor(explored_goals, dtype=torch.int8).to(device) 172 | goals_mask_t += goals_mask_t - 1 # has 1 at final states and -1 elsewhere 173 | value_t += goals_mask_t.type(dtype=torch.float32) 174 | # find target value and target policy 175 | max_val_t, max_act_t = value_t.max(dim=1) 176 | elif value_targets == ValueTargetsMethod.ZeroGoalValue: 177 | value_t -= 1.0 178 | max_val_t, max_act_t = value_t.max(dim=1) 179 | max_val_t[goal_indices] = 0.0 180 | max_act_t[goal_indices] = 0 181 | else: 182 | assert False, "Unsupported method of value targets" 183 | 184 | # create train input 185 | enc_input = encode_states(cube_env, cube_states) 186 | enc_input_t = torch.tensor(enc_input).to(device) 187 | cube_depths_t = torch.tensor(cube_depths, dtype=torch.float32).to(device) 188 | weights_t = 1/cube_depths_t 189 | return enc_input_t.detach(), weights_t.detach(), max_act_t.detach(), max_val_t.detach() 190 | -------------------------------------------------------------------------------- /articles/01_rubic/csvs/c3x3/c3-zg-d20-noweight.csv: -------------------------------------------------------------------------------- 1 | start_dt,stop_dt,duration,depth,scramble,is_solved,solve_steps,sol_len_naive,sol_len_bfs,tree_depth_max,tree_depth_mean 2 | 2019-01-06T07:34:49.025535,2019-01-06T07:34:49.028813,0.003278,1,10,1,1,1,1,1,1.0 3 | 2019-01-06T07:34:49.028945,2019-01-06T07:34:49.029800,0.000855,1,1,1,1,1,1,1,1.0 4 | 2019-01-06T07:34:49.029917,2019-01-06T07:34:49.030748,0.000831,1,0,1,1,1,1,1,1.0 5 | 2019-01-06T07:34:49.030854,2019-01-06T07:34:49.031681,0.000827,1,11,1,1,1,1,1,1.0 6 | 2019-01-06T07:34:49.031790,2019-01-06T07:34:49.032615,0.000825,1,4,1,1,1,1,1,1.0 7 | 2019-01-06T07:34:49.033197,2019-01-06T07:34:49.035896,0.002699,2,3 3,1,3,2,2,2,1.6875 8 | 2019-01-06T07:34:49.036178,2019-01-06T07:34:49.038753,0.002575,2,11 1,1,3,2,2,2,1.6875 9 | 2019-01-06T07:34:49.038957,2019-01-06T07:34:49.041521,0.002564,2,11 8,1,3,2,2,2,1.6875 10 | 2019-01-06T07:34:49.041798,2019-01-06T07:34:49.043501,0.001703,2,9 6,1,2,2,2,2,1.5 11 | 2019-01-06T07:34:49.043697,2019-01-06T07:34:49.046234,0.002537,2,0 1,1,3,2,2,2,1.6875 12 | 2019-01-06T07:34:49.046943,2019-01-06T07:34:49.047898,0.000955,3,3 8 9,1,1,1,1,1,1.0 13 | 2019-01-06T07:34:49.048012,2019-01-06T07:34:49.053227,0.005215,3,0 8 3,1,6,3,3,3,2.3142857142857145 14 | 2019-01-06T07:34:49.053690,2019-01-06T07:34:49.058083,0.004393,3,11 8 6,1,5,3,3,3,2.230769230769231 15 | 2019-01-06T07:34:49.058448,2019-01-06T07:34:49.062753,0.004305,3,9 4 0,1,5,3,3,3,2.230769230769231 16 | 2019-01-06T07:34:49.063119,2019-01-06T07:34:49.067419,0.0043,3,6 5 4,1,5,3,3,3,2.230769230769231 17 | 2019-01-06T07:34:49.068419,2019-01-06T07:34:49.076232,0.007813,4,5 1 1 6,1,9,4,4,4,2.975 18 | 2019-01-06T07:34:49.077005,2019-01-06T07:34:49.082164,0.005159,4,4 0 11 7,1,6,4,4,4,2.532258064516129 19 | 2019-01-06T07:34:49.082674,2019-01-06T07:34:49.088708,0.006034,4,1 8 4 9,1,7,4,4,4,2.611111111111111 20 | 2019-01-06T07:34:49.089223,2019-01-06T07:34:49.096098,0.006875,4,11 1 0 10,1,8,4,4,4,2.911111111111111 21 | 2019-01-06T07:34:49.096773,2019-01-06T07:34:49.102783,0.00601,4,3 1 6 4,1,7,4,4,4,2.75 22 | 2019-01-06T07:34:49.103909,2019-01-06T07:34:49.192085,0.088176,5,2 5 5 3 10,1,84,5,5,8,6.467635402906208 23 | 2019-01-06T07:34:49.196558,2019-01-06T07:34:49.205348,0.00879,5,8 8 4 11 9,1,10,5,5,5,3.5076923076923077 24 | 2019-01-06T07:34:49.206325,2019-01-06T07:34:49.210627,0.004302,5,2 7 0 1 2,1,5,3,3,3,2.0384615384615383 25 | 2019-01-06T07:34:49.211055,2019-01-06T07:34:49.221391,0.010336,5,10 6 9 1 6,1,12,5,5,5,3.4823529411764707 26 | 2019-01-06T07:34:49.222546,2019-01-06T07:34:49.238979,0.016433,5,8 0 10 11 1,1,19,5,5,5,3.596774193548387 27 | 2019-01-06T07:34:49.241292,2019-01-06T07:34:50.011383,0.770091,6,2 7 0 11 11 4,1,560,84,10,12,7.763881483738289 28 | 2019-01-06T07:34:50.109492,2019-01-06T07:44:50.121035,600.011543,6,5 2 10 2 11 7,0,105805,-1,-1,28,15.900892352749986 29 | 2019-01-06T07:44:54.545974,2019-01-06T07:54:55.011851,600.465877,6,8 9 4 1 3 1,0,120665,-1,-1,28,14.770508980456585 30 | 2019-01-06T07:55:00.700228,2019-01-06T07:55:02.172393,1.472165,6,6 1 8 7 0 0,1,652,94,6,13,7.958104004554944 31 | 2019-01-06T07:55:02.201829,2019-01-06T07:55:02.234097,0.032268,6,3 11 3 7 6 1,1,45,4,4,7,5.070707070707071 32 | 2019-01-06T07:55:02.235852,2019-01-06T07:55:02.262483,0.026631,7,4 6 5 10 7 0 5,1,44,23,7,8,5.805486284289277 33 | 2019-01-06T07:55:02.267075,2019-01-06T07:55:02.269623,0.002548,7,2 4 5 10 11 10 2,1,5,3,3,3,2.230769230769231 34 | 2019-01-06T07:55:02.269985,2019-01-06T07:57:25.776701,143.506716,7,4 6 9 10 3 4 4,1,24379,307,7,21,11.808345533197985 35 | 2019-01-06T07:57:26.713655,2019-01-06T07:57:26.934170,0.220515,7,8 0 2 4 6 3 3,1,122,49,7,8,5.374570446735395 36 | 2019-01-06T07:57:26.948845,2019-01-06T07:57:26.954971,0.006126,7,3 8 9 11 9 8 11,1,11,5,5,5,3.6214285714285714 37 | 2019-01-06T07:57:26.956341,2019-01-06T08:07:26.967006,600.010665,8,0 4 3 11 3 1 5 9,0,118423,-1,-1,29,15.895919949506569 38 | 2019-01-06T08:07:32.723817,2019-01-06T08:17:33.250822,600.527005,8,10 6 5 9 7 5 0 0,0,248445,-1,-1,31,17.783428141415904 39 | 2019-01-06T08:17:46.266327,2019-01-06T08:27:47.625819,601.359492,8,5 4 2 9 4 0 2 10,0,194114,-1,-1,29,17.612143130264197 40 | 2019-01-06T08:28:27.417914,2019-01-06T08:28:28.432729,1.014815,8,4 3 3 3 6 6 11 1,1,47,6,6,10,6.624020887728459 41 | 2019-01-06T08:28:28.436034,2019-01-06T08:38:28.441233,600.005199,8,0 3 5 9 8 1 3 2,0,124913,-1,-1,28,15.891528379865527 42 | 2019-01-06T08:38:33.975364,2019-01-06T08:48:34.575263,600.599899,9,9 8 11 0 0 3 2 6 9,0,229432,-1,-1,33,16.520290779491642 43 | 2019-01-06T08:48:47.303654,2019-01-06T08:58:48.626792,601.323138,9,8 1 5 7 11 10 1 11 10,0,106030,-1,-1,29,15.850959022393413 44 | 2019-01-06T08:58:57.309543,2019-01-06T09:08:57.798890,600.489347,9,7 7 9 10 1 8 10 3 3,0,213492,-1,-1,29,17.10003110137338 45 | 2019-01-06T09:09:10.824940,2019-01-06T09:09:24.983087,14.158147,9,2 0 10 1 6 10 3 5 7,1,8791,105,9,21,12.526243079176568 46 | 2019-01-06T09:09:25.703448,2019-01-06T09:12:35.098575,189.395127,9,1 1 3 2 5 0 9 2 2,1,104993,9,9,29,14.920753261631077 47 | 2019-01-06T09:12:43.625518,2019-01-06T09:16:06.184186,202.558668,10,1 8 9 9 2 5 2 9 8 1,1,48649,162,6,26,13.85702721107708 48 | 2019-01-06T09:16:07.440988,2019-01-06T09:26:07.640386,600.199398,10,7 9 4 4 1 0 5 7 8 3,0,269636,-1,-1,30,18.970058872867043 49 | 2019-01-06T09:26:40.146385,2019-01-06T09:26:47.186041,7.039656,10,0 5 3 6 8 5 3 8 11 10,1,2667,134,10,19,12.261158572505318 50 | 2019-01-06T09:26:47.810821,2019-01-06T09:36:47.819528,600.008707,10,1 11 2 6 9 1 9 4 2 0,0,391970,-1,-1,29,15.694122107418114 51 | 2019-01-06T09:37:09.860857,2019-01-06T09:47:12.868729,603.007872,10,11 8 5 9 10 10 11 11 8 3,0,264003,-1,-1,32,18.315199699193567 52 | 2019-01-06T09:48:01.797306,2019-01-06T09:58:03.328172,601.530866,11,8 8 11 11 7 3 0 5 2 9 8,0,157198,-1,-1,30,16.9572262888426 53 | 2019-01-06T09:58:11.614824,2019-01-06T10:08:12.455077,600.840253,11,9 9 5 2 6 10 9 4 1 2 11,0,332227,-1,-1,29,18.693299620367835 54 | 2019-01-06T10:09:29.769998,2019-01-06T10:19:32.099144,602.329146,11,9 10 10 11 7 0 7 8 5 4 8,0,230695,-1,-1,30,16.93926787788448 55 | 2019-01-06T10:20:04.890378,2019-01-06T10:30:06.332124,601.441746,11,5 3 4 4 0 8 7 10 3 0 7,0,326845,-1,-1,32,19.573864291174655 56 | 2019-01-06T10:30:49.529535,2019-01-06T10:40:51.890872,602.361337,11,10 6 4 8 10 2 1 10 10 9 4,0,140177,-1,-1,30,15.91579442956684 57 | 2019-01-06T10:40:57.880298,2019-01-06T10:50:58.599116,600.718818,12,0 11 8 1 5 10 2 7 3 3 1 4,0,178686,-1,-1,27,15.015617122753136 58 | 2019-01-06T10:51:04.732280,2019-01-06T11:01:05.748438,601.016158,12,6 9 4 4 3 10 6 7 10 7 5 10,0,146164,-1,-1,27,15.703117961449946 59 | 2019-01-06T11:01:14.173672,2019-01-06T11:11:14.944878,600.771206,12,2 5 3 7 3 3 0 11 2 7 3 4,0,238071,-1,-1,28,16.134807211863887 60 | 2019-01-06T11:11:24.349314,2019-01-06T11:21:25.847806,601.498492,12,0 4 0 8 6 11 9 7 9 4 6 11,0,268457,-1,-1,32,16.45096353943884 61 | 2019-01-06T11:21:37.959357,2019-01-06T11:31:39.766598,601.807241,12,5 5 10 7 5 7 8 0 11 10 8 6,0,272934,-1,-1,32,18.585873667750235 62 | 2019-01-06T11:32:26.178062,2019-01-06T11:42:28.019802,601.84174,13,3 2 9 2 4 4 9 2 1 5 7 0 11,0,269683,-1,-1,34,18.17250209483416 63 | 2019-01-06T11:43:22.178838,2019-01-06T11:53:24.033502,601.854664,13,0 2 7 7 0 7 2 11 1 3 6 7 4,0,117526,-1,-1,30,17.096282475770987 64 | 2019-01-06T11:53:30.702613,2019-01-06T12:03:31.298465,600.595852,13,6 9 10 7 10 11 8 0 0 3 11 0 7,0,331877,-1,-1,31,17.964230303468458 65 | 2019-01-06T12:04:20.319673,2019-01-06T12:14:22.829660,602.509987,13,10 7 11 11 2 11 11 4 9 9 11 1 9,0,171938,-1,-1,32,20.117808296099835 66 | 2019-01-06T12:14:54.924657,2019-01-06T12:24:55.908453,600.983796,13,11 3 4 5 3 8 1 11 9 6 3 11 8,0,251041,-1,-1,34,18.75409834596751 67 | 2019-01-06T12:25:30.469588,2019-01-06T12:35:32.125919,601.656331,14,0 11 11 9 4 3 5 8 0 4 3 8 5 4,0,221351,-1,-1,35,18.46652292615473 68 | 2019-01-06T12:35:55.201700,2019-01-06T12:45:56.594318,601.392618,14,1 10 2 9 11 10 0 5 0 1 1 8 3 6,0,123002,-1,-1,29,18.092552609238965 69 | 2019-01-06T12:46:07.777955,2019-01-06T12:56:08.417951,600.639996,14,3 2 3 1 11 4 7 2 9 11 11 2 11 3,0,286564,-1,-1,31,17.43819731519242 70 | 2019-01-06T12:56:26.087101,2019-01-06T13:06:28.101441,602.01434,14,11 0 5 0 9 5 8 1 8 1 1 3 11 10,0,246544,-1,-1,31,18.174439245724056 71 | 2019-01-06T13:06:48.284900,2019-01-06T13:16:49.897424,601.612524,14,10 7 4 4 2 11 6 7 9 5 7 0 5 6,0,257149,-1,-1,32,16.568676996250606 72 | 2019-01-06T13:17:07.549802,2019-01-06T13:27:09.273427,601.723625,15,11 3 10 7 3 6 7 7 5 3 7 10 3 10 0,0,163748,-1,-1,29,17.718979832617723 73 | 2019-01-06T13:27:19.680887,2019-01-06T13:37:20.600300,600.919413,15,1 3 0 7 8 5 10 7 3 1 6 5 8 5 9,0,193028,-1,-1,34,19.935600953037273 74 | -------------------------------------------------------------------------------- /articles/01_rubic/libcube/cubes/cube3x3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classic cube 3x3 3 | """ 4 | import enum 5 | import collections 6 | 7 | from . import _env 8 | from . import _common 9 | 10 | # environment API 11 | State = collections.namedtuple("State", field_names=['corner_pos', 'side_pos', 'corner_ort', 'side_ort']) 12 | 13 | # rendered state -- list of colors of every side 14 | RenderedState = collections.namedtuple("RenderedState", field_names=['top', 'front', 'left', 'right', 'back', 'bottom']) 15 | 16 | # initial (solved state) 17 | initial_state = State(corner_pos=tuple(range(8)), side_pos=tuple(range(12)), corner_ort=tuple([0]*8), side_ort=tuple([0]*12)) 18 | 19 | 20 | def is_initial(state): 21 | """ 22 | Checks that this state is initial state 23 | :param state: State instance 24 | :return: True if state match initial, False otherwise 25 | """ 26 | return state.corner_pos == initial_state.corner_pos and \ 27 | state.side_pos == initial_state.side_pos and \ 28 | state.corner_ort == initial_state.corner_ort and \ 29 | state.side_ort == initial_state.side_ort 30 | 31 | 32 | # available actions. Capital actions denote clockwise rotation 33 | class Action(enum.Enum): 34 | R = 0 35 | L = 1 36 | T = 2 37 | D = 3 38 | F = 4 39 | B = 5 40 | r = 6 41 | l = 7 42 | t = 8 43 | d = 9 44 | f = 10 45 | b = 11 46 | 47 | 48 | _inverse_action = { 49 | Action.R: Action.r, 50 | Action.r: Action.R, 51 | Action.L: Action.l, 52 | Action.l: Action.L, 53 | Action.T: Action.t, 54 | Action.t: Action.T, 55 | Action.D: Action.d, 56 | Action.d: Action.D, 57 | Action.F: Action.f, 58 | Action.f: Action.F, 59 | Action.B: Action.b, 60 | Action.b: Action.B 61 | } 62 | 63 | 64 | def inverse_action(action): 65 | assert isinstance(action, Action) 66 | return _inverse_action[action] 67 | 68 | 69 | def _flip(side_ort, sides): 70 | return [ 71 | o if idx not in sides else 1-o 72 | for idx, o in enumerate(side_ort) 73 | ] 74 | 75 | 76 | _transform_map = { 77 | Action.R: [ 78 | ((1, 2), (2, 6), (6, 5), (5, 1)), # corner map 79 | ((1, 6), (6, 9), (9, 5), (5, 1)), # side map 80 | ((1, 2), (2, 1), (5, 1), (6, 2)), # corner rotate 81 | () # side flip 82 | ], 83 | Action.L: [ 84 | ((3, 0), (7, 3), (0, 4), (4, 7)), 85 | ((7, 3), (3, 4), (11, 7), (4, 11)), 86 | ((0, 1), (3, 2), (4, 2), (7, 1)), 87 | () 88 | ], 89 | Action.T: [ 90 | ((0, 3), (1, 0), (2, 1), (3, 2)), 91 | ((0, 3), (1, 0), (2, 1), (3, 2)), 92 | (), 93 | () 94 | ], 95 | Action.D: [ 96 | ((4, 5), (5, 6), (6, 7), (7, 4)), 97 | ((8, 9), (9, 10), (10, 11), (11, 8)), 98 | (), 99 | () 100 | ], 101 | Action.F: [ 102 | ((0, 1), (1, 5), (5, 4), (4, 0)), 103 | ((0, 5), (4, 0), (5, 8), (8, 4)), 104 | ((0, 2), (1, 1), (4, 1), (5, 2)), 105 | (0, 4, 5, 8) 106 | ], 107 | Action.B: [ 108 | ((2, 3), (3, 7), (7, 6), (6, 2)), 109 | ((2, 7), (6, 2), (7, 10), (10, 6)), 110 | ((2, 2), (3, 1), (6, 1), (7, 2)), 111 | (2, 6, 7, 10) 112 | ] 113 | } 114 | 115 | 116 | def transform(state, action): 117 | assert isinstance(state, State) 118 | assert isinstance(action, Action) 119 | global _transform_map 120 | 121 | is_inv = action not in _transform_map 122 | if is_inv: 123 | action = inverse_action(action) 124 | c_map, s_map, c_rot, s_flp = _transform_map[action] 125 | corner_pos = _common._permute(state.corner_pos, c_map, is_inv) 126 | corner_ort = _common._permute(state.corner_ort, c_map, is_inv) 127 | corner_ort = _common._rotate(corner_ort, c_rot) 128 | side_pos = _common._permute(state.side_pos, s_map, is_inv) 129 | side_ort = state.side_ort 130 | if s_flp: 131 | side_ort = _common._permute(side_ort, s_map, is_inv) 132 | side_ort = _flip(side_ort, s_flp) 133 | return State(corner_pos=tuple(corner_pos), corner_ort=tuple(corner_ort), 134 | side_pos=tuple(side_pos), side_ort=tuple(side_ort)) 135 | 136 | 137 | # make initial state of rendered side 138 | def _init_side(color): 139 | return [color if idx == 4 else None for idx in range(9)] 140 | 141 | 142 | # create initial sides in the right order 143 | def _init_sides(): 144 | return [ 145 | _init_side('W'), # top 146 | _init_side('G'), # left 147 | _init_side('O'), # back 148 | _init_side('R'), # front 149 | _init_side('B'), # right 150 | _init_side('Y') # bottom 151 | ] 152 | 153 | 154 | # corner cubelets colors (clockwise from main label). Order of cubelets are first top, 155 | # in counter-clockwise, started from front left 156 | corner_colors = ( 157 | ('W', 'R', 'G'), ('W', 'B', 'R'), ('W', 'O', 'B'), ('W', 'G', 'O'), 158 | ('Y', 'G', 'R'), ('Y', 'R', 'B'), ('Y', 'B', 'O'), ('Y', 'O', 'G') 159 | ) 160 | 161 | side_colors = ( 162 | ('W', 'R'), ('W', 'B'), ('W', 'O'), ('W', 'G'), 163 | ('R', 'G'), ('R', 'B'), ('O', 'B'), ('O', 'G'), 164 | ('Y', 'R'), ('Y', 'B'), ('Y', 'O'), ('Y', 'G') 165 | ) 166 | 167 | 168 | # map every 3-side cubelet to their projection on sides 169 | # sides are indexed in the order of _init_sides() function result 170 | corner_maps = ( 171 | # top layer 172 | ((0, 6), (3, 0), (1, 2)), 173 | ((0, 8), (4, 0), (3, 2)), 174 | ((0, 2), (2, 0), (4, 2)), 175 | ((0, 0), (1, 0), (2, 2)), 176 | # bottom layer 177 | ((5, 0), (1, 8), (3, 6)), 178 | ((5, 2), (3, 8), (4, 6)), 179 | ((5, 8), (4, 8), (2, 6)), 180 | ((5, 6), (2, 8), (1, 6)) 181 | ) 182 | 183 | # map every 2-side cubelet to their projection on sides 184 | side_maps = ( 185 | # top layer 186 | ((0, 7), (3, 1)), 187 | ((0, 5), (4, 1)), 188 | ((0, 1), (2, 1)), 189 | ((0, 3), (1, 1)), 190 | # middle layer 191 | ((3, 3), (1, 5)), 192 | ((3, 5), (4, 3)), 193 | ((2, 3), (4, 5)), 194 | ((2, 5), (1, 3)), 195 | # bottom layer 196 | ((5, 1), (3, 7)), 197 | ((5, 5), (4, 7)), 198 | ((5, 7), (2, 7)), 199 | ((5, 3), (1, 7)) 200 | ) 201 | 202 | 203 | # render state into human readable form 204 | def render(state): 205 | assert isinstance(state, State) 206 | global corner_colors, corner_maps, side_colors, side_maps 207 | 208 | sides = _init_sides() 209 | 210 | for corner, orient, maps in zip(state.corner_pos, state.corner_ort, corner_maps): 211 | cols = corner_colors[corner] 212 | cols = _common._map_orient(cols, orient) 213 | for (arr_idx, index), col in zip(maps, cols): 214 | sides[arr_idx][index] = col 215 | 216 | for side, orient, maps in zip(state.side_pos, state.side_ort, side_maps): 217 | cols = side_colors[side] 218 | cols = cols if orient == 0 else (cols[1], cols[0]) 219 | for (arr_idx, index), col in zip(maps, cols): 220 | sides[arr_idx][index] = col 221 | 222 | return RenderedState(top=sides[0], left=sides[1], back=sides[2], front=sides[3], 223 | right=sides[4], bottom=sides[5]) 224 | 225 | 226 | # shape of encoded cube state 227 | encoded_shape = (20, 24) 228 | 229 | 230 | def encode_inplace(target, state): 231 | """ 232 | Encode cude into existig zeroed numpy array 233 | Follows encoding described in paper https://arxiv.org/abs/1805.07470 234 | :param target: numpy array 235 | :param state: state to be encoded 236 | """ 237 | assert isinstance(state, State) 238 | 239 | # handle corner cubelets: find their permuted position 240 | for corner_idx in range(8): 241 | perm_pos = state.corner_pos.index(corner_idx) 242 | corn_ort = state.corner_ort[perm_pos] 243 | target[corner_idx, perm_pos * 3 + corn_ort] = 1 244 | 245 | # handle side cubelets 246 | for side_idx in range(12): 247 | perm_pos = state.side_pos.index(side_idx) 248 | side_ort = state.side_ort[perm_pos] 249 | target[8 + side_idx, perm_pos * 2 + side_ort] = 1 250 | 251 | 252 | # register env 253 | _env.register(_env.CubeEnv(name="cube3x3", state_type=State, initial_state=initial_state, 254 | is_goal_pred=is_initial, action_enum=Action, 255 | transform_func=transform, inverse_action_func=inverse_action, 256 | render_func=render, encoded_shape=encoded_shape, encode_func=encode_inplace)) 257 | -------------------------------------------------------------------------------- /articles/01_rubic/csvs/c3x3/c3-zg-d20-noweight-no-decay=5.61e-1.csv: -------------------------------------------------------------------------------- 1 | start_dt,stop_dt,duration,depth,scramble,is_solved,solve_steps,sol_len_naive,sol_len_bfs,tree_depth_max,tree_depth_mean 2 | 2019-01-07T14:44:09.558439,2019-01-07T14:44:09.582932,0.024493,1,10,1,1,1,1,1,1.0 3 | 2019-01-07T14:44:09.583052,2019-01-07T14:44:09.583543,0.000491,1,1,1,1,1,1,1,1.0 4 | 2019-01-07T14:44:09.583647,2019-01-07T14:44:09.584115,0.000468,1,0,1,1,1,1,1,1.0 5 | 2019-01-07T14:44:09.584216,2019-01-07T14:44:09.584675,0.000459,1,11,1,1,1,1,1,1.0 6 | 2019-01-07T14:44:09.584779,2019-01-07T14:44:09.585237,0.000458,1,4,1,1,1,1,1,1.0 7 | 2019-01-07T14:44:09.585752,2019-01-07T14:44:09.603070,0.017318,2,3 3,1,3,2,2,2,1.6875 8 | 2019-01-07T14:44:09.603345,2019-01-07T14:44:09.604791,0.001446,2,11 1,1,3,2,2,2,1.6875 9 | 2019-01-07T14:44:09.604982,2019-01-07T14:44:09.606413,0.001431,2,11 8,1,3,2,2,2,1.6875 10 | 2019-01-07T14:44:09.606675,2019-01-07T14:44:09.607618,0.000943,2,9 6,1,2,2,2,2,1.5 11 | 2019-01-07T14:44:09.607800,2019-01-07T14:44:09.609209,0.001409,2,0 1,1,3,2,2,2,1.6875 12 | 2019-01-07T14:44:09.609883,2019-01-07T14:44:09.610418,0.000535,3,3 8 9,1,1,1,1,1,1.0 13 | 2019-01-07T14:44:09.610523,2019-01-07T14:44:09.612957,0.002434,3,0 8 3,1,5,3,3,3,2.230769230769231 14 | 2019-01-07T14:44:09.613316,2019-01-07T14:44:09.615768,0.002452,3,11 8 6,1,5,3,3,3,2.230769230769231 15 | 2019-01-07T14:44:09.616119,2019-01-07T14:44:09.619293,0.003174,3,9 4 0,1,5,3,3,3,2.230769230769231 16 | 2019-01-07T14:44:09.619643,2019-01-07T14:44:09.622066,0.002423,3,6 5 4,1,5,3,3,3,2.230769230769231 17 | 2019-01-07T14:44:09.622913,2019-01-07T14:44:09.628242,0.005329,4,5 1 1 6,1,10,4,4,4,3.231707317073171 18 | 2019-01-07T14:44:09.629109,2019-01-07T14:44:09.632098,0.002989,4,4 0 11 7,1,6,4,4,4,2.532258064516129 19 | 2019-01-07T14:44:09.632604,2019-01-07T14:44:09.636155,0.003551,4,1 8 4 9,1,7,4,4,4,2.611111111111111 20 | 2019-01-07T14:44:09.636691,2019-01-07T14:44:09.640405,0.003714,4,11 1 0 10,1,7,4,4,4,2.7625 21 | 2019-01-07T14:44:09.641008,2019-01-07T14:44:09.645792,0.004784,4,3 1 6 4,1,8,4,4,4,2.911111111111111 22 | 2019-01-07T14:44:09.646854,2019-01-07T14:44:09.651303,0.004449,5,2 5 5 3 10,1,8,5,5,5,3.5555555555555554 23 | 2019-01-07T14:44:09.651914,2019-01-07T14:44:09.656700,0.004786,5,1 9 10 2 11,1,9,5,5,5,3.260869565217391 24 | 2019-01-07T14:44:09.657389,2019-01-07T14:44:09.662894,0.005505,5,4 11 8 3 10,1,9,5,5,5,3.309090909090909 25 | 2019-01-07T14:44:09.663741,2019-01-07T14:44:09.668053,0.004312,5,6 4 1 3 11,1,8,5,5,5,3.402439024390244 26 | 2019-01-07T14:44:09.668730,2019-01-07T14:44:09.674954,0.006224,5,6 10 7 2 4,1,12,5,5,5,2.9076923076923076 27 | 2019-01-07T14:44:09.676565,2019-01-07T14:44:11.664350,1.987785,6,11 9 6 9 6 5,1,1299,6,6,18,10.872542933995447 28 | 2019-01-07T14:44:11.690601,2019-01-07T14:44:11.698328,0.007727,6,0 9 2 3 2 7,1,7,4,4,4,2.75 29 | 2019-01-07T14:44:11.698905,2019-01-07T14:44:11.702759,0.003854,6,3 7 11 4 5 2,1,7,4,4,4,2.7875 30 | 2019-01-07T14:44:11.703371,2019-01-07T14:44:11.709321,0.00595,6,11 1 2 4 1 9,1,11,6,6,6,3.767857142857143 31 | 2019-01-07T14:44:11.710259,2019-01-07T14:44:11.731925,0.021666,6,6 11 3 1 9 11,1,37,6,6,7,4.916387959866221 32 | 2019-01-07T14:44:11.735611,2019-01-07T14:44:11.977758,0.242147,7,0 7 6 5 10 7 11,1,183,83,5,10,6.227197346600332 33 | 2019-01-07T14:44:11.986667,2019-01-07T14:44:11.993433,0.006766,7,3 3 11 7 4 11 9,1,12,7,7,7,4.6722222222222225 34 | 2019-01-07T14:44:11.994673,2019-01-07T14:44:15.489236,3.494563,7,3 5 2 4 0 11 8,1,1444,161,7,18,10.803833491434629 35 | 2019-01-07T14:44:15.521995,2019-01-07T14:44:15.534007,0.012012,7,9 10 6 1 2 11 10,1,15,7,7,7,4.761111111111111 36 | 2019-01-07T14:44:15.535294,2019-01-07T14:44:15.541103,0.005809,7,5 2 5 4 11 9 9,1,11,5,5,5,3.532934131736527 37 | 2019-01-07T14:44:15.542690,2019-01-07T14:44:15.686961,0.144271,8,0 10 1 4 7 10 6 7,1,136,76,8,12,9.986732489972232 38 | 2019-01-07T14:44:15.700454,2019-01-07T14:44:15.767975,0.067521,8,9 7 11 1 6 1 3 6,1,80,32,8,11,7.826822157434402 39 | 2019-01-07T14:44:15.775367,2019-01-07T14:44:15.807690,0.032323,8,9 10 10 9 5 10 5 2,1,57,8,8,8,5.178489702517163 40 | 2019-01-07T14:44:15.814579,2019-01-07T14:44:18.037154,2.222575,8,9 8 11 0 7 9 10 3,1,2540,8,8,21,12.633602672733732 41 | 2019-01-07T14:44:18.223678,2019-01-07T14:44:18.254370,0.030692,8,5 4 0 3 3 11 1 11,1,39,8,8,8,5.794117647058823 42 | 2019-01-07T14:44:18.258822,2019-01-07T14:53:18.798303,540.539481,9,5 8 7 7 11 7 5 3 5,1,52238,201,9,27,16.868602552838905 43 | 2019-01-07T14:53:21.892641,2019-01-07T15:03:22.107635,600.214994,9,1 8 8 4 9 6 8 0 4,0,69482,-1,-1,32,19.01852813211099 44 | 2019-01-07T15:03:26.193389,2019-01-07T15:03:26.495540,0.302151,9,9 4 1 2 11 3 11 4 5,1,13,7,7,7,4.196969696969697 45 | 2019-01-07T15:03:26.496736,2019-01-07T15:03:48.521456,22.02472,9,11 8 0 11 2 4 5 5 7,1,5747,15,15,21,13.055312081370891 46 | 2019-01-07T15:03:49.539340,2019-01-07T15:03:51.724195,2.184855,9,2 11 9 11 3 3 10 11 2,1,1187,79,9,18,11.100944787972002 47 | 2019-01-07T15:03:51.812089,2019-01-07T15:04:11.190050,19.377961,10,5 9 11 2 10 7 2 9 1 6,1,6325,10,10,24,16.67808122770022 48 | 2019-01-07T15:04:11.857487,2019-01-07T15:04:22.288398,10.430911,10,11 3 11 8 11 1 6 7 4 11,1,4649,8,8,22,13.455165196414635 49 | 2019-01-07T15:04:22.463115,2019-01-07T15:14:22.483708,600.020593,10,5 0 4 3 2 10 6 5 9 0,0,113456,-1,-1,31,18.086052982114165 50 | 2019-01-07T15:14:27.407146,2019-01-07T15:14:27.916715,0.509569,10,9 9 9 11 3 5 4 11 2 7,1,11,6,6,6,3.767857142857143 51 | 2019-01-07T15:14:27.917772,2019-01-07T15:24:27.922582,600.00481,10,7 10 7 0 7 7 2 3 10 0,0,60742,-1,-1,31,17.406939955774792 52 | 2019-01-07T15:24:29.578398,2019-01-07T15:34:29.827992,600.249594,11,10 8 4 1 8 5 4 2 1 1 8,0,102164,-1,-1,36,20.13788480311567 53 | 2019-01-07T15:34:41.246524,2019-01-07T15:44:41.701566,600.455042,11,10 6 9 8 5 3 6 6 11 10 2,0,82964,-1,-1,31,19.168137290213494 54 | 2019-01-07T15:44:46.589449,2019-01-07T15:54:46.956211,600.366762,11,8 7 9 10 11 8 3 8 6 9 0,0,80964,-1,-1,31,18.333843464593148 55 | 2019-01-07T15:54:51.278641,2019-01-07T15:57:49.381791,178.10315,11,0 7 0 10 6 8 11 8 11 4 5,1,27981,9,9,28,16.668082578863586 56 | 2019-01-07T15:57:51.160560,2019-01-07T16:07:51.267429,600.106869,11,6 2 0 1 3 5 10 9 0 7 4,0,84394,-1,-1,31,18.70516663046059 57 | 2019-01-07T16:07:57.452784,2019-01-07T16:07:57.844678,0.391894,12,7 7 0 1 10 5 6 6 6 6 5 5,1,9,4,4,4,2.7818181818181817 58 | 2019-01-07T16:07:57.845628,2019-01-07T16:17:57.850272,600.004644,12,9 11 7 9 1 4 4 9 8 11 7 0,0,147240,-1,-1,36,22.402116717238616 59 | 2019-01-07T16:18:16.362751,2019-01-07T16:28:17.071256,600.708505,12,6 7 9 7 0 8 4 9 8 8 10 9,0,94443,-1,-1,36,20.584762322648757 60 | 2019-01-07T16:28:32.024796,2019-01-07T16:38:32.439563,600.414767,12,0 0 8 10 8 11 11 2 1 3 6 10,0,136188,-1,-1,36,19.255684157852855 61 | 2019-01-07T16:38:47.411074,2019-01-07T16:38:48.209588,0.798514,12,10 7 9 7 7 0 7 6 9 10 2 3,1,163,44,8,11,7.816482582837723 62 | 2019-01-07T16:38:48.227105,2019-01-07T16:48:48.228412,600.001307,13,3 10 0 3 8 11 4 3 8 8 10 3 4,0,78332,-1,-1,34,18.776139217901104 63 | 2019-01-07T16:48:52.781335,2019-01-07T16:58:22.898102,570.116767,13,5 10 11 2 0 5 4 0 5 10 10 11 4,1,71502,7,7,31,16.202572545728657 64 | 2019-01-07T16:58:25.259768,2019-01-07T17:08:25.560076,600.300308,13,5 3 8 10 1 6 9 7 0 0 10 1 2,0,120696,-1,-1,34,20.4408104090139 65 | 2019-01-07T17:08:38.677557,2019-01-07T17:18:39.251250,600.573693,13,5 6 9 11 3 1 1 11 8 10 9 11 0,0,162422,-1,-1,33,18.96892935573633 66 | 2019-01-07T17:18:50.158980,2019-01-07T17:28:50.970918,600.811938,13,7 11 8 6 10 0 8 6 9 9 6 9 6,0,81892,-1,-1,36,21.129603342709775 67 | 2019-01-07T17:28:57.046836,2019-01-07T17:38:57.460947,600.414111,14,9 6 2 10 7 9 9 1 4 8 5 8 6 9,0,121576,-1,-1,30,18.275118930648812 68 | 2019-01-07T17:39:03.893906,2019-01-07T17:49:04.520704,600.626798,14,1 9 11 0 9 5 9 5 9 11 10 1 6 6,0,79036,-1,-1,35,19.351324738506445 69 | 2019-01-07T17:49:09.831021,2019-01-07T17:59:10.184197,600.353176,14,4 9 11 10 3 2 7 10 9 0 8 10 0 8,0,97135,-1,-1,34,19.69256232154972 70 | 2019-01-07T17:59:17.689028,2019-01-07T18:09:18.141573,600.452545,14,11 4 6 9 5 7 2 2 9 5 7 2 6 7,0,143079,-1,-1,34,19.440330490473848 71 | 2019-01-07T18:09:28.087935,2019-01-07T18:19:28.803745,600.71581,14,11 7 11 0 10 3 11 1 10 7 6 10 6 8,0,167107,-1,-1,32,18.534979663612063 72 | 2019-01-07T18:19:54.199761,2019-01-07T18:29:55.070337,600.870576,15,9 7 10 1 10 7 10 9 1 9 8 9 8 11 9,0,122919,-1,-1,34,21.93279873853761 73 | 2019-01-07T18:30:23.798214,2019-01-07T18:40:24.386446,600.588232,15,10 11 3 11 8 7 7 0 7 8 11 2 6 6 2,0,153655,-1,-1,34,19.249943400358564 74 | 2019-01-07T18:40:40.889187,2019-01-07T18:50:41.677443,600.788256,15,4 8 3 11 11 2 9 4 11 8 5 2 11 1 11,0,143622,-1,-1,38,22.705381560985707 75 | 2019-01-07T18:51:08.917242,2019-01-07T19:01:09.640506,600.723264,15,1 5 5 1 9 6 6 2 6 8 4 6 9 0 1,0,143012,-1,-1,33,19.98241963504499 76 | 2019-01-07T19:01:23.791643,2019-01-07T19:11:24.502050,600.710407,15,8 4 6 10 8 1 5 4 0 5 1 4 1 8 7,0,172071,-1,-1,35,20.817842064917873 77 | 2019-01-07T19:11:53.288295,2019-01-07T19:21:54.197683,600.909388,16,4 2 6 5 1 0 10 5 7 7 9 10 8 6 10 3,0,160646,-1,-1,34,18.869670068329217 78 | -------------------------------------------------------------------------------- /articles/01_rubic/libcube/mcts.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import collections 4 | 5 | from . import cubes 6 | from . import model 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | class MCTS: 13 | """ 14 | Monte Carlo Tree Search state and method 15 | """ 16 | def __init__(self, cube_env, state, net, exploration_c=100, virt_loss_nu=100.0, device="cpu"): 17 | assert isinstance(cube_env, cubes.CubeEnv) 18 | assert cube_env.is_state(state) 19 | 20 | self.cube_env = cube_env 21 | self.root_state = state 22 | self.net = net 23 | self.exploration_c = exploration_c 24 | self.virt_loss_nu = virt_loss_nu 25 | self.device = device 26 | 27 | # Tree state 28 | shape = (len(cube_env.action_enum), ) 29 | # correspond to N_s(a) in the paper 30 | self.act_counts = collections.defaultdict(lambda: np.zeros(shape, dtype=np.uint32)) 31 | # correspond to W_s(a) 32 | self.val_maxes = collections.defaultdict(lambda: np.zeros(shape, dtype=np.float32)) 33 | # correspond to P_s(a) 34 | self.prob_actions = {} 35 | # correspond to L_s(a) 36 | self.virt_loss = collections.defaultdict(lambda: np.zeros(shape, dtype=np.float32)) 37 | # TODO: check speed and memory of edge-less version 38 | self.edges = {} 39 | 40 | def __len__(self): 41 | return len(self.edges) 42 | 43 | def __repr__(self): 44 | return "MCTS(states=%d)" % len(self.edges) 45 | 46 | def dump_root(self): 47 | print("Root state:") 48 | self.dump_state(self.root_state) 49 | # states, _ = cubes.explore_state(self.cube_env, self.root_state) 50 | # for idx, s in enumerate(states): 51 | # print("") 52 | # print("State %d" % idx) 53 | # self.dump_state(s) 54 | 55 | def dump_state(self, s): 56 | print("") 57 | print("act_counts: %s" % ", ".join(map(lambda v: "%8d" % v, self.act_counts[s].tolist()))) 58 | print("probs: %s" % ", ".join(map(lambda v: "%.2e" % v, self.prob_actions[s].tolist()))) 59 | print("val_maxes: %s" % ", ".join(map(lambda v: "%.2e" % v, self.val_maxes[s].tolist()))) 60 | 61 | act_counts = self.act_counts[s] 62 | N_sqrt = np.sqrt(np.sum(act_counts)) 63 | u = self.exploration_c * N_sqrt / (act_counts + 1) 64 | print("u: %s" % ", ".join(map(lambda v: "%.2e" % v, u.tolist()))) 65 | u *= self.prob_actions[s] 66 | print("u*prob: %s" % ", ".join(map(lambda v: "%.2e" % v, u.tolist()))) 67 | q = self.val_maxes[s] - self.virt_loss[s] 68 | print("q: %s" % ", ".join(map(lambda v: "%.2e" % v, q.tolist()))) 69 | fin = u + q 70 | print("u*prob + q: %s" % ", ".join(map(lambda v: "%.2e" % v, fin.tolist()))) 71 | act = np.argmax(fin, axis=0) 72 | print("Action: %s" % act) 73 | 74 | def search(self): 75 | s, path_actions, path_states = self._search_leaf() 76 | 77 | child_states, child_goal = self.cube_env.explore_state(s) 78 | self.edges[s] = child_states 79 | 80 | value = self._expand_leaves([s])[0] 81 | self._backup_leaf(path_states, path_actions, value) 82 | 83 | if np.any(child_goal): 84 | path_actions.append(np.argmax(child_goal)) 85 | return path_actions 86 | return None 87 | 88 | def _search_leaf(self): 89 | """ 90 | Starting the root state, find path to the leaf node 91 | :return: tuple: (state, path_actions, path_states) 92 | """ 93 | s = self.root_state 94 | path_actions = [] 95 | path_states = [] 96 | 97 | # walking down the tree 98 | while True: 99 | next_states = self.edges.get(s) 100 | if next_states is None: 101 | break 102 | 103 | act_counts = self.act_counts[s] 104 | N_sqrt = np.sqrt(np.sum(act_counts)) 105 | if N_sqrt < 1e-6: 106 | act = random.randrange(len(self.cube_env.action_enum)) 107 | else: 108 | u = self.exploration_c * N_sqrt / (act_counts + 1) 109 | u *= self.prob_actions[s] 110 | q = self.val_maxes[s] - self.virt_loss[s] 111 | act = np.argmax(u + q) 112 | self.virt_loss[s][act] += self.virt_loss_nu 113 | path_actions.append(act) 114 | path_states.append(s) 115 | s = next_states[act] 116 | return s, path_actions, path_states 117 | 118 | def _expand_leaves(self, leaf_states): 119 | """ 120 | From list of states expand them using the network 121 | :param leaf_states: list of states 122 | :return: list of state values 123 | """ 124 | policies, values = self.evaluate_states(leaf_states) 125 | for s, p in zip(leaf_states, policies): 126 | self.prob_actions[s] = p 127 | return values 128 | 129 | def _backup_leaf(self, states, actions, value): 130 | """ 131 | Update tree state after reaching and expanding the leaf node 132 | :param states: path of states (without final leaf state) 133 | :param actions: path of actions 134 | :param value: value of leaf node 135 | """ 136 | for path_s, path_a in zip(states, actions): 137 | self.act_counts[path_s][path_a] += 1 138 | w = self.val_maxes[path_s] 139 | w[path_a] = max(w[path_a], value) 140 | self.virt_loss[path_s][path_a] -= self.virt_loss_nu 141 | 142 | def search_batch(self, batch_size): 143 | """ 144 | Perform a batches search to increase efficiency. 145 | :param batch_size: size of search batch 146 | :return: path to solution or None if not found 147 | """ 148 | batch_size = min(batch_size, len(self) + 1) 149 | batch_states, batch_actions, batch_paths = [], [], [] 150 | for _ in range(batch_size): 151 | s, path_acts, path_s = self._search_leaf() 152 | batch_states.append(s) 153 | batch_actions.append(path_acts) 154 | batch_paths.append(path_s) 155 | 156 | for s, path_actions in zip(batch_states, batch_actions): 157 | child, goals = self.cube_env.explore_state(s) 158 | self.edges[s] = child 159 | if np.any(goals): 160 | return path_actions + [np.argmax(goals)] 161 | 162 | values = self._expand_leaves(batch_states) 163 | for value, path_states, path_actions in zip(values, batch_paths, batch_actions): 164 | self._backup_leaf(path_states, path_actions, value) 165 | return None 166 | 167 | def evaluate_states(self, states): 168 | """ 169 | Ask network to return policy and values 170 | :param net: 171 | :param states: 172 | :return: 173 | """ 174 | enc_states = model.encode_states(self.cube_env, states) 175 | enc_states_t = torch.tensor(enc_states).to(self.device) 176 | policy_t, value_t = self.net(enc_states_t) 177 | policy_t = F.softmax(policy_t, dim=1) 178 | return policy_t.detach().cpu().numpy(), value_t.squeeze(-1).detach().cpu().numpy() 179 | 180 | def eval_states_values(self, states): 181 | enc_states = model.encode_states(self.cube_env, states) 182 | enc_states_t = torch.tensor(enc_states).to(self.device) 183 | value_t = self.net(enc_states_t, value_only=True) 184 | return value_t.detach().cpu().numpy() 185 | 186 | def get_depth_stats(self): 187 | """ 188 | Calculate minimum, maximum, and mean depth of children in the tree 189 | :return: dict with stats 190 | """ 191 | max_depth = 0 192 | sum_depth = 0 193 | leaves_count = 0 194 | q = collections.deque([(self.root_state, 0)]) 195 | met = set() 196 | 197 | while q: 198 | s, depth = q.popleft() 199 | met.add(s) 200 | for ss in self.edges[s]: 201 | if ss not in self.edges: 202 | max_depth = max(max_depth, depth+1) 203 | sum_depth += depth+1 204 | leaves_count += 1 205 | elif ss not in met: 206 | q.append((ss, depth+1)) 207 | return { 208 | 'max': max_depth, 209 | 'mean': sum_depth / leaves_count, 210 | 'leaves': leaves_count 211 | } 212 | 213 | def dump_solution(self, solution): 214 | assert isinstance(solution, list) 215 | 216 | s = self.root_state 217 | r = self.cube_env.render(s) 218 | print(r) 219 | for aidx in solution: 220 | a = self.cube_env.action_enum(aidx) 221 | print(a, aidx) 222 | s = self.cube_env.transform(s, a) 223 | r = self.cube_env.render(s) 224 | print(r) 225 | 226 | def find_solution(self): 227 | queue = collections.deque([(self.root_state, [])]) 228 | seen = set() 229 | 230 | while queue: 231 | s, path = queue.popleft() 232 | seen.add(s) 233 | c_states, c_goals = self.cube_env.explore_state(s) 234 | for a_idx, (c_state, c_goal) in enumerate(zip(c_states, c_goals)): 235 | p = path + [a_idx] 236 | if c_goal: 237 | return p 238 | if c_state in seen or c_state not in self.edges: 239 | continue 240 | queue.append((c_state, p)) 241 | 242 | 243 | -------------------------------------------------------------------------------- /articles/01_rubic/csvs/c3x3/c3-zg-d20-noweight-no-decay=7.29e-1.csv: -------------------------------------------------------------------------------- 1 | start_dt,stop_dt,duration,depth,scramble,is_solved,solve_steps,sol_len_naive,sol_len_bfs,tree_depth_max,tree_depth_mean 2 | 2019-01-06T14:09:16.552747,2019-01-06T14:09:16.556388,0.003641,1,10,1,1,1,1,1,1.0 3 | 2019-01-06T14:09:16.556505,2019-01-06T14:09:16.557775,0.00127,1,1,1,1,1,1,1,1.0 4 | 2019-01-06T14:09:16.557882,2019-01-06T14:09:16.559134,0.001252,1,0,1,1,1,1,1,1.0 5 | 2019-01-06T14:09:16.559234,2019-01-06T14:09:16.560486,0.001252,1,11,1,1,1,1,1,1.0 6 | 2019-01-06T14:09:16.560587,2019-01-06T14:09:16.561850,0.001263,1,4,1,1,1,1,1,1.0 7 | 2019-01-06T14:09:16.562391,2019-01-06T14:09:16.566287,0.003896,2,3 3,1,3,2,2,2,1.6875 8 | 2019-01-06T14:09:16.566554,2019-01-06T14:09:16.570355,0.003801,2,11 1,1,3,2,2,2,1.6875 9 | 2019-01-06T14:09:16.570543,2019-01-06T14:09:16.574331,0.003788,2,11 8,1,3,2,2,2,1.6875 10 | 2019-01-06T14:09:16.574590,2019-01-06T14:09:16.577116,0.002526,2,9 6,1,2,2,2,2,1.5 11 | 2019-01-06T14:09:16.577296,2019-01-06T14:09:16.581111,0.003815,2,0 1,1,3,2,2,2,1.6875 12 | 2019-01-06T14:09:16.581816,2019-01-06T14:09:16.583161,0.001345,3,3 8 9,1,1,1,1,1,1.0 13 | 2019-01-06T14:09:16.583268,2019-01-06T14:09:16.589829,0.006561,3,0 8 3,1,5,3,3,3,2.230769230769231 14 | 2019-01-06T14:09:16.590204,2019-01-06T14:09:16.596763,0.006559,3,11 8 6,1,5,3,3,3,2.230769230769231 15 | 2019-01-06T14:09:16.597126,2019-01-06T14:09:16.603598,0.006472,3,9 4 0,1,5,3,3,3,2.230769230769231 16 | 2019-01-06T14:09:16.603957,2019-01-06T14:09:16.613108,0.009151,3,6 5 4,1,7,3,3,3,2.4125 17 | 2019-01-06T14:09:16.614026,2019-01-06T14:09:16.617896,0.00387,4,1 1 6 1,1,3,2,2,2,1.6875 18 | 2019-01-06T14:09:16.618082,2019-01-06T14:09:16.627136,0.009054,4,5 9 4 0,1,7,4,4,4,2.75 19 | 2019-01-06T14:09:16.627638,2019-01-06T14:09:16.639319,0.011681,4,1 6 1 8,1,9,4,4,4,3.03 20 | 2019-01-06T14:09:16.639908,2019-01-06T14:09:16.651579,0.011671,4,9 11 1 0,1,9,4,4,4,2.975 21 | 2019-01-06T14:09:16.652389,2019-01-06T14:09:16.661435,0.009046,4,1 6 4 7,1,7,4,4,4,2.75 22 | 2019-01-06T14:09:16.663770,2019-01-06T14:09:16.675613,0.011843,5,5 5 3 10 11,1,9,5,5,5,3.5642857142857145 23 | 2019-01-06T14:09:16.676580,2019-01-06T14:09:16.690900,0.01432,5,10 2 11 3 2,1,11,5,5,5,3.38125 24 | 2019-01-06T14:09:16.692098,2019-01-06T14:09:16.705241,0.013143,5,8 3 10 5 0,1,10,5,5,5,3.5076923076923077 25 | 2019-01-06T14:09:16.706136,2019-01-06T14:09:16.718102,0.011966,5,4 1 3 11 3,1,9,5,5,5,3.260869565217391 26 | 2019-01-06T14:09:16.718778,2019-01-06T14:09:16.732084,0.013306,5,7 2 4 2 3,1,10,5,5,5,3.533333333333333 27 | 2019-01-06T14:09:16.733558,2019-01-06T14:09:16.911121,0.177563,6,9 6 9 6 5 3,1,117,6,6,12,8.273816314888762 28 | 2019-01-06T14:09:16.915739,2019-01-06T14:09:16.925316,0.009577,6,9 5 7 0 1 5,1,7,4,4,4,2.75 29 | 2019-01-06T14:09:16.925913,2019-01-06T14:09:16.971490,0.045577,6,9 1 1 11 7 8,1,33,8,8,11,6.27710843373494 30 | 2019-01-06T14:09:16.973543,2019-01-06T14:09:16.997381,0.023838,6,11 4 6 10 10 5,1,18,6,6,6,4.45 31 | 2019-01-06T14:09:16.999458,2019-01-06T14:09:17.009837,0.010379,6,5 0 9 8 3 3,1,8,4,4,4,2.963636363636364 32 | 2019-01-06T14:09:17.017249,2019-01-06T14:09:17.121869,0.10462,7,10 0 3 1 0 5 1,1,78,7,7,9,6.035183349851338 33 | 2019-01-06T14:09:17.134077,2019-01-06T14:09:17.151273,0.017196,7,3 3 3 8 7 2 6,1,13,5,5,5,3.2 34 | 2019-01-06T14:09:17.152471,2019-01-06T14:09:17.207169,0.054698,7,8 1 0 10 8 0 1,1,41,7,7,7,5.096818810511756 35 | 2019-01-06T14:09:17.212023,2019-01-06T14:09:17.356197,0.144174,7,11 7 2 3 4 3 0,1,105,7,7,10,6.455197132616488 36 | 2019-01-06T14:09:17.363460,2019-01-06T14:09:17.374085,0.010625,7,5 1 0 7 9 9 1,1,8,5,5,5,3.158536585365854 37 | 2019-01-06T14:09:17.375150,2019-01-06T14:09:19.349742,1.974592,8,4 2 5 1 3 5 4 2,1,1998,8,8,17,10.442109123737517 38 | 2019-01-06T14:09:19.533959,2019-01-06T14:09:21.350351,1.816392,8,2 1 5 9 4 6 11 4,1,1325,8,8,19,12.895317816269719 39 | 2019-01-06T14:09:21.415372,2019-01-06T14:09:21.432305,0.016933,8,4 2 3 5 7 3 3 2,1,25,8,8,8,5.0177304964539005 40 | 2019-01-06T14:09:21.436662,2019-01-06T14:09:21.479209,0.042547,8,2 9 6 2 2 2 11 9,1,65,6,6,8,5.407932011331445 41 | 2019-01-06T14:09:21.483741,2019-01-06T14:11:59.523841,158.0401,8,8 6 2 3 2 4 0 10,1,44410,8,8,30,16.803861237559396 42 | 2019-01-06T14:12:04.030074,2019-01-06T14:22:04.195707,600.165633,9,2 9 0 5 2 9 1 8 10,0,274787,-1,-1,32,17.204026107709293 43 | 2019-01-06T14:22:23.404471,2019-01-06T14:22:29.638089,6.233618,9,6 1 8 6 7 9 6 7 9,1,5434,9,9,22,13.614096867246845 44 | 2019-01-06T14:22:30.065716,2019-01-06T14:22:30.605404,0.539688,9,1 2 7 3 3 6 4 7 2,1,767,9,9,16,9.257389665462542 45 | 2019-01-06T14:22:30.735416,2019-01-06T14:22:30.761385,0.025969,9,9 10 5 4 4 5 4 9 9,1,42,7,7,7,4.839080459770115 46 | 2019-01-06T14:22:30.768297,2019-01-06T14:22:32.928365,2.160068,9,10 9 2 11 2 9 7 4 6,1,3201,9,9,20,9.406191796697156 47 | 2019-01-06T14:22:33.304934,2019-01-06T14:23:19.177325,45.872391,10,8 10 1 0 9 1 1 6 2 10,1,29613,10,10,28,19.453958884997462 48 | 2019-01-06T14:23:23.360435,2019-01-06T14:23:26.199258,2.838823,10,8 5 8 7 4 7 0 3 3 5,1,3179,10,10,18,10.754414154343346 49 | 2019-01-06T14:23:26.619923,2019-01-06T14:23:27.722785,1.102862,10,11 11 4 6 11 11 2 0 7 7,1,1120,10,10,16,9.823237727772838 50 | 2019-01-06T14:23:27.998577,2019-01-06T14:33:28.014597,600.01602,10,7 5 1 6 3 0 9 11 0 8,0,88878,-1,-1,27,15.944966726318448 51 | 2019-01-06T14:33:31.632578,2019-01-06T14:43:32.039710,600.407132,10,8 4 11 2 6 1 8 10 10 6,0,419277,-1,-1,33,17.289153285215036 52 | 2019-01-06T14:44:05.621081,2019-01-06T14:44:08.934641,3.31356,11,10 1 1 1 10 5 4 11 3 7 7,1,77,45,5,10,7.4206219312602295 53 | 2019-01-06T14:44:08.936874,2019-01-06T14:44:09.497343,0.560469,11,3 3 2 9 10 11 4 8 5 1 11,1,334,177,7,11,9.061358058358332 54 | 2019-01-06T14:44:09.512882,2019-01-06T14:54:09.515693,600.002811,11,7 3 1 0 1 11 10 10 8 11 2,0,149224,-1,-1,30,18.404876548987385 55 | 2019-01-06T14:54:30.454222,2019-01-06T15:04:31.244918,600.790696,11,11 1 8 11 8 11 4 0 9 4 11,0,189869,-1,-1,33,19.821504997547162 56 | 2019-01-06T15:05:10.093751,2019-01-06T15:15:11.170104,601.076353,11,4 2 10 8 3 5 7 0 11 2 0,0,157672,-1,-1,28,17.78882512180261 57 | 2019-01-06T15:15:24.680907,2019-01-06T15:25:25.526476,600.845569,12,9 1 0 3 1 1 3 8 11 4 7 0,0,322440,-1,-1,34,19.76268595301969 58 | 2019-01-06T15:25:56.790797,2019-01-06T15:35:59.087638,602.296841,12,9 8 8 4 11 3 7 3 10 6 5 3,0,130731,-1,-1,34,17.93915392271383 59 | 2019-01-06T15:36:06.863143,2019-01-06T15:46:07.524300,600.661157,12,8 11 7 2 6 11 6 7 3 7 5 5,0,356716,-1,-1,32,17.288594039956102 60 | 2019-01-06T15:46:28.640718,2019-01-06T15:56:31.377751,602.737033,12,5 8 0 7 10 2 0 11 0 2 7 10,0,281370,-1,-1,32,18.214257456090174 61 | 2019-01-06T15:56:47.762704,2019-01-06T16:06:49.616766,601.854062,12,9 8 4 7 2 4 4 9 4 2 2 1,0,302444,-1,-1,32,18.764034775388737 62 | 2019-01-06T16:07:13.054100,2019-01-06T16:17:15.179990,602.12589,13,5 3 3 1 2 3 3 10 5 7 7 0 11,0,262322,-1,-1,34,18.14576680085131 63 | 2019-01-06T16:17:29.595574,2019-01-06T16:27:31.333910,601.738336,13,5 6 11 0 5 8 9 7 6 4 7 3 7,0,424392,-1,-1,32,16.798179279444852 64 | 2019-01-06T16:28:01.993915,2019-01-06T16:30:33.372628,151.378713,13,11 10 2 3 11 9 8 8 8 0 2 2 7,1,100649,15,15,28,15.017517507065948 65 | 2019-01-06T16:31:02.272831,2019-01-06T16:41:02.753406,600.480575,13,6 4 4 2 3 3 4 7 11 3 7 7 4,0,433371,-1,-1,29,16.651715750615637 66 | 2019-01-06T16:41:37.487185,2019-01-06T16:51:41.212517,603.725332,13,1 1 0 10 1 0 7 11 8 10 7 5 9,0,255143,-1,-1,32,17.931627082008582 67 | 2019-01-06T16:52:00.538415,2019-01-06T16:52:34.753950,34.215535,14,6 1 9 0 3 7 10 9 4 5 10 5 3 6,1,9405,146,12,24,15.988570226528985 68 | 2019-01-06T16:52:37.746065,2019-01-06T17:02:37.780319,600.034254,14,9 8 1 2 3 8 0 9 5 9 9 8 6 9,0,250171,-1,-1,32,19.255058810877472 69 | 2019-01-06T17:02:53.188684,2019-01-06T17:12:54.819317,601.630633,14,11 3 0 10 8 0 7 8 1 9 5 7 5 0,0,235395,-1,-1,37,19.60527717815236 70 | 2019-01-06T17:13:39.794081,2019-01-06T17:16:34.484647,174.690566,14,10 11 9 11 10 5 1 5 5 9 7 8 8 8,1,122402,10,10,30,17.5409955065966 71 | 2019-01-06T17:16:55.027920,2019-01-06T17:26:55.644971,600.617051,14,10 3 1 0 0 2 6 3 6 10 2 2 6 3,0,444103,-1,-1,35,17.406692138811064 72 | 2019-01-06T17:27:50.162265,2019-01-06T17:37:54.181558,604.019293,15,7 11 3 8 7 4 5 4 1 1 11 2 2 4 9,0,197462,-1,-1,29,16.260482242977105 73 | 2019-01-06T17:38:01.023271,2019-01-06T17:48:02.199973,601.176702,15,6 5 1 1 11 2 4 2 2 4 7 5 0 2 5,0,382590,-1,-1,33,19.63565055518015 74 | 2019-01-06T17:48:47.393800,2019-01-06T17:58:50.583346,603.189546,15,8 0 5 4 4 0 0 2 6 9 7 2 2 10 3,0,143903,-1,-1,30,18.127055838762413 75 | 2019-01-06T17:59:59.242073,2019-01-06T18:10:00.015164,600.773091,15,3 3 8 7 8 0 11 4 3 1 11 1 5 4 4,0,278915,-1,-1,33,17.399902736588253 76 | 2019-01-06T18:10:19.871131,2019-01-06T18:20:21.847020,601.975889,15,7 8 8 10 1 9 8 1 1 3 4 3 1 5 8,0,359569,-1,-1,33,16.867698194822697 77 | 2019-01-06T18:20:41.461789,2019-01-06T18:30:44.344106,602.882317,16,0 9 11 10 10 3 6 7 10 0 3 8 7 7 9 7,0,367019,-1,-1,33,19.639358938868003 78 | 2019-01-06T18:31:52.691697,2019-01-06T18:41:55.682982,602.991285,16,4 11 1 3 0 1 6 5 4 9 5 4 8 7 0 3,0,320230,-1,-1,32,18.21490847815826 79 | 2019-01-06T18:42:28.989401,2019-01-06T18:52:31.395108,602.405707,16,10 0 10 7 7 0 3 5 9 0 10 5 1 3 8 1,0,413090,-1,-1,32,16.77479307210613 80 | 2019-01-06T18:52:52.647275,2019-01-06T19:02:56.225685,603.57841,16,0 4 8 10 2 5 3 5 2 7 2 6 10 7 9 5,0,390117,-1,-1,35,19.640656685414083 81 | 2019-01-06T19:03:40.063774,2019-01-06T19:13:43.327787,603.264013,16,2 9 2 0 3 11 4 8 5 7 4 4 3 8 4 2,0,109440,-1,-1,30,18.542187330803504 82 | 2019-01-06T19:14:10.746391,2019-01-06T19:24:11.287342,600.540951,17,6 4 6 10 11 2 0 9 0 5 5 10 7 5 2 4 9,0,504794,-1,-1,28,16.580335207594697 83 | 2019-01-06T19:24:51.578995,2019-01-06T19:34:56.566842,604.987847,17,6 7 10 11 11 4 8 0 10 1 0 4 1 9 11 7 10,0,338284,-1,-1,34,19.235462609380917 84 | 2019-01-06T19:36:13.973676,2019-01-06T19:46:16.576125,602.602449,17,6 3 4 7 5 5 4 9 10 5 10 5 4 0 11 3 8,0,367812,-1,-1,31,17.94752277377971 85 | 2019-01-06T19:46:44.420677,2019-01-06T19:56:47.442714,603.022037,17,5 9 7 3 0 3 2 2 7 3 6 10 8 9 0 0 8,0,501793,-1,-1,32,16.73799891113393 86 | 2019-01-06T19:57:25.119296,2019-01-06T20:07:30.037665,604.918369,17,4 0 8 11 8 10 9 11 7 2 4 0 5 2 1 3 10,0,165278,-1,-1,30,17.27349818998854 87 | -------------------------------------------------------------------------------- /articles/01_rubic/tests/libcube/cubes/test_cube2x2.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import random 4 | 5 | from libcube.cubes import cube2x2 6 | 7 | 8 | class CubeRender(unittest.TestCase): 9 | def test_init_render(self): 10 | state = cube2x2.initial_state 11 | render = cube2x2.render(state) 12 | self.assertIsInstance(render, cube2x2.RenderedState) 13 | self.assertEqual(render.top, ['W'] * 4) 14 | self.assertEqual(render.back, ['O'] * 4) 15 | self.assertEqual(render.bottom, ['Y'] * 4) 16 | self.assertEqual(render.front, ['R'] * 4) 17 | self.assertEqual(render.left, ['G'] * 4) 18 | self.assertEqual(render.right, ['B'] * 4) 19 | 20 | 21 | class CubeTransforms(unittest.TestCase): 22 | def test_top(self): 23 | s = cube2x2.initial_state 24 | s = cube2x2.transform(s, cube2x2.Action.T) 25 | r = cube2x2.render(s) 26 | self.assertEqual(r.top, ['W'] * 4) 27 | self.assertEqual(r.back, ['G'] * 2 + ['O'] * 2) 28 | self.assertEqual(r.bottom, ['Y'] * 4) 29 | self.assertEqual(r.front, ['B'] * 2 + ['R'] * 2) 30 | self.assertEqual(r.left, ['R'] * 2 + ['G'] * 2) 31 | self.assertEqual(r.right, ['O'] * 2 + ['B'] * 2) 32 | 33 | def test_top_rev(self): 34 | s = cube2x2.initial_state 35 | s = cube2x2.transform(s, cube2x2.Action.t) 36 | r = cube2x2.render(s) 37 | self.assertEqual(r.top, ['W'] * 4) 38 | self.assertEqual(r.back, ['B'] * 2 + ['O'] * 2) 39 | self.assertEqual(r.bottom, ['Y'] * 4) 40 | self.assertEqual(r.front, ['G'] * 2 + ['R'] * 2) 41 | self.assertEqual(r.left, ['O'] * 2 + ['G'] * 2) 42 | self.assertEqual(r.right, ['R'] * 2 + ['B'] * 2) 43 | 44 | def test_down(self): 45 | s = cube2x2.initial_state 46 | s = cube2x2.transform(s, cube2x2.Action.D) 47 | r = cube2x2.render(s) 48 | self.assertEqual(r.back, ['O'] * 2 + ['B'] * 2) 49 | self.assertEqual(r.bottom, ['Y'] * 4) 50 | self.assertEqual(r.front, ['R'] * 2 + ['G'] * 2) 51 | self.assertEqual(r.left, ['G'] * 2 + ['O'] * 2) 52 | self.assertEqual(r.right, ['B'] * 2 + ['R'] * 2) 53 | self.assertEqual(r.top, ['W'] * 4) 54 | 55 | def test_down_rev(self): 56 | s = cube2x2.initial_state 57 | s = cube2x2.transform(s, cube2x2.Action.d) 58 | r = cube2x2.render(s) 59 | self.assertEqual(r.back, ['O'] * 2 + ['G'] * 2) 60 | self.assertEqual(r.bottom, ['Y'] * 4) 61 | self.assertEqual(r.front, ['R'] * 2 + ['B'] * 2) 62 | self.assertEqual(r.left, ['G'] * 2 + ['R'] * 2) 63 | self.assertEqual(r.right, ['B'] * 2 + ['O'] * 2) 64 | self.assertEqual(r.top, ['W'] * 4) 65 | 66 | def test_right(self): 67 | s = cube2x2.initial_state 68 | s = cube2x2.transform(s, cube2x2.Action.R) 69 | r = cube2x2.render(s) 70 | self.assertEqual(r.back, ['W', 'O'] * 2) 71 | self.assertEqual(r.bottom, ['Y', 'O'] * 2) 72 | self.assertEqual(r.front, ['R', 'Y'] * 2) 73 | self.assertEqual(r.left, ['G'] * 4) 74 | self.assertEqual(r.right, ['B'] * 4) 75 | self.assertEqual(r.top, ['W', 'R'] * 2) 76 | 77 | def test_right_rev(self): 78 | s = cube2x2.initial_state 79 | s = cube2x2.transform(s, cube2x2.Action.r) 80 | r = cube2x2.render(s) 81 | self.assertEqual(r.back, ['Y', 'O'] * 2) 82 | self.assertEqual(r.bottom, ['Y', 'R'] * 2) 83 | self.assertEqual(r.front, ['R', 'W'] * 2) 84 | self.assertEqual(r.left, ['G'] * 4) 85 | self.assertEqual(r.right, ['B'] * 4) 86 | self.assertEqual(r.top, ['W', 'O'] * 2) 87 | 88 | def test_left(self): 89 | s = cube2x2.initial_state 90 | s = cube2x2.transform(s, cube2x2.Action.L) 91 | r = cube2x2.render(s) 92 | self.assertEqual(r.back, ['O', 'Y'] * 2) 93 | self.assertEqual(r.bottom, ['R', 'Y'] * 2) 94 | self.assertEqual(r.front, ['W', 'R'] * 2) 95 | self.assertEqual(r.left, ['G'] * 4) 96 | self.assertEqual(r.right, ['B'] * 4) 97 | self.assertEqual(r.top, ['O', 'W'] * 2) 98 | 99 | def test_left_rev(self): 100 | s = cube2x2.initial_state 101 | s = cube2x2.transform(s, cube2x2.Action.l) 102 | r = cube2x2.render(s) 103 | self.assertEqual(r.back, ['O', 'W'] * 2) 104 | self.assertEqual(r.bottom, ['O', 'Y'] * 2) 105 | self.assertEqual(r.front, ['Y', 'R'] * 2) 106 | self.assertEqual(r.left, ['G'] * 4) 107 | self.assertEqual(r.right, ['B'] * 4) 108 | self.assertEqual(r.top, ['R', 'W'] * 2) 109 | 110 | def test_front(self): 111 | s = cube2x2.initial_state 112 | s = cube2x2.transform(s, cube2x2.Action.F) 113 | r = cube2x2.render(s) 114 | self.assertEqual(r.back, ['O'] * 4) 115 | self.assertEqual(r.bottom, ['B'] * 2 + ['Y'] * 2) 116 | self.assertEqual(r.front, ['R'] * 4) 117 | self.assertEqual(r.left, ['G', 'Y'] * 2) 118 | self.assertEqual(r.right, ['W', 'B'] * 2) 119 | self.assertEqual(r.top, ['W'] * 2 + ['G'] * 2) 120 | 121 | def test_front_rev(self): 122 | s = cube2x2.initial_state 123 | s = cube2x2.transform(s, cube2x2.Action.f) 124 | r = cube2x2.render(s) 125 | self.assertEqual(r.back, ['O'] * 4) 126 | self.assertEqual(r.bottom, ['G'] * 2 + ['Y'] * 2) 127 | self.assertEqual(r.front, ['R'] * 4) 128 | self.assertEqual(r.left, ['G', 'W'] * 2) 129 | self.assertEqual(r.right, ['Y', 'B'] * 2) 130 | self.assertEqual(r.top, ['W'] * 2 + ['B'] * 2) 131 | 132 | def test_back(self): 133 | s = cube2x2.initial_state 134 | s = cube2x2.transform(s, cube2x2.Action.B) 135 | r = cube2x2.render(s) 136 | self.assertEqual(r.back, ['O'] * 4) 137 | self.assertEqual(r.bottom, ['Y'] * 2 + ['G'] * 2) 138 | self.assertEqual(r.front, ['R'] * 4) 139 | self.assertEqual(r.left, ['W', 'G'] * 2) 140 | self.assertEqual(r.right, ['B', 'Y'] * 2) 141 | self.assertEqual(r.top, ['B'] * 2 + ['W'] * 2) 142 | 143 | def test_back_rev(self): 144 | s = cube2x2.initial_state 145 | s = cube2x2.transform(s, cube2x2.Action.b) 146 | r = cube2x2.render(s) 147 | self.assertEqual(r.back, ['O'] * 4) 148 | self.assertEqual(r.bottom, ['Y'] * 2 + ['B'] * 2) 149 | self.assertEqual(r.front, ['R'] * 4) 150 | self.assertEqual(r.left, ['Y', 'G'] * 2) 151 | self.assertEqual(r.right, ['B', 'W'] * 2) 152 | self.assertEqual(r.top, ['G'] * 2 + ['W'] * 2) 153 | 154 | def test_inverse_right(self): 155 | s = cube2x2.initial_state 156 | s = cube2x2.transform(s, cube2x2.Action.R) 157 | s = cube2x2.transform(s, cube2x2.Action.r) 158 | self.assertEqual(s, cube2x2.initial_state) 159 | 160 | s = cube2x2.initial_state 161 | s = cube2x2.transform(s, cube2x2.Action.r) 162 | s = cube2x2.transform(s, cube2x2.Action.R) 163 | self.assertEqual(s, cube2x2.initial_state) 164 | 165 | def test_inverse_left(self): 166 | s = cube2x2.initial_state 167 | s = cube2x2.transform(s, cube2x2.Action.L) 168 | s = cube2x2.transform(s, cube2x2.Action.l) 169 | self.assertEqual(s, cube2x2.initial_state) 170 | 171 | s = cube2x2.initial_state 172 | s = cube2x2.transform(s, cube2x2.Action.l) 173 | s = cube2x2.transform(s, cube2x2.Action.L) 174 | self.assertEqual(s, cube2x2.initial_state) 175 | 176 | def test_inverse_top(self): 177 | s = cube2x2.initial_state 178 | s = cube2x2.transform(s, cube2x2.Action.T) 179 | s = cube2x2.transform(s, cube2x2.Action.t) 180 | self.assertEqual(s, cube2x2.initial_state) 181 | 182 | s = cube2x2.initial_state 183 | s = cube2x2.transform(s, cube2x2.Action.t) 184 | s = cube2x2.transform(s, cube2x2.Action.T) 185 | self.assertEqual(s, cube2x2.initial_state) 186 | 187 | def test_inverse_down(self): 188 | s = cube2x2.initial_state 189 | s = cube2x2.transform(s, cube2x2.Action.D) 190 | s = cube2x2.transform(s, cube2x2.Action.d) 191 | self.assertEqual(s, cube2x2.initial_state) 192 | 193 | s = cube2x2.initial_state 194 | s = cube2x2.transform(s, cube2x2.Action.d) 195 | s = cube2x2.transform(s, cube2x2.Action.D) 196 | self.assertEqual(s, cube2x2.initial_state) 197 | 198 | def test_inverse_front(self): 199 | s = cube2x2.initial_state 200 | s = cube2x2.transform(s, cube2x2.Action.F) 201 | s = cube2x2.transform(s, cube2x2.Action.f) 202 | self.assertEqual(s, cube2x2.initial_state) 203 | 204 | s = cube2x2.initial_state 205 | s = cube2x2.transform(s, cube2x2.Action.f) 206 | s = cube2x2.transform(s, cube2x2.Action.F) 207 | self.assertEqual(s, cube2x2.initial_state) 208 | 209 | def test_inverse_back(self): 210 | s = cube2x2.initial_state 211 | s = cube2x2.transform(s, cube2x2.Action.B) 212 | s = cube2x2.transform(s, cube2x2.Action.b) 213 | self.assertEqual(s, cube2x2.initial_state) 214 | 215 | s = cube2x2.initial_state 216 | s = cube2x2.transform(s, cube2x2.Action.b) 217 | s = cube2x2.transform(s, cube2x2.Action.B) 218 | self.assertEqual(s, cube2x2.initial_state) 219 | 220 | def test_inverse(self): 221 | s = cube2x2.initial_state 222 | for a in cube2x2.Action: 223 | s = cube2x2.transform(s, a) 224 | r = cube2x2.render(s) 225 | s = cube2x2.transform(s, cube2x2.inverse_action(a)) 226 | r2 = cube2x2.render(s) 227 | self.assertEqual(s, cube2x2.initial_state) 228 | 229 | def test_sequence(self): 230 | acts = [cube2x2.Action.R, cube2x2.Action.t, cube2x2.Action.R, cube2x2.Action.D, cube2x2.Action.F, 231 | cube2x2.Action.d, cube2x2.Action.T, cube2x2.Action.R, cube2x2.Action.D, cube2x2.Action.F] 232 | 233 | s = cube2x2.initial_state 234 | for a in acts: 235 | s = cube2x2.transform(s, a) 236 | r = cube2x2.render(s) 237 | for a in reversed(acts): 238 | s = cube2x2.transform(s, cube2x2.inverse_action(a)) 239 | r = cube2x2.render(s) 240 | self.assertEqual(s, cube2x2.initial_state) 241 | 242 | 243 | class CubeEncoding(unittest.TestCase): 244 | def test_init(self): 245 | tgt = np.zeros(shape=cube2x2.encoded_shape) 246 | s = cube2x2.initial_state 247 | cube2x2.encode_inplace(tgt, s) 248 | 249 | def test_random(self): 250 | s = cube2x2.initial_state 251 | for _ in range(200): 252 | a = cube2x2.Action(random.randrange(len(cube2x2.Action))) 253 | s = cube2x2.transform(s, a) 254 | tgt = np.zeros(shape=cube2x2.encoded_shape) 255 | cube2x2.encode_inplace(tgt, s) 256 | self.assertEqual(tgt.sum(), 8) 257 | 258 | 259 | if __name__ == '__main__': 260 | unittest.main() 261 | -------------------------------------------------------------------------------- /articles/01_rubic/tests/libcube/cubes/test_cube3x3.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import random 4 | 5 | from libcube.cubes import cube3x3 6 | 7 | 8 | class CubeRender(unittest.TestCase): 9 | def test_init_render(self): 10 | state = cube3x3.initial_state 11 | render = cube3x3.render(state) 12 | self.assertIsInstance(render, cube3x3.RenderedState) 13 | self.assertEqual(render.top, ['W'] * 9) 14 | self.assertEqual(render.back, ['O'] * 9) 15 | self.assertEqual(render.bottom, ['Y'] * 9) 16 | self.assertEqual(render.front, ['R'] * 9) 17 | self.assertEqual(render.left, ['G'] * 9) 18 | self.assertEqual(render.right, ['B'] * 9) 19 | 20 | 21 | class CubeTransforms(unittest.TestCase): 22 | def test_top(self): 23 | s = cube3x3.initial_state 24 | s = cube3x3.transform(s, cube3x3.Action.T) 25 | r = cube3x3.render(s) 26 | self.assertEqual(r.top, ['W'] * 9) 27 | self.assertEqual(r.back, ['G'] * 3 + ['O'] * 6) 28 | self.assertEqual(r.bottom, ['Y'] * 9) 29 | self.assertEqual(r.front, ['B'] * 3 + ['R'] * 6) 30 | self.assertEqual(r.left, ['R'] * 3 + ['G'] * 6) 31 | self.assertEqual(r.right, ['O'] * 3 + ['B'] * 6) 32 | 33 | def test_top_rev(self): 34 | s = cube3x3.initial_state 35 | s = cube3x3.transform(s, cube3x3.Action.t) 36 | r = cube3x3.render(s) 37 | self.assertEqual(r.top, ['W'] * 9) 38 | self.assertEqual(r.back, ['B'] * 3 + ['O'] * 6) 39 | self.assertEqual(r.bottom, ['Y'] * 9) 40 | self.assertEqual(r.front, ['G'] * 3 + ['R'] * 6) 41 | self.assertEqual(r.left, ['O'] * 3 + ['G'] * 6) 42 | self.assertEqual(r.right, ['R'] * 3 + ['B'] * 6) 43 | 44 | def test_down(self): 45 | s = cube3x3.initial_state 46 | s = cube3x3.transform(s, cube3x3.Action.D) 47 | r = cube3x3.render(s) 48 | self.assertEqual(r.back, ['O'] * 6 + ['B'] * 3) 49 | self.assertEqual(r.bottom, ['Y'] * 9) 50 | self.assertEqual(r.front, ['R'] * 6 + ['G'] * 3) 51 | self.assertEqual(r.left, ['G'] * 6 + ['O'] * 3) 52 | self.assertEqual(r.right, ['B'] * 6 + ['R'] * 3) 53 | self.assertEqual(r.top, ['W'] * 9) 54 | 55 | def test_down_rev(self): 56 | s = cube3x3.initial_state 57 | s = cube3x3.transform(s, cube3x3.Action.d) 58 | r = cube3x3.render(s) 59 | self.assertEqual(r.back, ['O'] * 6 + ['G'] * 3) 60 | self.assertEqual(r.bottom, ['Y'] * 9) 61 | self.assertEqual(r.front, ['R'] * 6 + ['B'] * 3) 62 | self.assertEqual(r.left, ['G'] * 6 + ['R'] * 3) 63 | self.assertEqual(r.right, ['B'] * 6 + ['O'] * 3) 64 | self.assertEqual(r.top, ['W'] * 9) 65 | 66 | def test_right(self): 67 | s = cube3x3.initial_state 68 | s = cube3x3.transform(s, cube3x3.Action.R) 69 | r = cube3x3.render(s) 70 | self.assertEqual(r.back, ['W', 'O', 'O'] * 3) 71 | self.assertEqual(r.bottom, ['Y', 'Y', 'O'] * 3) 72 | self.assertEqual(r.front, ['R', 'R', 'Y'] * 3) 73 | self.assertEqual(r.left, ['G'] * 9) 74 | self.assertEqual(r.right, ['B'] * 9) 75 | self.assertEqual(r.top, ['W', 'W', 'R'] * 3) 76 | 77 | def test_right_rev(self): 78 | s = cube3x3.initial_state 79 | s = cube3x3.transform(s, cube3x3.Action.r) 80 | r = cube3x3.render(s) 81 | self.assertEqual(r.back, ['Y', 'O', 'O'] * 3) 82 | self.assertEqual(r.bottom, ['Y', 'Y', 'R'] * 3) 83 | self.assertEqual(r.front, ['R', 'R', 'W'] * 3) 84 | self.assertEqual(r.left, ['G'] * 9) 85 | self.assertEqual(r.right, ['B'] * 9) 86 | self.assertEqual(r.top, ['W', 'W', 'O'] * 3) 87 | 88 | def test_left(self): 89 | s = cube3x3.initial_state 90 | s = cube3x3.transform(s, cube3x3.Action.L) 91 | r = cube3x3.render(s) 92 | self.assertEqual(r.back, ['O', 'O', 'Y'] * 3) 93 | self.assertEqual(r.bottom, ['R', 'Y', 'Y'] * 3) 94 | self.assertEqual(r.front, ['W', 'R', 'R'] * 3) 95 | self.assertEqual(r.left, ['G'] * 9) 96 | self.assertEqual(r.right, ['B'] * 9) 97 | self.assertEqual(r.top, ['O', 'W', 'W'] * 3) 98 | 99 | def test_left_rev(self): 100 | s = cube3x3.initial_state 101 | s = cube3x3.transform(s, cube3x3.Action.l) 102 | r = cube3x3.render(s) 103 | self.assertEqual(r.back, ['O', 'O', 'W'] * 3) 104 | self.assertEqual(r.bottom, ['O', 'Y', 'Y'] * 3) 105 | self.assertEqual(r.front, ['Y', 'R', 'R'] * 3) 106 | self.assertEqual(r.left, ['G'] * 9) 107 | self.assertEqual(r.right, ['B'] * 9) 108 | self.assertEqual(r.top, ['R', 'W', 'W'] * 3) 109 | 110 | def test_front(self): 111 | s = cube3x3.initial_state 112 | s = cube3x3.transform(s, cube3x3.Action.F) 113 | r = cube3x3.render(s) 114 | self.assertEqual(r.back, ['O'] * 9) 115 | self.assertEqual(r.bottom, ['B'] * 3 + ['Y'] * 6) 116 | self.assertEqual(r.front, ['R'] * 9) 117 | self.assertEqual(r.left, ['G', 'G', 'Y'] * 3) 118 | self.assertEqual(r.right, ['W', 'B', 'B'] * 3) 119 | self.assertEqual(r.top, ['W'] * 6 + ['G'] * 3) 120 | 121 | def test_front_rev(self): 122 | s = cube3x3.initial_state 123 | s = cube3x3.transform(s, cube3x3.Action.f) 124 | r = cube3x3.render(s) 125 | self.assertEqual(r.back, ['O'] * 9) 126 | self.assertEqual(r.bottom, ['G'] * 3 + ['Y'] * 6) 127 | self.assertEqual(r.front, ['R'] * 9) 128 | self.assertEqual(r.left, ['G', 'G', 'W'] * 3) 129 | self.assertEqual(r.right, ['Y', 'B', 'B'] * 3) 130 | self.assertEqual(r.top, ['W'] * 6 + ['B'] * 3) 131 | 132 | def test_back(self): 133 | s = cube3x3.initial_state 134 | s = cube3x3.transform(s, cube3x3.Action.B) 135 | r = cube3x3.render(s) 136 | self.assertEqual(r.back, ['O'] * 9) 137 | self.assertEqual(r.bottom, ['Y'] * 6 + ['G'] * 3) 138 | self.assertEqual(r.front, ['R'] * 9) 139 | self.assertEqual(r.left, ['W', 'G', 'G'] * 3) 140 | self.assertEqual(r.right, ['B', 'B', 'Y'] * 3) 141 | self.assertEqual(r.top, ['B'] * 3 + ['W'] * 6) 142 | 143 | def test_back_rev(self): 144 | s = cube3x3.initial_state 145 | s = cube3x3.transform(s, cube3x3.Action.b) 146 | r = cube3x3.render(s) 147 | self.assertEqual(r.back, ['O'] * 9) 148 | self.assertEqual(r.bottom, ['Y'] * 6 + ['B'] * 3) 149 | self.assertEqual(r.front, ['R'] * 9) 150 | self.assertEqual(r.left, ['Y', 'G', 'G'] * 3) 151 | self.assertEqual(r.right, ['B', 'B', 'W'] * 3) 152 | self.assertEqual(r.top, ['G'] * 3 + ['W'] * 6) 153 | 154 | def test_inverse_right(self): 155 | s = cube3x3.initial_state 156 | s = cube3x3.transform(s, cube3x3.Action.R) 157 | s = cube3x3.transform(s, cube3x3.Action.r) 158 | self.assertEqual(s, cube3x3.initial_state) 159 | 160 | s = cube3x3.initial_state 161 | s = cube3x3.transform(s, cube3x3.Action.r) 162 | s = cube3x3.transform(s, cube3x3.Action.R) 163 | self.assertEqual(s, cube3x3.initial_state) 164 | 165 | def test_inverse_left(self): 166 | s = cube3x3.initial_state 167 | s = cube3x3.transform(s, cube3x3.Action.L) 168 | s = cube3x3.transform(s, cube3x3.Action.l) 169 | self.assertEqual(s, cube3x3.initial_state) 170 | 171 | s = cube3x3.initial_state 172 | s = cube3x3.transform(s, cube3x3.Action.l) 173 | s = cube3x3.transform(s, cube3x3.Action.L) 174 | self.assertEqual(s, cube3x3.initial_state) 175 | 176 | def test_inverse_top(self): 177 | s = cube3x3.initial_state 178 | s = cube3x3.transform(s, cube3x3.Action.T) 179 | s = cube3x3.transform(s, cube3x3.Action.t) 180 | self.assertEqual(s, cube3x3.initial_state) 181 | 182 | s = cube3x3.initial_state 183 | s = cube3x3.transform(s, cube3x3.Action.t) 184 | s = cube3x3.transform(s, cube3x3.Action.T) 185 | self.assertEqual(s, cube3x3.initial_state) 186 | 187 | def test_inverse_down(self): 188 | s = cube3x3.initial_state 189 | s = cube3x3.transform(s, cube3x3.Action.D) 190 | s = cube3x3.transform(s, cube3x3.Action.d) 191 | self.assertEqual(s, cube3x3.initial_state) 192 | 193 | s = cube3x3.initial_state 194 | s = cube3x3.transform(s, cube3x3.Action.d) 195 | s = cube3x3.transform(s, cube3x3.Action.D) 196 | self.assertEqual(s, cube3x3.initial_state) 197 | 198 | def test_inverse_front(self): 199 | s = cube3x3.initial_state 200 | s = cube3x3.transform(s, cube3x3.Action.F) 201 | s = cube3x3.transform(s, cube3x3.Action.f) 202 | self.assertEqual(s, cube3x3.initial_state) 203 | 204 | s = cube3x3.initial_state 205 | s = cube3x3.transform(s, cube3x3.Action.f) 206 | s = cube3x3.transform(s, cube3x3.Action.F) 207 | self.assertEqual(s, cube3x3.initial_state) 208 | 209 | def test_inverse_back(self): 210 | s = cube3x3.initial_state 211 | s = cube3x3.transform(s, cube3x3.Action.B) 212 | s = cube3x3.transform(s, cube3x3.Action.b) 213 | self.assertEqual(s, cube3x3.initial_state) 214 | 215 | s = cube3x3.initial_state 216 | s = cube3x3.transform(s, cube3x3.Action.b) 217 | s = cube3x3.transform(s, cube3x3.Action.B) 218 | self.assertEqual(s, cube3x3.initial_state) 219 | 220 | def test_inverse(self): 221 | s = cube3x3.initial_state 222 | for a in cube3x3.Action: 223 | s = cube3x3.transform(s, a) 224 | r = cube3x3.render(s) 225 | s = cube3x3.transform(s, cube3x3.inverse_action(a)) 226 | r2 = cube3x3.render(s) 227 | self.assertEqual(s, cube3x3.initial_state) 228 | 229 | def test_sequence(self): 230 | acts = [cube3x3.Action.R, cube3x3.Action.t, cube3x3.Action.R, cube3x3.Action.D, cube3x3.Action.F, 231 | cube3x3.Action.d, cube3x3.Action.T, cube3x3.Action.R, cube3x3.Action.D, cube3x3.Action.F] 232 | 233 | s = cube3x3.initial_state 234 | for a in acts: 235 | s = cube3x3.transform(s, a) 236 | r = cube3x3.render(s) 237 | for a in reversed(acts): 238 | s = cube3x3.transform(s, cube3x3.inverse_action(a)) 239 | r = cube3x3.render(s) 240 | self.assertEqual(s, cube3x3.initial_state) 241 | 242 | 243 | class CubeEncoding(unittest.TestCase): 244 | def test_init(self): 245 | tgt = np.zeros(shape=cube3x3.encoded_shape) 246 | s = cube3x3.initial_state 247 | cube3x3.encode_inplace(tgt, s) 248 | 249 | def test_random(self): 250 | s = cube3x3.initial_state 251 | for _ in range(200): 252 | a = cube3x3.Action(random.randrange(len(cube3x3.Action))) 253 | s = cube3x3.transform(s, a) 254 | tgt = np.zeros(shape=cube3x3.encoded_shape) 255 | cube3x3.encode_inplace(tgt, s) 256 | self.assertEqual(tgt.sum(), 20) 257 | 258 | 259 | if __name__ == '__main__': 260 | unittest.main() 261 | --------------------------------------------------------------------------------